Skip to content

Commit

Permalink
Fix remat bug on primitives with multiple outputs.
Browse files Browse the repository at this point in the history
Addresses #25841

PiperOrigin-RevId: 715434084
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 14, 2025
1 parent 2408fb7 commit b6acb9c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,8 @@ def has_effects(effects) -> bool:
outvars_copy = list[Atom](eqn.outvars)
offload_eqn = core.JaxprEqn(
outvars_copy, resvars, device_put_p,
dict(devices=[TransferToMemoryKind(policy.dst)], srcs=[None],
dict(devices=[TransferToMemoryKind(policy.dst)
] * len(outvars_copy), srcs=[None],
copy_semantics=[CopySemantics.COPY]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
Expand All @@ -1093,7 +1094,8 @@ def has_effects(effects) -> bool:
residuals.update(resvars)
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p,
dict(devices=[TransferToMemoryKind(policy.src)], srcs=[None],
dict(devices=[TransferToMemoryKind(policy.src)
] * len(resvars), srcs=[None],
copy_semantics=[CopySemantics.COPY]),
set(), source_info_util.new_source_info(),
JaxprEqnContext(None, False))
Expand Down
21 changes: 21 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,5 +1842,26 @@ def f(x):
if compiled_stats is not None:
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

def test_primitive_with_multiple_outputs(self):
# Test for https://github.com/jax-ml/jax/issues/25841
shape = (128,)
inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

def policy(prim, *args, **kwargs):
del args, kwargs
if prim.multiple_results:
return Offloadable("device", "pinned_host")
return Recompute

@functools.partial(remat, policy=policy)
def test_fn(x):
# Need any primitive with multiple outputs and a non-trivial grad.
x1, _ = jax.lax.approx_max_k(x, k=2)
return jnp.sum(x1)

fn = jax.grad(test_fn)
jax.jit(fn)(inp) # doesn't crash


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b6acb9c

Please sign in to comment.