Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.

PiperOrigin-RevId: 716770575
yashk2810 authored and Google-ML-Automation committed Jan 18, 2025
1 parent 9fb2976 commit efcca6a
Showing 4 changed files with 17 additions and 7 deletions.
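
To make the commit message's workaround concrete, here is a minimal usage
sketch of the newly exposed jax.hidden_axes, adapted from the
test_auto_mode_mix change below. The mesh shape, array contents, and mesh
construction are illustrative assumptions; the real test runs under a
harness that enables the experimental sharding-in-types mode, which is
elided here.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumed setup (not part of this commit): a 2x2 device mesh with
# axes ('x', 'y').
mesh = jax.make_mesh((2, 2), ('x', 'y'))
arr = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

# If a primitive inside `h` has no sharding rule yet, hiding the 'x'
# axis lets it run anyway; out_shardings declares how the result is
# sharded once 'x' becomes visible again.
@partial(jax.hidden_axes, axes='x', out_shardings=P('x', None))
def h(y):
  # Inside, y's sharding spec no longer mentions the hidden 'x' axis.
  return jnp.sin(y)

out = jax.jit(h)(arr)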
1 change: 1 addition & 0 deletions jax/__init__.py
@@ -126,6 +126,7 @@
 from jax._src.api import value_and_grad as value_and_grad
 from jax._src.api import vjp as vjp
 from jax._src.api import vmap as vmap
+from jax._src.api import hidden_axes as hidden_axes
 from jax._src.sharding_impls import NamedSharding as NamedSharding
 from jax._src.sharding_impls import make_mesh as make_mesh

2 changes: 2 additions & 0 deletions jax/_src/api.py
@@ -100,6 +100,8 @@
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
 
+hidden_axes = pjit.hidden_axes
+
 
 def _nan_check_posthook(fun, args, kwargs, output):
   """Hook function called by the C++ jit/pmap to perform NaN checking."""
17 changes: 12 additions & 5 deletions jax/_src/lax/utils.py
@@ -47,10 +47,17 @@ def standard_primitive(shape_rule, dtype_rule, name,
 
 def _get_array_abstraction_level(a): return a.array_abstraction_level
 
-def call_sharding_rule(rule, num_out, *avals, **kwargs):
+def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
-    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
-      return None if num_out is None else [None] * num_out
+    if rule is None:
+      if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+        return None if num_out is None else [None] * num_out
+      else:
+        raise ValueError(
+            f'sharding rule for {prim.name} is not implemented. Please file a'
+            ' bug at https://github.com/jax-ml/jax/issues. You can work around'
+            ' this error by dropping that operation into full hidden sharding'
+            ' mode via: `jax.hidden_axes(fun, out_shardings=...)`')
     return rule(*avals, **kwargs)
   return None if num_out is None else [None] * num_out

@@ -65,7 +72,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
     out_aval = core.ShapedArray(
         shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
         weak_type=weak_type,
-        sharding=call_sharding_rule(sharding_rule, None, *avals, **kwargs))
+        sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
     core.check_avals_context_mesh([out_aval], prim.name)
     return out_aval
   elif least_specialized is core.DShapedArray:
@@ -90,7 +97,7 @@ def standard_multi_result_abstract_eval(
     out_dtypes = dtype_rule(*avals, **kwargs)
     core.check_avals_context_mesh(avals, prim.name)
     out_shardings = call_sharding_rule(
-        sharding_rule, len(out_shapes), *avals, **kwargs)
+        prim, sharding_rule, len(out_shapes), *avals, **kwargs)
     out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
                  for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
                                                 weak_types, out_shardings)]
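In effect, this utils.py change turns a silently missing sharding rule into
a hard error whenever some mesh axes are still visible. A simplified,
standalone sketch of the new dispatch (names are flattened for clarity;
all_axes_hidden stands in for mesh_lib.get_abstract_mesh()._are_all_axes_hidden):

def call_sharding_rule_sketch(prim_name, rule, num_out, all_axes_hidden,
                              *avals, **kwargs):
  if rule is None:
    if all_axes_hidden:
      # Fully hidden mesh: there is no sharding to propagate.
      return None if num_out is None else [None] * num_out
    # Visible axes but no rule: fail loudly and point at the workaround.
    raise ValueError(
        f'sharding rule for {prim_name} is not implemented; work around it '
        'via jax.hidden_axes(fun, out_shardings=...)')
  # Otherwise defer to the primitive's registered sharding rule.
  return rule(*avals, **kwargs)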
4 changes: 2 additions & 2 deletions tests/pjit_test.py
@@ -51,7 +51,7 @@
 from jax._src.sharding_impls import (
     AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
     SingleDeviceSharding, parse_flatten_op_sharding)
-from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
+from jax._src.pjit import (pjit, mesh_cast, visible_axes,
                            use_hidden_axes, use_visible_axes)
 from jax._src import mesh as mesh_lib
 from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
@@ -6048,7 +6048,7 @@ def test_auto_mode_mix(self, mesh):
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(hidden_axes, axes='x', out_shardings=P('x', None))
+    @partial(jax.hidden_axes, axes='x', out_shardings=P('x', None))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
