From fbed329feecff61554886c13deef033bfb0e8eaf Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 16 Apr 2024 13:01:31 -0700
Subject: [PATCH] add more test cases

---
 .../dynamo/conversion/test_layer_norm_aten.py | 25 ++++++++++++++-----
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index 6b4a1d6961..c0e055304a 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -39,18 +39,31 @@ def forward(self, x):
 
 
 class TestNativeLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
                 )[0]
 
-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             LayerNorm(),
             inputs,
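
Note (reviewer sketch, not part of the patch): the new parameterized cases all
feed a (5, 3, 2, 4) input and vary only normalized_shape, i.e. how many trailing
dimensions layer norm reduces over. Below is a minimal standalone sketch of the
call the test exercises, using one of the new cases. It relies only on what the
diff shows: torch.ops.aten.native_layer_norm.default takes
(input, normalized_shape, weight, bias, eps) and returns a (output, mean, rstd)
tuple, which is why the test indexes [0].

    import torch

    # One of the new parameterizations: normalize over the last two dims.
    input_shape, normalized_shape, eps = (5, 3, 2, 4), [2, 4], 1e-05
    x = torch.randn(input_shape)
    weight = torch.randn(normalized_shape)  # per-element scale (gamma)
    bias = torch.randn(normalized_shape)    # per-element shift (beta)

    # native_layer_norm returns (output, mean, rstd); the test keeps only output.
    out = torch.ops.aten.native_layer_norm.default(
        x, normalized_shape, weight, bias, eps
    )[0]
    assert out.shape == x.shape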