Skip to content

Commit

Permalink
Merge pull request #158 from fastmachinelearning/feature/remove_forke…
Browse files Browse the repository at this point in the history
…d_identity

Remove identity nodes with output forking
  • Loading branch information
maltanar authored Dec 6, 2024
2 parents 46721d9 + a5d5668 commit d9269a9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/qonnx/transformation/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,20 @@ def apply(self, model):
graph_modified = False
for node in graph.node:
node_ind += 1
if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node):
if node.op_type in ["Add", "Sub"]:
A = model.get_initializer(node.input[1])
if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all():
remove_node_and_rewire(model, node)
graph_modified = True
break

elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node):
elif node.op_type in ["Mul", "Div"]:
A = model.get_initializer(node.input[1])
if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all():
remove_node_and_rewire(model, node)
graph_modified = True
break
elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node):
elif node.op_type == "Pad":
pads = get_by_name(node.attribute, "pads")
if pads is not None:
# older versions of Pad op specify pads as attribute
Expand Down
31 changes: 24 additions & 7 deletions tests/transformation/test_remove_identity_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@
import qonnx.core.onnx_exec as oxe
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import SortGraph
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model


def insert_identity_op(model, op, as_first_node, approx):
def insert_identity_op(model, op, as_first_node, approx, fork_after_id):
kwargs = {}
inp_ndims = 4 if as_first_node else 2
if approx:
Expand Down Expand Up @@ -71,12 +72,24 @@ def insert_identity_op(model, op, as_first_node, approx):
model.set_initializer("value", val)
inplist = ["inp" if as_first_node else "div_out", "value"]
identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs)
old_2nd_node = graph.node[1]
old_last_node = graph.node[-1]
graph.node.append(identity_node)
if fork_after_id:
graph.node.append(helper.make_node("Mul", ["ident_out", "mul2"], ["mulbranch0_out"]))
model.set_initializer("mul2", np.asarray([2.0], dtype=np.float32))
graph.node.append(helper.make_node("Mul", ["ident_out", "mul3"], ["mulbranch1_out"]))
model.set_initializer("mul3", np.asarray([3.0], dtype=np.float32))
graph.node.append(helper.make_node("Add", ["mulbranch0_out", "mulbranch1_out"], ["idfork_out"]))
subgraph_out = "idfork_out"
else:
subgraph_out = "ident_out"

if as_first_node:
graph.node.insert(0, identity_node)
graph.node[1].input[0] = "ident_out"
old_2nd_node.input[0] = subgraph_out
else:
graph.node.insert(3, identity_node)
graph.node[-1].input[0] = "ident_out"
old_last_node.input[0] = subgraph_out
model = model.transform(SortGraph())

return model

Expand All @@ -86,7 +99,10 @@ def insert_identity_op(model, op, as_first_node, approx):
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
@pytest.mark.parametrize("fork_before_id", [False, True])
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
@pytest.mark.parametrize("fork_after_id", [False, True])
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id, fork_after_id):
if approx and not (op in ["Add", "Sub", "Mul", "Div"]):
pytest.skip(f"approx=True not relevant for {op}")
# set up onnx model
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
Expand Down Expand Up @@ -119,7 +135,8 @@ def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
model.set_initializer("shape", shape_values)
model.set_initializer("div", div_values)
model.set_initializer("matmul", matmul_values)
insert_identity_op(model, op, as_first_node, approx)
insert_identity_op(model, op, as_first_node, approx, fork_after_id)
model = model.transform(InferShapes())
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
idict = {"inp": inp_values}
Expand Down

0 comments on commit d9269a9

Please sign in to comment.