Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet. #25970

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import hidden_axes as hidden_axes
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.sharding_impls import make_mesh as make_mesh

Expand Down
2 changes: 2 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

hidden_axes = pjit.hidden_axes


def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
Expand Down
17 changes: 12 additions & 5 deletions jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,17 @@ def standard_primitive(shape_rule, dtype_rule, name,

def _get_array_abstraction_level(a): return a.array_abstraction_level

def call_sharding_rule(rule, num_out, *avals, **kwargs):
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
if config.sharding_in_types.value:
if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
return None if num_out is None else [None] * num_out
if rule is None:
if mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
return None if num_out is None else [None] * num_out
else:
raise ValueError(
f'sharding rule for {prim.name} is not implemented. Please file a'
' bug at https://github.com/jax-ml/jax/issues. You can work around'
' this error by dropping that operation into full hidden sharding'
' mode via: `jax.hidden_axes(fun, out_shardings=...)`')
return rule(*avals, **kwargs)
return None if num_out is None else [None] * num_out

Expand All @@ -65,7 +72,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
out_aval = core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
sharding=call_sharding_rule(sharding_rule, None, *avals, **kwargs))
sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
core.check_avals_context_mesh([out_aval], prim.name)
return out_aval
elif least_specialized is core.DShapedArray:
Expand All @@ -90,7 +97,7 @@ def standard_multi_result_abstract_eval(
out_dtypes = dtype_rule(*avals, **kwargs)
core.check_avals_context_mesh(avals, prim.name)
out_shardings = call_sharding_rule(
sharding_rule, len(out_shapes), *avals, **kwargs)
prim, sharding_rule, len(out_shapes), *avals, **kwargs)
out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
weak_types, out_shardings)]
Expand Down
4 changes: 2 additions & 2 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding, parse_flatten_op_sharding)
from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
from jax._src.pjit import (pjit, mesh_cast, visible_axes,
use_hidden_axes, use_visible_axes)
from jax._src import mesh as mesh_lib
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
Expand Down Expand Up @@ -6048,7 +6048,7 @@ def test_auto_mode_mix(self, mesh):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@partial(hidden_axes, axes='x', out_shardings=P('x', None))
@partial(jax.hidden_axes, axes='x', out_shardings=P('x', None))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
Expand Down
Loading