2025-01-14 nightly release (b026b51)
pytorchbot committed Jan 14, 2025
1 parent 4405f3d commit 5349967
Showing 68 changed files with 716 additions and 645 deletions.
11 changes: 9 additions & 2 deletions .ci/scripts/build-qnn-sdk.sh
@@ -1,5 +1,6 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -11,10 +12,16 @@ set -o xtrace
build_qnn_backend() {
echo "Start building qnn backend."
export ANDROID_NDK_ROOT=/opt/ndk
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"

bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release
# Workaround to avoid issues around missing flatccrt library (depending on the
# number of jobs used), see issue #7300:
# Build twice (second time with `--no_clean`) to make sure libflatccrt.a is
# available.
# TODO: Remove this workaround once the underlying issue is fixed.
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release || \
bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release --no_clean
}

set_up_aot() {
4 changes: 2 additions & 2 deletions .ci/scripts/setup-qnn-deps.sh
@@ -16,9 +16,9 @@ install_qnn() {
QNN_INSTALLATION_DIR=/tmp/qnn
mkdir -p "${QNN_INSTALLATION_DIR}"

curl -Lo /tmp/v2.25.0.24.07.28.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.25.0.240728.zip"
curl -Lo /tmp/v2.28.0.24.10.29.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.28.0.241029.zip"
echo "Finishing downloading qnn sdk."
unzip -qo /tmp/v2.25.0.24.07.28.zip -d /tmp
unzip -qo /tmp/v2.28.0.24.10.29.zip -d /tmp
echo "Finishing unzip qnn sdk."


2 changes: 1 addition & 1 deletion .ci/scripts/test_llama.sh
@@ -121,7 +121,7 @@ echo "COREML option ${COREML}"
if [[ "${MODE}" =~ .*qnn.* ]]; then
QNN=ON
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang"
export PYTHONPATH=".."
cp schema/program.fbs exir/_serialize/program.fbs
13 changes: 4 additions & 9 deletions backends/arm/_passes/arm_pass_manager.py
@@ -28,6 +28,7 @@
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
)
@@ -62,7 +63,6 @@
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_manager import PassManager

@@ -72,9 +72,7 @@ class ArmPassManager(PassManager):
def _transform(self, graph_module: torch.fx.GraphModule):
return self(graph_module).graph_module

def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
"""Apply passes before transforming program to backend"""
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(DecomposeLinearPass())
@@ -137,11 +135,8 @@ def transform_to_backend_pipeline(
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
if memory_format == "nhwc":
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(DecomposeSelectPass())
self.add_pass(AnnotateChannelsLastDimOrder())

return self._transform(exported_program.graph_module)

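With this change, transform_to_backend_pipeline no longer inspects the compile spec for a `permute_memory_format` entry: AnnotateChannelsLastDimOrder (and the new DecomposeSelectPass) are always appended, and the method takes only the exported program. A minimal sketch of the updated call site, mirroring the preprocess change later in this commit (the surrounding lowering function is illustrative, not the real preprocess implementation):

from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

def run_arm_pipeline(edge_program):
    # compile_spec is no longer passed; channels-last annotation and the
    # select decomposition now run unconditionally.
    return ArmPassManager().transform_to_backend_pipeline(
        exported_program=edge_program
    )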
56 changes: 56 additions & 0 deletions backends/arm/_passes/decompose_select.py
@@ -0,0 +1,56 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeSelectPass(ExportPass):
"""
This pass decomposes select into slice + squeeze to ensure that the ATen and TOSA outputs have the same rank (input rank - 1).
"""

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:

if node.op != "call_function":
continue

if node.target in (
exir_ops.edge.aten.select.int,
exir_ops.edge.aten.select_copy.int,
):
slice_op = exir_ops.edge.aten.slice_copy.Tensor
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
else:
continue

input_node, dim, index = node.args

rank = len(input_node.meta["val"].size())
dim = dim % rank if dim < 0 else dim
index = index % rank if index < 0 else index
dim_list = list(range(rank))

with graph_module.graph.inserting_before(node):
slice_node = create_node(
graph_module.graph, slice_op, (input_node, dim, index, index + 1)
)
squeeze_node = create_node(
graph_module.graph, squeeze_op, (slice_node, dim_list)
)

node.replace_all_uses_with(squeeze_node)
graph_module.graph.erase_node(node)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
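
The new pass rewrites each select/select_copy node as a length-1 slice_copy followed by a squeeze_copy on the selected dimension, so the decomposed graph still drops rank by one like the original op. A quick sanity check of the same equivalence with plain PyTorch tensor ops (the pass itself operates on exir edge-dialect nodes, so this is only an illustration):

import torch

x = torch.randn(2, 3, 4)
dim, index = 1, 2

selected = x.select(dim, index)                    # rank drops by one -> shape (2, 4)
decomposed = x.narrow(dim, index, 1).squeeze(dim)  # length-1 slice, then squeeze

assert torch.equal(selected, decomposed)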
28 changes: 2 additions & 26 deletions backends/arm/arm_backend.py
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -49,8 +49,6 @@ def __init__(self):
self.compiler_flags = []
self.output_format = None
self.path_for_intermediates = None
# TODO MLETORCH-265 Remove permute_nhwc flag
self.permute_nhwc = False
self.quantize_io = False
self.tosa_version = None
self.input_order = None
@@ -118,16 +116,6 @@ def dump_intermediate_artifacts_to(
self.path_for_intermediates = output_path
return self

def set_permute_memory_format(
self, set_nhwc_permutation: bool = True
) -> "ArmCompileSpecBuilder":
"""
Permute to channel last in compiler and runtime. Compilation and
runtime will convert rank 4 inputs to channel last for each sub-graph.
"""
self.permute_nhwc = set_nhwc_permutation
return self

def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
"""
Quantization of inputs and dequantization of outputs for cases where
@@ -170,11 +158,6 @@ def build(self) -> List[CompileSpec]:
CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
)

if self.permute_nhwc:
self.compile_spec.append(
CompileSpec("permute_memory_format", "nhwc".encode())
)

if self.input_order:
self.compile_spec.append(
CompileSpec(
@@ -188,13 +171,6 @@ def build(self) -> List[CompileSpec]:
return self.compile_spec


def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
for spec in compile_spec:
if spec.key == "permute_memory_format":
return spec.value.decode() == "nhwc"
return False


def is_tosa(compile_spec: List[CompileSpec]) -> bool:
for spec in compile_spec:
if spec.key == "output_format":
@@ -264,7 +240,7 @@ def preprocess( # noqa: C901
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager().transform_to_backend_pipeline(
exported_program=edge_program, compile_spec=compile_spec
exported_program=edge_program
)

node_visitors = get_node_visitors(edge_program, tosa_spec)
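Since the `permute_memory_format` compile spec is removed, callers of ArmCompileSpecBuilder simply drop the old set_permute_memory_format() call; channels-last handling is now applied by the pass manager unconditionally. An illustrative sketch under that assumption, using only builder methods visible in this diff (target and output-format configuration is assumed to happen elsewhere, and the artifact path is hypothetical):

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

builder = ArmCompileSpecBuilder()
# builder.set_permute_memory_format(True)  # removed in this commit; no replacement needed
compile_spec = (
    builder.dump_intermediate_artifacts_to("/tmp/arm_artifacts")  # hypothetical path
    .set_quantize_io(True)
    .build()
)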
1 change: 0 additions & 1 deletion backends/arm/operators/__init__.py
@@ -30,7 +30,6 @@
op_repeat,
op_rshift,
op_rsqrt,
op_select,
op_sigmoid,
op_slice,
op_squeeze,
68 changes: 0 additions & 68 deletions backends/arm/operators/op_select.py

This file was deleted.

30 changes: 3 additions & 27 deletions backends/arm/runtime/ArmBackendEthosU.cpp
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 Arm Limited and/or its affiliates.
* Copyright 2023-2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
@@ -76,7 +76,6 @@ namespace arm {

typedef struct {
FreeableBuffer* processed;
bool permuted_io_flag;
} ExecutionHandle;

extern "C" {
@@ -125,14 +124,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
handle->processed = processed;

handle->permuted_io_flag = false;
for (auto& compile_spec : compile_specs) {
if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
handle->permuted_io_flag = true;
}
}

// Return the same buffer we were passed - this data will be
// executed directly
return handle;
@@ -225,11 +216,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
// which require permutation.
bool permuted_input_shape;
ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
i,
tensor_in,
&handles.inputs->io[i],
execution_handle->permuted_io_flag,
&permuted_input_shape));
i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
bool both_char = tensor_in.scalar_type() == ScalarType::Char and
handles.inputs->io[i].elem_size == 1;
bool both_int = tensor_in.scalar_type() == ScalarType::Int and
@@ -330,11 +317,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {

bool permuted_output_shape;
ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
i,
tensor_out,
&handles.outputs->io[i],
execution_handle->permuted_io_flag,
&permuted_output_shape));
i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
if (tensor_out.scalar_type() == ScalarType::Char and
permuted_output_shape) {
EXECUTORCH_PROF_SCOPE(
@@ -395,7 +378,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
int index,
const executorch::aten::Tensor tensor,
VelaIO* io,
bool permuted_io_flag,
bool* is_permuted) const {
bool permuted_shape = false;

@@ -409,12 +391,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
if (permuted_shape) {
ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
}
if (permuted_io_flag != permuted_shape) {
ET_LOG(
Error,
"Permute compile flag and permuted input/output don't agree");
return Error::InvalidProgram;
}
}
*is_permuted = permuted_shape;
return Error::Ok;
