Skip to content

Commit

Permalink
add all g3 tests using facto (#7707)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #7707

It has to be in the whole file to retain the same order. Will dedup in following diffs

fixed FACTO and sub scalar test cases

Reproducing testcases from internal testing.
```
✗ Fail: on_device_ai/Assistant/Jarvis/nightly:test_g3_nightly - test_g3_sub_tensor_out_11 (on_device_ai.Assistant.Jarvis.nightly.test_g3_nightly.TestOperators) (25.3s)
/data/users/zonglinpeng/fbsource
7c1566d4aa2e+
****************************************************************************************************
OrderedDict([('alpha', 1.6130766369937761)])
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500],
        [147.8750, 247.8750]])]
```
VS
```
✓ Pass: executorch/examples/cadence/operators:test_g3_ops - test_g3_sub_tensor_out_11 (executorch.examples.cadence.operators.test_g3_ops.ATenOpTestCases) (1.0s)
****************************************************************************************************
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500],
        [147.8750, 247.8750]])]
OrderedDict([('alpha', 1.6130766369937761)])
```

Differential Revision: D68195603
  • Loading branch information
zonglinpeng authored and facebook-github-bot committed Jan 16, 2025
1 parent 590c04f commit 1feba50
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 12 deletions.
50 changes: 38 additions & 12 deletions examples/cadence/operators/facto_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
| "mul.Tensor"
| "div.Tensor"
):
tensor_constraints.append(
cp.Dtype.In(lambda deps: [torch.float]),
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float]),
cp.Size.Le(lambda deps, r, d: 2),
cp.Rank.Le(lambda deps: 2),
]
)
case (
"add.Tensor"
Expand All @@ -37,35 +41,60 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
| "mul.Scalar"
| "div.Scalar"
):
tensor_constraints.append(
cp.Dtype.In(lambda deps: [torch.float, torch.int]),
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
cp.Size.Le(lambda deps, r, d: 2),
cp.Rank.Le(lambda deps: 2),
]
)
case "native_layer_norm.default":
tensor_constraints.extend(
[
cp.Size.Le(lambda deps, r, d: 2**4),
cp.Rank.Le(lambda deps: 2**4),
]
)
case _:
tensor_constraints.append(
cp.Dtype.In(lambda deps: [torch.float, torch.int]),
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
cp.Size.Le(lambda deps, r, d: 2),
cp.Rank.Le(lambda deps: 2),
]

)
tensor_constraints.extend(
[
cp.Value.Ge(lambda deps, dtype, struct: -(2**8)),
cp.Value.Le(lambda deps, dtype, struct: 2**8),
cp.Rank.Ge(lambda deps: 1),
cp.Rank.Le(lambda deps: 2**2),
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.Le(lambda deps, r, d: 2**2),
]
)


def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
match op_name:
case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar":
return [ScalarDtype.int]
case _:
return [ScalarDtype.float, ScalarDtype.int]


def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
# minimal example to test add.Tensor using FACTO
spec = SpecDictDB[op_name]
tensor_constraints = []
# common tensor constraints
apply_tensor_contraints(op_name, tensor_constraints)

for index, in_spec in enumerate(copy.deepcopy(spec.inspec)):
if in_spec.type.is_scalar():
if in_spec.name != "alpha":
spec.inspec[index].constraints.extend(
[
cp.Dtype.In(lambda deps: [ScalarDtype.float, ScalarDtype.int]),
cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
cp.Value.Ge(lambda deps, dtype: -(2**8)),
cp.Value.Le(lambda deps, dtype: 2**2),
cp.Size.Ge(lambda deps, r, d: 1),
Expand All @@ -80,9 +109,6 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
]
)
elif in_spec.type.is_tensor():
tensor_constraints = []
# common tensor constraints
apply_tensor_contraints(op_name, tensor_constraints)
spec.inspec[index].constraints.extend(tensor_constraints)

return [
Expand Down
2 changes: 2 additions & 0 deletions examples/cadence/operators/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

TESTS_LIST = [
"add_op",
"g3_ops",
"quantized_conv1d_op",
"quantized_linear_op",
]
Expand Down Expand Up @@ -46,5 +47,6 @@ def _define_test_target(test_name):
"fbcode//executorch/backends/cadence/aot:ops_registrations",
"fbcode//executorch/backends/cadence/aot:export_example",
"fbcode//executorch/backends/cadence/aot:compiler",
"fbcode//executorch/examples/cadence/operators:facto_util",
],
)
264 changes: 264 additions & 0 deletions examples/cadence/operators/test_g3_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import unittest
from typing import Any, cast, List, OrderedDict, Tuple

from executorch.examples.cadence.operators import facto_util

from parameterized import parameterized

from executorch.backends.cadence.aot.ops_registrations import * # noqa

import torch
import torch.nn as nn
from executorch.backends.cadence.aot.export_example import export_model


class ATenOpTestCases(unittest.TestCase):
def run_and_verify(self, model: nn.Module, inputs: Tuple[Any, ...]) -> None:
model.eval()
export_model(
model, inputs, file_name=self._testMethodName, run_and_compare=False
)

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("add.Tensor")])
@torch.no_grad()
def test_g3_add_tensor_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class AddTensor(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.add(x, y, alpha=self.alpha)

model = AddTensor(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("add.Scalar")])
@torch.no_grad()
def test_aten_add_Scalar_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class AddScalar(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: float):
return torch.add(x, y, alpha=self.alpha)

inputs = posargs[:-1] # posargs = [x_tensor, y_scalar, alpha_scalar]
alpha = posargs[-1]
model = AddScalar(alpha)

self.run_and_verify(model, tuple(inputs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("sub.Tensor")])
@torch.no_grad()
def test_g3_sub_tensor_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class SubTensor(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.sub(x, y, alpha=self.alpha)

model = SubTensor(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("sub.Scalar")])
@torch.no_grad()
def test_g3_sub_scalar_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
# Tensor-Scalar subtraction
class SubScalar(torch.nn.Module):
def __init__(self, other):
super().__init__()
self.other = other

def forward(self, x):
return torch.ops.aten.sub.Scalar(x, self.other)

inputs = posargs[0] # posargs = [x_tensor, y_scalar, alpha_scalar]
model = SubScalar(posargs[1])

self.run_and_verify(model, (inputs,))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("div.Tensor")])
@torch.no_grad()
def test_g3_div_tensor_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class DivTensor(nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.div(x, y + 1)

model = DivTensor(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("div.Scalar")])
@torch.no_grad()
def test_g3_div_scalar_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class DivScalar(nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.div(x, y + 1)

model = DivScalar(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("exp.default")])
@torch.no_grad()
def test_g3_exp_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class Exp(nn.Module):
def forward(self, x: torch.Tensor):
return torch.exp(x)

model = Exp(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("mul.Tensor")])
@torch.no_grad()
def test_g3_mul_tensor_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class MulTensor(nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return x * y

model = MulTensor(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("mul.Scalar")])
@torch.no_grad()
def test_g3_mul_scalar_out(
self,
posargs: List[str],
inkwargs: OrderedDict[str, str],
) -> None:
class MulScalar(nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return x * y

model = MulScalar(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("native_layer_norm.default")])
@torch.no_grad()
def test_g3_native_layer_norm_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
inputs, normalized_shape, weight, bias, _ = posargs
model = nn.LayerNorm(normalized_shape, eps=1e-5)
if weight is not None:
weight = cast(torch.Tensor, weight)
model.weight = nn.Parameter(torch.rand_like(weight))
if bias is not None:
bias = cast(torch.Tensor, bias)
model.bias = nn.Parameter(torch.rand_like(bias))

self.run_and_verify(model, (inputs,))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("neg.default")])
@torch.no_grad()
def test_g3_neg_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
class Neg(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.neg(x)

model = Neg(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("rsqrt.default")])
@torch.no_grad()
def test_g3_rsqrt_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
class Rsqrt(nn.Module):
def forward(self, x: torch.Tensor):
return torch.ops.aten.rsqrt(x)

model = Rsqrt(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("sigmoid.default")])
@torch.no_grad()
def test_g3_sigmoid_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
model = nn.Sigmoid(**inkwargs)

self.run_and_verify(model, tuple(posargs))

# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
@parameterized.expand([*facto_util.facto_testcase_gen("_softmax.default")])
@torch.no_grad()
def test_g3__softmax_out(
self,
posargs: List[int],
inkwargs: OrderedDict[str, str],
) -> None:
inputs, _, _ = posargs
model = nn.Softmax(dim=-1)

self.run_and_verify(model, (inputs,))


if __name__ == "__main__":
unittest.main()

0 comments on commit 1feba50

Please sign in to comment.