From 9a0e9e55d81e8ea1b1fd2fa4eaf67074f5908bec Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 14 Nov 2024 17:31:16 -0800 Subject: [PATCH] [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`. Also lower shardings with `Collective` axes correctly to HloSharding. PiperOrigin-RevId: 696703030 --- jax/_src/interpreters/mlir.py | 14 ++++++++++++ jax/_src/lax/lax.py | 41 ++++++++++------------------------- jax/_src/lax/parallel.py | 40 +++++++++++++++++++--------------- jax/_src/mesh.py | 19 ++++++++++++++++ jax/_src/sharding_impls.py | 7 ++++-- tests/pjit_test.py | 26 ++++++++++++++++++++++ 6 files changed, 99 insertions(+), 48 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bef465c6aa75..ee3c929b26f7 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2474,6 +2474,20 @@ def _wrap_with_spmd_op(name: str, wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape") +def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): + if sharding_proto is None: + proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + else: + proto = sharding_proto + # TODO(yashkatariya): Setting all axes as unspecified should work even when + # any axes is Collective because that's what happens in partial auto shmap. + # Do that after tests for it exists. + unspecified_dims = (set(range(aval.ndim)) + if aval.sharding.mesh.are_all_axes_collective else None) + return wrap_with_sharding_op( + ctx, op, aval, proto, unspecified_dims=unspecified_dims) + + def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c45d8f5c80b2..b780aab870e9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2203,14 +2203,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): for op, in_aval in zip(ops, in_avals): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) - elif in_aval.sharding.mesh.are_all_axes_collective: - out.append(op) else: - # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains - # CompilerShardingAxis, then specify `unspecified_dims` via - # `wrap_with_sharding_op`. 
- sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() - out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto)) return out @@ -2226,10 +2221,7 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if aval_out.sharding.mesh.are_all_axes_collective: - return [out] - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] else: return [out] @@ -2646,8 +2638,7 @@ def _integer_pow_lowering(ctx, x, *, y): out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) @@ -3029,8 +3020,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, if config.sharding_in_types.value: if sharding is not None: assert aval_out.sharding == sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3765,8 +3755,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): if config.sharding_in_types.value: if out_type is not None: assert aval_out.sharding == out_type - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) + result = mlir.lower_sharding_under_shit(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) return [result] @@ -4231,8 +4220,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, if config.sharding_in_types.value: if sharding is not None: assert sharding == aval_out.sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, @@ -4645,8 +4633,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _reshape_staging_rule( @@ -4726,8 +4713,7 @@ def _transpose_lower(ctx, x, *, permutation): permutation = [*permutation, *trailing_dims] out = hlo.transpose(x, mlir.dense_int_array(permutation)) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] transpose_p = standard_primitive( @@ -4868,8 
+4854,7 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): def _add_shit_to_select(ctx, op, aval_out): if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto) + return mlir.lower_sharding_under_shit(ctx, op, aval_out) return op def _select_hlo_lowering(ctx, which, *cases): @@ -5241,8 +5226,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): with ir.InsertionPoint(reducer_region): hlo.return_([reducer(*reducer_region.arguments)]) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] return op.results mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, @@ -5941,8 +5925,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): out = mlir.iota(ctx, aval_out, dimension=dimension) if config.sharding_in_types.value: assert aval_out.sharding == sharding - proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(iota_p, _iota_lower) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3a1c1ef3bcf1..c8cea6a9df5b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,9 +24,11 @@ from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes -from jax._src import sharding_impls +from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, + NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -635,9 +637,15 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), - arg.dtype) for arg in args] + if config.sharding_in_types.value: + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) + for arg in args + ] + else: + out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) + for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _check_axis_names(axes): @@ -673,10 +681,7 @@ def _positional_reduce(aval, arg): _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) def all_reduce(aval, x): if is_spmd: @@ -694,7 +699,11 @@ def all_reduce(aval, x): else: op = hlo.AllReduceOp( [x.type], [x], replica_groups=replica_groups, **other_args) - scalar_aval = core.ShapedArray((), aval.dtype) + if config.sharding_in_types.value: + scalar_aval = core.ShapedArray( + (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) + else: + scalar_aval = core.ShapedArray((), aval.dtype) scalar_type = 
mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_block): @@ -778,7 +787,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): axis_context = ctx.module_context.axis_context is_manual = ( - isinstance(axis_context, sharding_impls.SPMDAxisContext) + isinstance(axis_context, SPMDAxisContext) and axis_context.manual_axes ) if is_manual: @@ -896,7 +905,7 @@ def _all_to_all_lowering( raise ValueError('Replica groups must be equally sized') is_spmd = isinstance( ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1129,10 +1138,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, x_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) if not tiled: new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) @@ -1260,7 +1266,7 @@ def _reduce_scatter_lowering( axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1489,7 +1495,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: device_id = hlo.partition_id() diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 082c443fade4..6c6017c4b2b7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -107,6 +107,17 @@ class AxisTypes(enum.Enum): User = enum.auto() Collective = enum.auto() +def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: + if axis_types is None: + return {} + d = {} + for t, names in axis_types.items(): + if isinstance(names, tuple): + for n in names: + d[n] = t + else: + d[names] = t + return d _mesh_object_dict = {} # type: ignore @@ -269,6 +280,10 @@ def shape_tuple(self): def axis_sizes(self) -> tuple[int, ...]: return self.devices.shape + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @property def size(self): return math.prod(self.shape.values()) if self.devices.ndim else 0 @@ -390,6 +405,10 @@ def axis_names(self): def axis_sizes(self) -> tuple[int, ...]: return self._axis_sizes + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @functools.cached_property def size(self): return math.prod(self._axis_sizes) if self._axis_sizes else 0 diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9b847f15d86a..8957a6186339 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -137,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - if self._manual_axes: + mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() + if t == mesh_lib.AxisTypes.Collective} + manual_axes = self._manual_axes.union(mesh_manual_axes) + if manual_axes: axis_names = 
self.mesh.axis_names - for manual_axis in self._manual_axes: + for manual_axis in manual_axes: special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL replicated_mesh_axes = [] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8a63bbe39099..7196a6335960 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5225,6 +5225,32 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_shard_map_dot(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axes_collective) + self.assertTrue(y.sharding.mesh.are_all_axes_collective) + allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) + z = x @ allgatherd_y + return jax.lax.psum(z, axis_name='y') + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', None))(x, y) + self.assertEqual(z.sharding.spec, P('x', None)) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):
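Note (appended for readers; not part of the patch): the heart of this change is the decision the new `lower_sharding_under_shit` helper makes before delegating to `wrap_with_sharding_op`: if every axis of the aval's mesh is `Collective` (as is the case inside `shard_map`), all dimensions of the aval are marked as unspecified so the SPMD partitioner leaves them alone. The standalone sketch below mirrors that logic together with the new `axis_names_to_types` helper from `jax/_src/mesh.py`. `FakeMesh`, `unspecified_dims_for`, and the final asserts are illustrative stand-ins written for this note, not JAX APIs.

import enum
from dataclasses import dataclass

class AxisTypes(enum.Enum):
  Auto = enum.auto()
  User = enum.auto()
  Collective = enum.auto()

def axis_names_to_types(axis_types):
  # Mirrors the new helper in jax/_src/mesh.py: flatten a
  # {AxisTypes: name-or-tuple-of-names} mapping into {name: AxisTypes}.
  if axis_types is None:
    return {}
  d = {}
  for t, names in axis_types.items():
    for n in (names if isinstance(names, tuple) else (names,)):
      d[n] = t
  return d

@dataclass
class FakeMesh:
  axis_types: dict

  @property
  def _name_to_type(self):
    return axis_names_to_types(self.axis_types)

  @property
  def are_all_axes_collective(self):
    # Stand-in for the mesh property the new lowering path consults.
    types = self._name_to_type.values()
    return bool(types) and all(t == AxisTypes.Collective for t in types)

def unspecified_dims_for(ndim, mesh):
  # The choice lower_sharding_under_shit makes before calling
  # wrap_with_sharding_op: inside shard_map every mesh axis is Collective,
  # so every dimension is left unspecified for the SPMD partitioner.
  return set(range(ndim)) if mesh.are_all_axes_collective else None

assert unspecified_dims_for(2, FakeMesh({AxisTypes.Collective: ('x', 'y')})) == {0, 1}
assert unspecified_dims_for(2, FakeMesh({AxisTypes.User: ('x', 'y')})) is None

The `sharding_impls.py` hunk complements this on the HloSharding side: mesh axes typed as `Collective` are folded into the manual axes when building the proto, so they lower to `OpSharding.Type.MANUAL` just like entries in `_manual_axes`.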