diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 402f3a5b..c45aea6c 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -113,20 +113,20 @@ def apply(self, model): graph_modified = False for node in graph.node: node_ind += 1 - if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node): + if node.op_type in ["Add", "Sub"]: A = model.get_initializer(node.input[1]) if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all(): remove_node_and_rewire(model, node) graph_modified = True break - elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node): + elif node.op_type in ["Mul", "Div"]: A = model.get_initializer(node.input[1]) if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all(): remove_node_and_rewire(model, node) graph_modified = True break - elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node): + elif node.op_type == "Pad": pads = get_by_name(node.attribute, "pads") if pads is not None: # older versions of Pad op specify pads as attribute diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index 506cdbc3..91e74554 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -34,13 +34,14 @@ import qonnx.core.onnx_exec as oxe from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.general import SortGraph from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.remove import RemoveIdentityOps from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model -def insert_identity_op(model, op, as_first_node, approx): +def insert_identity_op(model, op, as_first_node, approx, fork_after_id): kwargs = {} inp_ndims = 4 if as_first_node else 2 if approx: @@ -71,12 +72,24 @@ def insert_identity_op(model, op, as_first_node, approx): model.set_initializer("value", val) inplist = ["inp" if as_first_node else "div_out", "value"] identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs) + old_2nd_node = graph.node[1] + old_last_node = graph.node[-1] + graph.node.append(identity_node) + if fork_after_id: + graph.node.append(helper.make_node("Mul", ["ident_out", "mul2"], ["mulbranch0_out"])) + model.set_initializer("mul2", np.asarray([2.0], dtype=np.float32)) + graph.node.append(helper.make_node("Mul", ["ident_out", "mul3"], ["mulbranch1_out"])) + model.set_initializer("mul3", np.asarray([3.0], dtype=np.float32)) + graph.node.append(helper.make_node("Add", ["mulbranch0_out", "mulbranch1_out"], ["idfork_out"])) + subgraph_out = "idfork_out" + else: + subgraph_out = "ident_out" + if as_first_node: - graph.node.insert(0, identity_node) - graph.node[1].input[0] = "ident_out" + old_2nd_node.input[0] = subgraph_out else: - graph.node.insert(3, identity_node) - graph.node[-1].input[0] = "ident_out" + old_last_node.input[0] = subgraph_out + model = model.transform(SortGraph()) return model @@ -86,7 +99,10 @@ def insert_identity_op(model, op, as_first_node, approx): @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) @pytest.mark.parametrize("fork_before_id", [False, True]) -def test_remove_identity_ops(op, as_first_node, approx, fork_before_id): +@pytest.mark.parametrize("fork_after_id", [False, True]) +def test_remove_identity_ops(op, as_first_node, approx, fork_before_id, fork_after_id): + if approx and not (op in ["Add", "Sub", "Mul", "Div"]): + pytest.skip(f"approx=True not relevant for {op}") # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, []) @@ -119,7 +135,8 @@ def test_remove_identity_ops(op, as_first_node, approx, fork_before_id): model.set_initializer("shape", shape_values) model.set_initializer("div", div_values) model.set_initializer("matmul", matmul_values) - insert_identity_op(model, op, as_first_node, approx) + insert_identity_op(model, op, as_first_node, approx, fork_after_id) + model = model.transform(InferShapes()) model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values}