
enable TritonFusedRMSNorm with local_map annotation
ghstack-source-id: 6125011aba1a4bd9521fb4a3b761b62285ea6195
Pull Request resolved: #364
XilunWu committed Jun 4, 2024
1 parent 1ceaa4e commit 51825ca
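
As the commit title says, the DTensor local_map annotation (from torch.distributed._tensor.experimental) is what lets the fused Triton kernel run on sharded inputs: it wraps a function written against plain torch.Tensor so it can be called with DTensor arguments, with in_placements declaring the sharding each argument is expected to arrive with and out_placements declaring the sharding of each result. A minimal sketch of that idea (the per-rank function, shapes, and device below are illustrative assumptions, not code from this commit):

import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor, init_device_mesh
from torch.distributed._tensor.experimental import local_map


def scale_rows(x, w):
    # runs on each rank's local shard only; no collectives are issued
    return x * w


def demo():
    # assumes torch.distributed is already initialized (e.g. via torchrun)
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    dx = distribute_tensor(torch.randn(4, 4, 4), mesh, [Shard(1)])  # sharded on dim 1
    dw = distribute_tensor(torch.randn(4), mesh, [Replicate()])     # replicated
    scaled_fn = local_map(
        scale_rows,
        out_placements=[Shard(1)],                  # output keeps the dim-1 sharding
        in_placements=([Shard(1)], [Replicate()]),  # expected input shardings
    )
    return scaled_fn(dx, dw)  # a DTensor sharded on dim 1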
Showing 4 changed files with 71 additions and 6 deletions.
55 changes: 55 additions & 0 deletions test/test_fused_rms_norm.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

from torchtitan.models.norms import fused_rms_norm_fn


class TestFusedRMSNorm(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    def test_fused_rms_norm(self):
        mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate

        dx = distribute_tensor(x, mesh, [Shard(1)])
        dw = distribute_tensor(w, mesh, [Replicate()])

        comm_mode = CommDebugMode()
        # fused rmsnorm
        with comm_mode:
            out = fused_rms_norm_fn(dx, dw)

        self.assertEqual(comm_mode.get_total_counts(), 0)

        with comm_mode:
            grad_out = torch.ones_like(out)
            out.backward(grad_out)

        self.assertEqual(comm_mode.get_total_counts(), 0)


if __name__ == "__main__":
    run_tests()
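
For reference, a DTensorTestBase test like the one above is typically launched directly (python test/test_fused_rms_norm.py); run_tests() spawns the world_size = 4 ranks itself, so on a CUDA machine this presumably needs at least 4 GPUs.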
15 changes: 15 additions & 0 deletions torchtitan/models/norms.py
@@ -6,12 +6,17 @@

import math

from functools import partial

import torch
import torch.nn as nn

import triton
import triton.language as tl

from torch.distributed._tensor.experimental import local_map
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
@@ -214,6 +219,11 @@ def _rms_norm_bwd_kernel_sm(


class TritonFusedRMSNorm(torch.autograd.Function):
    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        x_shape_start = x.shape
@@ -256,6 +266,11 @@ def forward(ctx, x, weight, eps):
        y = y.reshape(x_shape_start)
        return y

    @partial(
        local_map,
        out_placements=([Shard(1)], [_Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
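Reading the local_map annotations above (a brief gloss, not text from the commit): the placements are positional, one entry per argument or output, with None meaning "no DTensor placement" (used for ctx, the eps scalar, and eps's gradient).

# forward:  in_placements=(None, [Shard(1)], [Replicate()], None)
#           -> ctx has no placement, x arrives sharded on dim 1, weight is
#              replicated on every rank, eps is a plain scalar.
#           out_placements=[Shard(1)]
#           -> the normalized output keeps x's dim-1 sharding, so forward
#              needs no communication.
#
# backward: in_placements=(None, [Shard(1)])
#           -> dy arrives with the same dim-1 sharding as the output.
#           out_placements=([Shard(1)], [_Partial()], None)
#           -> grad_x stays sharded on dim 1; grad_weight is a partial sum
#              (each rank holds only its local contribution, and the
#              all-reduce is presumably deferred until the partial DTensor
#              is redistributed or consumed); eps gets no gradient.
#
# This is consistent with the test above seeing zero collectives in both
# the forward and the backward pass.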
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -300,11 +300,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)

tp_mesh = world_mesh["tp"]
row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -35,7 +35,7 @@ warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
tensor_parallel_degree = 2
fp8_linear = ""
compile = false
dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
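Bumping tensor_parallel_degree from 1 to 2 makes the debug config exercise tensor parallelism. Combined with the removal of the fused_rmsnorm guard in parallelize_llama.py above, a config that selects norm_type = "fused_rmsnorm" (the job_config.model.norm_type field checked by the removed code) should now run under TP instead of raising NotImplementedError.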
