Skip to content

Commit

Permalink
parametrize in python test
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Dec 26, 2024
1 parent 6a557bb commit 84230a5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 79 deletions.
18 changes: 2 additions & 16 deletions tests/cpp/test_embedding_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using EmbeddingTest = NVFuserTest;

constexpr int64_t n = 5, s = 2;

TEST_F(EmbeddingTest, Basic) {
TEST_F(EmbeddingTest, EmbeddingNode) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
std::vector<int64_t> inp_shape({s});
Expand All @@ -45,19 +45,5 @@ TEST_F(EmbeddingTest, Basic) {
FusionExecutorCache executor_cache(std::move(fusion));
auto nvf_out = executor_cache.runFusionWithInputs({input, weight});
EXPECT_TRUE(at::allclose(nvf_out[0], aten_out));
}

// INSTANTIATE_TEST_SUITE_P(
// LinearWithoutBias,
// LinearNodeParametrizedTest,
// testing::Combine(
// testing::Values(
// Sizes({k}),
// Sizes({m, k}),
// Sizes({b, m, k}),
// Sizes({1, k}),
// Sizes({b, 1, k})),
// testing::Values(Sizes({n, k}), Sizes({1, k})),
// testing::Values(std::nullopt)));

}
} // namespace nvfuser
124 changes: 61 additions & 63 deletions tests/python/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,69 +11,67 @@
from functools import partial
import torch.nn.functional as F

class TestEmbedding(NVFuserTest):
def test_embedding(self):
def fusion_func(
fd: FusionDefinition,
has_optional_inputs: list[bool],
optional_inputs_dtypes: list[DataType]
):
input = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.Int,
is_cpu=False,
)
weight = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
)
# padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
optional_inputs = [None] * 5
for idx in range(len(optional_inputs)):
if has_optional_inputs[idx]:
optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
out = fd.ops.embedding(input, weight, *optional_inputs)
fd.add_output(out)

N, S = 10, 3
input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)

padding_idx_vals = [None, -1, -2]
max_norm_vals = [None, 1e-5]
norm_type_vals = [None, 2.0, 1.0]
scale_grad_by_freq = [None, True]
sparse = [None, False, True]
optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]
@pytest.mark.parametrize("padding_idx", [None, -2])
@pytest.mark.parametrize("max_norm", [None, 1e-5])
@pytest.mark.parametrize("norm_type", [None, 1.0])
@pytest.mark.parametrize("scale_grad_by_freq", [None, True])
@pytest.mark.parametrize("sparse", [None, True])
def test_embedding(
padding_idx: None | int,
max_norm: None | float,
norm_type: None | float,
scale_grad_by_freq: None | bool,
sparse: None | bool
):
def fusion_func(
fd: FusionDefinition,
has_optional_inputs: list[bool],
optional_inputs_dtypes: list[DataType]
):
input = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.Int,
is_cpu=False,
)
weight = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
)
# padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
optional_inputs = [None] * 5
for idx in range(len(optional_inputs)):
if has_optional_inputs[idx]:
optional_inputs[idx] = fd.define_scalar(value=None, dtype=optional_inputs_dtypes[idx])
out = fd.ops.embedding(input, weight, *optional_inputs)
fd.add_output(out)


# TODO: Try to move this to pytest_ops.py. Currently, it does not work since the API between nvFuser and torch differs.
for padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse in itertools.product(
padding_idx_vals, max_norm_vals, norm_type_vals, scale_grad_by_freq, sparse
):
with self.subTest(padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse):
# Reset the FusionCache or the fusion would not recompile for all subtests, failing checks in exec_nvfuser.
FusionCache.reset()
optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
has_optional_inputs = [None] * 5
inputs = [input, weight]
for idx, param in enumerate(optional_inputs):
if param is not None:
has_optional_inputs[idx] = True
inputs.append(param)

with FusionDefinition() as fd:
fusion_func(fd,
has_optional_inputs=has_optional_inputs,
optional_inputs_dtypes = optional_inputs_dtypes)
nvf_out = fd.execute(inputs)
N, S = 10, 3
input = torch.randint(N, (S,), dtype=torch.int64, device='cuda', requires_grad=False)
weight = torch.randn(N, S, dtype=torch.bfloat16, device='cuda', requires_grad=True)
optional_inputs_dtypes = [DataType.Int, DataType.Float, DataType.Float, DataType.Bool, DataType.Bool]

torch.manual_seed(0)
norm_type = 2.0 if norm_type is None else norm_type
scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
sparse = False if sparse is None else sparse
ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
torch.testing.assert_close(nvf_out[0], ref_out)
# This is not in pytest_ops.py since the torch API does not accept None values for some arguments.
# Different inputs for nvfuser and torch API cannot be handled within OpInfo
optional_inputs = [padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse]
has_optional_inputs = [None] * 5
inputs = [input, weight]
for idx, param in enumerate(optional_inputs):
if param is not None:
has_optional_inputs[idx] = True
inputs.append(param)

with FusionDefinition() as fd:
fusion_func(fd,
has_optional_inputs=has_optional_inputs,
optional_inputs_dtypes= optional_inputs_dtypes)
nvf_out = fd.execute(inputs)

norm_type = 2.0 if norm_type is None else norm_type
scale_grad_by_freq = False if scale_grad_by_freq is None else scale_grad_by_freq
sparse = False if sparse is None else sparse
ref_out = F.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
torch.testing.assert_close(nvf_out[0], ref_out)

0 comments on commit 84230a5

Please sign in to comment.