From 009dddff4717005170e79f72ff694c23f84c9eb8 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 20 Jan 2025 19:45:52 -0800
Subject: [PATCH] Make sure reshard and mesh_cast behave properly under eager mode

PiperOrigin-RevId: 717702155
---
 jax/_src/pjit.py   | 12 +++++++-----
 tests/pjit_test.py | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index e06ce59a3461..bd2a60acef27 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2685,12 +2685,13 @@ def _sharding_constraint_batcher(
 # TODO(yashkatariya): Make shardings optional.
 def mesh_cast(xs, out_shardings):
   x_flat, treedef = tree_flatten(xs)
+  x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
   shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
   out_flat = [
       mesh_cast_p.bind(
-          x, src_sharding=x.aval.sharding,
+          x, src_sharding=x_aval.sharding,
           dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
-      for x, s in safe_zip(x_flat, shardings_flat)
+      for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat)
   ]
   return tree_unflatten(treedef, out_flat)
 
@@ -2778,11 +2779,12 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
 def reshard(xs, out_shardings):
   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("reshard shardings", treedef, out_shardings)
+  x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
   out_flat = []
-  for x, s in safe_zip(x_flat, shardings_flat):
+  for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat):
     ds = canonicalize_sharding(s)
-    ds = ds.with_spec(ds.spec._normalized_spec(x.ndim))  # type: ignore
-    out_flat.append(reshard_p.bind(x, src_sharding=x.aval.sharding,
+    ds = ds.with_spec(ds.spec._normalized_spec(x_aval.ndim))  # type: ignore
+    out_flat.append(reshard_p.bind(x, src_sharding=x_aval.sharding,
                                    dst_sharding=ds))
   return tree_unflatten(treedef, out_flat)
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index fefb2b3abede..5ccb9a7595ed 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -6277,6 +6277,23 @@ def f(arr1, arr2):
     out = f(arr1, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
 
+  def test_reshard_eager_mode(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
+                           axis_types={AxisTypes.Visible: ('x', 'y')})
+    np_inp = np.arange(16.).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
+
+    def matmul_reshard(arr1, arr2):
+      arr2 = reshard(arr2, P('y', None))
+      self.assertEqual(arr2.aval.sharding.spec, P('y', None))
+      out = jnp.einsum('xy,yz->xz', arr1, arr2)
+      self.assertEqual(out.aval.sharding.spec, P('x', None))
+      return out
+
+    with jax.sharding.use_mesh(mesh):
+      matmul_reshard(arr1, arr2)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_full_auto_outside_jit(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
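
For reference, the eager-mode path covered by the new test can also be exercised as a
standalone script. The sketch below mirrors the test; the import paths for reshard,
AxisTypes, and jtu are assumptions (the test file already has these names in scope),
and it assumes at least 4 available devices:

# Minimal standalone sketch of eager-mode reshard, mirroring test_reshard_eager_mode.
# Assumed import paths; reshard itself is defined in jax/_src/pjit.py by this patch.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu      # assumed: provides create_mesh as used in the test
from jax._src.mesh import AxisTypes        # assumed location of AxisTypes
from jax._src.pjit import reshard          # defined in jax/_src/pjit.py at this commit

# A 2x2 mesh whose axes are visible to the sharding type system, as in the test.
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
                       axis_types={AxisTypes.Visible: ('x', 'y')})
np_inp = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))

with jax.sharding.use_mesh(mesh):
  # Eager mode (no jit): reshard now derives the source sharding from the
  # argument's abstract value instead of reading x.aval directly.
  arr2 = reshard(arr2, P('y', None))
  out = jnp.einsum('xy,yz->xz', arr1, arr2)  # result is sharded as P('x', None)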