[sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`.

Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
yashk2810 authored and Google-ML-Automation committed Nov 15, 2024
1 parent 4511f0c commit 9a0e9e5
Showing 6 changed files with 99 additions and 48 deletions.
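The recurring change in the lowering rules below is a single refactor: rules that previously special-cased meshes whose axes are all collective (returning the op unwrapped), or else wrapped the op with a fully specified sharding proto, now call the new `mlir.lower_sharding_under_shit` helper, which instead marks every dimension as unspecified when all mesh axes are collective. The following is a minimal, self-contained sketch of that decision; `ToyMesh`, `ToyAval`, and `choose_unspecified_dims` are illustrative stand-ins, not JAX APIs.

from dataclasses import dataclass

@dataclass
class ToyMesh:
    are_all_axes_collective: bool

@dataclass
class ToyAval:
    ndim: int
    mesh: ToyMesh

def choose_unspecified_dims(aval):
    # Mirrors the new helper's branch: when every mesh axis is collective
    # (e.g. inside shard_map's manual context), leave all dims unspecified so
    # the partitioner does not pin them; otherwise the sharding proto applies
    # to every dim and nothing is left unspecified.
    if aval.mesh.are_all_axes_collective:
        return set(range(aval.ndim))
    return None

print(choose_unspecified_dims(ToyAval(ndim=2, mesh=ToyMesh(True))))   # {0, 1}
print(choose_unspecified_dims(ToyAval(ndim=2, mesh=ToyMesh(False))))  # None

In the real helper the returned set is forwarded as `unspecified_dims` to `wrap_with_sharding_op`, together with the aval's HloSharding proto.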
14 changes: 14 additions & 0 deletions jax/_src/interpreters/mlir.py
@@ -2474,6 +2474,20 @@ def _wrap_with_spmd_op(name: str,
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")


def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
if sharding_proto is None:
proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
else:
proto = sharding_proto
# TODO(yashkatariya): Setting all axes as unspecified should work even when
# any axis is Collective because that's what happens in partial auto shmap.
# Do that once tests for it exist.
unspecified_dims = (set(range(aval.ndim))
if aval.sharding.mesh.are_all_axes_collective else None)
return wrap_with_sharding_op(
ctx, op, aval, proto, unspecified_dims=unspecified_dims)


def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
if config.use_shardy_partitioner.value:
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
41 changes: 12 additions & 29 deletions jax/_src/lax/lax.py
@@ -2203,14 +2203,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
for op, in_aval in zip(ops, in_avals):
if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
out.append(op)
elif in_aval.sharding.mesh.are_all_axes_collective:
out.append(op)
else:
# TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
# CompilerShardingAxis, then specify `unspecified_dims` via
# `wrap_with_sharding_op`.
sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto))
return out


@@ -2226,10 +2221,7 @@ def _nary_lower_hlo(op: Callable, ctx,

out = op(*args)
if config.sharding_in_types.value:
if aval_out.sharding.mesh.are_all_axes_collective:
return [out]
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
else:
return [out]

@@ -2646,8 +2638,7 @@ def _integer_pow_lowering(ctx, x, *, y):
out, = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
@@ -3029,8 +3020,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
if config.sharding_in_types.value:
if sharding is not None:
assert aval_out.sharding == sharding
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
@@ -3765,8 +3755,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
if config.sharding_in_types.value:
if out_type is not None:
assert aval_out.sharding == out_type
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
if accumulation_aval.dtype != aval_out.dtype:
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
return [result]
@@ -4231,8 +4220,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
@@ -4645,8 +4633,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
out = mlir.reshape(ctx, x, aval_out)
if config.sharding_in_types.value:
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

def _reshape_staging_rule(
@@ -4726,8 +4713,7 @@ def _transpose_lower(ctx, x, *, permutation):
permutation = [*permutation, *trailing_dims]
out = hlo.transpose(x, mlir.dense_int_array(permutation))
if config.sharding_in_types.value:
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

transpose_p = standard_primitive(
@@ -4868,8 +4854,7 @@ def _select_hlo_lowering_opaque(ctx, which, *cases):

def _add_shit_to_select(ctx, op, aval_out):
if config.sharding_in_types.value:
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto)
return mlir.lower_sharding_under_shit(ctx, op, aval_out)
return op

def _select_hlo_lowering(ctx, which, *cases):
@@ -5241,8 +5226,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
with ir.InsertionPoint(reducer_region):
hlo.return_([reducer(*reducer_region.arguments)])
if config.sharding_in_types.value:
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)]
return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)]
return op.results

mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
@@ -5941,8 +5925,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
out = mlir.iota(ctx, aval_out, dimension=dimension)
if config.sharding_in_types.value:
assert aval_out.sharding == sharding
proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(iota_p, _iota_lower)

40 changes: 23 additions & 17 deletions jax/_src/lax/parallel.py
@@ -24,9 +24,11 @@

from jax import tree_util
from jax._src import core
from jax._src import config
from jax._src import dispatch
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
NamedSharding, PartitionSpec as P)
from jax._src.core import AxisName, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -635,9 +637,15 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes),
arg.dtype) for arg in args]
if config.sharding_in_types.value:
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
for arg in args
]
else:
out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype)
for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

def _check_axis_names(axes):
@@ -673,10 +681,7 @@ def _positional_reduce(aval, arg):
_replica_groups(ctx.module_context.axis_env, named_axes,
axis_index_groups))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))

def all_reduce(aval, x):
if is_spmd:
@@ -694,7 +699,11 @@ def all_reduce(aval, x):
else:
op = hlo.AllReduceOp(
[x.type], [x], replica_groups=replica_groups, **other_args)
scalar_aval = core.ShapedArray((), aval.dtype)
if config.sharding_in_types.value:
scalar_aval = core.ShapedArray(
(), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P()))
else:
scalar_aval = core.ShapedArray((), aval.dtype)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_block):
@@ -778,7 +787,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):

axis_context = ctx.module_context.axis_context
is_manual = (
isinstance(axis_context, sharding_impls.SPMDAxisContext)
isinstance(axis_context, SPMDAxisContext)
and axis_context.manual_axes
)
if is_manual:
@@ -896,7 +905,7 @@ def _all_to_all_lowering(
raise ValueError('Replica groups must be equally sized')
is_spmd = isinstance(
ctx.module_context.axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
@@ -1129,10 +1138,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
x_aval, = ctx.avals_in
out_aval, = ctx.avals_out
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))
if not tiled:
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
@@ -1260,7 +1266,7 @@ def _reduce_scatter_lowering(
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
@@ -1489,7 +1495,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
device_id = hlo.partition_id()
19 changes: 19 additions & 0 deletions jax/_src/mesh.py
@@ -107,6 +107,17 @@ class AxisTypes(enum.Enum):
User = enum.auto()
Collective = enum.auto()

def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
if axis_types is None:
return {}
d = {}
for t, names in axis_types.items():
if isinstance(names, tuple):
for n in names:
d[n] = t
else:
d[names] = t
return d

_mesh_object_dict = {} # type: ignore

@@ -269,6 +280,10 @@ def shape_tuple(self):
def axis_sizes(self) -> tuple[int, ...]:
return self.devices.shape

@functools.cached_property
def _name_to_type(self):
return axis_names_to_types(self.axis_types)

@property
def size(self):
return math.prod(self.shape.values()) if self.devices.ndim else 0
@@ -390,6 +405,10 @@ def axis_names(self):
def axis_sizes(self) -> tuple[int, ...]:
return self._axis_sizes

@functools.cached_property
def _name_to_type(self):
return axis_names_to_types(self.axis_types)

@functools.cached_property
def size(self):
return math.prod(self._axis_sizes) if self._axis_sizes else 0
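For reference, the new `axis_names_to_types` helper and the cached `_name_to_type` properties simply invert the mesh's `axis_types` mapping (axis type -> name or tuple of names) into a per-name lookup. A standalone sketch of that helper, with the enum trimmed to the two members visible in the diff above rather than importing the internal `jax._src.mesh` module:

import enum

class AxisTypes(enum.Enum):
    User = enum.auto()
    Collective = enum.auto()

def axis_names_to_types(axis_types):
    # Invert {axis_type: name-or-tuple-of-names} into {name: axis_type}.
    if axis_types is None:
        return {}
    d = {}
    for t, names in axis_types.items():
        if isinstance(names, tuple):
            for n in names:
                d[n] = t
        else:
            d[names] = t
    return d

# Both 'x' and 'y' map to AxisTypes.Collective:
print(axis_names_to_types({AxisTypes.Collective: ('x', 'y')}))
# 'x' maps to AxisTypes.User, 'y' to AxisTypes.Collective:
print(axis_names_to_types({AxisTypes.User: 'x', AxisTypes.Collective: 'y'}))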
7 changes: 5 additions & 2 deletions jax/_src/sharding_impls.py
@@ -137,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding(
mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}

special_axes = {}
if self._manual_axes:
mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
if t == mesh_lib.AxisTypes.Collective}
manual_axes = self._manual_axes.union(mesh_manual_axes)
if manual_axes:
axis_names = self.mesh.axis_names
for manual_axis in self._manual_axes:
for manual_axis in manual_axes:
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL

replicated_mesh_axes = []
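The sharding_impls.py change above unions collective mesh axes with the explicitly manual ones, so both lower to MANUAL entries in the HloSharding. A small, self-contained sketch of that bookkeeping; the string "MANUAL" stands in for `xc.OpSharding.Type.MANUAL`, and `manual_axis_positions` is an illustrative name, not a JAX function:

import enum

class AxisTypes(enum.Enum):
    User = enum.auto()
    Collective = enum.auto()

def manual_axis_positions(axis_names, name_to_type, explicit_manual_axes):
    # Collective axes are treated like manual axes for partitioning, so they
    # are unioned with the explicitly manual set before building special_axes.
    mesh_manual = {n for n, t in name_to_type.items()
                   if t == AxisTypes.Collective}
    manual_axes = set(explicit_manual_axes) | mesh_manual
    return {axis_names.index(n): "MANUAL" for n in manual_axes}

print(manual_axis_positions(
    ("x", "y"),
    {"x": AxisTypes.User, "y": AxisTypes.Collective},
    explicit_manual_axes=frozenset()))
# {1: 'MANUAL'}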
26 changes: 26 additions & 0 deletions tests/pjit_test.py
@@ -5225,6 +5225,32 @@ def f(x, y):
self.assertArraysEqual(out, (np_inp * np_inp) * 2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

def test_shard_map_dot(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))

def g(x, y):
self.assertTrue(x.sharding.mesh.are_all_axes_collective)
self.assertTrue(y.sharding.mesh.are_all_axes_collective)
allgathered_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
z = x @ allgathered_y
return jax.lax.psum(z, axis_name='y')

@jax.jit
def f(x, y):
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
out_specs=P('x', None))(x, y)
self.assertEqual(z.sharding.spec, P('x', None))
out = z * 2
self.assertEqual(out.sharding.spec, P('x', None))
return out

out = f(arr, arr2)
self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):