Update on "use-pt-pinned-commit for test-arm-{backend,reference}-dele…
Browse files Browse the repository at this point in the history
…gation"

Without this, these builds don't respect the torchgen pinned commit and thus fail with #7546.
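The workflow change itself is not among the files shown below. A hedged sketch of its shape, assuming the affected jobs call a reusable workflow that exposes a `use-pt-pinned-commit` boolean input, as the title suggests (file, job, and workflow names here are illustrative):

```
# Illustrative only: the real job definitions live in the repository's workflow files.
test-arm-backend-delegation:
  uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
  with:
    # Assumption: this input makes the build use the torchgen commit pinned
    # in .ci/docker/ci_commit_pins/pytorch.txt instead of the latest commit.
    use-pt-pinned-commit: true
```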

Differential Revision: [D67996459](https://our.internmc.facebook.com/intern/diff/D67996459/)

[ghstack-poisoned]
Github Executorch committed Jan 13, 2025
2 parents bda857d + fb2cd54 commit 05385f1
Showing 54 changed files with 2,687 additions and 412 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-2ea4b56ec872424e486c4fe2d55da061067a2ed3
+0a94bb432ed75cc2d950d81b2921363218a7e459
2 changes: 1 addition & 1 deletion .github/workflows/apple-perf.yml
@@ -410,7 +410,7 @@ jobs:
runs-on: linux.2xlarge
steps:
- name: Download the apps from GitHub
-uses: actions/download-artifact@v3
+uses: actions/download-artifact@v4
with:
# The name here needs to match the name of the upload-artifact parameter
name: ios-apps
10 changes: 5 additions & 5 deletions .github/workflows/apple.yml
@@ -53,7 +53,7 @@ jobs:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_DEMO_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD
-upload-artifact: ios-apps
+upload-artifact: ios-demo-app
script: |
set -eux
@@ -83,10 +83,10 @@ jobs:
runs-on: linux.2xlarge
steps:
- name: Download the artifacts from GitHub
-uses: actions/download-artifact@v3
+uses: actions/download-artifact@v4
with:
# The name here needs to match the name of the upload-artifact parameter
-name: ios-apps
+name: ios-demo-app
path: ${{ runner.temp }}/artifacts/

- name: Verify the artifacts
@@ -216,7 +216,7 @@ jobs:
role-to-assume: arn:aws:iam::308535385114:role/gha_executorch_upload-frameworks-ios
aws-region: us-east-1
- name: Download the artifact
-uses: actions/download-artifact@v3
+uses: actions/download-artifact@v4
with:
# NB: The name here needs to match the upload-artifact name from build-frameworks-ios job
name: executorch-frameworks-ios
@@ -291,7 +291,7 @@ jobs:
python-version: '3.11'
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-upload-artifact: ios-apps
+upload-artifact: ios-benchmark-app
secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD
timeout: 90
script: |
2 changes: 1 addition & 1 deletion .gitmodules
@@ -66,7 +66,7 @@
url = https://github.com/pybind/pybind11.git
[submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"]
path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
-url = https://github.com/foss-xtensa/nnlib-FusionG3/
+url = https://github.com/foss-xtensa/nnlib-FusionG3.git
[submodule "third-party/ao"]
path = third-party/ao
url = https://github.com/pytorch/ao.git
6 changes: 6 additions & 0 deletions backends/apple/mps/mps_preprocess.py
@@ -32,6 +32,9 @@
CompileSpec,
PreprocessResult,
)

+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.exir.program._program import _transform
from torch.export.exported_program import ExportedProgram

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
# FlatBuffer graph, process the `output` nodes and add their id to
# the `output_ids` array in the schema.

+# TODO: Remove this once we have better support for dim-order ops.
+edge_program = _transform(edge_program, DimOrderOpsRevertPass())

mps_graph = MPSGraph(
version="0",
mps_nodes=[],
19 changes: 19 additions & 0 deletions backends/apple/mps/operators/constant_ops.py
@@ -79,6 +79,25 @@ def define_node(
)


+@register_node_visitor
+class ToDimOrderEmptyVisitor(NodeVisitor):
+    target = ["dim_order_ops._empty_dim_order.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here: DimOrderOpsRevertPass replaces this op with
+        # aten.empty.memory_format. If we do, we cannot handle it at the moment,
+        # so raise an exception.
+        raise NotImplementedError(
+            "dim_order_ops._empty_dim_order.default is not supported yet"
+        )


@register_node_visitor
class FullLikeVisitor(NodeVisitor):
target = "aten.full_like.default"
19 changes: 19 additions & 0 deletions backends/apple/mps/operators/op_clone.py
@@ -33,3 +33,22 @@ def define_node(
)
input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
self.tensor_to_id[node] = input_id


+@register_node_visitor
+class ToDimOrderCopyVisitor(NodeVisitor):
+    target = ["dim_order_ops._to_dim_order_copy.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here: DimOrderOpsRevertPass replaces this op with
+        # aten._to_copy. If we do, we cannot handle it at the moment, so raise
+        # an exception.
+        raise NotImplementedError(
+            "dim_order_ops._to_dim_order_copy.default is not supported yet"
+        )
15 changes: 15 additions & 0 deletions backends/apple/mps/test/test_mps.py
@@ -1829,6 +1829,21 @@ def forward(self, x):
Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
)

+    def test_mps_backend_to_copy(self):
+        class Copy(torch.nn.Module):
+            def forward(self, x):
+                return (
+                    torch.ops.aten._to_copy.default(
+                        x + 2, memory_format=torch.contiguous_format
+                    )
+                    + x
+                )
+
+        model_inputs = (torch.randn(1, 3, 3),)
+        self.lower_and_test_with_partitioner(
+            Copy(), model_inputs, func_name=inspect.stack()[0].function[5:]
+        )

def test_mps_backend_floor(self):
class Floor(torch.nn.Module):
def forward(self, x):
7 changes: 1 addition & 6 deletions backends/apple/mps/test/test_mps_utils.py
@@ -26,10 +26,7 @@

# Config for Capturing the weights, will be moved in the future

-# TODO(T182928844): Delegate dim order op to backend.
-_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
-    _check_ir_validity=False, _skip_dim_order=True
-)
+_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)


class ansi_colors:
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
dynamic_shapes=dynamic_shapes,
edge_compile_config=EdgeCompileConfig(
_check_ir_validity=False,
-_skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
),
)

@@ -250,7 +246,6 @@ def lower_module_and_test_output(
export(delegated_program, sample_inputs, strict=True),
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False,
-_skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
),
).to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
22 changes: 22 additions & 0 deletions backends/arm/README.md
@@ -122,6 +122,28 @@ Then you can run the tests with
pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP
```

+### Code coverage
+
+To get code coverage:
+
+```
+coverage run --source=<SRC> --rcfile=backends/arm/test/.coveragerc -m pytest \
+--config-file=/dev/null backends/arm/test/
+```
+
+All files in `SRC` and its child directories will be analysed for code coverage,
+unless explicitly excluded in the .coveragerc file. If you are using a venv, `SRC`
+might be under `env/lib/python<VERSION_NUMBER>/site-packages/executorch/`. To get
+the absolute path, run:
+
+```
+python -c "import executorch; print(executorch.__path__)"
+```
+
+This prints a list of paths containing the executorch sources. Pick the one
+located under `env/lib`; if that does not work, try the others. Add
+`backends/arm` to the `--source` path to get code coverage for the Arm
+backend only.

### A note on unit tests

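For illustration, the README command above with `--source` resolved for the Arm backend might look like this (a sketch assuming a Python 3.10 venv at `env/`; substitute the path printed by the snippet above):

```
coverage run \
  --source=env/lib/python3.10/site-packages/executorch/backends/arm \
  --rcfile=backends/arm/test/.coveragerc \
  -m pytest --config-file=/dev/null backends/arm/test/
coverage report
```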
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -37,6 +37,9 @@
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
+from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
+    FuseQuantizedActivationPass,
+)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
@@ -73,6 +76,7 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
+self.add_pass(FuseQuantizedActivationPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(DecomposeLayerNormPass())
60 changes: 60 additions & 0 deletions backends/arm/_passes/fuse_quantized_activation_pass.py
@@ -0,0 +1,60 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.tosa_quant_utils import q_op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import Node
+
+
+class FuseQuantizedActivationPass(ExportPass):
+    def _is_fuseable_quantized_activation(self, node: Node):
+        """Fuse activations that have a zero lower bound and are quantized with a zero-point equal to qmin."""
+        is_fuseable = node.target == exir_ops.edge.aten.relu.default
+        if node.target == exir_ops.edge.aten.hardtanh.default:
+            min_val = node.args[1]
+            is_fuseable = min_val == 0
+
+        is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
+        if is_quantized:
+            quant_node = next(iter(node.users))
+            zp = quant_node.args[2]
+            qmin = quant_node.args[3]
+
+        return is_fuseable and is_quantized and zp == qmin
+
+    def _is_fuseable_input(self, node: Node):
+        return (
+            node.target
+            in (
+                exir_ops.edge.aten.convolution.default,
+                exir_ops.edge.aten.linear.default,
+            )
+            and len(node.users) == 1
+        )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if not self._is_fuseable_quantized_activation(node):
+                continue
+
+            input_node = node.args[0]
+            if not self._is_fuseable_input(input_node):
+                continue
+
+            node.replace_all_uses_with(input_node)
+            graph_module.graph.erase_node(node)
+            modified = True
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
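To make the rewrite concrete, here is a minimal, self-contained torch.fx sketch of the same fusion idea (plain torch ops instead of the exir edge dialect, and without the quantization checks the real pass performs):

```
import torch
import torch.fx as fx


class ConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))


gm = fx.symbolic_trace(ConvRelu())
for node in list(gm.graph.nodes):
    # Mirror the pass: a fuseable activation is removed by rerouting its
    # users to the producing node (here, the convolution output).
    if node.op == "call_function" and node.target is torch.relu:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # forward now returns self.conv(x) directly
```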
66 changes: 65 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
@@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
_annotate_output_qspec(node, quant_property.qspec)


+def _match_pattern(
+    node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
+) -> bool:
+    """
+    Check whether there is a two-node chain parent -> child containing 'node'
+    that matches the chain provided in 'pattern'. If 'filter_fn' is provided,
+    also check that every node in the chain passes the filter.
+    Each element of 'pattern' is a list of node types, matched disjunctively.
+    """
+    assert len(pattern) == 2, "Only two-node patterns are currently supported"
+
+    if node.target in pattern[0]:
+        assert len(node.users) != 0
+        parent = node
+        child = next(iter(node.users))
+    elif node.target in pattern[1]:
+        assert len(node.args) != 0
+        parent = node.args[0]
+        child = node
+    else:
+        return False
+
+    if len(parent.users) != 1:
+        return False
+
+    if parent.target not in pattern[0] or child.target not in pattern[1]:
+        return False
+
+    if filter_fn is not None:
+        return filter_fn(parent) and filter_fn(child)
+
+    return True


_one_to_one = [
torch.ops.aten.exp.default,
torch.ops.aten.log.default,
@@ -164,7 +199,36 @@ def get_quant_properties( # noqa: C901
bias_qspec = quantization_config.get_bias_qspec()

quant_properties = _OpQuantProperties()
-    if node.target in (
+
+    def any_or_hardtanh_min_zero(n: Node):
+        # Check that if the node is a hardtanh, its min_val is zero
+        return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
+
+    if _match_pattern(
+        node,
+        [
+            [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.linear.default,
+            ],
+            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+        ],
+        any_or_hardtanh_min_zero,
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        else:
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+    elif node.target in (
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
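As a quick illustration of the `_match_pattern` contract added above (a sketch; `node` is any node in the traced graph, and the targets mirror the usage in this diff):

```
# True only when `node` belongs to a conv2d -> relu chain in which the conv
# output feeds the relu exclusively and both nodes pass the filter.
matched = _match_pattern(
    node,
    [
        [torch.ops.aten.conv2d.default],
        [torch.ops.aten.relu.default],
    ],
    any_or_hardtanh_min_zero,
)
```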
8 changes: 8 additions & 0 deletions backends/arm/test/.coveragerc
@@ -0,0 +1,8 @@
+[run]
+omit =
+    *__init__.py*
+
+[report]
+skip_covered = true
+exclude_also =
+    raise NotImplementedError
7 changes: 4 additions & 3 deletions backends/arm/test/ops/test_conv_combos.py
@@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module):
]

test_data = [
-        (20 * torch.randn(1, 3, 256, 256),),
-        (5 * torch.randn(1, 3, 256, 256),),
+        (2 * torch.randn(1, 3, 256, 256),),
+        (0.5 * torch.randn(1, 3, 256, 256),),
         (torch.randn(1, 3, 256, 256),),
-        (-5 * torch.randn(1, 3, 256, 256),),
+        (-0.5 * torch.randn(1, 3, 256, 256),),
+        (-2 * torch.randn(1, 3, 256, 256),),
]

def __init__(self):
(Diffs for the remaining changed files are not shown.)
