Skip to content

Commit

Permalink
fix torch fn
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Jan 16, 2025
1 parent e5b0594 commit a4a8e33
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/python/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ def fusion_func(
norm_type = 2.0 if norm_type is None else norm_type
scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
sparse = False if sparse is None else sparse
ref_out = F.embedding_fwd(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
torch.testing.assert_close(nvf_out[0], ref_out)

0 comments on commit a4a8e33

Please sign in to comment.