diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py index 81143e45..241311a4 100644 --- a/src/qonnx/transformation/infer_data_layouts.py +++ b/src/qonnx/transformation/infer_data_layouts.py @@ -62,9 +62,23 @@ def _dims_to_layout(model, node, ndims): else: return DataLayout.UNKNOWN else: - # propagate input layout to output - # TODO this won't work for concat, squeeze/unsqueeze/reshape... - return model.get_tensor_layout(node.input[0]) + # Check whether there is a layout annotation for the first input + # TODO: There are multi-input operations, why should the first + # determine the output layout? + if layout := model.get_tensor_layout(node.input[0]): + # If annotation present: propagate input layout to output + # TODO: this won't work for concat, squeeze/unsqueeze/reshape... + return layout + # Fallback to the same defaults as for the FINN-Ops above + else: + if ndims == 4: + return DataLayout.NHWC + elif ndims == 3: + return DataLayout.NWC + elif ndims == 2: + return DataLayout.NC + else: + return DataLayout.UNKNOWN def _infer_node_data_layout(model, node):