fix: refactor layer norm converter with INormalization Layer (#2755)
zewenli98 authored and peri044 committed Apr 29, 2024
1 parent ea4d580 commit 744abeb
Showing 2 changed files with 81 additions and 102 deletions.
111 changes: 20 additions & 91 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -105,102 +106,30 @@ def layer_norm(
     cudnn_enable: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    shape = weight.shape
-    gamma = to_numpy(weight).reshape(shape)
-    beta = to_numpy(bias).reshape(shape)
-
-    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
-
-    # E[x]
-    mean_expected_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
-    )
-
-    # X-E[x]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        input,
-        mean_expected_trt,
-    )
-
-    # Variance = mean(pow(x_sub_mean, 2))
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow_var",
-        sub_trt,
-        pow_trt,
-    )
-    mean_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
-    )
-
-    # sqrt((var + eps))
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        mean_trt,
-        eps_trt,
-    )
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
-
-    # (X - E[X]) / sqrt((var + eps))
-    div_trt = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
-
-    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
-    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
-
-    # y * gamma + beta
-    scaled_y = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        div_trt,
-        gamma_trt,
-    )
+    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)
+
+    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+    if tuple(input.shape) != tuple(weight.shape):
+        weight = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
+        )
+    if tuple(input.shape) != tuple(bias.shape):
+        bias = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
+        )

-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        scaled_y,
-        beta_trt,
-    )
+    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
+    layer_norm.epsilon = eps
+    layer_norm.compute_precision = input.dtype
+    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)

     if return_mean_rstd:
         # return fake mean and rstd for now
-        return output, None, None
+        return layer_norm.get_output(0), None, None

-    return output
+    return layer_norm.get_output(0)


 def native_group_norm(
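Note on the refactor above: the newly imported `get_axes_for_reduce_op` turns the normalized dimensions `dims` into the axes mask that `ctx.net.add_normalization` takes. A minimal sketch of that encoding, assuming the helper simply sets bit d for each normalized dimension d (a hypothetical re-implementation for illustration, not the helper's actual code):

```python
def axes_bitmask(dims):
    # Hypothetical helper: set bit d for every dimension d that is
    # normalized over, which is the axes encoding TensorRT layers take.
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask


# For a rank-4 input normalized over its last two dimensions (dims 2 and 3):
assert axes_bitmask([2, 3]) == 0b1100
```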
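For context: the deleted code built layer norm out of elementwise and reduce layers (mean, centered input, variance, sqrt, divide, scale, shift), while the new code hands the same computation to a single TensorRT INormalization layer via `ctx.net.add_normalization`, with weight and bias expanded to the input shape and `axes` marking the normalized dimensions. A small PyTorch sketch of the arithmetic both versions implement (shapes and eps below are illustrative only, not taken from the converter):

```python
import torch


def layer_norm_reference(x, normalized_shape, weight, bias, eps=1e-5):
    # Normalize over the trailing len(normalized_shape) dimensions, mirroring
    # dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)                 # E[x]
    var = (x - mean).pow(2).mean(dim=dims, keepdim=True)  # Var[x]
    y = (x - mean) / torch.sqrt(var + eps)                # normalize
    return y * weight + bias                              # scale and shift


x = torch.randn(5, 3, 2, 4)
weight, bias = torch.randn(2, 4), torch.randn(2, 4)
expected = torch.ops.aten.layer_norm.default(x, [2, 4], weight, bias, 1e-5)
assert torch.allclose(layer_norm_reference(x, (2, 4), weight, bias), expected, atol=1e-5)
```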
72 changes: 61 additions & 11 deletions tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -1,32 +1,75 @@
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input

 from .harness import DispatchTestCase


 class TestLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
-                    True,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
                 )

-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             LayerNorm(),
             inputs,
         )


 class TestNativeLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
+                )[0]
+
+        inputs = [torch.randn(input_shape)]
+        self.run_test(
+            LayerNorm(),
+            inputs,
+        )
+
+    def test_layernorm_with_dynamic_shape(self):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
@@ -37,10 +80,17 @@ def forward(self, x):
                     1e-05,
                 )[0]

-        inputs = [torch.randn(1, 3, 224, 224)]
-        self.run_test(
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
             LayerNorm(),
-            inputs,
+            input_specs,
         )

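One note on the test changes: `parameterized.expand` generates a separate test case per tuple in the list, unpacking each tuple into the test method's arguments, which is how the same converter test now covers several `normalized_shape` configurations. A tiny self-contained illustration of the pattern (class and method names here are illustrative, not from the test suite):

```python
import unittest

from parameterized import parameterized


class ShapeDemo(unittest.TestCase):
    @parameterized.expand(
        [
            ((5, 3, 2, 4), [4]),
            ((5, 3, 2, 4), [2, 4]),
        ]
    )
    def test_shapes(self, input_shape, normalized_shape):
        # One test_shapes_* case is generated per tuple above.
        self.assertTrue(len(normalized_shape) <= len(input_shape))


if __name__ == "__main__":
    unittest.main()
```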
