[xnnpack] Fix pyre issues for deconv
digantdesai committed Jan 17, 2025
1 parent 9f47380 commit 4f62baf
Showing 2 changed files with 3 additions and 3 deletions.
backends/xnnpack/_passes/tag_implicit_q_dq_pass.py (1 addition, 1 deletion)
@@ -90,7 +90,7 @@ def is_supported_quant_op(self, node: torch.fx.Node) -> bool:

         # Weight and Input should both be quantized
         if op_name == exir_ops.edge.aten.convolution.default.name():
-            return is_dequant(node.args[1])
+            return is_dequant(cast(torch.fx.Node, node.args[1]))

         return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET

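For context, a minimal sketch (not part of the commit; the helper names and the dequant check are assumptions) of why the cast is needed: pyre types torch.fx.Node.args as a tuple of Argument, a broad union of possible argument types, so indexing it does not yield a torch.fx.Node, and passing node.args[1] to a function annotated to take a Node trips the checker. typing.cast narrows the type for pyre without changing runtime behavior.

# Hypothetical sketch; only torch.fx.Node and typing.cast come from the commit.
from typing import cast

import torch.fx


def is_dequant(node: torch.fx.Node) -> bool:
    # Stand-in predicate; the real check lives in the XNNPACK backend utilities.
    return node.op == "call_function" and "dequantize" in str(node.target)


def weight_is_dequant(conv_node: torch.fx.Node) -> bool:
    # conv_node.args is typed Tuple[Argument, ...]; cast() tells pyre that the
    # second argument (the convolution weight) is a torch.fx.Node. It is a
    # no-op at runtime.
    weight = cast(torch.fx.Node, conv_node.args[1])
    return is_dequant(weight)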
backends/xnnpack/operators/node_visitor.py (2 additions, 2 deletions)
@@ -560,11 +560,11 @@ def get_serialized_buffer_index(
             # which should be used for depthwise/transpose convolution weights for XNNPACK
             shape = const_val.shape
             const_val = const_val.reshape(
-                (groups, const_val.shape[0] // groups) + const_val.shape[1:]
+                (groups, const_val.shape[0] // groups) + tuple(const_val.shape[1:])
             )
             const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim())))
             const_val = const_val.reshape(
-                (shape[1] * groups, shape[0] // groups) + shape[2:]
+                (shape[1] * groups, shape[0] // groups) + tuple(shape[2:])
             ).contiguous()

         if convert_to_nhwc:
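For context, a minimal sketch (an assumption about the underlying pyre complaint; it is not spelled out in the commit) of why the tuple() wrappers help: const_val.shape and its slices are torch.Size objects, a tuple subclass. Concatenating a hand-built tuple with a torch.Size works fine at runtime, but pyre does not always unify that expression with the Tuple[int, ...] that reshape expects, so converting the slice to a plain tuple keeps the checker satisfied.

# Hypothetical illustration of the pattern; the shapes and group count are
# made up. The real change lives in backends/xnnpack/operators/node_visitor.py.
import torch

const_val = torch.randn(8, 4, 3, 3)  # e.g. a grouped transpose-conv weight
groups = 2

# const_val.shape[1:] is a torch.Size. Wrapping it in tuple() gives pyre a
# plain Tuple[int, ...] to concatenate; the runtime result is identical.
new_shape = (groups, const_val.shape[0] // groups) + tuple(const_val.shape[1:])
regrouped = const_val.reshape(new_shape)
assert regrouped.shape == (2, 4, 4, 3, 3)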
