Make sure reshard and mesh_cast behave properly under eager mode
PiperOrigin-RevId: 717702155
yashk2810 authored and Google-ML-Automation committed Jan 21, 2025
1 parent a943ebf commit 009dddf
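
Under eager mode the flattened inputs are concrete values rather than tracers, and values such as raw NumPy arrays carry no .aval attribute, so the previous x.aval.sharding lookups could fail. Both mesh_cast and reshard now compute abstract values up front with core.shaped_abstractify, which handles NumPy arrays, jax.Arrays, and tracers alike. A minimal sketch of that distinction (not part of the commit; the import path mirrors the one used in jax/_src/pjit.py):

    import numpy as np
    from jax._src import core

    x = np.arange(4.0)
    # x.aval                           # AttributeError: NumPy arrays have no .aval
    aval = core.shaped_abstractify(x)  # a ShapedArray; works for NumPy arrays, jax.Arrays, and tracers
    print(aval.sharding)               # the value the patched code passes as src_sharding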
Showing 2 changed files with 24 additions and 5 deletions.
12 changes: 7 additions & 5 deletions jax/_src/pjit.py
@@ -2685,12 +2685,13 @@ def _sharding_constraint_batcher(
 # TODO(yashkatariya): Make shardings optional.
 def mesh_cast(xs, out_shardings):
   x_flat, treedef = tree_flatten(xs)
+  x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
   shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
   out_flat = [
       mesh_cast_p.bind(
-          x, src_sharding=x.aval.sharding,
+          x, src_sharding=x_aval.sharding,
           dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
-      for x, s in safe_zip(x_flat, shardings_flat)
+      for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat)
   ]
   return tree_unflatten(treedef, out_flat)
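
The mesh_cast change keeps the list-comprehension structure; the source sharding is simply read from the precomputed aval instead of from x.aval inside the bind, so the calling convention is unchanged. A hedged usage sketch (the array names are hypothetical, and broadcasting a single spec relies on flatten_axes accepting a prefix pytree, as it does elsewhere in JAX):

    from jax.sharding import PartitionSpec as P

    xs = (a, b)                       # hypothetical arrays laid out on the current mesh
    ys = mesh_cast(xs, P('x', None))  # one spec broadcast across both leaves
    ys = mesh_cast(xs, (P('x', None), P(None, 'y')))  # or one spec per leaf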

@@ -2778,11 +2779,12 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
 def reshard(xs, out_shardings):
   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings)
+  x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
   out_flat = []
-  for x, s in safe_zip(x_flat, shardings_flat):
+  for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat):
     ds = canonicalize_sharding(s)
-    ds = ds.with_spec(ds.spec._normalized_spec(x.ndim))  # type: ignore
-    out_flat.append(reshard_p.bind(x, src_sharding=x.aval.sharding,
+    ds = ds.with_spec(ds.spec._normalized_spec(x_aval.ndim))  # type: ignore
+    out_flat.append(reshard_p.bind(x, src_sharding=x_aval.sharding,
                                    dst_sharding=ds))
   return tree_unflatten(treedef, out_flat)
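
reshard additionally normalizes the destination spec to the operand's rank before binding, and that rank is now read from the precomputed aval as well. A sketch of what the normalization is assumed to do (_normalized_spec is a private helper; the padded result is inferred from the P('y', None) assertions in the test below):

    from jax.sharding import PartitionSpec as P

    spec = P('y')
    padded = spec._normalized_spec(2)  # assumed equivalent to P('y', None) for a rank-2 operand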

17 changes: 17 additions & 0 deletions tests/pjit_test.py
@@ -6277,6 +6277,23 @@ def f(arr1, arr2):
     out = f(arr1, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
 
+  def test_reshard_eager_mode(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
+                           axis_types={AxisTypes.Visible: ('x', 'y')})
+    np_inp = np.arange(16.).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
+
+    def matmul_reshard(arr1, arr2):
+      arr2 = reshard(arr2, P('y', None))
+      self.assertEqual(arr2.aval.sharding.spec, P('y', None))
+      out = jnp.einsum('xy,yz->xz', arr1, arr2)
+      self.assertEqual(out.aval.sharding.spec, P('x', None))
+      return out
+
+    with jax.sharding.use_mesh(mesh):
+      matmul_reshard(arr1, arr2)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_full_auto_outside_jit(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
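
For readers who want the new test's scenario outside the test harness, a hedged standalone sketch (it assumes a runtime with at least 4 devices, and the import paths for reshard and AxisTypes are assumptions mirroring what this test module appears to use):

    import numpy as np
    import jax
    from jax._src import test_util as jtu   # assumed path, as in the test module
    from jax._src.pjit import reshard       # assumed import path at this revision
    from jax._src.mesh import AxisTypes     # assumed import path at this revision
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
                           axis_types={AxisTypes.Visible: ('x', 'y')})
    arr = jax.device_put(np.arange(16.).reshape(8, 2).T,
                         NamedSharding(mesh, P('y', 'x')))
    with jax.sharding.use_mesh(mesh):
      out = reshard(arr, P('y', None))      # eager call: no jit wrapper
    print(out.sharding.spec)                # expected: PartitionSpec('y', None)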
