[Mosaic GPU] Implement tiled and swizzled transfers for tiled layouts
PiperOrigin-RevId: 694449664
apaszke authored and Google-ML-Automation committed Nov 8, 2024
1 parent 5e43220 commit 6a124ac
Showing 4 changed files with 245 additions and 36 deletions.
216 changes: 181 additions & 35 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -19,7 +19,7 @@
import dataclasses
import functools
import math
from typing import Sequence, TypeVar, Iterable
from typing import Iterable, Sequence, TypeVar

import jax
from jaxlib.mlir import ir
@@ -110,6 +110,23 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
return strides

def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
for tile in self.tiles:
untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
indices = (
*untiled,
*(i // t for i, t in zip(tiled, tile)),
*(i % t for i, t in zip(tiled, tile)),
)
return indices

def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
for tile in reversed(self.tiles):
untiled = indices[:-2 * len(tile)]
outer = indices[-2 * len(tile):-len(tile)]
inner = indices[-len(tile):]
indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile)))
return indices
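
A minimal round-trip sketch of these two helpers (hypothetical single-level tiling, chosen for illustration only; Tiling is constructed from a tuple of tile shapes, as elsewhere in this file):

tiling = Tiling(((64, 8),))
idx = (130, 21)
tiled = tiling.tile_indices(idx)
assert tiled == (2, 2, 2, 5)  # (130 // 64, 21 // 8, 130 % 64, 21 % 8)
assert tiling.untile_indices(tiled) == idx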

def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
"""Like built-in enumerate, but returns negative indices into the sequence."""
@@ -185,6 +202,15 @@ def __post_init__(self):
if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE:
raise ValueError

@property
def base_tile_shape(self) -> tuple[int, ...]:
"""The shape of the first tile in the tiling expression.

This tile acts as the divisibility constraint for a suffix of arrays to
which this layout applies.
"""
return self.tiling.tiles[0]

@functools.cached_property
def tiled_tiling_shape(self) -> tuple[int, ...]:
"""The shape of the suffix of the array after tiling.
@@ -194,7 +220,7 @@ def tiled_tiling_shape(self) -> tuple[int, ...]:
so the tiled shape always ends with this suffix, no matter what array shape
it's applied to.
"""
return self.tiling.tile_shape(self.tiling.tiles[0])
return self.tiling.tile_shape(self.base_tile_shape)
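
A worked sketch of this suffix for a hypothetical two-level tiling, assuming tile_shape splits each tiled dimension the same way tile_indices does:

t = Tiling(((64, 8), (16, 8)))
# The base tile (64, 8) tiled by itself gives (1, 1, 64, 8); its trailing
# (64, 8) suffix is then tiled again by (16, 8), yielding the constant suffix.
assert t.tile_shape((64, 8)) == (1, 1, 4, 1, 16, 8)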

@property
def vector_length(self) -> int:
@@ -231,6 +257,8 @@ def lane_indices(self) -> tuple[ir.Value, ...]:
assert math.prod(tiled_shape) == WARP_SIZE
lane_strides = utils.get_contiguous_strides(tiled_shape)
lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
# TODO(apaszke): Rewrite so that we can be sure that this never actually
# does arithmetic for any dimensions that are not in lane_dims.
return tuple(
arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
for stride, size in zip(lane_strides, tiled_shape)
@@ -1260,10 +1288,8 @@ def _store_untiled_tiled(self, ref: ir.Value):
ptr = utils.memref_ptr(ref)
# Fold warp and lane offsets into the pointer once, since they are dynamic.
dyn_strides = [arith.constant(i32, s) for s in strides]
def dyn_dot(x, y):
return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
warp_offset = dyn_dot(layout.warp_indices(), dyn_strides)
lane_offset = dyn_dot(layout.lane_indices(), dyn_strides)
warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
dyn_offset = arith.addi(warp_offset, lane_offset)
ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
# All warp tile offsets are static and can be fused into the store.
@@ -1273,41 +1299,68 @@ def dyn_dot(x, y):
llvm.store(reg, reg_ptr)

def store_tiled(self, ref, swizzle: int | None):
if self.layout != WGMMA_LAYOUT:
raise NotImplementedError
dtype = self.mlir_dtype
bw = mgpu.bytewidth(dtype)
m, n = self.shape
assert m % 64 == 0 # This is implied by the layout.
cols_per_tile = swizzle // bw
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
if n < cols_per_tile: # We allow singular tiles shorter than swizzle.
expected_shape = [m // 64, 1, 64, cols_per_tile]
if ir.MemRefType(ref.type).shape != expected_shape:
raise ValueError(ref.type, (m, n))
for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
vector.store(get(self.registers), ref, idxs)
match self.layout:
case WGMMAFragLayout():
dtype = self.mlir_dtype
bw = mgpu.bytewidth(dtype)
m, n = self.shape
assert m % 64 == 0 # This is implied by the layout.
cols_per_tile = swizzle // bw
expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
if n < cols_per_tile: # We allow singular tiles shorter than swizzle.
expected_shape = [m // 64, 1, 64, cols_per_tile]
if ir.MemRefType(ref.type).shape != expected_shape:
raise ValueError(ref.type, (m, n))
for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
vector.store(get(self.registers), ref, idxs)
case TiledLayout():
layout, shape = self.layout, self.shape
for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape):
llvm.store(get(self.registers), ptr)
case _:
raise NotImplementedError(self.layout)

@classmethod
def load_tiled(
cls, ref, swizzle: int | None, *, is_signed: bool | None = None
cls,
ref,
swizzle: int | None,
*,
is_signed: bool | None = None,
layout: FragmentedLayout = WGMMA_LAYOUT,
):
ref_ty = ir.MemRefType(ref.type)
dtype = ref_ty.element_type
bw = mgpu.bytewidth(dtype)
m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
if m_tile_size != 64 or n_tile_size != (swizzle // bw):
raise ValueError
m, n = m_tiles * m_tile_size, n_tiles * n_tile_size
assert m % 64 == 0 # This is implied by the layout.
registers = np.full(
(m_tiles, n // 8, 2, 1),
vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)),
dtype=object,
)
for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle):
update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs))
return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed)
match layout:
case TiledLayout():
ref_ty = ir.MemRefType(ref.type)
tiled_shape = ref_ty.shape
if len(tiled_shape) % 2:
raise ValueError("Tiled reference must have even rank")
tiling = Tiling((tiled_shape[len(tiled_shape) // 2:],))
shape = tiling.untile_shape(tiled_shape)
registers = np.full(layout.registers_shape(shape), None, dtype=object)
reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape):
update(registers, llvm.load(reg_ty, ptr))
assert all(r is not None for r in registers.flat)
case WGMMAFragLayout():
bw = mgpu.bytewidth(dtype)
m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
if m_tile_size != 64 or n_tile_size != (swizzle // bw):
raise ValueError("Unsupported tiling for the WGMMA layout")
m, n = m_tiles * m_tile_size, n_tiles * n_tile_size
assert m % 64 == 0 # This is implied by the layout.
registers = np.full(
(m_tiles, n // 8, 2, 1),
vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)),
dtype=object,
)
for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle):
update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs))
case _:
raise NotImplementedError(layout)
return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
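
A hedged usage sketch of the new layout parameter (mirroring the test added to gpu_test.py below; smem_in/smem_out are SMEM refs with matching tiling and swizzle):

t = FragmentedArray.load_tiled(
smem_in, swizzle=128, is_signed=True, layout=_tiled_wgmma_layout((m, n))
)
t.store_tiled(smem_out, swizzle=128)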

@staticmethod
def transfer_tiled(shape, dtype, swizzle: int | None):
Expand Down Expand Up @@ -1393,6 +1446,99 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx):
regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new)
yield get_register, update_registers, idx

@staticmethod
def transfer_tiled2(
ref: ir.Value,
swizzle: int | None,
layout: TiledLayout,
shape: tuple[int, ...],
):
"""Generate a transfer schedule for a tiled layout.
Given a ref with one level tiling applied to it (we assume all dimensions
have been tiled), this function generates an iterable describing a good
schedule for swizzled SMEM loads/stores.
At each step, the iterable yields a tuple of three values:
* a function that takes a register array and returns the register to be
stored at the current address
* a function that takes a register array and a register loaded from the
current address, and updates the register array with that register
* the current address for load/store instructions
"""
# TODO(apaszke): Use ldmatrix/stmatrix when possible.
c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x)
tiling = layout.tiling

ref_ty = ir.MemRefType(ref.type)
dtype = ref_ty.element_type
if ref_ty.rank % 2:
raise ValueError("Tiled refence must have even rank")
ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:])
ref_tiling = Tiling((ref_tiling_shape,))
ref_strides, _ = ref_ty.get_strides_and_offset()
if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape:
raise ValueError("Ref shape does not match the expected array shape")
if len(layout.base_tile_shape) > len(ref_tiling_shape):
raise ValueError("Memory tiling must be a multiple of the register tiling")
ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
raise ValueError("Memory tiling must be a multiple of the register tiling")

if swizzle not in {32, 64, 128}:
raise ValueError("Only swizzled transfers supported")
bw = mgpu.bytewidth(dtype)
swizzle_tile_elems = 16 // bw
swizzle_group_elems = 128 // bw
swizzle_groups_per_block = swizzle // 16
swizzle_block_elems = swizzle_groups_per_block * swizzle_group_elems

tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
if tiled_strides[layout.vector_dim] != 1:
raise ValueError("Stride of the vectorized dimension should be 1")
for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
tiled_shape[d] = 1
full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
full_layout = dataclasses.replace(layout, tiling=full_tiling)

# XXX: This method is still slightly incomplete. For example, it does not
# verify that the vector transfers don't cross swizzle tile boundaries. It
# also does not guarantee that the transfer pattern does not cause bank
# conflicts. For that reason, we only allow a select subset of layouts.
if layout != _tiled_wgmma_layout(shape) or bw > 2:
raise NotImplementedError("transfer_tiled2 not general enough yet")

dyn_tiled_strides = [c(s) for s in tiled_strides]
lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides)
dyn_offset = arith.addi(lane_offset, warp_offset)
if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
raise ValueError("Tiled stores can be performed into SMEM")
ptr = utils.memref_ptr(ref, memory_space=3)
for tile_idx in np.ndindex(*tiled_shape):
const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides))
# We split the offset into a part that interacts with swizzling and a
# part that doesn't. This lets us generate better code because constant
# offsets can be fused into load and store instructions.
const_offset_swizzle = const_offset % swizzle_block_elems
const_offset_no_swizzle = const_offset - const_offset_swizzle
offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle))
swizzle_group = arith.remui(
arith.divui(offset_pre_swizzle, c(swizzle_group_elems)),
c(swizzle_groups_per_block),
)
swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems))
offset = arith.xori(offset_pre_swizzle, swizzle_bits)
reg_ptr = utils.getelementptr(ptr, [offset], dtype)
reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype)
reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx))
def get_register(regs, reg_idx=reg_idx):
return regs[reg_idx]
def update_registers(regs, new, reg_idx=reg_idx):
regs[reg_idx] = new
yield get_register, update_registers, reg_ptr
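
A worked numeric sketch of the swizzle arithmetic above, for a 128-byte swizzle and 2-byte elements (values hypothetical, illustration only):

bw = 2  # e.g. f16
swizzle = 128
swizzle_tile_elems = 16 // bw  # 8
swizzle_group_elems = 128 // bw  # 64
swizzle_groups_per_block = swizzle // 16  # 8
offset_pre_swizzle = 200  # a linear element offset within a swizzle block
swizzle_group = (offset_pre_swizzle // swizzle_group_elems) % swizzle_groups_per_block  # 3
swizzle_bits = swizzle_group * swizzle_tile_elems  # 24
offset = offset_pre_swizzle ^ swizzle_bits  # 200 ^ 24 == 208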

def tree_flatten(self):
aux = self.layout, self.registers.shape, self.is_signed
return list(self.registers.flat), aux
4 changes: 4 additions & 0 deletions jax/experimental/mosaic/gpu/utils.py
@@ -1047,3 +1047,7 @@ def getelementptr(
static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices]
dyn_indices = [i for i in indices if not isinstance(i, int)]
return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype)


def dyn_dot(x, y):
return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
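
A short usage sketch, mirroring the call sites in fragmented_array.py above (x holds dynamic MLIR i32 indices, y the corresponding constant strides):

# dyn_dot builds sum(x[i] * y[i]) out of arith.muli / arith.addi ops, e.g.:
dyn_strides = [arith.constant(i32, s) for s in strides]
offset = dyn_dot(layout.lane_indices(), dyn_strides)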
2 changes: 1 addition & 1 deletion tests/mosaic/BUILD
@@ -37,7 +37,7 @@ jax_multiplatform_test(
"gpu_h100",
"gpu_h100_2gpu",
],
shard_count = 4,
shard_count = 8,
tags = ["multiaccelerator"],
deps = [
"//jax:mosaic_gpu",
59 changes: 59 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -18,6 +18,8 @@
import itertools
import math
import operator
import os
import re
import unittest

from absl.testing import absltest, parameterized
@@ -1627,6 +1629,63 @@ def kernel(ctx, dst, _):
expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
np.testing.assert_array_equal(f(), expected)

@parameterized.product(
load_tiled=[False, True],
store_tiled=[False, True],
dtype=[jnp.int16],
swizzle=[32, 64, 128],
num_col_tiles=[1, 2, 4],
)
def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles):
mlir_dtype = utils.dtype_to_ir_type(dtype)
col_tiling = swizzle // bytewidth(mlir_dtype)
m, n = 128, col_tiling * num_col_tiles
tiling = (64, col_tiling)
tiled_layout = fa._tiled_wgmma_layout((m, n))
load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT
store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT
def kernel(ctx, in_, out, smems):
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=load_layout
)
t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
ctx.await_async_copy(0)
expected = (
np.arange(m * n, dtype=dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)

prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
try:
with jtu.capture_stdout() as get_sass:
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
[expected, expected, mgpu.TMABarrier()],
)(expected)
finally:
if prev_dump is not None:
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
else:
del os.environ["MOSAIC_GPU_DUMP_SASS"]
np.testing.assert_array_equal(iota, expected)

# Verify that we don't use too many registers for the transfers.
# We verify LDS and STS separately, because they might use two different
# methods of computing offsets and we don't rely on CSE between them.
register_pattern = re.compile(r"(R[0-9]+)")
expected_regs = swizzle // bytewidth(mlir_dtype) // 8
for instr in ("STS", "LDS"):
with self.subTest(instr + " count"):
addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
chain = itertools.chain.from_iterable
used_regs = set(chain(register_pattern.findall(addr) for addr in addrs))
self.assertLen(used_regs, expected_regs)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
