Add reshard API in experimental.

Currently for sharding_in_types we have two APIs: `mesh_cast` and `reshard`. Both work in sharding_in_types mode and affect the sharding of the aval. Their semantics are:

* `mesh_cast`: The AxisTypes of the src and dst meshes **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: The mesh must be the **same** between src and dst (same axis_names, axis_sizes, and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might eventually make `reshard` == `device_put`, hence the API lives in experimental; that decision can be made later. The reason not to simply give `device_put` this power is that `device_put` already does a lot (and will gain even more capabilities in the near future, such as cross-host transfers), and its semantics would become very confusing if we kept piling sharding-in-types behavior onto it.
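A minimal usage sketch of `reshard` (not part of this commit's diff), built only from APIs that appear in the tests below; it assumes at least 2 devices and that sharding-in-types mode is enabled (the exact config flag is elided):

```python
# Illustrative sketch only. The mesh construction and mesh context manager
# mirror the tests in this commit (jtu.create_mesh / mesh_lib.use_mesh);
# treat them as assumptions rather than the public API.
import numpy as np
import jax
from jax._src import mesh as mesh_lib
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.shard import reshard  # new in this commit

mesh = jtu.create_mesh((2, 1), ('x', 'y'),
                       axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
arr = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

with mesh_lib.use_mesh(mesh):
  # Same mesh, same aval shape; only the layout changes, so data movement
  # (e.g. an all-gather along 'y') is allowed.
  out = jax.jit(lambda x: reshard(x, P('x', None)))(arr)
  assert out.sharding == NamedSharding(mesh, P('x', None))
```

By contrast, `mesh_cast` is for the case where the src and dst meshes differ only in their AxisTypes (e.g. Visible vs Hidden), with no visible data movement.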

PiperOrigin-RevId: 716879595
yashk2810 authored and Google-ML-Automation committed Jan 20, 2025
1 parent 4fd0bb0 commit c69e51d
Showing 9 changed files with 159 additions and 40 deletions.
1 change: 0 additions & 1 deletion jax/__init__.py
@@ -126,7 +126,6 @@
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import hidden_axes as hidden_axes
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.sharding_impls import make_mesh as make_mesh

2 changes: 0 additions & 2 deletions jax/_src/api.py
@@ -99,8 +99,6 @@
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

hidden_axes = pjit.hidden_axes


def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
3 changes: 2 additions & 1 deletion jax/_src/lax/slicing.py
@@ -1883,7 +1883,8 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
# TODO(yashkatariya): Write a proper gather sharding rule.
if mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective: # type: ignore
return None
raise GatherShardingError(
"Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
2 changes: 1 addition & 1 deletion jax/_src/lax/utils.py
@@ -57,7 +57,7 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
f'sharding rule for {prim.name} is not implemented. Please file a'
' bug at https://github.com/jax-ml/jax/issues. You can work around'
' this error by dropping that operation into full hidden sharding'
' mode via: `jax.hidden_axes(fun, out_shardings=...)`')
' mode via: `jax.experimental.shard.hidden_axes(fun, out_shardings=...)`')
return rule(*avals, **kwargs)
return None if num_out is None else [None] * num_out

51 changes: 44 additions & 7 deletions jax/_src/pjit.py
@@ -2684,17 +2684,11 @@ def _sharding_constraint_batcher(

# TODO(yashkatariya): Make shardings optional.
def mesh_cast(xs, out_shardings):
if isinstance(out_shardings, (NamedSharding, PartitionSpec)):
return tree_map(
lambda x: mesh_cast_p.bind(
x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
out_shardings, check_mesh_consistency=False)), xs)

x_flat, treedef = tree_flatten(xs)
shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
out_flat = [
mesh_cast_p.bind(
x, src_sharding=x.sharding,
x, src_sharding=x.aval.sharding,
dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
for x, s in safe_zip(x_flat, shardings_flat)
]
@@ -2779,6 +2773,49 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
# batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
# batching.skippable_batchers[mesh_cast_p] = lambda _: ()

# -------------------- reshard ------------------------------------

def reshard(xs, out_shardings):
x_flat, treedef = tree_flatten(xs)
shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings)
out_flat = []
for x, s in safe_zip(x_flat, shardings_flat):
ds = canonicalize_sharding(s)
ds = ds.with_spec(ds.spec._normalized_spec(x.ndim)) # type: ignore
out_flat.append(reshard_p.bind(x, src_sharding=x.aval.sharding,
dst_sharding=ds))
return tree_unflatten(treedef, out_flat)

reshard_p = core.Primitive('reshard')

def _reshard_abstract_eval(aval, src_sharding, dst_sharding):
if src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh:
raise ValueError(
f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not'
' equal the mesh of the target sharding'
f' {dst_sharding.mesh.abstract_mesh} for shape {aval.str_short()}')
return aval.update(sharding=dst_sharding)
reshard_p.def_abstract_eval(_reshard_abstract_eval)

def _reshard_impl(x, src_sharding, dst_sharding):
return dispatch.apply_primitive(reshard_p, x, src_sharding=src_sharding,
dst_sharding=dst_sharding)
reshard_p.def_impl(_reshard_impl)

def _reshard_transpose_rule(ct, _, src_sharding, dst_sharding):
return [reshard_p.bind(ct, src_sharding=dst_sharding,
dst_sharding=src_sharding)]
ad.deflinear2(reshard_p, _reshard_transpose_rule)

def _reshard_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
aval, = ctx.avals_in
aval_out, = ctx.avals_out
proto = (dst_sharding._to_sdy_sharding(aval.ndim)
if config.use_shardy_partitioner.value else
dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)

# -------------------- auto and user mode -------------------------

def _get_new_mesh(axes: str | tuple[str, ...] | None,
8 changes: 5 additions & 3 deletions jax/_src/sharding_impls.py
@@ -1766,11 +1766,13 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding) # type: ignore
else:
if (check_mesh_consistency and
sharding.mesh != mesh_lib.get_abstract_mesh()):
sharding.mesh.abstract_mesh != mesh_lib.get_abstract_mesh()):
raise ValueError(
f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
f' of sharding {sharding.mesh}. This error occurs at source: '
f' {source_info_util.summarize(source_info_util.current())}')
f' of sharding {sharding.mesh.abstract_mesh}. This error occurs at'
f' source: {source_info_util.summarize(source_info_util.current())}')
if isinstance(sharding.mesh, mesh_lib.Mesh):
sharding = NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)

for s in flatten_spec(sharding.spec):
if sharding.mesh._name_to_type[s] in {
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -176,6 +176,8 @@ def test_primitive_coverage(self):
continue
if p.name == "mesh_cast":
continue
if p.name == "reshard":
continue
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if p.name == "optimization_barrier":
continue
18 changes: 18 additions & 0 deletions jax/experimental/shard.py
@@ -0,0 +1,18 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.pjit import (
reshard as reshard,
hidden_axes as hidden_axes,
)
112 changes: 87 additions & 25 deletions tests/pjit_test.py
@@ -43,18 +43,19 @@
from jax.sharding import PartitionSpec as P, Mesh
from jax.experimental import multihost_utils
from jax.experimental.shard_map import shard_map
from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING
from jax.experimental.custom_partitioning import (
custom_partitioning, SdyShardingRule, BATCHING)
from jax._src import array
from jax._src.sharding import Sharding, common_devices_indices_map
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding, parse_flatten_op_sharding)
from jax._src.pjit import (pjit, mesh_cast, visible_axes,
use_hidden_axes, use_visible_axes)
from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
use_hidden_axes, use_visible_axes, reshard)
from jax._src import mesh as mesh_lib
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
from jax._src.mesh import AxisTypes
from jax._src.interpreters import pxla
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
@@ -5905,24 +5906,6 @@ def f(x, y):
ValueError, "For primitive dot_general, context mesh.*aval mesh"):
f(arr, arr.T)

def test_mesh_cast_src_dst_mesh_mismatch(self):
np_inp = np.arange(16.).reshape(8, 2)
mesh = jtu.create_mesh((2, 1), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
axis_types={mesh_lib.AxisTypes.Visible: ('a', 'b')})
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
f = lambda x: mesh_cast(x, NamedSharding(mesh2, P('a', 'b')))
with self.assertRaisesRegex(
ValueError, "Mesh shape of the input.*does not match"):
f(arr)

with mesh_lib.use_mesh(mesh):
with self.assertRaisesRegex(
ValueError, "Mesh shape of the input.*does not match"):
jax.jit(f)(arr)

@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_split(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
@@ -5960,8 +5943,7 @@ def test_return_output_different_context(self, mesh):

@jax.jit
def f(x):
auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'})
with set_abstract_mesh(auto_mesh):
with use_hidden_axes('x'):
x = mesh_cast(x, P(None, None))
return x

@@ -6048,7 +6030,7 @@ def test_auto_mode_mix(self, mesh):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@partial(jax.hidden_axes, axes='x', out_shardings=P('x', None))
@partial(hidden_axes, axes='x', out_shardings=P('x', None))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@@ -6181,6 +6163,86 @@ def g(x, y):
out = jax.jit(jax.grad(g))(embed, tok)
self.assertEqual(out.sharding, embed.sharding)

@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_reshard_error(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

def f(x):
y = reshard(x, P('x', None))
self.assertEqual(y.aval.sharding.spec, P('x', None))
return y

out = f(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

f = jax.jit(f)

out = f(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

lowered_text = f.lower(arr).as_text()
self.check_wsc_in_lowered(lowered_text)

def g(x):
y = f(x)
return jnp.sum(y)

out = jax.grad(g)(arr)
self.assertEqual(out.sharding, arr.sharding)

out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)

@jax.jit
def h(x):
with use_hidden_axes('x'):
return reshard(x, P('y', None))

with self.assertRaisesRegex(
ValueError, 'Mesh of the input.*does not equal.*target sharding'):
h(arr)

@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_full_auto_outside_jit(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = x * 2
self.assertEqual(y.sharding.spec, P(None, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, None))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
return a

hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y'))
out = hf(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={AxisTypes.Hidden: ('x', 'y')})
def test_full_visible_outside_jit(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = x * 2
self.assertEqual(y.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
return z

hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y'))
out = hf(arr) # doesn't crash
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
