Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Dropout with ratio=0 with RemoveIdentityOps #157

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions src/qonnx/transformation/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,40 +107,54 @@ def __init__(self, atol=1e-05):
self.atol = atol

def apply(self, model):
opset_version = model.model.opset_import[0].version
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
for node in graph.node:
node_ind += 1
if n.op_type in ["Add", "Sub"] and not model.is_fork_node(n) and not model.is_join_node(n):
A = model.get_initializer(n.input[1])
if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node):
A = model.get_initializer(node.input[1])
if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all():
remove_node_and_rewire(model, n)
remove_node_and_rewire(model, node)
graph_modified = True
break

elif n.op_type in ["Mul", "Div"] and not model.is_fork_node(n) and not model.is_join_node(n):
A = model.get_initializer(n.input[1])
elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node):
A = model.get_initializer(node.input[1])
if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all():
remove_node_and_rewire(model, n)
remove_node_and_rewire(model, node)
graph_modified = True
break
elif n.op_type == "Pad" and not model.is_fork_node(n) and not model.is_join_node(n):
pads = get_by_name(n.attribute, "pads")
elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node):
pads = get_by_name(node.attribute, "pads")
if pads is not None:
# older versions of Pad op specify pads as attribute
pads = np.asarray(pads.ints, dtype=np.int64)
else:
# newer versions of Pad op specify pads as input
pads = model.get_initializer(n.input[1])
pads = model.get_initializer(node.input[1])

if (pads is not None) and (pads == 0).all():
remove_node_and_rewire(model, n)
remove_node_and_rewire(model, node)
graph_modified = True
break
elif n.op_type == "Identity":
remove_node_and_rewire(model, n)
elif node.op_type == "Identity":
remove_node_and_rewire(model, node)
graph_modified = True
break
elif node.op_type == "Dropout":
if opset_version < 12:
dropout_ratio = get_by_name(node.attribute, "ratio")
dropout_id_cond = not (dropout_ratio is None) and dropout_ratio.f == 0
else:
based_on_inplen = len(node.input) == 1
based_on_ratio_inp = (not based_on_inplen) and model.get_initializer(node.input[1]) == 0
dropout_id_cond = based_on_inplen or based_on_ratio_inp
if dropout_id_cond:
remove_node_and_rewire(model, node)
graph_modified = True
break

model = model.transform(InferShapes())
return (model, graph_modified)
12 changes: 10 additions & 2 deletions tests/transformation/test_remove_identity_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@


def insert_identity_op(model, op, as_first_node, approx):
kwargs = {}
inp_ndims = 4 if as_first_node else 2
if approx:
zero_val = 0.000001
one_val = 0.999999
Expand All @@ -53,6 +55,12 @@ def insert_identity_op(model, op, as_first_node, approx):
val = np.asarray([one_val], dtype=np.float32)
elif op in ["Identity"]:
val = None
elif op == "Pad":
# opset 11 and above: padding specified as input and not attribute
val = np.asarray([0] * 2 * inp_ndims, dtype=np.int64)
elif op == "Dropout":
val = None
kwargs = {"ratio": 0.0}
else:
return

Expand All @@ -62,7 +70,7 @@ def insert_identity_op(model, op, as_first_node, approx):
else:
model.set_initializer("value", val)
inplist = ["inp" if as_first_node else "div_out", "value"]
identity_node = helper.make_node(op, inplist, ["ident_out"])
identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs)
if as_first_node:
graph.node.insert(0, identity_node)
graph.node[1].input[0] = "ident_out"
Expand All @@ -74,7 +82,7 @@ def insert_identity_op(model, op, as_first_node, approx):


# identity operations to be inserted
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity", "Pad", "Dropout"])
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
@pytest.mark.parametrize("fork_before_id", [False, True])
Expand Down
Loading