From cd78c653e77689633c01afd7880d5934c4913d4f Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 15 Oct 2024 03:26:28 -0700 Subject: [PATCH] [Pallas] Use core_map instead of shard_map for Shmallas - core_map is like a shard_map but it takes in no inputs and outputs - we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU) - we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target PiperOrigin-RevId: 686036101 --- jax/_src/dispatch.py | 2 +- jax/_src/mesh.py | 9 --- jax/_src/pallas/core.py | 81 ++++++++++++++++++++++++--- jax/_src/pallas/mosaic/core.py | 63 ++++++++++++++++++++- jax/_src/pallas/mosaic/lowering.py | 52 ----------------- jax/experimental/pallas/__init__.py | 2 + tests/pallas/tpu_pallas_state_test.py | 34 +++++------ 7 files changed, 148 insertions(+), 95 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index a1accf95ee74..179f8430febe 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -243,7 +243,7 @@ def get_intermediate_shardings( out.extend((i, source_info) for i in eqn.params['in_shardings']) out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - if not eqn.params['mesh']._is_jax_device_mesh: + if isinstance(eqn.params['mesh'], AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) def _names_to_pspec(names): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 3b4dd7ca4c2c..468de7ee57ab 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -270,11 +270,6 @@ def local_mesh(self): def _local_mesh(self, process_index): return _get_local_mesh(self, process_index) - @property - def _is_jax_device_mesh(self): - # Returns if the mesh contains JAX devices or not - return True - @functools.cached_property def device_ids(self): assert not self.empty @@ -377,10 +372,6 @@ def size(self): def shape(self): return collections.OrderedDict(self.shape_tuple) - @property - def _is_jax_device_mesh(self): - return False - @property def _internal_device_list(self): return None diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 0ff463562355..56f513d05158 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -31,7 +31,6 @@ from jax._src import core as jax_core from jax._src import dtypes from jax._src import linear_util as lu -from jax._src import mesh as mesh_lib from jax._src import state from jax._src import tree_util from jax._src import util @@ -1019,14 +1018,6 @@ def pytreedef_mismatch_err_msg( return "\n".join(msg) -class PallasMesh(mesh_lib.Mesh): - """A specialized mesh used for lowering shard_map -> pallas_call.""" - - @property - def _is_jax_device_mesh(self): - return False - - @dataclasses.dataclass(frozen=True) class CostEstimate: flops: int @@ -1038,3 +1029,75 @@ def to_json(self) -> bytes: f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},' f' "bytes_accessed": {self.bytes_accessed}}}' ).encode("ascii") + + +core_map_p = jax_core.Primitive("core_map") +core_map_p.multiple_results = True + +def core_map(mesh): + """Runs a function on a mesh, mapping it over the devices in the mesh. + + The function should be stateful in that it takes in no inputs and returns + no outputs but can mutate closed-over Refs, for example. 
+ """ + def wrapped(f): + flat_args, in_tree = tree_util.tree_flatten(((), {})) + flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) + with jax_core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) + out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh) + if out: + raise ValueError("core_map-ped functions must not return any outputs.") + return tree_util.tree_unflatten(out_tree_thunk(), out) + return wrapped + + +@core_map_p.def_effectful_abstract_eval +def _core_map_abstract_eval(*args, jaxpr, mesh): + del args + if jaxpr.outvars: + raise ValueError("core_map must not return any outputs.") + effs = set() + for eff in jaxpr.effects: + if not isinstance(eff, jax_core.NamedAxisEffect): + effs.add(eff) + continue + if eff.name not in mesh.shape: + effs.add(eff) + return [], effs + + +_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} +@state_discharge.register_discharge_rule(core_map_p) +def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs): + if type(mesh) not in _core_map_mesh_rules: + raise NotImplementedError(f"Mesh type {type(mesh)} not supported.") + return _core_map_mesh_rules[type(mesh)]( + in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs + ) + + +def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): + del in_atoms + with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())): + jax_core.check_jaxpr(jaxpr) + effs = set() + for eff in jaxpr.effects: + if not isinstance(eff, jax_core.NamedAxisEffect): + effs.add(eff) + continue + if eff.name not in mesh.shape: + effs.add(eff) + return [], effs +jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule + + +def _core_map_axis_subst(params, subst, traverse): + if not traverse: + return params + def shadowed_subst(name): + return (name,) if name in params['mesh'].shape else subst(name) + with jax_core.extend_axis_env_nd(params['mesh'].shape.items()): + new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) + return dict(params, jaxpr=new_jaxpr) +jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 57f1cad325bb..3407fe2e3435 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -15,6 +15,7 @@ """Contains TPU-specific Pallas abstractions.""" from __future__ import annotations +import collections from collections.abc import Sequence import dataclasses import enum @@ -27,6 +28,7 @@ from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call import jax.numpy as jnp import numpy as np @@ -208,14 +210,69 @@ class TensorCore: id: int -def create_tensorcore_mesh(axis_name: str) -> pallas_core.PallasMesh: +@dataclasses.dataclass(frozen=True) +class TensorCoreMesh: + """A mesh of TensorCores.""" + devices: np.ndarray + axis_names: Sequence[str] + + @property + def shape(self): + return collections.OrderedDict(zip(self.axis_names, self.devices.shape)) + + +def create_tensorcore_mesh( + axis_name: str, devices: Sequence[jax.Device] | None = None +) -> TensorCoreMesh: # TODO(b/355036384): emit a better error if we don't have tensorcores. 
- num_cores = jax.devices()[0].num_cores - return pallas_core.PallasMesh( + if devices is None: + devices = jax.devices() + num_cores = devices[0].num_cores + return TensorCoreMesh( np.array([TensorCore(i) for i in range(num_cores)]), [axis_name], ) + def runtime_assert_enabled() -> bool: """Returns whether runtime asserts are enabled.""" return _ENABLE_RUNTIME_ASSERT.value + + +def _tensorcore_mesh_discharge_rule( + in_avals, + out_avals, + *args, + mesh, + jaxpr, +): + del out_avals + assert isinstance(mesh, TensorCoreMesh) + if len(mesh.shape) > 1: + raise NotImplementedError("Mesh must be 1D") + core_axis_name, num_cores = list(mesh.shape.items())[0] + def body(*args): + # Due to aliasing, args contains aliased inputs and outputs so we remove + # outputs. + in_refs = args[:len(in_avals)] + jax_core.eval_jaxpr(jaxpr, in_refs) + assert len(jaxpr.outvars) == 0 + out = pallas_call.pallas_call( + body, + out_shape=in_avals, + in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] + * len(in_avals), + out_specs=[pallas_core.BlockSpec( + memory_space=pallas_core.MemorySpace.ANY)] + * len(in_avals), + input_output_aliases={i: i for i in range(len(in_avals))}, + grid=((core_axis_name, num_cores),), + compiler_params=dict( + mosaic=dict(dimension_semantics=("parallel",)), + ), + )(*args) + return out, () + +pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( + _tensorcore_mesh_discharge_rule +) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index de77b71f544a..568bc20a1d44 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -50,7 +50,6 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core @@ -3078,54 +3077,3 @@ def _lower_fun(shape): lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering - -# Lowering for shard_map - -# Technically this is not a lowering rule, but a discharge rule. When we use -# a special pallas mesh for a shard_map inside of a run_state, we turn it into -# a pallas call. The pallas_call has named grid axes corresponding to the names -# in the pallas mesh. It also sets up input/output aliasing automatically. 
- -def _shard_map_discharge_rule( - in_avals, - out_avals, - *args, - mesh, - auto, - in_names, - out_names, - jaxpr, - check_rep, - rewrite, -): - del out_avals, auto, in_names, out_names, check_rep, rewrite - if not isinstance(mesh, pallas_core.PallasMesh): - raise NotImplementedError("Mesh must be a PallasMesh") - if len(mesh.shape) > 1: - raise NotImplementedError("Mesh must be 1D") - core_axis_name, num_cores = list(mesh.shape.items())[0] - def body(*args): - in_refs = args[:len(in_avals)] - jax_core.eval_jaxpr(jaxpr, (), *in_refs) - assert len(jaxpr.outvars) == 0 - out = pallas_call.pallas_call( - body, - out_shape=in_avals, - in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - out_specs=[pallas_core.BlockSpec( - memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - input_output_aliases={i: i for i in range(len(in_avals))}, - grid=((core_axis_name, num_cores),), - compiler_params=dict( - mosaic=dict(dimension_semantics=("parallel",)), - ), - )(*args) - return out, () - - -from jax.experimental import shard_map -state_discharge.register_discharge_rule(shard_map.shard_map_p)( - _shard_map_discharge_rule -) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 0a82137f8dd6..3eee14fea516 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -21,6 +21,7 @@ from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec from jax._src.pallas.core import CompilerParams +from jax._src.pallas.core import core_map from jax._src.pallas.core import CostEstimate from jax._src.pallas.core import GridSpec from jax._src.pallas.core import IndexingMode @@ -53,6 +54,7 @@ from jax._src.pallas.utils import next_power_of_2 from jax._src.pallas.utils import strides_from_shape from jax._src.pallas.utils import when +from jax._src.state.discharge import run_state from jax._src.state.indexing import ds from jax._src.state.indexing import dslice from jax._src.state.indexing import Slice diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index b017cac2fba0..ab3a82dab09f 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -17,9 +17,7 @@ from absl.testing import absltest import jax from jax._src import test_util as jtu -from jax._src.state import discharge as state_discharge from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -51,7 +49,7 @@ def f_stateful(refs): @jax.jit def f(x): - _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + _, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x))) return y x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) @@ -73,7 +71,7 @@ def f_stateful(refs): @jax.jit def f(x): - _, y = state_discharge.run_state(f_stateful)((x, jnp.zeros_like(x))) + _, y = pl.run_state(f_stateful)((x, jnp.zeros_like(x))) return y x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) @@ -101,7 +99,7 @@ def f_stateful(refs): @jax.jit def f(x): - _, y = state_discharge.run_state(f_stateful)((x[None], jnp.zeros_like(x))) + _, y = pl.run_state(f_stateful)((x[None], jnp.zeros_like(x))) return y x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) @@ -128,7 +126,7 @@ def f_stateful(refs): @jax.jit def f(x): - _, y, o = state_discharge.run_state(f_stateful)( + _, y, o = pl.run_state(f_stateful)( (x, 
jnp.zeros_like(x), jnp.zeros_like(x)) ) return y, o @@ -178,7 +176,7 @@ def matmul_pipeline_kernel(acc_ref): scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)], )() - _, _, o = state_discharge.run_state(run_matmul)( + _, _, o = pl.run_state(run_matmul)( (x, y, jnp.ones((m, n), dtype=x.dtype)) ) return o @@ -202,11 +200,7 @@ def setUp(self): def test_can_create_tensorcore_mesh(self): _ = pltpu.create_tensorcore_mesh("x") - def test_can_trivially_shard_map_with_pallas_mesh(self): - mesh = pltpu.create_tensorcore_mesh("x") - _ = shard_map.shard_map(lambda: None, mesh, in_specs=(), out_specs=None)() - - def test_can_run_basic_pallas_kernel_with_shard_map(self): + def test_can_run_basic_pallas_kernel_with_core_map(self): mesh = pltpu.create_tensorcore_mesh("x") @jax.jit @@ -214,19 +208,18 @@ def f(x): y = jnp.zeros_like(x) def inner(refs): x_ref, y_ref = refs - def kernel(): + @pl.core_map(mesh) + def _(): def alloc(sem): pltpu.async_copy(x_ref, y_ref, sem).wait() pl.run_scoped(alloc, pltpu.SemaphoreType.DMA) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) + _, y = pl.run_state(inner)((x, y)) return y x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) y = f(x) np.testing.assert_array_equal(y, x) - def test_can_query_core_index_pallas_kernel_with_shard_map(self): + def test_can_query_core_index_pallas_kernel_with_core_map(self): mesh = pltpu.create_tensorcore_mesh("x") @jax.jit @@ -234,7 +227,8 @@ def f(x): y = jnp.zeros_like(x) def inner(refs): x_ref, y_ref = refs - def kernel(): + @pl.core_map(mesh) + def _(): num_cores = jax.lax.psum(1, "x") slc_size = 16 // num_cores def alloc(x_vmem_ref, y_vmem_ref, sem): @@ -254,9 +248,7 @@ def alloc(x_vmem_ref, y_vmem_ref, sem): pltpu.VMEM((slc_size, 128), y_ref.dtype), pltpu.SemaphoreType.DMA, ) - shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None, - check_rep=False)() - _, y = state_discharge.run_state(inner)((x, y)) + _, y = pl.run_state(inner)((x, y)) return y num_cores = jax.devices()[0].num_cores x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128))
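
For reference, a minimal end-to-end sketch of the new user-facing API introduced by this patch, adapted from the updated test in tests/pallas/tpu_pallas_state_test.py above. It assumes a TPU whose chips expose more than one TensorCore: `pl.core_map(mesh)` maps the stateful body over the cores of the mesh, and `pl.run_state` threads the Refs in and out.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# One mesh axis spanning the TensorCores of the local chip.
mesh = pltpu.create_tensorcore_mesh("x")

@jax.jit
def f(x):
  def inner(refs):
    x_ref, y_ref = refs

    # The core_map-ped function takes no inputs and returns no outputs;
    # it communicates only by mutating the closed-over Refs.
    @pl.core_map(mesh)
    def _():
      def alloc(sem):
        # Copy the input Ref into the output Ref via an async DMA.
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pl.run_scoped(alloc, pltpu.SemaphoreType.DMA)

  _, y = pl.run_state(inner)((x, jnp.zeros_like(x)))
  return y

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
np.testing.assert_array_equal(f(x), x)

As the TensorCoreMesh discharge rule in jax/_src/pallas/mosaic/core.py shows, such a core_map is ultimately discharged into a pallas_call with a named grid axis over the cores and input/output aliasing, replacing the shard_map-based discharge path removed from jax/_src/pallas/mosaic/lowering.py.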