Skip to content

Commit

Permalink
enable TritonFusedRMSNorm with local_map annotation
Browse files Browse the repository at this point in the history
ghstack-source-id: 213ef4323f9888463076ea580c3b72e2359ec492
Pull Request resolved: #364
  • Loading branch information
XilunWu committed Jun 9, 2024
1 parent 104bd6c commit 22a7490
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
57 changes: 57 additions & 0 deletions test/test_fused_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributed._tensor import (
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)

from torchtitan.models.norms import fused_rms_norm_fn


class TestFusedRMSNorm(DTensorTestBase):
@property
def world_size(self):
return 4

@skip_if_lt_x_gpu(4)
@with_comms
def test_fused_rms_norm(self):
mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
)
x = torch.randn(4, 4, 4, device=self.device_type) # Shard(1)
w = torch.randn(4, device=self.device_type, requires_grad=True) # Replicate

dx = distribute_tensor(x, mesh, [Shard(1)])
dw = distribute_tensor(w, mesh, [Replicate()])

comm_mode = CommDebugMode()
# fused rmsnorm
with comm_mode:
out = fused_rms_norm_fn(dx, dw)

self.assertEqual(comm_mode.get_total_counts(), 0)

with comm_mode:
grad_out = torch.ones_like(out)
out.backward(grad_out)

self.assertEqual(comm_mode.get_total_counts(), 0)


if __name__ == "__main__":
run_tests()
15 changes: 15 additions & 0 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

import math

from functools import partial

import torch
import torch.nn as nn

import triton
import triton.language as tl

from torch.distributed._tensor.experimental import local_map
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Expand Down Expand Up @@ -214,6 +219,11 @@ def _rms_norm_bwd_kernel_sm(


class TritonFusedRMSNorm(torch.autograd.Function):
@partial(
local_map,
out_placements=[Shard(1)],
in_placements=(None, [Shard(1)], [Replicate()], None),
)
@staticmethod
def forward(ctx, x, weight, eps):
x_shape_start = x.shape
Expand Down Expand Up @@ -256,6 +266,11 @@ def forward(ctx, x, weight, eps):
y = y.reshape(x_shape_start)
return y

@partial(
local_map,
out_placements=([Shard(1)], [_Partial()], None),
in_placements=(None, [Shard(1)]),
)
@staticmethod
def backward(ctx, dy):
x, weight, rstd = ctx.saved_tensors
Expand Down
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)

tp_mesh = world_mesh["tp"]
row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
Expand Down

0 comments on commit 22a7490

Please sign in to comment.