diff --git a/jax/__init__.py b/jax/__init__.py
index b9d82d9f926c..d24ec60e1057 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -126,7 +126,6 @@
 from jax._src.api import value_and_grad as value_and_grad
 from jax._src.api import vjp as vjp
 from jax._src.api import vmap as vmap
-from jax._src.api import hidden_axes as hidden_axes
 from jax._src.sharding_impls import NamedSharding as NamedSharding
 from jax._src.sharding_impls import make_mesh as make_mesh
 
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 3ed68c85054e..5585544d80d7 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -99,8 +99,6 @@
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
 
-hidden_axes = pjit.hidden_axes
-
 def _nan_check_posthook(fun, args, kwargs, output):
   """Hook function called by the C++ jit/pmap to perform NaN checking."""
 
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 9e7dd24c744c..d5f32f5f45fa 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1883,7 +1883,8 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
                           slice_sizes, unique_indices, indices_are_sorted,
                           mode, fill_value):
   # TODO(yashkatariya): Write a proper gather sharding rule.
-  if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+  cur_mesh = mesh_lib.get_abstract_mesh()
+  if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective:  # type: ignore
     return None
   raise GatherShardingError(
       "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index 2ec4d3591476..c5f8f31576ec 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -57,7 +57,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
           f'sharding rule for {prim.name} is not implemented. Please file a'
           ' bug at https://github.com/jax-ml/jax/issues. You can work around'
           ' this error by dropping that operation into full hidden sharding'
-          ' mode via: `jax.hidden_axes(fun, out_shardings=...)`')
+          ' mode via: `jax.experimental.shard.hidden_axes(fun, out_shardings=...)`')
     return rule(*avals, **kwargs)
   return None if num_out is None else [None] * num_out
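
Taken together, the hunks above remove the `jax.hidden_axes` alias: the only
public entry point is now `jax.experimental.shard.hidden_axes` (the new module
appears later in this diff), which is what the updated `call_sharding_rule`
error message advertises. A minimal usage sketch, modeled on
`test_auto_mode_mix` below; the surrounding mesh context (a 2x2 ('x', 'y')
user mesh with the input sharded as P('x', 'y')) is assumed rather than set
up here:

    from functools import partial
    import jax.numpy as jnp
    from jax.experimental.shard import hidden_axes  # new import path
    from jax.sharding import PartitionSpec as P

    # Hide 'x' inside h: the body traces with spec P(None, 'y'), and the
    # declared out_shardings constrains the result back to P('x', None).
    @partial(hidden_axes, axes='x', out_shardings=P('x', None))
    def h(y):
      return jnp.sin(y)
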
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 2a0b00211319..636d8a68f142 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2684,17 +2684,11 @@ def _sharding_constraint_batcher(
 
 # TODO(yashkatariya): Make shardings optional.
 def mesh_cast(xs, out_shardings):
-  if isinstance(out_shardings, (NamedSharding, PartitionSpec)):
-    return tree_map(
-        lambda x: mesh_cast_p.bind(
-            x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
-                out_shardings, check_mesh_consistency=False)), xs)
-
   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
   out_flat = [
       mesh_cast_p.bind(
-          x, src_sharding=x.sharding,
+          x, src_sharding=x.aval.sharding,
           dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
       for x, s in safe_zip(x_flat, shardings_flat)
   ]
@@ -2779,6 +2773,49 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
 # batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
 # batching.skippable_batchers[mesh_cast_p] = lambda _: ()
 
+# -------------------- reshard ------------------------------------
+
+def reshard(xs, out_shardings):
+  x_flat, treedef = tree_flatten(xs)
+  shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings)
+  out_flat = []
+  for x, s in safe_zip(x_flat, shardings_flat):
+    ds = canonicalize_sharding(s)
+    ds = ds.with_spec(ds.spec._normalized_spec(x.ndim))  # type: ignore
+    out_flat.append(reshard_p.bind(x, src_sharding=x.aval.sharding,
+                                   dst_sharding=ds))
+  return tree_unflatten(treedef, out_flat)
+
+reshard_p = core.Primitive('reshard')
+
+def _reshard_abstract_eval(aval, src_sharding, dst_sharding):
+  if src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh:
+    raise ValueError(
+        f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not'
+        ' equal the mesh of the target sharding'
+        f' {dst_sharding.mesh.abstract_mesh} for shape {aval.str_short()}')
+  return aval.update(sharding=dst_sharding)
+reshard_p.def_abstract_eval(_reshard_abstract_eval)
+
+def _reshard_impl(x, src_sharding, dst_sharding):
+  return dispatch.apply_primitive(reshard_p, x, src_sharding=src_sharding,
+                                  dst_sharding=dst_sharding)
+reshard_p.def_impl(_reshard_impl)
+
+def _reshard_transpose_rule(ct, _, src_sharding, dst_sharding):
+  return [reshard_p.bind(ct, src_sharding=dst_sharding,
+                         dst_sharding=src_sharding)]
+ad.deflinear2(reshard_p, _reshard_transpose_rule)
+
+def _reshard_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
+  aval, = ctx.avals_in
+  aval_out, = ctx.avals_out
+  proto = (dst_sharding._to_sdy_sharding(aval.ndim)
+           if config.use_shardy_partitioner.value else
+           dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
+  return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
+mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
+
 # -------------------- auto and user mode -------------------------
 
 def _get_new_mesh(axes: str | tuple[str, ...] | None,
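
The `reshard` primitive added above pairs an abstract-eval rule that installs
the destination sharding on the output aval with a linear transpose rule that
swaps `src_sharding` and `dst_sharding`, so a reshard in the forward pass
becomes the inverse reshard on the cotangent. A sketch of the observable
behavior, lifted from `test_reshard_error` below (again assuming a current
2x2 ('x', 'y') user mesh and an `arr` sharded as P('x', 'y')):

    import jax
    import jax.numpy as jnp
    from jax.experimental.shard import reshard
    from jax.sharding import PartitionSpec as P

    def g(x):
      y = reshard(x, P('x', None))  # forward: y's spec is P('x', None)
      return jnp.sum(y)

    # The transpose rule reshards the cotangent from P('x', None) back to
    # the input's original sharding:
    grads = jax.jit(jax.grad(g))(arr)
    assert grads.sharding == arr.sharding
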
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 8150c9a90037..b2a482be5450 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -1766,11 +1766,13 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
     sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
   else:
     if (check_mesh_consistency and
-        sharding.mesh != mesh_lib.get_abstract_mesh()):
+        sharding.mesh.abstract_mesh != mesh_lib.get_abstract_mesh()):
       raise ValueError(
           f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
-          f' of sharding {sharding.mesh}. This error occurs at source: '
-          f' {source_info_util.summarize(source_info_util.current())}')
+          f' of sharding {sharding.mesh.abstract_mesh}. This error occurs at'
+          f' source: {source_info_util.summarize(source_info_util.current())}')
+    if isinstance(sharding.mesh, mesh_lib.Mesh):
+      sharding = NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
 
   for s in flatten_spec(sharding.spec):
     if sharding.mesh._name_to_type[s] in {
diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py
index 4679e81b0ad8..98e608817cc2 100644
--- a/jax/experimental/jax2tf/tests/primitives_test.py
+++ b/jax/experimental/jax2tf/tests/primitives_test.py
@@ -176,6 +176,8 @@ def test_primitive_coverage(self):
         continue
       if p.name == "mesh_cast":
         continue
+      if p.name == "reshard":
+        continue
       # TODO: Remove once tensorflow is 2.10.0 everywhere.
       if p.name == "optimization_barrier":
         continue
diff --git a/jax/experimental/shard.py b/jax/experimental/shard.py
new file mode 100644
index 000000000000..6674ac3b4681
--- /dev/null
+++ b/jax/experimental/shard.py
@@ -0,0 +1,18 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jax._src.pjit import (
+    reshard as reshard,
+    hidden_axes as hidden_axes,
+)
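
`jax/experimental/shard.py` above is the new public home for both entry
points; the `name as name` spelling is the usual way to mark explicit
re-exports for strict type checkers. Downstream code imports them as:

    from jax.experimental.shard import hidden_axes, reshard
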
match"): - f(arr) - - with mesh_lib.use_mesh(mesh): - with self.assertRaisesRegex( - ValueError, "Mesh shape of the input.*does not match"): - jax.jit(f)(arr) - @jtu.with_user_mesh((2, 2), ('x', 'y')) def test_split(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -5960,8 +5943,7 @@ def test_return_output_different_context(self, mesh): @jax.jit def f(x): - auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'}) - with set_abstract_mesh(auto_mesh): + with use_hidden_axes('x'): x = mesh_cast(x, P(None, None)) return x @@ -6048,7 +6030,7 @@ def test_auto_mode_mix(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(jax.hidden_axes, axes='x', out_shardings=P('x', None)) + @partial(hidden_axes, axes='x', out_shardings=P('x', None)) def h(y): self.assertEqual(y.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6181,6 +6163,86 @@ def g(x, y): out = jax.jit(jax.grad(g))(embed, tok) self.assertEqual(out.sharding, embed.sharding) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_reshard_error(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + def f(x): + y = reshard(x, P('x', None)) + self.assertEqual(y.aval.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + f = jax.jit(f) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + lowered_text = f.lower(arr).as_text() + self.check_wsc_in_lowered(lowered_text) + + def g(x): + y = f(x) + return jnp.sum(y) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + @jax.jit + def h(x): + with use_hidden_axes('x'): + return reshard(x, P('y', None)) + + with self.assertRaisesRegex( + ValueError, 'Mesh of the input.*does not equal.*target sharding'): + h(arr) + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_full_auto_outside_jit(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = x * 2 + self.assertEqual(y.sharding.spec, P(None, None)) + z = jnp.sin(y) + self.assertEqual(z.sharding.spec, P(None, None)) + a = z @ z.T + self.assertEqual(a.sharding.spec, P(None, None)) + return a + + hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y')) + out = hf(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_user_mesh((2, 2), ('x', 'y'), + axis_types={AxisTypes.Hidden: ('x', 'y')}) + def test_full_visible_outside_jit(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = x * 2 + self.assertEqual(y.sharding.spec, P('x', 'y')) + z = jnp.sin(y) + self.assertEqual(z.sharding.spec, P('x', 'y')) + return z + + hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y')) + out = hf(arr) # doesn't crash + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):