Skip to content

Commit

Permalink
Merge pull request #163 from fastmachinelearning/fix/chanlast_forking…
Browse files Browse the repository at this point in the history
…_transpose

Don't remove opposing transposes if first one is forking
  • Loading branch information
maltanar authored Dec 18, 2024
2 parents cf640b9 + cf7c56e commit a6a23ed
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
Binary file added src/qonnx/data/onnx/residual_block_clean.onnx
Binary file not shown.
5 changes: 4 additions & 1 deletion src/qonnx/transformation/channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,10 @@ def apply(self, model):
ndim = len(input_shape)
if list(to_channels_first_args(ndim)) == perm_1:
successor_nodes = model.find_direct_successors(n)
if successor_nodes is None:
# skip if:
# - this Transpose has no successors (nothing to do)
# - this Transpose output is forking (cannot remove)
if successor_nodes is None or len(successor_nodes) > 1:
continue
successor_node = successor_nodes[0]

Expand Down
55 changes: 55 additions & 0 deletions tests/transformation/test_channelslast_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
from pkgutil import get_data

import qonnx.core.onnx_exec as oxe
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.util.basic import gen_finn_dt_tensor


def test_channelslast_residual():
raw_m = get_data("qonnx.data", "onnx/residual_block_clean.onnx")
model = ModelWrapper(raw_m)
iname = model.graph.input[0].name
idt = model.get_tensor_datatype(iname)
ishape = model.get_tensor_shape(iname)
idict = {iname: gen_finn_dt_tensor(idt, ishape)}
oname = model.graph.output[0].name
expected_out = oxe.execute_onnx(model, idict)[oname]
model = model.transform(ConvertToChannelsLastAndClean(make_input_channels_last=False))
expected_ops = ["Transpose", "Conv", "Conv", "Relu", "Conv", "Relu", "Add", "MaxPool", "Transpose"]
ops = [x.op_type for x in model.graph.node]
assert ops == expected_ops, "Did not found expected op sequence after lowering and channels-last"
for node in model.graph.node:
if node.op_type in ["Conv", "MaxPool"]:
assert node.domain == "qonnx.custom_op.channels_last"
out = oxe.execute_onnx(model, idict)[oname]
assert np.isclose(expected_out, out, atol=1e-4).all()

0 comments on commit a6a23ed

Please sign in to comment.