Skip to content

Commit

Permalink
Revert "Remove unused functions for quantization handling" (#7724)
Browse files Browse the repository at this point in the history
Revert "Remove unused functions for quantization handling (#7700)"

This reverts commit ffc2020.
  • Loading branch information
oscarandersson8218 authored Jan 17, 2025
1 parent ffc2020 commit eaad7ff
Show file tree
Hide file tree
Showing 10 changed files with 399 additions and 21 deletions.
7 changes: 5 additions & 2 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -15,7 +15,7 @@
get_node_arg,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -43,6 +43,9 @@ def _transpose_impl(*args, **kwargs):
return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
op_bmm,
op_cat,
op_conv2d,
op_dequant,
op_exp,
op_full,
op_get_item,
Expand All @@ -23,6 +24,7 @@
op_min,
op_mul,
op_permute,
op_quant,
op_reciprocal,
op_relu,
op_repeat,
Expand Down
35 changes: 35 additions & 0 deletions backends/arm/operators/op_dequant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DequantVisitor(NodeVisitor):
target = "quantized_decomposed.dequantize_per_tensor.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
item_name = inputs[0].name
## Simply add an identityOp
tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name])
7 changes: 4 additions & 3 deletions backends/arm/operators/op_hardtanh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2025 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -19,6 +19,7 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import quantize_value
from serializer.tosa_serializer import TosaOp


Expand All @@ -43,8 +44,8 @@ def define_node(
input_qparams = get_input_qparams(node) # pyre-ignore[16]
qargs = input_qparams[0]
# Convert to quantized representation
clamp_min_qs = qargs.quantize_value(inputs[1].number).item()
clamp_max_qs = qargs.quantize_value(inputs[2].number).item()
clamp_min_qs = quantize_value(inputs[1].number, qargs)
clamp_max_qs = quantize_value(inputs[2].number, qargs)
# Set fp values to 0.0 since they are not used
clamp_min_fp = 0.0
clamp_max_fp = 0.0
Expand Down
35 changes: 35 additions & 0 deletions backends/arm/operators/op_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class QuantVisitor(NodeVisitor):
target = "quantized_decomposed.quantize_per_tensor.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
item_name = inputs[0].name
## Simply add an identityOp
tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name])
8 changes: 5 additions & 3 deletions backends/arm/operators/op_relu.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
import torch.fx

Expand Down Expand Up @@ -42,8 +43,9 @@ def define_node(
clamp_max_qs = 0
if inputs[0].dtype == ts.DType.INT8:
out_qargs = get_output_qparams(node) # pyre-ignore[16]
clamp_min_qs = out_qargs[0].quantize_value(0).item()
clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item()
clamp_min_qs = tqutils.quantize_value(0, out_qargs[0])
clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0])

else:
clamp_min_fp = 0
clamp_max_fp = float("inf")
Expand Down
22 changes: 19 additions & 3 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
dq_op,
get_quantized_node_output_dtype,
is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from torch.export.exported_program import ExportedProgram
Expand All @@ -30,8 +35,15 @@ def process_call_function(
# Convert output (this node itself)
output = TosaArg(node)

is_dq_node = node.target == dq_op
if is_dq_node:
output_dtype = ts.DType.INT8
else:
output_dtype = output.dtype
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name, tosa_shape(output.shape, output.dim_order), output.dtype
output.name,
tosa_shape(output.shape, output.dim_order),
output_dtype,
)

# Visiting each Node
Expand Down Expand Up @@ -67,7 +79,11 @@ def process_inputs(
tensor = ts.TosaSerializerTensor(
inputs[0].name,
tosa_shape(input_shape, input_dim_order),
inputs[0].dtype,
(
map_dtype(get_quantized_node_output_dtype(node))
if is_node_quantized(node)
else inputs[0].dtype
),
data=None,
placeholderFilename=inputs[0].name + ".npy",
)
Expand Down
Loading

0 comments on commit eaad7ff

Please sign in to comment.