diff --git a/src/qonnx/transformation/create_generic_partitions.py b/src/qonnx/transformation/create_generic_partitions.py index bc6ea7d7..999d4efe 100755 --- a/src/qonnx/transformation/create_generic_partitions.py +++ b/src/qonnx/transformation/create_generic_partitions.py @@ -195,13 +195,14 @@ def __init__(self, partitioning={}, partition_dir=None): def apply(self, model): # prepare node -> int assignment fct. def partitioning_func(node): - partition_id = -1 - for key in self.partitioning: - if node in list(model.graph.node) and list(model.graph.node).index(node) in list(self.partitioning[key]): - assert partition_id == -1, """single node assigned to multiple partitions""" - partition_id = key - - return partition_id + if node not in model.graph.node: + return -1 + node_index = list(model.graph.node).index(node) + candidates = list(filter(lambda key_value: node_index in key_value[1], self.partitioning.items())) + if len(candidates) == 0: + return -1 + assert len(candidates) == 1, f"single node assigned to multiple partitions: {candidates}" + return candidates[0][0] # partition_id # apply partitioning model = model.transform(PartitionFromLambda(partitioning_func, self.partition_dir))