diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index a9d12706ff47..040174b900c9 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -19,7 +19,7 @@
 import dataclasses
 import functools
 import math
-from typing import Sequence, TypeVar, Iterable
+from typing import Iterable, Sequence, TypeVar
 
 import jax
 from jaxlib.mlir import ir
@@ -110,6 +110,23 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]:
       strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled)
     return strides
 
+  def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
+    for tile in self.tiles:
+      untiled, tiled = indices[:-len(tile)], indices[-len(tile):]
+      indices = (
+          *untiled,
+          *(i // t for i, t in zip(tiled, tile)),
+          *(i % t for i, t in zip(tiled, tile)),
+      )
+    return indices
+
+  def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]:
+    for tile in reversed(self.tiles):
+      untiled = indices[:-2 * len(tile)]
+      outer = indices[-2 * len(tile):-len(tile)]
+      inner = indices[-len(tile):]
+      indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile)))
+    return indices
 
 def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]:
   """Like built-in enumerate, but returns negative indices into the sequence."""
@@ -185,6 +202,15 @@ def __post_init__(self):
     if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE:
       raise ValueError
 
+  @property
+  def base_tile_shape(self) -> tuple[int, ...]:
+    """The shape of the first tile in the tiling expression.
+
+    This tile acts as the divisibility constraint for a suffix of the shape of
+    arrays to which this layout applies.
+    """
+    return self.tiling.tiles[0]
+
   @functools.cached_property
   def tiled_tiling_shape(self) -> tuple[int, ...]:
     """The shape of the suffix of the array after tiling.
@@ -194,7 +220,7 @@ def tiled_tiling_shape(self) -> tuple[int, ...]:
     so the tiled shape always ends with this suffix, no matter what array
     shape it's applied to.
     """
-    return self.tiling.tile_shape(self.tiling.tiles[0])
+    return self.tiling.tile_shape(self.base_tile_shape)
 
   @property
   def vector_length(self) -> int:
@@ -231,6 +257,8 @@ def lane_indices(self) -> tuple[ir.Value, ...]:
     assert math.prod(tiled_shape) == WARP_SIZE
     lane_strides = utils.get_contiguous_strides(tiled_shape)
     lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32))
+    # TODO(apaszke): Rewrite so that we can be sure that this never actually
+    # does arithmetic for any dimensions that are not in lane_dims.
     return tuple(
         arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32))
         for stride, size in zip(lane_strides, tiled_shape)
     )
@@ -1260,10 +1288,8 @@ def _store_untiled_tiled(self, ref: ir.Value):
     ptr = utils.memref_ptr(ref)
     # Fold warp and lane offsets into the pointer once, since they are dynamic.
     dyn_strides = [arith.constant(i32, s) for s in strides]
-    def dyn_dot(x, y):
-      return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
-    warp_offset = dyn_dot(layout.warp_indices(), dyn_strides)
-    lane_offset = dyn_dot(layout.lane_indices(), dyn_strides)
+    warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides)
+    lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides)
     dyn_offset = arith.addi(warp_offset, lane_offset)
     ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype)
     # All warp tile offsets are static and can be fused into the store.
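For intuition, tile_indices splits each index along the tiled suffix into an outer tile index
followed by an offset within the tile, and untile_indices inverts that mapping. A minimal
standalone sketch of the round trip for a single level of tiling, in plain Python with no MLIR
dependencies (the helper names below are illustrative, not part of the patch):

def tile_indices_1level(indices, tile):
  # Outer tile indices first, then the offsets within each tile.
  return (*(i // t for i, t in zip(indices, tile)),
          *(i % t for i, t in zip(indices, tile)))

def untile_indices_1level(indices, tile):
  # Inverse: recombine (outer, inner) pairs into untiled indices.
  outer, inner = indices[:len(tile)], indices[len(tile):]
  return tuple(o * t + i for o, i, t in zip(outer, inner, tile))

tile = (64, 8)
idx = (130, 21)
assert tile_indices_1level(idx, tile) == (2, 2, 2, 5)
assert untile_indices_1level(tile_indices_1level(idx, tile), tile) == idx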
@@ -1273,41 +1299,68 @@ def dyn_dot(x, y):
       llvm.store(reg, reg_ptr)
 
   def store_tiled(self, ref, swizzle: int | None):
-    if self.layout != WGMMA_LAYOUT:
-      raise NotImplementedError
-    dtype = self.mlir_dtype
-    bw = mgpu.bytewidth(dtype)
-    m, n = self.shape
-    assert m % 64 == 0  # This is implied by the layout.
-    cols_per_tile = swizzle // bw
-    expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
-    if n < cols_per_tile:  # We allow singular tiles shorter than swizzle.
-      expected_shape = [m // 64, 1, 64, cols_per_tile]
-    if ir.MemRefType(ref.type).shape != expected_shape:
-      raise ValueError(ref.type, (m, n))
-    for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
-      vector.store(get(self.registers), ref, idxs)
+    match self.layout:
+      case WGMMAFragLayout():
+        dtype = self.mlir_dtype
+        bw = mgpu.bytewidth(dtype)
+        m, n = self.shape
+        assert m % 64 == 0  # This is implied by the layout.
+        cols_per_tile = swizzle // bw
+        expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
+        if n < cols_per_tile:  # We allow singular tiles shorter than swizzle.
+          expected_shape = [m // 64, 1, 64, cols_per_tile]
+        if ir.MemRefType(ref.type).shape != expected_shape:
+          raise ValueError(ref.type, (m, n))
+        for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
+          vector.store(get(self.registers), ref, idxs)
+      case TiledLayout():
+        layout, shape = self.layout, self.shape
+        for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape):
+          llvm.store(get(self.registers), ptr)
+      case _:
+        raise NotImplementedError(self.layout)
 
   @classmethod
   def load_tiled(
-      cls, ref, swizzle: int | None, *, is_signed: bool | None = None
+      cls,
+      ref,
+      swizzle: int | None,
+      *,
+      is_signed: bool | None = None,
+      layout: FragmentedLayout = WGMMA_LAYOUT,
   ):
     ref_ty = ir.MemRefType(ref.type)
     dtype = ref_ty.element_type
-    bw = mgpu.bytewidth(dtype)
-    m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
-    if m_tile_size != 64 or n_tile_size != (swizzle // bw):
-      raise ValueError
-    m, n = m_tiles * m_tile_size, n_tiles * n_tile_size
-    assert m % 64 == 0  # This is implied by the layout.
-    registers = np.full(
-        (m_tiles, n // 8, 2, 1),
-        vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)),
-        dtype=object,
-    )
-    for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle):
-      update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs))
-    return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed)
+    match layout:
+      case TiledLayout():
+        ref_ty = ir.MemRefType(ref.type)
+        tiled_shape = ref_ty.shape
+        if len(tiled_shape) % 2:
+          raise ValueError("Tiled reference must have even rank")
+        tiling = Tiling((tiled_shape[len(tiled_shape) // 2:],))
+        shape = tiling.untile_shape(tiled_shape)
+        registers = np.full(layout.registers_shape(shape), None, dtype=object)
+        reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type)
+        for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape):
+          update(registers, llvm.load(reg_ty, ptr))
+        assert all(r is not None for r in registers.flat)
+      case WGMMAFragLayout():
+        bw = mgpu.bytewidth(dtype)
+        m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
+        if m_tile_size != 64 or n_tile_size != (swizzle // bw):
+          raise ValueError
+        m, n = m_tiles * m_tile_size, n_tiles * n_tile_size
+        assert m % 64 == 0  # This is implied by the layout.
+        registers = np.full(
+            (m_tiles, n // 8, 2, 1),
+            vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)),
+            dtype=object,
+        )
+        for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle):
+          update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs))
+      case _:
+        raise NotImplementedError(layout)
+    return cls(_registers=registers, _layout=layout, _is_signed=is_signed)
 
   @staticmethod
   def transfer_tiled(shape, dtype, swizzle: int | None):
@@ -1393,6 +1446,99 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx):
         regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new)
       yield get_register, update_registers, idx
 
+  @staticmethod
+  def transfer_tiled2(
+      ref: ir.Value,
+      swizzle: int | None,
+      layout: TiledLayout,
+      shape: tuple[int, ...],
+  ):
+    """Generate a transfer schedule for a tiled layout.
+
+    Given a ref with one level of tiling applied to it (we assume all
+    dimensions have been tiled), this function generates an iterable describing
+    a good schedule for swizzled SMEM loads/stores.
+
+    At each step, the iterable yields a tuple of three values:
+    * a function that takes a register array and returns the register to be
+      stored at the current address
+    * a function that takes a register array and a register loaded from the
+      current address, and updates the register array with that register
+    * the current address for load/store instructions
+    """
+    # TODO(apaszke): Use ldmatrix/stmatrix when possible.
+    c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x)
+    tiling = layout.tiling
+
+    ref_ty = ir.MemRefType(ref.type)
+    dtype = ref_ty.element_type
+    if ref_ty.rank % 2:
+      raise ValueError("Tiled reference must have even rank")
+    ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:])
+    ref_tiling = Tiling((ref_tiling_shape,))
+    ref_strides, _ = ref_ty.get_strides_and_offset()
+    if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape:
+      raise ValueError()
+    if len(layout.base_tile_shape) > len(ref_tiling_shape):
+      raise ValueError("Memory tiling must be a multiple of the register tiling")
+    ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):]
+    if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
+      raise ValueError("Memory tiling must be a multiple of the register tiling")
+
+    if swizzle not in {32, 64, 128}:
+      raise ValueError("Only swizzled transfers supported")
+    bw = mgpu.bytewidth(dtype)
+    swizzle_tile_elems = 16 // bw
+    swizzle_group_elems = 128 // bw
+    swizzle_groups_per_block = swizzle // 16
+    swizzle_block_elems = swizzle_groups_per_block * swizzle_group_elems
+
+    tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
+    tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
+    if tiled_strides[layout.vector_dim] != 1:
+      raise ValueError("Stride of the vectorized dimension should be 1")
+    for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
+      tiled_shape[d] = 1
+    full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
+    full_layout = dataclasses.replace(layout, tiling=full_tiling)
+
+    # XXX: This method is still slightly incomplete. For example, it does not
+    # verify that the vector transfers don't cross swizzle tile boundaries. It
+    # also does not guarantee that the transfer pattern does not cause bank
+    # conflicts. For that reason, we only allow a select subset of layouts.
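To make the swizzle arithmetic used below concrete, here is a minimal plain-Python sketch of the
same computation, assuming swizzle=128 and a 2-byte element type (the constants mirror
swizzle_tile_elems, swizzle_group_elems and swizzle_groups_per_block above; the helper name is
illustrative only, not part of the patch):

swizzle, bw = 128, 2
swizzle_tile_elems = 16 // bw             # 8 elements per 16-byte swizzle atom
swizzle_group_elems = 128 // bw           # 64 elements per 128-byte swizzle group
swizzle_groups_per_block = swizzle // 16  # 8 groups per swizzle block

def swizzled_offset(offset):
  # Each 128-byte group XORs the offset by a different multiple of the atom size.
  group = offset // swizzle_group_elems % swizzle_groups_per_block
  return offset ^ (group * swizzle_tile_elems)

assert swizzled_offset(7) == 7     # group 0 is left untouched
assert swizzled_offset(70) == 78   # group 1, so the offset is XORed with 8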
+    if layout != _tiled_wgmma_layout(shape) or bw > 2:
+      raise NotImplementedError("transfer_tiled2 not general enough yet")
+
+    dyn_tiled_strides = [c(s) for s in tiled_strides]
+    lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
+    warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides)
+    dyn_offset = arith.addi(lane_offset, warp_offset)
+    if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
+      raise ValueError("Tiled stores can only be performed into SMEM")
+    ptr = utils.memref_ptr(ref, memory_space=3)
+    for tile_idx in np.ndindex(*tiled_shape):
+      const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides))
+      # We split the offset into a part that interacts with swizzling and a
+      # part that doesn't. This lets us generate better code because constant
+      # offsets can be fused into load and store instructions.
+      const_offset_swizzle = const_offset % swizzle_block_elems
+      const_offset_no_swizzle = const_offset - const_offset_swizzle
+      offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle))
+      swizzle_group = arith.remui(
+          arith.divui(offset_pre_swizzle, c(swizzle_group_elems)),
+          c(swizzle_groups_per_block),
+      )
+      swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems))
+      offset = arith.xori(offset_pre_swizzle, swizzle_bits)
+      reg_ptr = utils.getelementptr(ptr, [offset], dtype)
+      reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype)
+      reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx))
+      def get_register(regs, reg_idx=reg_idx):
+        return regs[reg_idx]
+      def update_registers(regs, new, reg_idx=reg_idx):
+        regs[reg_idx] = new
+      yield get_register, update_registers, reg_ptr
+
   def tree_flatten(self):
     aux = self.layout, self.registers.shape, self.is_signed
     return list(self.registers.flat), aux
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index f8918488563e..b716456eceb3 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -1047,3 +1047,7 @@ def getelementptr(
   static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices]
   dyn_indices = [i for i in indices if not isinstance(i, int)]
   return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype)
+
+
+def dyn_dot(x, y):
+  return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y)))
diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD
index ca2c9a4bf27d..6ea9c02b9639 100644
--- a/tests/mosaic/BUILD
+++ b/tests/mosaic/BUILD
@@ -37,7 +37,7 @@ jax_multiplatform_test(
         "gpu_h100",
         "gpu_h100_2gpu",
     ],
-    shard_count = 4,
+    shard_count = 8,
     tags = ["multiaccelerator"],
     deps = [
         "//jax:mosaic_gpu",
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 382cad79fb2b..157f682f5eef 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -18,6 +18,8 @@
 import itertools
 import math
 import operator
+import os
+import re
 import unittest
 
 from absl.testing import absltest, parameterized
@@ -1627,6 +1629,63 @@ def kernel(ctx, dst, _):
     expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
     np.testing.assert_array_equal(f(), expected)
 
+  @parameterized.product(
+      load_tiled=[False, True],
+      store_tiled=[False, True],
+      dtype=[jnp.int16],
+      swizzle=[32, 64, 128],
+      num_col_tiles=[1, 2, 4],
+  )
+  def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles):
+    mlir_dtype = utils.dtype_to_ir_type(dtype)
+    col_tiling = swizzle // bytewidth(mlir_dtype)
+    m, n = 128, col_tiling * num_col_tiles
+    tiling = (64, col_tiling)
+    tiled_layout = fa._tiled_wgmma_layout((m, n))
+    load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT
+    store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT
+    def kernel(ctx, in_, out, smems):
+      smem_in, smem_out, barrier = smems
+      ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
+      barrier.wait()
+      t = mgpu.FragmentedArray.load_tiled(
+          smem_in, swizzle=swizzle, is_signed=True, layout=load_layout
+      )
+      t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle)
+      mgpu.commit_shared()
+      ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
+      ctx.await_async_copy(0)
+    expected = (
+        np.arange(m * n, dtype=dtype)
+        .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
+        .transpose(0, 2, 1, 3)
+    )
+
+    prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
+    os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
+    try:
+      with jtu.capture_stdout() as get_sass:
+        iota = mgpu.as_gpu_kernel(
+            kernel, (1, 1, 1), (128, 1, 1), expected, expected,
+            [expected, expected, mgpu.TMABarrier()],
+        )(expected)
+    finally:
+      if prev_dump is not None:
+        os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
+      else:
+        del os.environ["MOSAIC_GPU_DUMP_SASS"]
+    np.testing.assert_array_equal(iota, expected)
+
+    # Verify that we don't use too many registers for the transfers.
+    # We verify LDS and STS separately, because they might use two different
+    # methods of computing offsets and we don't rely on CSE between them.
+    register_pattern = re.compile(r"(R[0-9]+)")
+    expected_regs = swizzle // bytewidth(mlir_dtype) // 8
+    for instr in ("STS", "LDS"):
+      with self.subTest(instr + " count"):
+        addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
+        chain = itertools.chain.from_iterable
+        used_regs = set(chain(register_pattern.findall(addr) for addr in addrs))
+        self.assertLen(used_regs, expected_regs)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
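As a side note on the register-count check above: it inspects only the address operands of the
shared-memory transfer instructions. A minimal sketch of the same regex logic, run over a made-up
SASS fragment (the instruction text below is hypothetical and for illustration only):

import itertools
import re

sass = """
STS.128 [R4], R8
STS.128 [R4+0x80], R12
LDS.U R16, [R5+0x100]
"""
register_pattern = re.compile(r"(R[0-9]+)")
addrs = re.findall(r"STS.* \[(.*)\]", sass)
used_regs = set(itertools.chain.from_iterable(register_pattern.findall(a) for a in addrs))
assert used_regs == {"R4"}  # both stores share a single address register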