Commit 381b660

RMSNorm blocked implementation

Rahul Batra committed Nov 7, 2024
1 parent 086312b
Showing 1 changed file with 63 additions and 31 deletions.

python/perf-kernels/rmsnorm.py (63 additions & 31 deletions)
@@ -46,35 +46,67 @@ def get_autotune_config():

 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
+def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
                BLOCK_SIZE: tl.constexpr):
-    row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
-        input_ptrs = row_start_ptr + col_offsets
-        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
-        row_norm = row * row  #square each value
-        row_norm = tl.sum(row_norm, axis=-1)  #sum across columns (axis=-1)
-        row_norm = row_norm / n_cols  #divide by n_cols
-        row_norm = row_norm + epsilon  #add epsilon
-        row_norm = tl.rsqrt(row_norm)  #take rsqrt, this is the normalization value
-        rms_norm = row * row_norm  #multiply each x by the normalization value
-        rms_norm = rms_norm * g  #element-wise multiplication with g
-
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
-        output_ptrs = output_row_start_ptr + col_offsets
-        output_ptrs = tl.multiple_of(output_ptrs, (16, ))
-        tl.store(output_ptrs, rms_norm, mask=mask)
+    row_idx = tl.program_id(0)
+
+    #Calculate the squared mean block by block
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_sum = 0.0
+    n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")
+        row_block = row_block * row_block  #square every value in the block
+        row_sum += (tl.sum(row_block, axis=-1) / n_cols)  #sum across the row
+
+    #The last block may be partial, so mask it
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+    row_block = row_block * row_block  #square every value in the block
+    row_sum += (tl.sum(row_block, axis=-1) / n_cols)  #sum across the row
+
+    row_norm = row_sum + eps
+    row_norm = tl.rsqrt(row_norm)
+
+    #Blocked normalization
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    for b in tl.range(0, n_cols_blks):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load a block of input
+        g = tl.load(g_ptr + col_offsets, cache_modifier=".cg")  #load a block of g
+        output = row_block * row_norm  #multiply each x by the normalization value
+        output = output * g  #element-wise multiplication with g
+
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, output)
+
+    #The last block may be partial, so mask it
+    col_offsets = n_cols_blks * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load a block of input
+    g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  #load a block of g
+    output = row_block * row_norm  #multiply each x by the normalization value
+    output = output * g  #element-wise multiplication with g
+
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, output, mask=mask)


 def triton_rmsnorm(x, g, epsilon=1e-6):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    #Restricting BLOCK_SIZE to 64KB is an important optimization. Otherwise,
+    #performance can drop significantly for larger n_cols.
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))

     y = torch.empty_like(x, device='cuda')

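Note on the new structure: the kernel still computes x * rsqrt(mean(x^2) + eps) * g per row, but now in two blocked passes (a squared-mean pass, then a normalize-and-scale pass), so BLOCK_SIZE no longer has to cover an entire row. A minimal pure-PyTorch sketch of the same scheme, with a hypothetical helper name and a small block_size for illustration:

    import torch

    def blocked_rmsnorm_reference(x, g, eps=1e-6, block_size=8):
        #Illustrative sketch only, mirroring the kernel's blocking
        n_rows, n_cols = x.shape
        y = torch.empty_like(x)
        #Matches n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 above
        n_full = (n_cols + block_size - 1) // block_size - 1
        for r in range(n_rows):
            row = x[r]
            #Pass 1: accumulate the mean of squares block by block
            row_sum = 0.0
            for b in range(n_full):
                blk = row[b * block_size:(b + 1) * block_size]
                row_sum += float((blk * blk).sum()) / n_cols
            tail = row[n_full * block_size:]  #tail block, masked in the kernel
            row_sum += float((tail * tail).sum()) / n_cols
            inv_rms = (row_sum + eps) ** -0.5
            #Pass 2: reload each block, normalize, and scale by g
            for b in range(n_full + 1):
                s = slice(b * block_size, min((b + 1) * block_size, n_cols))
                y[r, s] = row[s] * inv_rms * g[s]
        return y

    #Sanity check against a direct formulation:
    x = torch.randn(4, 10)
    g = torch.randn(10)
    ref = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6) * g
    assert torch.allclose(blocked_rmsnorm_reference(x, g), ref, atol=1e-6)

Only the last of the cdiv(n_cols, BLOCK_SIZE) blocks can be partial, which is why the kernel loads the full blocks unmasked and handles the tail block separately with a mask.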
@@ -84,7 +116,6 @@ def triton_rmsnorm(x, g, epsilon=1e-6):

     return y

-
 def torch_rmsnorm(x, g):
     M, N = x.shape
     if hasattr(torch.nn, 'RMSNorm'):
@@ -95,15 +126,17 @@ def torch_rmsnorm(x, g):
         rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g
     return rms_norm

-
+# yapf: disable
 @pytest.mark.parametrize('M, N', [
-    (1, 4),
-    (2, 10),
-    (8192, 4096),
-    (4096, 8192),
-    (1, 8192),
-    (873, 1245),
-])
+    (1, 4),
+    (2, 10),
+    (8192, 4096),
+    (4096, 8192),
+    (1, 8192),
+    (873, 1245),
+    (1, 98304)
+])
+# yapf: enable
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
@@ -114,7 +147,6 @@ def test_rmsnorm(M, N):

     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

-
 #Benchmark
 arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}

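A usage sketch for the wrapper (assumes the definitions above are importable from this file and a GPU is available; the import path is illustrative):

    import torch
    from rmsnorm import triton_rmsnorm, torch_rmsnorm  #hypothetical import path

    x = torch.randn(873, 1245, device='cuda')
    g = torch.randn(1245, device='cuda')
    y_triton = triton_rmsnorm(x, g)
    y_torch = torch_rmsnorm(x, g)
    assert torch.allclose(y_triton, y_torch)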
