Skip to content

Commit

Permalink
Remove memories flag now that JAX 0.5.0 has been released since it al…
Browse files Browse the repository at this point in the history
…ways defaults to True.

PiperOrigin-RevId: 716908015
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Jan 18, 2025
1 parent 36daf36 commit 5a068da
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,22 +974,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
upgrade=True,
help='If True, pmap and shard_map API will be merged.')

# Remove after next JAX release on Jan 15, 2025.
if hasattr(jax_jit.global_state(), 'enable_memories'):
def _update_jax_memories_global(val):
jax_jit.global_state().enable_memories = val

def _update_jax_memories_thread_local(val):
jax_jit.thread_local_state().enable_memories = val

enable_memories = bool_state(
'jax_enable_memories',
default=True,
upgrade=True,
update_global_hook=_update_jax_memories_global,
update_thread_local_hook=_update_jax_memories_thread_local,
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))

spmd_mode = enum_state(
name='jax_spmd_mode',
Expand Down

0 comments on commit 5a068da

Please sign in to comment.