From 16826a685f4e40e40d11fb31434674b63203f982 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Wed, 18 Dec 2024 23:59:30 +0100 Subject: [PATCH 1/2] [Test] add test_channelslast_residual Co-authored-by: Javier Campos --- src/qonnx/data/onnx/residual_block_clean.onnx | Bin 0 -> 19297 bytes .../test_channelslast_residual.py | 55 ++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 src/qonnx/data/onnx/residual_block_clean.onnx create mode 100644 tests/transformation/test_channelslast_residual.py diff --git a/src/qonnx/data/onnx/residual_block_clean.onnx b/src/qonnx/data/onnx/residual_block_clean.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fea0eadb0cedc57b63bf0c0320c17e4ba7d47ca7 GIT binary patch literal 19297 zcmeI4PjAyO7{|!%e$2Kn0)5>)nyN};Lzf0=qC4IJ;og~Lqk?qx4Gt1_C&i+;3d|7}r@lOoP%^!AH#_AUz#N8bUCx$^d9Zz}8+eZBeMb9fUL^Cp zoHdaZSzoN4e~)e)W%Vp8;%1UmSwtctF5jU}m84B{qkD@>1KMlmb)IHT7qnUF z!q9X<4wVZ--G$+QcA?u$x}bvBTFisjx-c}8&^aqTr_a+=IexSK?%BMf6ROQHD%s-k z+3p`CC$Gw~hy$-Dzn)>FOaHkTVeBw#@^q0DBSG`aI;kf9=N;k_i}2UL^4{23>rWzbj2wdjFaQR?02lxRU;qq&0WbgtzyKHk17H9QfB`T72EYIq00Us)@(jr1 zI`!1>6`lU6r;c&@cliu~a>Nh}fB`T72EYIq00UqE41fVJ00zJS7ytuc01SWuFaQR? z02lxRVBowA$kRVnaOa9n|5U-A*SlKXBzB)}5sNXpMy$5k Date: Thu, 19 Dec 2024 00:01:47 +0100 Subject: [PATCH 2/2] [ChanLast] don't remove opposing transposes if first one is forking --- src/qonnx/transformation/channels_last.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index c22889a6..175af058 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -301,7 +301,10 @@ def apply(self, model): ndim = len(input_shape) if list(to_channels_first_args(ndim)) == perm_1: successor_nodes = model.find_direct_successors(n) - if successor_nodes is None: + # skip if: + # - this Transpose has no successors (nothing to do) + # - this Transpose output is forking (cannot remove) + if successor_nodes is None or len(successor_nodes) > 1: continue successor_node = successor_nodes[0]