Add reshard API in experimental #25978 (merged, 1 commit, Jan 20, 2025)

Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. The semantics of both: reshard keeps the input and output shardings on the same (abstract) mesh and only changes the PartitionSpec, while mesh_cast casts an array between meshes whose axis types differ, so it skips the mesh-consistency check.
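A minimal usage sketch of the new `reshard` API, mirroring the `test_reshard_error` test added below. The setup is an assumption: it presumes four devices, the experimental sharding_in_types mode, and a mesh with Visible axes (the tests get all of this from the `jtu.with_user_mesh` helper), and it uses the internal `mesh_lib.use_mesh` context manager the same way the tests do.

```python
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import mesh as mesh_lib        # internal, as in the tests below
from jax.experimental.shard import reshard   # new in this PR

# Assumptions: 4 devices, sharding_in_types enabled, Visible mesh axes.
mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'),
    axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
arr = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
  # reshard stays on the same (abstract) mesh; only the spec changes.
  return reshard(x, P('x', None))

with mesh_lib.use_mesh(mesh):  # sets the context mesh, as in the tests
  out = f(arr)
  # out.sharding == NamedSharding(mesh, P('x', None))
```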

1 change: 0 additions & 1 deletion jax/__init__.py
@@ -126,7 +126,6 @@
 from jax._src.api import value_and_grad as value_and_grad
 from jax._src.api import vjp as vjp
 from jax._src.api import vmap as vmap
-from jax._src.api import hidden_axes as hidden_axes
 from jax._src.sharding_impls import NamedSharding as NamedSharding
 from jax._src.sharding_impls import make_mesh as make_mesh
2 changes: 0 additions & 2 deletions jax/_src/api.py
@@ -99,8 +99,6 @@
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip

-hidden_axes = pjit.hidden_axes
-

 def _nan_check_posthook(fun, args, kwargs, output):
   """Hook function called by the C++ jit/pmap to perform NaN checking."""
3 changes: 2 additions & 1 deletion jax/_src/lax/slicing.py
@@ -1883,7 +1883,8 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
                           slice_sizes, unique_indices, indices_are_sorted,
                           mode, fill_value):
   # TODO(yashkatariya): Write a proper gather sharding rule.
-  if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+  cur_mesh = mesh_lib.get_abstract_mesh()
+  if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective:  # type: ignore
     return None
   raise GatherShardingError(
       "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
2 changes: 1 addition & 1 deletion jax/_src/lax/utils.py
@@ -57,7 +57,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
           f'sharding rule for {prim.name} is not implemented. Please file a'
           ' bug at https://github.com/jax-ml/jax/issues. You can work around'
           ' this error by dropping that operation into full hidden sharding'
-          ' mode via: `jax.hidden_axes(fun, out_shardings=...)`')
+          ' mode via: `jax.experimental.shard.hidden_axes(fun, out_shardings=...)`')
     return rule(*avals, **kwargs)
   return None if num_out is None else [None] * num_out
51 changes: 44 additions & 7 deletions jax/_src/pjit.py
@@ -2684,17 +2684,11 @@ def _sharding_constraint_batcher(

 # TODO(yashkatariya): Make shardings optional.
 def mesh_cast(xs, out_shardings):
-  if isinstance(out_shardings, (NamedSharding, PartitionSpec)):
-    return tree_map(
-        lambda x: mesh_cast_p.bind(
-            x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
-                out_shardings, check_mesh_consistency=False)), xs)
-
   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
   out_flat = [
       mesh_cast_p.bind(
-          x, src_sharding=x.sharding,
+          x, src_sharding=x.aval.sharding,
           dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
       for x, s in safe_zip(x_flat, shardings_flat)
   ]
@@ -2779,6 +2773,49 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
 # batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
 # batching.skippable_batchers[mesh_cast_p] = lambda _: ()

+# -------------------- reshard ------------------------------------
+
+def reshard(xs, out_shardings):
+  x_flat, treedef = tree_flatten(xs)
+  shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings)
+  out_flat = []
+  for x, s in safe_zip(x_flat, shardings_flat):
+    ds = canonicalize_sharding(s)
+    ds = ds.with_spec(ds.spec._normalized_spec(x.ndim))  # type: ignore
+    out_flat.append(reshard_p.bind(x, src_sharding=x.aval.sharding,
+                                   dst_sharding=ds))
+  return tree_unflatten(treedef, out_flat)
+
+reshard_p = core.Primitive('reshard')
+
+def _reshard_abstract_eval(aval, src_sharding, dst_sharding):
+  if src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh:
+    raise ValueError(
+        f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not'
+        ' equal the mesh of the target sharding'
+        f' {dst_sharding.mesh.abstract_mesh} for shape {aval.str_short()}')
+  return aval.update(sharding=dst_sharding)
+reshard_p.def_abstract_eval(_reshard_abstract_eval)
+
+def _reshard_impl(x, src_sharding, dst_sharding):
+  return dispatch.apply_primitive(reshard_p, x, src_sharding=src_sharding,
+                                  dst_sharding=dst_sharding)
+reshard_p.def_impl(_reshard_impl)
+
+def _reshard_transpose_rule(ct, _, src_sharding, dst_sharding):
+  return [reshard_p.bind(ct, src_sharding=dst_sharding,
+                         dst_sharding=src_sharding)]
+ad.deflinear2(reshard_p, _reshard_transpose_rule)
+
+def _reshard_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
+  aval, = ctx.avals_in
+  aval_out, = ctx.avals_out
+  proto = (dst_sharding._to_sdy_sharding(aval.ndim)
+           if config.use_shardy_partitioner.value else
+           dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
+  return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
+mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
+
 # -------------------- auto and user mode -------------------------

 def _get_new_mesh(axes: str | tuple[str, ...] | None,
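The transpose rule registered above makes `reshard` linear: the cotangent is reshard-ed from the destination sharding back to the source sharding, so gradients come out sharded like the primal input. A sketch of that behavior, reusing `mesh`, `arr`, and the imports from the earlier sketch (same experimental-setup assumptions), matching the `jax.grad` assertions in `test_reshard_error` below:

```python
import jax.numpy as jnp

def g(x):
  # Forward: reshard to P('x', None). Backward: _reshard_transpose_rule
  # binds reshard_p with src and dst swapped, restoring x's sharding.
  return jnp.sum(reshard(x, P('x', None)))

with mesh_lib.use_mesh(mesh):
  grads = jax.grad(g)(arr)
  # grads.sharding == arr.sharding == NamedSharding(mesh, P('x', 'y'))
```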
8 changes: 5 additions & 3 deletions jax/_src/sharding_impls.py
@@ -1766,11 +1766,13 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
     sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
   else:
     if (check_mesh_consistency and
-        sharding.mesh != mesh_lib.get_abstract_mesh()):
+        sharding.mesh.abstract_mesh != mesh_lib.get_abstract_mesh()):
       raise ValueError(
           f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
-          f' of sharding {sharding.mesh}. This error occurs at source: '
-          f' {source_info_util.summarize(source_info_util.current())}')
+          f' of sharding {sharding.mesh.abstract_mesh}. This error occurs at'
+          f' source: {source_info_util.summarize(source_info_util.current())}')
+    if isinstance(sharding.mesh, mesh_lib.Mesh):
+      sharding = NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)

   for s in flatten_spec(sharding.spec):
     if sharding.mesh._name_to_type[s] in {
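The net effect of this change: a `NamedSharding` built on a concrete `Mesh` is now compared via, and normalized to, its abstract mesh. A small sketch, assuming four available devices (the `Mesh.abstract_mesh` property is the one the diff itself relies on):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))       # concrete Mesh
s = NamedSharding(mesh, P('x', 'y'))

# canonicalize_sharding now rebuilds this on the abstract mesh, so the
# rest of the tracing machinery only ever sees abstract-mesh shardings:
s_canonical = NamedSharding(mesh.abstract_mesh, s.spec)
```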
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -176,6 +176,8 @@ def test_primitive_coverage(self):
         continue
       if p.name == "mesh_cast":
         continue
+      if p.name == "reshard":
+        continue
       # TODO: Remove once tensorflow is 2.10.0 everywhere.
       if p.name == "optimization_barrier":
         continue
18 changes: 18 additions & 0 deletions jax/experimental/shard.py
@@ -0,0 +1,18 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from jax._src.pjit import (
+    reshard as reshard,
+    hidden_axes as hidden_axes,
+)
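The new module simply re-exports `reshard` and `hidden_axes` from `jax._src.pjit`, so user code imports them from `jax.experimental.shard`. A sketch of `hidden_axes` usage following the tests below (the `functools.partial` decorator style mirrors `test_auto_mode_mix`, and the axes and out_shardings mirror `test_full_auto_outside_jit`; the surrounding mesh setup is assumed):

```python
from functools import partial

import jax.numpy as jnp
from jax.experimental.shard import hidden_axes, reshard
from jax.sharding import PartitionSpec as P

# Inside f, axes 'x' and 'y' are hidden: intermediate shardings show up
# as P(None, None), and out_shardings reinstates P('x', 'y') on the result.
@partial(hidden_axes, axes=('x', 'y'), out_shardings=P('x', 'y'))
def f(x):
  y = jnp.sin(x * 2)
  return y @ y.T
```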
112 changes: 87 additions & 25 deletions tests/pjit_test.py
@@ -43,18 +43,19 @@
 from jax.sharding import PartitionSpec as P, Mesh
 from jax.experimental import multihost_utils
 from jax.experimental.shard_map import shard_map
-from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING
+from jax.experimental.custom_partitioning import (
+    custom_partitioning, SdyShardingRule, BATCHING)
 from jax._src import array
 from jax._src.sharding import Sharding, common_devices_indices_map
 from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src.sharding_impls import (
     AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
     SingleDeviceSharding, parse_flatten_op_sharding)
-from jax._src.pjit import (pjit, mesh_cast, visible_axes,
-                           use_hidden_axes, use_visible_axes)
+from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
+                           use_hidden_axes, use_visible_axes, reshard)
 from jax._src import mesh as mesh_lib
-from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
+from jax._src.mesh import AxisTypes
 from jax._src.interpreters import pxla
 from jax._src.lib.mlir import dialects
 from jax._src import xla_bridge
@@ -5905,24 +5906,6 @@ def f(x, y):
         ValueError, "For primitive dot_general, context mesh.*aval mesh"):
       f(arr, arr.T)

-  def test_mesh_cast_src_dst_mesh_mismatch(self):
-    np_inp = np.arange(16.).reshape(8, 2)
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'),
-                           axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
-    mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
-                            axis_types={mesh_lib.AxisTypes.Visible: ('a', 'b')})
-    s = NamedSharding(mesh, P('x', 'y'))
-    arr = jax.device_put(np_inp, s)
-    f = lambda x: mesh_cast(x, NamedSharding(mesh2, P('a', 'b')))
-    with self.assertRaisesRegex(
-        ValueError, "Mesh shape of the input.*does not match"):
-      f(arr)
-
-    with mesh_lib.use_mesh(mesh):
-      with self.assertRaisesRegex(
-          ValueError, "Mesh shape of the input.*does not match"):
-        jax.jit(f)(arr)
-
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_split(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
@@ -5960,8 +5943,7 @@ def test_return_output_different_context(self, mesh):

     @jax.jit
     def f(x):
-      auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'})
-      with set_abstract_mesh(auto_mesh):
+      with use_hidden_axes('x'):
         x = mesh_cast(x, P(None, None))
       return x
@@ -6048,7 +6030,7 @@ def test_auto_mode_mix(self, mesh):
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)

-    @partial(jax.hidden_axes, axes='x', out_shardings=P('x', None))
+    @partial(hidden_axes, axes='x', out_shardings=P('x', None))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
@@ -6181,6 +6163,86 @@ def g(x, y):
     out = jax.jit(jax.grad(g))(embed, tok)
     self.assertEqual(out.sharding, embed.sharding)

+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_reshard_error(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    def f(x):
+      y = reshard(x, P('x', None))
+      self.assertEqual(y.aval.sharding.spec, P('x', None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
+    f = jax.jit(f)
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
+    lowered_text = f.lower(arr).as_text()
+    self.check_wsc_in_lowered(lowered_text)
+
+    def g(x):
+      y = f(x)
+      return jnp.sum(y)
+
+    out = jax.grad(g)(arr)
+    self.assertEqual(out.sharding, arr.sharding)
+
+    out = jax.jit(jax.grad(g))(arr)
+    self.assertEqual(out.sharding, arr.sharding)
+
+    @jax.jit
+    def h(x):
+      with use_hidden_axes('x'):
+        return reshard(x, P('y', None))
+
+    with self.assertRaisesRegex(
+        ValueError, 'Mesh of the input.*does not equal.*target sharding'):
+      h(arr)
+
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_full_auto_outside_jit(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x * 2
+      self.assertEqual(y.sharding.spec, P(None, None))
+      z = jnp.sin(y)
+      self.assertEqual(z.sharding.spec, P(None, None))
+      a = z @ z.T
+      self.assertEqual(a.sharding.spec, P(None, None))
+      return a
+
+    hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y'))
+    out = hf(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
+
+  @jtu.with_user_mesh((2, 2), ('x', 'y'),
+                      axis_types={AxisTypes.Hidden: ('x', 'y')})
+  def test_full_visible_outside_jit(self, mesh):
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x * 2
+      self.assertEqual(y.sharding.spec, P('x', 'y'))
+      z = jnp.sin(y)
+      self.assertEqual(z.sharding.spec, P('x', 'y'))
+      return z
+
+    hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y'))
+    out = hf(arr)  # doesn't crash
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):