diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 0f7f38f7..402f3a5b 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -107,40 +107,54 @@ def __init__(self, atol=1e-05): self.atol = atol def apply(self, model): + opset_version = model.model.opset_import[0].version graph = model.graph node_ind = 0 graph_modified = False - for n in graph.node: + for node in graph.node: node_ind += 1 - if n.op_type in ["Add", "Sub"] and not model.is_fork_node(n) and not model.is_join_node(n): - A = model.get_initializer(n.input[1]) + if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node): + A = model.get_initializer(node.input[1]) if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all(): - remove_node_and_rewire(model, n) + remove_node_and_rewire(model, node) graph_modified = True break - elif n.op_type in ["Mul", "Div"] and not model.is_fork_node(n) and not model.is_join_node(n): - A = model.get_initializer(n.input[1]) + elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node): + A = model.get_initializer(node.input[1]) if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all(): - remove_node_and_rewire(model, n) + remove_node_and_rewire(model, node) graph_modified = True break - elif n.op_type == "Pad" and not model.is_fork_node(n) and not model.is_join_node(n): - pads = get_by_name(n.attribute, "pads") + elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node): + pads = get_by_name(node.attribute, "pads") if pads is not None: # older versions of Pad op specify pads as attribute pads = np.asarray(pads.ints, dtype=np.int64) else: # newer versions of Pad op specify pads as input - pads = model.get_initializer(n.input[1]) + pads = model.get_initializer(node.input[1]) if (pads is not None) and (pads == 0).all(): - remove_node_and_rewire(model, n) + remove_node_and_rewire(model, node) graph_modified = True break - elif n.op_type == "Identity": - remove_node_and_rewire(model, n) + elif node.op_type == "Identity": + remove_node_and_rewire(model, node) graph_modified = True break + elif node.op_type == "Dropout": + if opset_version < 12: + dropout_ratio = get_by_name(node.attribute, "ratio") + dropout_id_cond = not (dropout_ratio is None) and dropout_ratio.f == 0 + else: + based_on_inplen = len(node.input) == 1 + based_on_ratio_inp = (not based_on_inplen) and model.get_initializer(node.input[1]) == 0 + dropout_id_cond = based_on_inplen or based_on_ratio_inp + if dropout_id_cond: + remove_node_and_rewire(model, node) + graph_modified = True + break + model = model.transform(InferShapes()) return (model, graph_modified) diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index cfe01a82..506cdbc3 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -41,6 +41,8 @@ def insert_identity_op(model, op, as_first_node, approx): + kwargs = {} + inp_ndims = 4 if as_first_node else 2 if approx: zero_val = 0.000001 one_val = 0.999999 @@ -53,6 +55,12 @@ def insert_identity_op(model, op, as_first_node, approx): val = np.asarray([one_val], dtype=np.float32) elif op in ["Identity"]: val = None + elif op == "Pad": + # opset 11 and above: padding specified as input and not attribute + val = np.asarray([0] * 2 * inp_ndims, dtype=np.int64) + elif op == "Dropout": + val = None + kwargs = {"ratio": 0.0} else: return @@ -62,7 +70,7 @@ def insert_identity_op(model, op, as_first_node, approx): else: model.set_initializer("value", val) inplist = ["inp" if as_first_node else "div_out", "value"] - identity_node = helper.make_node(op, inplist, ["ident_out"]) + identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs) if as_first_node: graph.node.insert(0, identity_node) graph.node[1].input[0] = "ident_out" @@ -74,7 +82,7 @@ def insert_identity_op(model, op, as_first_node, approx): # identity operations to be inserted -@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) +@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity", "Pad", "Dropout"]) @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) @pytest.mark.parametrize("fork_before_id", [False, True])