Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert Identity nodes on given tensor and top-level graph inputs #166

Merged
merged 6 commits into from
Jan 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
python-version: ["3.10"]

steps:
- name: Checkout
127 changes: 127 additions & 0 deletions src/qonnx/transformation/insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from onnx import helper as oh

from qonnx.transformation.base import Transformation
from qonnx.transformation.general import SortGraph


class InsertIdentityOnAllTopLevelIO(Transformation):
"""
Transformation that inserts an Identity node on all top-level inputs and outputs
of the ONNX graph. This can be useful before calling transformations that do not
gracefully handle edge cases where transformed tensors are top-level inputs or outputs.
"""

def apply(self, model):
graph = model.graph
for inp in graph.input:
model = model.transform(InsertIdentity(inp.name, "consumer"))
for out in graph.output:
model = model.transform(InsertIdentity(out.name, "producer"))
return model, False


class InsertIdentity(Transformation):
"""
Transformation that inserts an Identity node in the ONNX graph. For edge cases
where tensor_name is a graph input and producer_or_consumer is 'producer', the
graph input will be replaced with a new tensor name <old_name>_identity. For the
edge case where tensor_name is a graph output and producer_or_consumer is 'consumer',
the graph output will be replaced with a new tensor name <old_name>_identity

Parameters:
tensor_name (str): The name of the tensor where the Identity node will be inserted.
producer_or_consumer (str): Indicates whether the Identity node will be inserted before ('producer')
or after ('consumer') the tensor_name.

"""

def __init__(self, tensor_name, producer_or_consumer):
super().__init__()
self.tensor_name = tensor_name
self.producer_or_consumer = producer_or_consumer

def insert_node_before(self, model, tensor):
graph = model.graph
new_tensor_name = tensor + "_identity"
# rewire the tensor's producer to the new tensor
prod = model.find_producer(tensor)
if prod is not None:
prod_outlist = list(prod.output)
prod.output[prod_outlist.index(tensor)] = new_tensor_name
else:
# if the tensor is an input tensor (top-level)
# update the graph's input
top_inp_names = [inp.name for inp in graph.input]
graph.input[top_inp_names.index(tensor)].name = new_tensor_name
# Create a new node
identity_node = oh.make_node("Identity", [new_tensor_name], [tensor])
# Insert the new node
# we do this late in the process to avoid affecting find_producer
graph.node.append(identity_node)

def insert_node_after(self, model, tensor):
graph = model.graph
new_tensor_name = tensor + "_identity"
# rewire the tensor's consumers to the new node
consumers = model.find_consumers(tensor)
if consumers == []:
# if the tensor is an output tensor (top-level)
# find the graph's output and replace it with the new name
top_out_name = [out.name for out in graph.output]
graph.output[top_out_name.index(tensor)].name = new_tensor_name
# TODO what if feeding multiple graph outputs? seems unlikely...
else:
for consumer in consumers:
consumer_inplist = list(consumer.input)
consumer.input[consumer_inplist.index(tensor)] = new_tensor_name
# Create a new node
# we do this late in the process to avoid affecting find_consumers
identity_node = oh.make_node("Identity", [tensor], [new_tensor_name])
# Insert the new node
graph.node.append(identity_node)

def apply(self, model):
# Find the tensor in the graph
tshape = model.get_tensor_shape(self.tensor_name)
if tshape is None:
raise ValueError(f"Tensor '{self.tensor_name}' not found in the graph.")
tensor = self.tensor_name
# Insert the Identity node before or after the specified tensor
if self.producer_or_consumer == "producer":
self.insert_node_before(model, tensor)
elif self.producer_or_consumer == "consumer":
self.insert_node_after(model, tensor)
else:
raise ValueError("producer_or_consumer must be either 'producer' or 'consumer'.")

model = model.transform(SortGraph())
# important to return run_again=False to avoid infinite loop
return (model, False)
128 changes: 128 additions & 0 deletions tests/transformation/test_insert_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of AMD nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

from onnx import TensorProto
from onnx import helper as oh

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.insert import InsertIdentity, InsertIdentityOnAllTopLevelIO


@pytest.fixture
def simple_model():
# Create a simple ONNX model for testing
input_tensor = oh.make_tensor_value_info("input", TensorProto.FLOAT, [1, 2])
output_tensor = oh.make_tensor_value_info("output", TensorProto.FLOAT, [1, 2])
node1 = oh.make_node("Relu", ["input"], ["intermediate"])
node2 = oh.make_node("Relu", ["intermediate"], ["output"])
graph = oh.make_graph([node1, node2], "test_graph", [input_tensor], [output_tensor])
model = ModelWrapper(oh.make_model(graph))
model = model.transform(InferShapes())
return model


def test_insert_identity_on_all_top_level_io(simple_model):
orig_top_inp_names = [inp.name for inp in simple_model.graph.input]
orig_top_out_names = [out.name for out in simple_model.graph.output]
model = simple_model.transform(InsertIdentityOnAllTopLevelIO())
for inp in orig_top_inp_names:
assert model.find_consumer(inp).op_type == "Identity"
for out in orig_top_out_names:
assert model.find_producer(out).op_type == "Identity"
assert orig_top_inp_names == [inp.name for inp in model.graph.input]
assert orig_top_out_names == [out.name for out in model.graph.output]


def test_insert_identity_before_input(simple_model):
# Apply the transformation
transformation = InsertIdentity("input", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("input")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_input(simple_model):
# Apply the transformation
transformation = InsertIdentity("input", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("input")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_before_intermediate(simple_model):
# Apply the transformation
transformation = InsertIdentity("intermediate", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("intermediate")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_intermediate(simple_model):
# Apply the transformation
transformation = InsertIdentity("intermediate", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("intermediate")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_before_output(simple_model):
# Apply the transformation
transformation = InsertIdentity("output", "producer")
model = simple_model.transform(transformation)

identity_node = model.find_producer("output")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_insert_identity_after_output(simple_model):
# Apply the transformation
transformation = InsertIdentity("output", "consumer")
model = simple_model.transform(transformation)

identity_node = model.find_consumer("output")
assert identity_node is not None
assert identity_node.op_type == "Identity"


def test_tensor_not_found(simple_model):
# Apply the transformation with a non-existent tensor
transformation = InsertIdentity("non_existent_tensor", "producer")
with pytest.raises(ValueError):
simple_model.transform(transformation)

Unchanged files with check annotations Beta

.. _readme:

Check warning on line 1 in docs/readme.rst

GitHub Actions / docs

document isn't included in any toctree
.. include:: ../README.rst

Check warning on line 2 in docs/readme.rst

GitHub Actions / docs

Problems with "include" directive path:
.. _changes:

Check warning on line 1 in docs/changelog.rst

GitHub Actions / docs

document isn't included in any toctree