diff --git a/tests/cpp/test_embedding_node.cpp b/tests/cpp/test_embedding_node.cpp index 249e50a7238..fd77c606644 100644 --- a/tests/cpp/test_embedding_node.cpp +++ b/tests/cpp/test_embedding_node.cpp @@ -20,7 +20,7 @@ using EmbeddingTest = NVFuserTest; constexpr int64_t n = 5, s = 2; -TEST_F(EmbeddingTest, Basic) { +TEST_F(EmbeddingTest, EmbeddingNode) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); std::vector inp_shape({s}); @@ -45,19 +45,5 @@ TEST_F(EmbeddingTest, Basic) { FusionExecutorCache executor_cache(std::move(fusion)); auto nvf_out = executor_cache.runFusionWithInputs({input, weight}); EXPECT_TRUE(at::allclose(nvf_out[0], aten_out)); -} - -// INSTANTIATE_TEST_SUITE_P( -// LinearWithoutBias, -// LinearNodeParametrizedTest, -// testing::Combine( -// testing::Values( -// Sizes({k}), -// Sizes({m, k}), -// Sizes({b, m, k}), -// Sizes({1, k}), -// Sizes({b, 1, k})), -// testing::Values(Sizes({n, k}), Sizes({1, k})), -// testing::Values(std::nullopt))); - +} } // namespace nvfuser \ No newline at end of file diff --git a/tests/python/test_embedding.py b/tests/python/test_embedding.py index 8a459a03959..6fa264be983 100644 --- a/tests/python/test_embedding.py +++ b/tests/python/test_embedding.py @@ -11,69 +11,67 @@ from functools import partial import torch.nn.functional as F -class TestEmbedding(NVFuserTest): - def test_embedding(self): - def fusion_func( - fd: FusionDefinition, - has_optional_inputs: list[bool], - optional_inputs_dtypes: list[DataType] - ): - input = fd.define_tensor( - shape=[-1], - contiguity=[True], - dtype=DataType.Int, - is_cpu=False, - ) - weight = fd.define_tensor( - shape=[-1, -1], - contiguity=[True, True], - dtype=DataType.BFloat16, - is_cpu=False, - ) - # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse - optional_inputs = [None] * 5 - for idx in range(len(optional_inputs)): - if has_optional_inputs[idx]: - optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx]) - out = fd.ops.embedding(input, weight, *optional_inputs) - fd.add_output(out) - N, S = 10, 3 - input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False) - weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True) - - padding_idx_vals = [None, -1, -2] - max_norm_vals = [None, 1e-5] - norm_type_vals = [None, 2.0, 1.0] - scale_grad_by_freq = [None, True] - sparse = [None, False, True] - optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool] +@pytest.mark.parametrize("padding_idx", [None, -2]) +@pytest.mark.parametrize("max_norm", [None, 1e-5]) +@pytest.mark.parametrize("norm_type", [None, 1.0]) +@pytest.mark.parametrize("scale_grad_by_freq", [None, True]) +@pytest.mark.parametrize("sparse", [None, True]) +def test_embedding( + padding_idx: None | int, + max_norm: None | float, + norm_type: None | float, + scale_grad_by_freq: None | bool, + sparse: None | bool + ): + def fusion_func( + fd: FusionDefinition, + has_optional_inputs: list[bool], + optional_inputs_dtypes: list[DataType] + ): + input = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Int, + is_cpu=False, + ) + weight = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + ) + # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse + optional_inputs = [None] * 5 + for idx in range(len(optional_inputs)): + if has_optional_inputs[idx]: + optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx]) + out = fd.ops.embedding(input, weight, *optional_inputs) + fd.add_output(out) - - # TODO: Try to move this to pytest_ops.py. Currently, it does not work since the API between nvFuser and torch differs. - for padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse in itertools.product( - padding_idx_vals, max_norm_vals, norm_type_vals, scale_grad_by_freq, sparse - ): - with self.subTest(padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse): - # Reset the FusionCache or the fusion would not recompile for all subtests, failing checks in exec_nvfuser. - FusionCache.reset() - optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse] - has_optional_inputs = [None] * 5 - inputs = [input, weight] - for idx, param in enumerate(optional_inputs): - if param is not None: - has_optional_inputs[idx] = True - inputs.append(param) - - with FusionDefinition() as fd: - fusion_func(fd, - has_optional_inputs=has_optional_inputs, - optional_inputs_dtypes = optional_inputs_dtypes) - nvf_out = fd.execute(inputs) + N, S = 10, 3 + input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False) + weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True) + optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool] - torch.manual_seed(0) - norm_type = 2.0 if norm_type is None else norm_type - scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq - sparse = False if sparse is None else sparse - ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse) - torch.testing.assert_close(nvf_out[0], ref_out) \ No newline at end of file + # This is not in pytest_ops.py since the torch API does not accept None values for some arguments. + # Different inputs for nvfuser and torch API cannot be handled within OpInfo + optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse] + has_optional_inputs = [None] * 5 + inputs = [input, weight] + for idx, param in enumerate(optional_inputs): + if param is not None: + has_optional_inputs[idx] = True + inputs.append(param) + + with FusionDefinition() as fd: + fusion_func(fd, + has_optional_inputs=has_optional_inputs, + optional_inputs_dtypes= optional_inputs_dtypes) + nvf_out = fd.execute(inputs) + + norm_type = 2.0 if norm_type is None else norm_type + scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq + sparse = False if sparse is None else sparse + ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse) + torch.testing.assert_close(nvf_out[0], ref_out) \ No newline at end of file