[Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like shard_map, except that it takes no inputs and returns no outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores on a TPU or SMs on a GPU)
- we specify how the function is mapped over the device with a `mesh` object, which is also a convenient mechanism for picking the backend for Pallas to target (see the usage sketch below)

PiperOrigin-RevId: 686036101
sharadmv authored and Google-ML-Automation committed Oct 15, 2024
1 parent b076890 commit cd78c65
Showing 7 changed files with 148 additions and 95 deletions.
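
To show how the pieces introduced below fit together, here is a rough usage sketch. It relies on the `pl.core_map` and `pl.run_state` exports this change adds and on `create_tensorcore_mesh` from `jax/_src/pallas/mosaic/core.py`; the `pltpu` import path and the empty kernel body are illustrative assumptions, not code from this commit.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
# Assumed export path for the TPU mesh constructor defined below in
# jax/_src/pallas/mosaic/core.py.
from jax.experimental.pallas import tpu as pltpu


@jax.jit
def per_core_example(x):
  # The mesh describes how the body is mapped over the chip's cores
  # (TensorCores here) and doubles as the backend selector for Pallas.
  mesh = pltpu.create_tensorcore_mesh("core")
  y = jnp.zeros_like(x)

  def stateful(refs):
    x_ref, y_ref = refs

    # The core_map-ped body takes no inputs and returns no outputs; it can
    # only mutate the Refs it closes over (here x_ref and y_ref).
    @pl.core_map(mesh)
    def _():
      # Each TensorCore runs this body; a real kernel would use the core
      # index (jax.lax.axis_index("core")) to pick its slice of the refs.
      pass

  _, y = pl.run_state(stateful)((x, y))
  return y

run_state turns the initial values into mutable Refs and returns their final values, so all data flows through the closed-over Refs rather than through core_map's inputs or outputs.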
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
@@ -243,7 +243,7 @@ def get_intermediate_shardings(
       out.extend((i, source_info) for i in eqn.params['in_shardings'])
       out.extend((o, source_info) for o in eqn.params['out_shardings'])
     elif eqn.primitive is shard_map.shard_map_p:
-      if not eqn.params['mesh']._is_jax_device_mesh:
+      if isinstance(eqn.params['mesh'], AbstractMesh):
         continue
       source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
       def _names_to_pspec(names):
9 changes: 0 additions & 9 deletions jax/_src/mesh.py
@@ -270,11 +270,6 @@ def local_mesh(self):
   def _local_mesh(self, process_index):
     return _get_local_mesh(self, process_index)
 
-  @property
-  def _is_jax_device_mesh(self):
-    # Returns if the mesh contains JAX devices or not
-    return True
-
   @functools.cached_property
   def device_ids(self):
     assert not self.empty
@@ -377,10 +372,6 @@ def size(self):
   def shape(self):
     return collections.OrderedDict(self.shape_tuple)
 
-  @property
-  def _is_jax_device_mesh(self):
-    return False
-
   @property
   def _internal_device_list(self):
     return None
81 changes: 72 additions & 9 deletions jax/_src/pallas/core.py
@@ -31,7 +31,6 @@
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import linear_util as lu
-from jax._src import mesh as mesh_lib
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
@@ -1019,14 +1018,6 @@ def pytreedef_mismatch_err_msg(
   return "\n".join(msg)
 
 
-class PallasMesh(mesh_lib.Mesh):
-  """A specialized mesh used for lowering shard_map -> pallas_call."""
-
-  @property
-  def _is_jax_device_mesh(self):
-    return False
-
-
 @dataclasses.dataclass(frozen=True)
 class CostEstimate:
   flops: int
@@ -1038,3 +1029,75 @@ def to_json(self) -> bytes:
         f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
         f' "bytes_accessed": {self.bytes_accessed}}}'
     ).encode("ascii")
+
+
+core_map_p = jax_core.Primitive("core_map")
+core_map_p.multiple_results = True
+
+def core_map(mesh):
+  """Runs a function on a mesh, mapping it over the devices in the mesh.
+  The function should be stateful in that it takes in no inputs and returns
+  no outputs but can mutate closed-over Refs, for example.
+  """
+  def wrapped(f):
+    flat_args, in_tree = tree_util.tree_flatten(((), {}))
+    flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
+    with jax_core.extend_axis_env_nd(mesh.shape.items()):
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
+    out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh)
+    if out:
+      raise ValueError("core_map-ped functions must not return any outputs.")
+    return tree_util.tree_unflatten(out_tree_thunk(), out)
+  return wrapped
+
+
+@core_map_p.def_effectful_abstract_eval
+def _core_map_abstract_eval(*args, jaxpr, mesh):
+  del args
+  if jaxpr.outvars:
+    raise ValueError("core_map must not return any outputs.")
+  effs = set()
+  for eff in jaxpr.effects:
+    if not isinstance(eff, jax_core.NamedAxisEffect):
+      effs.add(eff)
+      continue
+    if eff.name not in mesh.shape:
+      effs.add(eff)
+  return [], effs
+
+
+_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
+@state_discharge.register_discharge_rule(core_map_p)
+def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs):
+  if type(mesh) not in _core_map_mesh_rules:
+    raise NotImplementedError(f"Mesh type {type(mesh)} not supported.")
+  return _core_map_mesh_rules[type(mesh)](
+      in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs
+  )
+
+
+def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh):
+  del in_atoms
+  with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())):
+    jax_core.check_jaxpr(jaxpr)
+  effs = set()
+  for eff in jaxpr.effects:
+    if not isinstance(eff, jax_core.NamedAxisEffect):
+      effs.add(eff)
+      continue
+    if eff.name not in mesh.shape:
+      effs.add(eff)
+  return [], effs
+jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule
+
+
+def _core_map_axis_subst(params, subst, traverse):
+  if not traverse:
+    return params
+  def shadowed_subst(name):
+    return (name,) if name in params['mesh'].shape else subst(name)
+  with jax_core.extend_axis_env_nd(params['mesh'].shape.items()):
+    new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
+  return dict(params, jaxpr=new_jaxpr)
+jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst
63 changes: 60 additions & 3 deletions jax/_src/pallas/mosaic/core.py
@@ -15,6 +15,7 @@
 """Contains TPU-specific Pallas abstractions."""
 from __future__ import annotations
 
+import collections
 from collections.abc import Sequence
 import dataclasses
 import enum
@@ -27,6 +28,7 @@
 from jax._src import dtypes
 from jax._src import util
 from jax._src.pallas import core as pallas_core
+from jax._src.pallas import pallas_call
 import jax.numpy as jnp
 import numpy as np
 
@@ -208,14 +210,69 @@ class TensorCore:
   id: int
 
 
-def create_tensorcore_mesh(axis_name: str) -> pallas_core.PallasMesh:
+@dataclasses.dataclass(frozen=True)
+class TensorCoreMesh:
+  """A mesh of TensorCores."""
+  devices: np.ndarray
+  axis_names: Sequence[str]
+
+  @property
+  def shape(self):
+    return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
+
+
+def create_tensorcore_mesh(
+    axis_name: str, devices: Sequence[jax.Device] | None = None
+) -> TensorCoreMesh:
   # TODO(b/355036384): emit a better error if we don't have tensorcores.
-  num_cores = jax.devices()[0].num_cores
-  return pallas_core.PallasMesh(
+  if devices is None:
+    devices = jax.devices()
+  num_cores = devices[0].num_cores
+  return TensorCoreMesh(
       np.array([TensorCore(i) for i in range(num_cores)]),
       [axis_name],
   )
 
 
 def runtime_assert_enabled() -> bool:
   """Returns whether runtime asserts are enabled."""
   return _ENABLE_RUNTIME_ASSERT.value
+
+
+def _tensorcore_mesh_discharge_rule(
+    in_avals,
+    out_avals,
+    *args,
+    mesh,
+    jaxpr,
+):
+  del out_avals
+  assert isinstance(mesh, TensorCoreMesh)
+  if len(mesh.shape) > 1:
+    raise NotImplementedError("Mesh must be 1D")
+  core_axis_name, num_cores = list(mesh.shape.items())[0]
+  def body(*args):
+    # Due to aliasing, args contains aliased inputs and outputs so we remove
+    # outputs.
+    in_refs = args[:len(in_avals)]
+    jax_core.eval_jaxpr(jaxpr, in_refs)
+  assert len(jaxpr.outvars) == 0
+  out = pallas_call.pallas_call(
+      body,
+      out_shape=in_avals,
+      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
+      * len(in_avals),
+      out_specs=[pallas_core.BlockSpec(
+          memory_space=pallas_core.MemorySpace.ANY)]
+      * len(in_avals),
+      input_output_aliases={i: i for i in range(len(in_avals))},
+      grid=((core_axis_name, num_cores),),
+      compiler_params=dict(
+          mosaic=dict(dimension_semantics=("parallel",)),
+      ),
+  )(*args)
+  return out, ()
+
+pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
+    _tensorcore_mesh_discharge_rule
+)
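
The registration above suggests the extension pattern core_map enables: a backend defines a small mesh type with a `shape` property, writes a discharge rule that turns the traced jaxpr into a kernel call, and registers it in `pallas_core._core_map_mesh_rules`. A minimal, hypothetical sketch for the GPU case mentioned in the commit message (the `SmMesh` name and everything in its rule are invented for illustration; this commit only implements the TensorCore mesh):

import collections
import dataclasses
from collections.abc import Sequence

import numpy as np
from jax._src.pallas import core as pallas_core


@dataclasses.dataclass(frozen=True)
class SmMesh:
  """Hypothetical mesh over the SMs of a GPU (not part of this commit)."""
  devices: np.ndarray
  axis_names: Sequence[str]

  @property
  def shape(self):
    return collections.OrderedDict(zip(self.axis_names, self.devices.shape))


def _sm_mesh_discharge_rule(in_avals, out_avals, *args, mesh, jaxpr):
  # A real rule would build a GPU kernel call here, analogous to how the
  # TensorCore rule above lowers to pallas_call with a named grid over cores.
  raise NotImplementedError("illustrative sketch only")


pallas_core._core_map_mesh_rules[SmMesh] = _sm_mesh_discharge_rule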
52 changes: 0 additions & 52 deletions jax/_src/pallas/mosaic/lowering.py
@@ -50,7 +50,6 @@
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 from jax._src.pallas import core as pallas_core
-from jax._src.pallas import pallas_call
 from jax._src.pallas import primitives
 from jax._src.pallas import utils as pallas_utils
 from jax._src.pallas.mosaic import core as tpu_core
@@ -3078,54 +3077,3 @@ def _lower_fun(shape):
 
 
 lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering
-
-# Lowering for shard_map
-
-# Technically this is not a lowering rule, but a discharge rule. When we use
-# a special pallas mesh for a shard_map inside of a run_state, we turn it into
-# a pallas call. The pallas_call has named grid axes corresponding to the names
-# in the pallas mesh. It also sets up input/output aliasing automatically.
-
-def _shard_map_discharge_rule(
-    in_avals,
-    out_avals,
-    *args,
-    mesh,
-    auto,
-    in_names,
-    out_names,
-    jaxpr,
-    check_rep,
-    rewrite,
-):
-  del out_avals, auto, in_names, out_names, check_rep, rewrite
-  if not isinstance(mesh, pallas_core.PallasMesh):
-    raise NotImplementedError("Mesh must be a PallasMesh")
-  if len(mesh.shape) > 1:
-    raise NotImplementedError("Mesh must be 1D")
-  core_axis_name, num_cores = list(mesh.shape.items())[0]
-  def body(*args):
-    in_refs = args[:len(in_avals)]
-    jax_core.eval_jaxpr(jaxpr, (), *in_refs)
-  assert len(jaxpr.outvars) == 0
-  out = pallas_call.pallas_call(
-      body,
-      out_shape=in_avals,
-      in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      out_specs=[pallas_core.BlockSpec(
-          memory_space=pallas_core.MemorySpace.ANY)]
-      * len(in_avals),
-      input_output_aliases={i: i for i in range(len(in_avals))},
-      grid=((core_axis_name, num_cores),),
-      compiler_params=dict(
-          mosaic=dict(dimension_semantics=("parallel",)),
-      ),
-  )(*args)
-  return out, ()
-
-
-from jax.experimental import shard_map
-state_discharge.register_discharge_rule(shard_map.shard_map_p)(
-    _shard_map_discharge_rule
-)
2 changes: 2 additions & 0 deletions jax/experimental/pallas/__init__.py
@@ -21,6 +21,7 @@
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
 from jax._src.pallas.core import CompilerParams
+from jax._src.pallas.core import core_map
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import GridSpec
 from jax._src.pallas.core import IndexingMode
@@ -53,6 +54,7 @@
 from jax._src.pallas.utils import next_power_of_2
 from jax._src.pallas.utils import strides_from_shape
 from jax._src.pallas.utils import when
+from jax._src.state.discharge import run_state
 from jax._src.state.indexing import ds
 from jax._src.state.indexing import dslice
 from jax._src.state.indexing import Slice