add all g3 tests using facto (#7707)
Summary:
Pull Request resolved: #7707

All tests have to live in one file to retain the same ordering; they will be deduplicated in follow-up diffs. This also fixes FACTO and the sub-Scalar test cases.

Reproduced test cases from internal testing:

```
✗ Fail: on_device_ai/Assistant/Jarvis/nightly:test_g3_nightly - test_g3_sub_tensor_out_11 (on_device_ai.Assistant.Jarvis.nightly.test_g3_nightly.TestOperators) (25.3s)
/data/users/zonglinpeng/fbsource 7c1566d4aa2e+
****************************************************************************************************
OrderedDict([('alpha', 1.6130766369937761)])
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500],
        [147.8750, 247.8750]])]
```

VS

```
✓ Pass: executorch/examples/cadence/operators:test_g3_ops - test_g3_sub_tensor_out_11 (executorch.examples.cadence.operators.test_g3_ops.ATenOpTestCases) (1.0s)
****************************************************************************************************
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500],
        [147.8750, 247.8750]])]
OrderedDict([('alpha', 1.6130766369937761)])
```

Differential Revision: D68195603
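For orientation, here is a minimal sketch (illustrative, not part of the diff) of the failing case above, assuming each FACTO-generated case is a `(posargs, inkwargs)` pair as the test signatures below suggest:

```python
# Hypothetical reconstruction of test_g3_sub_tensor_out_11 from the log above;
# facto_util.facto_testcase_gen is assumed to yield (posargs, inkwargs) pairs.
from collections import OrderedDict

import torch

posargs = [
    torch.tensor([[254, -199]], dtype=torch.int32),  # x
    torch.tensor([[-22.2500, 168.7500], [147.8750, 247.8750]]),  # y
]
inkwargs = OrderedDict([("alpha", 1.6130766369937761)])

# The parameterized test then effectively exports a module computing:
out = torch.sub(posargs[0], posargs[1], alpha=inkwargs["alpha"])
print(out.shape)  # torch.Size([2, 2]): (1, 2) broadcast against (2, 2)
```

The ordering note above matters because `parameterized.expand` derives the numeric suffix in a test name (e.g. the `_11` in `test_g3_sub_tensor_out_11`) from the case's position in the generated list, so internal and OSS runs only refer to the same case when the generation order matches.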
1 parent 590c04f · commit 1feba50
Showing 3 changed files with 304 additions and 12 deletions.
@@ -0,0 +1,264 @@
import unittest
from typing import Any, cast, List, OrderedDict, Tuple

from executorch.examples.cadence.operators import facto_util

from parameterized import parameterized

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

import torch
import torch.nn as nn
from executorch.backends.cadence.aot.export_example import export_model


class ATenOpTestCases(unittest.TestCase):
    def run_and_verify(self, model: nn.Module, inputs: Tuple[Any, ...]) -> None:
        model.eval()
        export_model(
            model, inputs, file_name=self._testMethodName, run_and_compare=False
        )

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("add.Tensor")])
    @torch.no_grad()
    def test_g3_add_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class AddTensor(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.add(x, y, alpha=self.alpha)

        model = AddTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("add.Scalar")])
    @torch.no_grad()
    def test_aten_add_Scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class AddScalar(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: float):
                return torch.add(x, y, alpha=self.alpha)

        inputs = posargs[:-1]  # posargs = [x_tensor, y_scalar, alpha_scalar]
        alpha = posargs[-1]
        model = AddScalar(alpha)

        self.run_and_verify(model, tuple(inputs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sub.Tensor")])
    @torch.no_grad()
    def test_g3_sub_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class SubTensor(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.sub(x, y, alpha=self.alpha)

        model = SubTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sub.Scalar")])
    @torch.no_grad()
    def test_g3_sub_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        # Tensor-Scalar subtraction
        class SubScalar(torch.nn.Module):
            def __init__(self, other):
                super().__init__()
                self.other = other

            def forward(self, x):
                return torch.ops.aten.sub.Scalar(x, self.other)

        inputs = posargs[0]  # posargs = [x_tensor, y_scalar, alpha_scalar]
        model = SubScalar(posargs[1])

        self.run_and_verify(model, (inputs,))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("div.Tensor")])
    @torch.no_grad()
    def test_g3_div_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class DivTensor(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.div(x, y + 1)

        model = DivTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("div.Scalar")])
    @torch.no_grad()
    def test_g3_div_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class DivScalar(nn.Module):
            # y is a Python scalar for the div.Scalar overload
            def forward(self, x: torch.Tensor, y: float):
                return torch.div(x, y + 1)

        model = DivScalar(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("exp.default")])
    @torch.no_grad()
    def test_g3_exp_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Exp(nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.exp(x)

        model = Exp(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("mul.Tensor")])
    @torch.no_grad()
    def test_g3_mul_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class MulTensor(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return x * y

        model = MulTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("mul.Scalar")])
    @torch.no_grad()
    def test_g3_mul_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class MulScalar(nn.Module):
            # y is a Python scalar for the mul.Scalar overload
            def forward(self, x: torch.Tensor, y: float):
                return x * y

        model = MulScalar(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("native_layer_norm.default")])
    @torch.no_grad()
    def test_g3_native_layer_norm_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        # posargs = [input, normalized_shape, weight, bias, eps]
        inputs, normalized_shape, weight, bias, _ = posargs
        model = nn.LayerNorm(normalized_shape, eps=1e-5)
        if weight is not None:
            weight = cast(torch.Tensor, weight)
            model.weight = nn.Parameter(torch.rand_like(weight))
        if bias is not None:
            bias = cast(torch.Tensor, bias)
            model.bias = nn.Parameter(torch.rand_like(bias))

        self.run_and_verify(model, (inputs,))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("neg.default")])
    @torch.no_grad()
    def test_g3_neg_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Neg(nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.neg(x)

        model = Neg(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("rsqrt.default")])
    @torch.no_grad()
    def test_g3_rsqrt_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Rsqrt(nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.ops.aten.rsqrt(x)

        model = Rsqrt(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sigmoid.default")])
    @torch.no_grad()
    def test_g3_sigmoid_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        model = nn.Sigmoid(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("_softmax.default")])
    @torch.no_grad()
    def test_g3__softmax_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        # posargs = [input, dim, half_to_float]; dim is fixed to -1 below
        inputs, _, _ = posargs
        model = nn.Softmax(dim=-1)

        self.run_and_verify(model, (inputs,))


if __name__ == "__main__":
    unittest.main()