EmbeddingFwdOp node with same functionality as F.embedding #3649
base: main
Conversation
force-pushed from e0e76e8 to a0aa5dc

EmbeddingOp node with same functionality as F.embedding
EmbeddingFwdOp node with same functionality as F.embedding

force-pushed from a4a8e33 to 9ea8f85
PR Reviewer Guide 🔍 (Review updated until commit af25ce0)
Here are some key observations to aid the review process:
!test

constexpr int64_t n = 5, s = 2;
// ...
TEST_F(EmbeddingTest, EmbeddingFwdNode) {
I wonder if it's possible to add a check to verify the output of the toString method as well.
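For illustration, a minimal sketch of such a check, assuming a hypothetical embedding_fwd front-end helper and the usual test utilities (makeSymbolicTensor); the exact names and parameter list are assumptions, not necessarily this PR's API:

```cpp
TEST_F(EmbeddingTest, EmbeddingFwdNodeToString) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // Symbolic 1-D index tensor and 2-D weight table as fusion inputs.
  TensorView* indices = makeSymbolicTensor(1, DataType::Int);
  TensorView* weight = makeSymbolicTensor(2, DataType::Float);
  fusion.addInput(indices);
  fusion.addInput(weight);

  // Assumed front-end helper for the new node; optional arguments left unset.
  TensorView* out = embedding_fwd(
      indices,
      weight,
      /*padding_idx=*/nullptr,
      /*max_norm=*/nullptr,
      /*norm_type=*/nullptr,
      /*scale_grad_by_freq=*/nullptr,
      /*sparse=*/nullptr);
  fusion.addOutput(out);

  // Check that the printed IR of the defining expression mentions the op
  // (the exact substring to look for is an assumption).
  const std::string str = out->definition()->toString();
  EXPECT_NE(str.find("embedding"), std::string::npos);
}
```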
  }
  std::optional<double> max_norm = std::nullopt;
  if (has_max_norm()) {
    auto idx = 5 + has_padding_idx();
nit: having this bare 5 bothers me a little bit, but I'm not sure what would be better.
It may not be ideal; however, we are fetching the previous variables based on fixed indices as well. The positions of the variables are constant, so it should be safe.
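Purely as an illustration, one option would be to give the offset a name; the constant name and the assumed input layout below are hypothetical, not the PR's actual code:

```cpp
#include <cstdint>

namespace {
// Assumed layout: five always-present inputs come before the optional
// padding_idx and max_norm slots.
constexpr int64_t kNumFixedInputs = 5;
}  // namespace

// max_norm sits right after the fixed inputs, shifted by one more slot
// when padding_idx is present.
int64_t maxNormInputIndex(bool has_padding_idx) {
  return kNumFixedInputs + (has_padding_idx ? 1 : 0);
}
```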
  if (norm_type == nullptr) {
    norm_type = IrBuilder::create<Val>(2.0, DataType::Double);
  }
  if (scale_grad_by_freq == nullptr) {
nit: can we use IrContainer::falseVal() here, e.g. input->fusion()->falseVal(), or get the current fusion and call the function on it?
similar pattern below.
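For reference, a minimal sketch of the suggested pattern, assuming input is a TensorView* registered in the current fusion and that the "similar pattern below" refers to the other defaulted boolean flag (sparse); the helper name is hypothetical:

```cpp
// Hypothetical helper illustrating the suggestion; not the PR's code.
void fillBooleanDefaults(
    TensorView* input,
    Val*& scale_grad_by_freq,
    Val*& sparse) {
  if (scale_grad_by_freq == nullptr) {
    // Reuse the fusion's cached false constant instead of creating a new Val.
    scale_grad_by_freq = input->fusion()->falseVal();
  }
  if (sparse == nullptr) {
    // "Similar pattern below": same idea for the other defaulted flag.
    sparse = input->fusion()->falseVal();
  }
}
```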
This PR adds an EmbeddingFwdOp node with the same functionality as F.embedding.

The op could in principle be implemented with take_along_axis. However, F.embedding allows optional parameters like max_norm and padding_idx, which would require further processing if implemented using take_along_axis. So I defaulted to creating a new node to guarantee performance parity.

Thunder uses prims.EMBEDDING if the optional parameters padding_idx/max_norm are specified, else it uses prims.TAKE. This prevents nvFuser from consuming the embedding operator in the other cases. Hence, in Thunder, nvFuser will also directly execute ltorch.embedding. This will require a separate backward API to consume ltorch.embedding_backward, and we cannot reuse the grad rules for prims.EMBEDDING. Hence the EmbeddingFwdOp naming instead of EmbeddingOp.
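A hypothetical C++ usage sketch of the new node is below. The embedding_fwd helper name, its parameter order, and the use of makeSymbolicTensor are assumptions for illustration; it roughly mirrors F.embedding(input, weight, padding_idx=0, max_norm=1.0).

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

// Symbolic 1-D indices and 2-D embedding table as fusion inputs.
TensorView* indices = makeSymbolicTensor(1, DataType::Int);
TensorView* weight = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(indices);
fusion.addInput(weight);

// Optional parameters that a plain take_along_axis could not express directly.
Val* padding_idx = IrBuilder::create<Val>(0L, DataType::Int);
Val* max_norm = IrBuilder::create<Val>(1.0, DataType::Double);

TensorView* out = embedding_fwd(
    indices,
    weight,
    padding_idx,
    max_norm,
    /*norm_type=*/nullptr,          // the op defaults this to 2.0
    /*scale_grad_by_freq=*/nullptr,
    /*sparse=*/nullptr);
fusion.addOutput(out);
```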