
Commit

Merge pull request #149 from fastmachinelearning/feature/improved_chanlast_eltwiseops

Improved channels-last via elementwise op generalization
maltanar authored Dec 13, 2024
2 parents 9b22db4 + a18a186 commit e5d5903
Showing 6 changed files with 328 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/qonnx/custom_op/channels_last/batch_normalization.py
@@ -87,7 +87,7 @@ def verify_node(self):

# verify number of attributes
num_of_attr = 2
if len(node.attribute) == num_of_attr:
if len(node.attribute) >= num_of_attr:
info_messages.append("The number of attributes is correct")
else:
info_messages.append(
Binary file not shown.
241 changes: 195 additions & 46 deletions src/qonnx/transformation/channels_last.py
@@ -26,31 +26,110 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from copy import deepcopy
from onnx import TensorProto, helper

from qonnx.analysis.topology import is_linear
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op import channels_last
from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args
from qonnx.transformation.base import Transformation
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import SortGraph
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
from qonnx.util.basic import get_by_name
from qonnx.util.onnx import is_eltwise_optype

# Standard ONNX nodes which require a ChannelsLast data format to function properly
_channelsLast_node_types = list(channels_last.custom_op.keys())

# Nodes, which do not modify the shape of the tensor
# And modify all values in the same way.
_move_through_nodes = ["Quant", "Relu"]
_move_through_nodes = ["Quant", "Relu", "Selu", "LeakyRelu", "Sigmoid", "Tanh"]

# Nodes, which do not modify the shape of the tensor,
# And modify all values in the same way, if the second tensor is a scalar.
_move_through_nodes_if_scalar = ["Mul", "Div", "Sub", "Add"]
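
The two lists above capture the property that makes the rewrite safe: shape-preserving elementwise ops, and binary ops whose second operand is a scalar, act on every element identically and therefore commute with a Transpose. A minimal numpy sketch of that property, for illustration only (not part of this commit):

import numpy as np

x = np.random.rand(1, 3, 8, 8)
perm = (0, 2, 3, 1)  # NCHW -> NHWC

# Unary op: transpose-then-relu equals relu-then-transpose
assert np.allclose(np.transpose(np.maximum(x, 0.0), perm), np.maximum(np.transpose(x, perm), 0.0))
# Binary op with a scalar operand commutes the same way
assert np.allclose(np.transpose(x * 2.0, perm), np.transpose(x, perm) * 2.0)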


def get_transpose_perms(transpose_node, model):
perm = get_by_name(transpose_node.attribute, "perm")
ndim = len(model.get_tensor_shape(transpose_node.input[0]))
if perm is None:
# default perm is to reverse the dim order
return list(range(ndim - 1, -1, -1))
else:
return list(perm.ints)
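
get_transpose_perms falls back to the ONNX default when a Transpose node carries no perm attribute, which is to reverse the dimension order; numpy's axis-less transpose behaves the same way. A quick check of that default, for illustration only:

import numpy as np

x = np.random.rand(2, 3, 4)
default_perm = list(range(x.ndim - 1, -1, -1))  # [2, 1, 0], what the helper returns
# numpy's transpose without explicit axes also reverses the axes
assert np.transpose(x).shape == np.transpose(x, axes=default_perm).shape == (4, 3, 2)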


def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
t0 = transpose_node.input[0]
t1 = transpose_node.output[0]
t2 = eltwise_node.output[0]
subgraph_inp_shape = model.get_tensor_shape(t0)
ndim_inp = len(subgraph_inp_shape)
perm = get_transpose_perms(transpose_node, model)

# before: t0 -> transpose -> t1 -> eltwise -> t2
# after: t0 -> eltwise -> t1 -> transpose -> t2
# find the eltwise inp index fed by transpose
transpose_in_ind = list(eltwise_node.input).index(t1)
# check all inputs for the eltwise op:
# we need to ensure those inputs get inverse-transposed
# to keep the graph semantics intact
for ind, eltwise_inp in enumerate(eltwise_node.input):
if ind == transpose_in_ind:
# the input that feeds from the original transpose
# node will be implicitly inverse-transposed, since we'll be
# moving that transpose node past the eltwise op
continue
inp_shape = model.get_tensor_shape(eltwise_inp)
ndim = len(inp_shape)
if ndim == 0:
# scalar input, always broadcastable, no action needed
continue
elif ndim == ndim_inp:
# input with matching dimensions, add inverse transpose
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
model.graph.node.append(new_transpose_node)
else:
# input with non-matching dimensions, assume broadcastable
# first add Unsqueeze node to match number of dimensions
unsqueeze_param_name = model.make_new_valueinfo_name()
model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
unsqueeze_out_name = model.make_new_valueinfo_name()
new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
model.graph.node.append(new_unsqueeze_node)
# now add inverse transpose
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
model.graph.node.append(new_transpose_node)
# rewire to swap transpose and eltwise node order
eltwise_node.input[transpose_in_ind] = t0
eltwise_node.output[0] = t1
transpose_node.input[0] = t1
transpose_node.output[0] = t2
# t1 tensor shape changes to inp_shape
model.set_tensor_shape(t1, subgraph_inp_shape)
model = model.transform(SortGraph())
model = model.transform(FoldConstants())
return model
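
The rewiring above relies on a simple identity: applying the elementwise op first and transposing afterwards gives the same result as transposing first, provided every other non-scalar input is run through the inverse permutation beforehand (scalars broadcast either way). A numpy sketch of that identity, for illustration only (not part of the commit):

import numpy as np

a = np.random.rand(1, 3, 8, 8)   # t0: input of the original Transpose
b = np.random.rand(1, 8, 8, 3)   # second eltwise input, already in the transposed layout
perm = (0, 2, 3, 1)              # e.g. channels-first -> channels-last
inv_perm = np.argsort(perm)

# before rewrite: t0 -> Transpose -> eltwise(., b)
before = np.transpose(a, perm) + b
# after rewrite: eltwise(t0, inverse-transposed b) -> Transpose
after = np.transpose(a + np.transpose(b, inv_perm), perm)
assert np.allclose(before, after)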


class ConvertToChannelsLastAndClean(Transformation):
"""
Converts data layout dependent nodes to ChannelsLast nodes and inserts transformations.
@@ -67,8 +67,7 @@ def __init__(self, make_input_channels_last=False):
super().__init__()
self._make_input_channels_last = make_input_channels_last

def apply(self, model):
assert model.analysis(is_linear)["is_linear"], "Only linear and non-branching models are supported at this moment."
def apply(self, model: ModelWrapper):
assert model.check_all_tensor_shapes_specified(), (
"All tensor shapes must be specified. " "Consider running InferShapes."
)
Expand All @@ -85,8 +163,9 @@ def apply(self, model):
# Technically only required if something changed in the previous trafo
model = model.transform(RemoveConsecutiveChanFirstAndChanLastTrafos())

# Apply MoveChanLastDownStream
# Apply MoveChanLastDownStream and MoveTransposePastFork
model = model.transform(MoveChanFirstDownstream())
model = model.transform(MoveTransposePastFork())

# Run RemoveConsecutiveChanFirstAndChanLastTrafos again,
# Technically only required if something changed in the previous trafo
@@ -218,9 +297,9 @@ def apply(self, model):
# Check the input shape and make sure we support it
input_shape = model.get_tensor_shape(n.input[0])
# Check that this is a "to chan first" trafo
perm_1 = get_by_name(n.attribute, "perm")
perm_1 = get_transpose_perms(n, model)
ndim = len(input_shape)
if list(to_channels_first_args(ndim)) == perm_1.ints:
if list(to_channels_first_args(ndim)) == perm_1:
successor_nodes = model.find_direct_successors(n)
if successor_nodes is None:
continue
@@ -229,8 +308,8 @@ def apply(self, model):
if successor_node.op_type == "Transpose":
# Check that this is a "to chan last" trafo,
# if so both can get removed.
perm_2 = get_by_name(successor_node.attribute, "perm")
if list(to_channels_last_args(ndim)) == perm_2.ints:
perm_2 = get_transpose_perms(successor_node, model)
if list(to_channels_last_args(ndim)) == perm_2:
# Connect original input to new output
input_tensor = n.input[0]
output_tensor_name = successor_node.output[0]
@@ -257,7 +336,7 @@ class MoveChanLastUpstream(Transformation):
Moves channel last transformations further upstream.
"""

def apply(self, model):
def apply(self, model: ModelWrapper):
graph = model.graph
node_ind = 0
graph_modified = False
@@ -268,8 +347,8 @@ def apply(self, model):
# Check the input shape and make sure we support it
input_shape = model.get_tensor_shape(n.input[0])
ndim = len(input_shape)
perm = get_by_name(n.attribute, "perm")
if list(to_channels_last_args(ndim)) == perm.ints:
perm = get_transpose_perms(n, model)
if list(to_channels_last_args(ndim)) == perm:
predecessors = model.find_direct_predecessors(n)
# Check if we reached the top of the graph
if predecessors is None:
@@ -285,6 +364,10 @@ def apply(self, model):
if second_inp_shape == [1] or second_inp_shape == []:
move_through_valid |= True

# don't move through if the predecessor output is a fork
if model.is_fork_node(predecessor):
move_through_valid = False

# Apply move through trafo if possible
if move_through_valid:
# Input tensors are always input 0
@@ -334,52 +417,29 @@ def apply(self, model):
node_ind = 0
graph_modified = False
# Find transpose nodes, which are "to chan first" trafos
for n in graph.node:
for node in graph.node:
node_ind += 1
if n.op_type == "Transpose":
if node.op_type == "Transpose":
# Check the input shape and make sure we support it
input_shape = model.get_tensor_shape(n.input[0])
input_shape = model.get_tensor_shape(node.input[0])
ndim = len(input_shape)
perm = get_by_name(n.attribute, "perm")
if list(to_channels_first_args(ndim)) == perm.ints:
perm = get_transpose_perms(node, model)
if list(to_channels_first_args(ndim)) == perm:
# Do not move the node, if it is at the top of the graph,
# this is a strange edge case, for 1D networks, where channels last and channels first trafos
# are identical.
predecessors = model.find_direct_predecessors(n)
predecessors = model.find_direct_predecessors(node)
if predecessors is None:
continue

successors = model.find_direct_successors(n)
successors = model.find_direct_successors(node)
if successors is None:
continue
successor = successors[0]
transpose_node = node

# Check if we can simply move through the next node
move_through_valid = successor.op_type in _move_through_nodes
# Check if we have a node, which applies a scalar change,
# then we can also move through.
if successor.op_type in _move_through_nodes_if_scalar:
second_inp_shape = model.get_tensor_shape(successor.input[1])
if second_inp_shape == [1] or second_inp_shape == []:
move_through_valid |= True
# Apply move through trafo if possible
if move_through_valid:
# Collect all tensors connecting n and successor
# and surrounding nodes
tensor_1 = n.input[0]
tensor_2 = n.output[0]
tensor_3 = successor.output[0]
# Now connect the tensors to the nodes again,
# but in different order
successor.input[0] = tensor_1
successor.output[0] = tensor_2
n.input[0] = tensor_2
n.output[0] = tensor_3

# Change the shape of the middle tensor
target_shape = model.get_tensor_shape(tensor_1)
model.set_tensor_shape(tensor_2, target_shape)

if is_eltwise_optype(successor.op_type):
model = move_transpose_past_eltwise(transpose_node, successor, model)
graph_modified = True
return model, graph_modified

@@ -422,7 +482,7 @@ def apply(self, model):
input_shape = model.get_tensor_shape(transp_node.input[0])
# check if transpose converts ChannelsLast to ChannelsFirst
ndim = len(input_shape)
perms = get_by_name(transp_node.attribute, "perm").ints
perms = get_transpose_perms(transp_node, model)
if list(to_channels_first_args(ndim)) == perms:
producer = model.find_producer(transp_node.input[0])
consumer = model.find_consumer(n.output[0])
@@ -505,3 +565,92 @@ def apply(self, model):
into subsequent node"
)
return model, graph_modified


class MoveOpPastFork(Transformation):
"""Move node operations past graph forks. Used when a node before a fork
can be merged with nodes in the branches
"""

def __init__(self, op_name_list):
super().__init__()
self.ops_to_move = op_name_list

def apply(self, model):
graph = model.graph
graph_modified = False
nodes = [n for n in graph.node]
node_ind = 0
for node in nodes:
node_ind += 1
if node.op_type in self.ops_to_move and model.is_fork_node(node) and not model.is_join_node(node):
# Restrict this transform to operations with constant parameters
# Assuming parameters is in input 1
if len(node.input) > 1:
op_init_param = model.get_initializer(node.input[1])
else:
op_init_param = None

# Check case when branches are empty and go
# to the same node
consumers = model.find_consumers(node.output[0])
assert len(consumers) > 1, "Must have >1 consumer"
unique_consumer = True
for consum_node in consumers[1:]:
if consumers[0] != consum_node:
unique_consumer = False
break

if unique_consumer:
continue

for consumer_node in consumers[1:]:
# create new node
new_output_tensor_name = model.make_new_valueinfo_name()
if op_init_param is None:
new_inp_list = [node.input[0]]
else:
new_param_name = model.make_new_valueinfo_name()
new_inp_list = [node.input[0], new_param_name]
model.set_initializer(new_param_name, op_init_param)
new_node = deepcopy(node)
new_node.input[:] = new_inp_list
new_node.output[:] = [new_output_tensor_name]
graph.node.insert(node_ind, new_node)
node_ind += 1

# change consumer input tensor
graph.node.remove(consumer_node)
for idx, consumer_input in enumerate(consumer_node.input):
if consumer_input == node.output[0]:
consumer_node.input[idx] = new_output_tensor_name
break
else:
raise Exception("Consumer should have the current node output as input")

graph.node.insert(node_ind, consumer_node)

graph_modified = True

model = model.transform(InferShapes())
return (model, graph_modified)
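
MoveOpPastFork is sound because duplicating a node into each branch of a fork leaves every consumer with exactly the values it saw before; each per-branch copy also gets its own copy of any constant parameter. A trivial numpy illustration (not part of the commit):

import numpy as np

x = np.random.rand(4, 4)
shared = x + 1.0                                             # one Add feeding both branches (fork)
fork = (shared * 2.0, np.maximum(shared, 0.0))
duplicated = ((x + 1.0) * 2.0, np.maximum(x + 1.0, 0.0))     # Add duplicated into each branch
assert all(np.allclose(f, d) for f, d in zip(fork, duplicated))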


class MoveAddPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add"])


class MoveMulPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Mul"])


class MoveLinearPastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Add", "Mul"])


class MoveTransposePastFork(MoveOpPastFork):
def __init__(self):
super().__init__(["Transpose"])