EmbeddingFwdOp node with same functionality as F.embedding (#3649)
base: main
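For orientation, here is a minimal standalone sketch of the lookup semantics this op mirrors from F.embedding: each integer index in the input selects one row of the weight table, so the output keeps the input's shape and appends the embedding dimension. This is plain C++ for illustration only; it is not the nvfuser implementation and ignores the optional padding_idx / max_norm / scale_grad_by_freq / sparse behaviors.

#include <cstddef>
#include <iostream>
#include <vector>

// Reference semantics: out[i][e] = weight[input[i]][e] for every flattened
// index position i and embedding column e.
std::vector<float> embedding_reference(
    const std::vector<long>& input, // integer indices, any flattened shape
    const std::vector<std::vector<float>>& weight) { // [num_embeddings][embedding_dim]
  const size_t embedding_dim = weight.empty() ? 0 : weight[0].size();
  std::vector<float> out;
  out.reserve(input.size() * embedding_dim);
  for (long idx : input) {
    const auto& row = weight.at(static_cast<size_t>(idx));
    out.insert(out.end(), row.begin(), row.end());
  }
  return out;
}

int main() {
  std::vector<std::vector<float>> weight = {{0.f, 1.f}, {2.f, 3.f}, {4.f, 5.f}};
  auto out = embedding_reference({2, 0, 1}, weight);
  for (float v : out) std::cout << v << ' '; // prints: 4 5 0 1 2 3
  std::cout << '\n';
}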
Changes from all commits
77b09cf
18816d4
7a1e646
db19eff
891d50e
6818d6d
44780c0
9ea8f85
08c2f1f
6698535
af25ce0
@@ -662,4 +662,71 @@ SdpfaBwdResult sdpfa_bwd(
  return {grad_query, grad_key, grad_value};
}

TensorView* embedding_fwd(
    TensorView* input,
    TensorView* weight,
    Val* padding_idx,
    Val* max_norm,
    Val* norm_type,
    Val* scale_grad_by_freq,
    Val* sparse) {
  auto input_domain = TensorDomain::noReductions(input->getLogicalDomain());
  auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain());
  NVF_CHECK(
      !input_domain.empty(),
      "Expected input to be at least 1D, got: ",
      input_domain.size());
  NVF_CHECK(
      weight_domain.size() == 2,
      "Expected weight to be 2D, got: ",
      weight_domain.size());

  NVF_CHECK(
      !padding_idx || padding_idx->isScalar(),
      "Expected padding_idx to be a scalar int.");
  NVF_CHECK(
      !max_norm || max_norm->isScalar(),
      "Expected max_norm to be a scalar double.");
  NVF_CHECK(
      !norm_type || norm_type->isScalar(),
      "Expected norm_type to be a scalar double.");
  NVF_CHECK(
      !scale_grad_by_freq || scale_grad_by_freq->isScalar(),
      "Expected scale_grad_by_freq to be a scalar bool.");
  NVF_CHECK(
      !sparse || sparse->isScalar(), "Expected sparse to be a scalar bool.");

  auto ndims_out = input_domain.size() + 1;
  std::vector<IterDomain*> out_domain(ndims_out, nullptr);

  for (auto idx : c10::irange(ndims_out - 1)) {
    out_domain[idx] = ops::newOutputIterDomain({input_domain[idx]});
  }
  out_domain[ndims_out - 1] = ops::newOutputIterDomain({weight_domain.back()});
  TensorDomain* out_td = IrBuilder::create<TensorDomain>(
      out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));
  TensorView* output = IrBuilder::create<TensorView>(out_td, weight->dtype());
  if (norm_type == nullptr) {
    norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
  }
  if (scale_grad_by_freq == nullptr) {
    scale_grad_by_freq = IrBuilder::create<Val>(false, DataType::Bool);
  }
  if (sparse == nullptr) {
    sparse = IrBuilder::create<Val>(false, DataType::Bool);
  }

Review comment (on the default-value blocks above): nit: can we use […]
Review comment: similar pattern below.
  IrBuilder::create<EmbeddingFwdOp>(
      output,
      input,
      weight,
      padding_idx,
      max_norm,
      norm_type,
      scale_grad_by_freq,
      sparse);

  return output;
}

} // namespace nvfuser
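As a reference for the iter-domain construction in the diff: the output reuses every input dimension and appends the weight's last (embedding) dimension. Below is a minimal standalone sketch of that shape rule in plain C++; the helper name is illustrative and not part of this PR.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Indices of shape (s_0, ..., s_{n-1}) looked up in a
// (num_embeddings, embedding_dim) table produce an output of shape
// (s_0, ..., s_{n-1}, embedding_dim).
std::vector<size_t> embedding_out_shape(
    const std::vector<size_t>& input_shape,
    const std::vector<size_t>& weight_shape) {
  assert(!input_shape.empty()); // mirrors the "at least 1D" check
  assert(weight_shape.size() == 2); // mirrors the "weight is 2D" check
  std::vector<size_t> out = input_shape;
  out.push_back(weight_shape.back()); // append embedding_dim
  return out;
}

int main() {
  for (size_t d : embedding_out_shape({8, 128}, {50257, 768})) {
    std::cout << d << ' '; // prints: 8 128 768
  }
  std::cout << '\n';
}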
Review comment: nit: having this free 5 bothers me a little bit, but not sure what would be better.
Reply: It may not be ideal; however, we are fetching the previous variables based on fixed indices as well. The position of the variables is constant, so it should be safe.
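For context on that reply, here is a minimal sketch of the fixed-index accessor pattern being discussed. The class name, member names, and index positions are hypothetical illustrations, not the actual EmbeddingFwdOp interface: arguments sit at fixed positions in the node's input list, and accessors read them back by constant index, which stays correct only as long as the construction order never changes.

#include <cassert>
#include <vector>

// Hypothetical stand-in for an IR value; the real nvfuser types differ.
struct Val {};

// Sketch: a node that stores its arguments at fixed input positions.
struct EmbeddingNodeSketch {
  // Assumed order (illustrative only): 0=input, 1=weight, 2=padding_idx,
  // 3=max_norm, 4=norm_type, 5=scale_grad_by_freq, 6=sparse.
  std::vector<Val*> inputs_;

  Val* norm_type() const { return inputs_.at(4); }
  // A "free 5"-style accessor: valid only while the argument order used at
  // construction time remains constant.
  Val* scale_grad_by_freq() const { return inputs_.at(5); }
};

int main() {
  EmbeddingNodeSketch node;
  node.inputs_.assign(7, nullptr);
  assert(node.scale_grad_by_freq() == nullptr);
  return 0;
}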