Skip to content

Commit

Permalink
Add cleanup transformation sorting inputs of commutative operations
Browse files Browse the repository at this point in the history
  • Loading branch information
iksnagreb committed Nov 14, 2023
1 parent cadd6b2 commit 7719a3e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import RemoveStaticGraphInputs, RemoveUnusedTensors, SortGraph
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
RemoveUnusedTensors,
SortGraph,
SortCommutativeInputsInitializerLast
)


class ModelWrapper:
Expand Down Expand Up @@ -149,6 +154,7 @@ def cleanup(self):
RemoveUnusedTensors(),
RemoveStaticGraphInputs(),
SortGraph(),
SortCommutativeInputsInitializerLast(),
]
for trn in cleanup_transforms:
transformed_model = transformed_model.transform(trn, cleanup=False, make_deepcopy=False)
Expand Down
57 changes: 57 additions & 0 deletions src/qonnx/transformation/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import qonnx.util.basic as util
from qonnx.transformation.base import Transformation

# Protobuf onnx graph node type
from onnx import NodeProto # noqa


class MovePadAttributeToTensor(Transformation):
"Move padding info from attribute into input tensor for Pad nodes."
Expand Down Expand Up @@ -359,3 +362,57 @@ def apply(self, model):

# one iteration is enough
return (model, False)


# Groups inputs by categories, i.e., groups dynamic inputs first, followed by
# initializers. Keeps order of inputs in each category.
def group_inputs_by_category(node: NodeProto, model): # noqa
# Select all dynamic inputs, which are those without initializer tensor
dynamics = [i for i in node.input if model.get_initializer(i) is None]
# Select all input which are initializers, which, by exclusion, are all
# those not among the dynamic inputs
initializers = [i for i in node.input if i not in dynamics]
# Return lists of dynamic anc initializer inputs
return dynamics, initializers


# Tidy-Up transformation sorting the inputs to all commutative operations to
# have initializer inputs last
class SortCommutativeInputsInitializerLast(Transformation):
"""
Sorts inputs of nodes describing commutative operations to have initializer
inputs last. This order of inputs is assumed by many other transformations.
"""

# Set of supported commutative operations
# TODO: There might be more valid operations
SUPPORTED_COMMUTATIVE_OPS = {"Add", "Mul", "And", "Or", "Xor", "Sum"}

# Applies the transform to a whole model graph
def apply(self, model): # noqa
# Get the model graph out of the model wrapper object
graph = model.graph
# Keep track of whether the graph has been modified
graph_modified = False
# Iterate all nodes in the graph keeping track of the index
for index, node in enumerate(graph.node):
# Check whether this node is among the supported
if node.op_type in self.SUPPORTED_COMMUTATIVE_OPS:
# Group node inputs by category
dynamics, initializers = group_inputs_by_category(node, model)
# Flatten the grouped input list
inputs = [*dynamics, *initializers]
# Length of sorted and original input list must match
assert len(inputs) == len(node.input)
# Reassigned inputs from sorted categories
# Note: ONNX does not allow direct assignment to node.input
for i, name in enumerate(inputs):
# The graph has been modified if any input is reordered
if node.input[i] != name:
# Note: This is never reset back to False
graph_modified = True
# Reassign input name at the new index
node.input[i] = name
# Return the transformed model and indicate whether the graph actually
# has been transformed
return model, graph_modified

0 comments on commit 7719a3e

Please sign in to comment.