diff --git a/tests/python/test_embedding.py b/tests/python/test_embedding.py index 0e2670f2cbb..0655cdafc50 100644 --- a/tests/python/test_embedding.py +++ b/tests/python/test_embedding.py @@ -73,5 +73,5 @@ def fusion_func( norm_type = 2.0 if norm_type is None else norm_type scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq sparse = False if sparse is None else sparse - ref_out = F.embedding_fwd(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse) + ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse) torch.testing.assert_close(nvf_out[0], ref_out) \ No newline at end of file