Skip to content

Commit

Permalink
[sharding_in_types] Rename .at[...].get(out_spec) to `.at[...].get(…
Browse files Browse the repository at this point in the history
…out_sharding)`.

PiperOrigin-RevId: 716466870
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Jan 17, 2025
1 parent 97cd748 commit af66719
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class _IndexUpdateHelper:
class _IndexUpdateRef:
def get(self, indices_are_sorted: bool = False, unique_indices: bool = False,
mode: str | None = None, fill_value: StaticScalar | None = None,
out_spec: PartitionSpec | None = None) -> Array: ...
out_spec: Sharding | PartitionSpec | None = None) -> Array: ...
def set(self, values: Any,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,8 +1886,8 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
if mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
return None
raise GatherShardingError(
"Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"
" gather indexing.")
"Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
" the gather indexing.")

def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
unique_indices, indices_are_sorted, fill_value,
Expand Down
10 changes: 6 additions & 4 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from jax._src.numpy import lax_numpy
from jax._src import mesh as mesh_lib
from jax._src.pjit import hidden_mode, PartitionSpec
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.ops import scatter
Expand Down Expand Up @@ -765,7 +766,7 @@ def __repr__(self) -> str:
return f"_IndexUpdateRef({self.array!r}, {self.index!r})"

def get(self, *, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None, out_spec=None):
mode=None, fill_value=None, out_sharding=None):
"""Equivalent to ``x[idx]``.
Returns the value of ``x`` that would result from the NumPy-style
Expand All @@ -779,10 +780,11 @@ def get(self, *, indices_are_sorted=False, unique_indices=False,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
fill_value=fill_value)
if out_spec is not None:
assert isinstance(out_spec, PartitionSpec)
if out_sharding is not None:
assert isinstance(out_sharding, (NamedSharding, PartitionSpec))
out_sharding = canonicalize_sharding(out_sharding)
take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_specs=out_spec)
out_specs=out_sharding.spec)
return take(self.array, self.index)

def set(self, values, *, indices_are_sorted=False, unique_indices=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6117,15 +6117,15 @@ def f(x):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))

@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_auto_gather_out_spec(self, mesh):
def test_auto_gather_out_sharding(self, mesh):
embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16),
jax.NamedSharding(mesh, P(None, 'x')))
tok = jax.device_put(jnp.arange(8 * 4).reshape(8, 4),
jax.NamedSharding(mesh, P('x', None)))

@jax.jit
def f(embed_vd, token_bt):
out = embed_vd.at[token_bt].get(out_spec=P('x', None, None))
out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None))
self.assertEqual(out.shape, (8, 4, 16))
self.assertEqual(out.sharding.spec, P('x', None, None))
return out
Expand Down

0 comments on commit af66719

Please sign in to comment.