Skip to content

Commit

Permalink
Merge pull request #92 from iksnagreb/fix/multi_threshold_exec_layouts
Browse files Browse the repository at this point in the history
Add rudimentary support for "arbitrary" dimensions in MultiThreshold
  • Loading branch information
maltanar authored Dec 17, 2024
2 parents b3121e2 + 68c7cc5 commit cf640b9
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/qonnx/custom_op/general/multithreshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ def execute_node(self, context, graph):
pass
else:
raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")

# Remember whether the shape has been modified to handle 1d or 3d data
# layouts
orig_shape = None
# If the input tensor has dimensions not covered by the NC or NCWH data
# layouts, the shape needs to be adapted such that it can be handled by
# multithreshold.
# TODO: Seems like a rather sketchy solution to support arbitrary data
# layouts. This does not even validate the assumption of channel last
# layout.
if v.ndim not in {2, 4}:
# Remember the original shape to be restored later
orig_shape = v.shape
# Assume last dimension to be the channel dimension C and reshape
# into NC layout which is supported by multithreshold
v = v.reshape((-1, v.shape[-1]))

# calculate output
output = multithreshold(v, thresholds, out_scale, out_bias)
# setting context according to output
Expand All @@ -145,6 +162,13 @@ def execute_node(self, context, graph):
pass
else:
raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")

# If the shape has been modified to support arbitrary layouts, restore
# the original shape
# TODO: Part of the rather sketchy solution above.
if orig_shape is not None:
output = output.reshape(orig_shape)

context[node.output[0]] = output

def verify_node(self):
Expand Down

0 comments on commit cf640b9

Please sign in to comment.