Copies supersede OptimizationBarrier #20440
Comments
For an example which is not fixed by the flag:

    import os
    from functools import partial

    # Must be set before importing jax so that XLA picks the flag up.
    os.environ["XLA_FLAGS"] = "--xla_cpu_copy_insertion_use_region_analysis=true"

    import jax
    import jax.numpy as jnp
    from jax import Array, jit, lax


    @partial(jit, donate_argnums=0)
    def roll1(a: Array) -> Array:
        """Roll with shift=1."""
        n = a.size
        x = a[n - 1]
        # Make the in-place updates below depend on x having been read first.
        a, x = lax.optimization_barrier((a, x))
        a = lax.fori_loop(1, n, lambda i, a: a.at[n - i].set(a[n - 1 - i]), a)
        a = a.at[0].set(x)
        return a


    if __name__ == "__main__":
        x = jnp.arange(100)
        print(jax.make_jaxpr(roll1)(x))
        lowered = roll1.lower(x)
        compiled = lowered.compile()
        print(compiled.as_text())

Running on CPU copies the input twice.
Running on GPU does not copy.
Consider the JAX function
Since XLA has control over scheduling, for efficiency it should schedule the slice first and then the in-place update, to avoid an unnecessary copy. However, specifically on the CPU backend, it chooses to copy twice instead, generating
(I'm not sure why it needs to make two copies here instead of just one, but the important part is that it copies at all.)
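The original snippet and its compiled output are not preserved in this copy of the issue. As a minimal sketch only, assuming a function of the shape described (a read from a donated buffer followed by an in-place update of the same buffer), the hypothetical slice_then_update and count_copies below show one way to check a compiled module for copy instructions. The function body and the substring-based counting are assumptions for illustration, not the author's code.

```python
from functools import partial

import jax.numpy as jnp
from jax import Array, jit


@partial(jit, donate_argnums=0)
def slice_then_update(x: Array) -> tuple[Array, Array]:
    """Hypothetical stand-in: read one element, then update x in place."""
    y = x[0]             # the slice: reads from the donated buffer
    x = x.at[0].set(-1)  # the update: can reuse the buffer if y is read first
    return x, y


def count_copies(hlo_text: str) -> int:
    """Rough heuristic (an assumption): count lines that contain a copy op."""
    return sum(" copy(" in line for line in hlo_text.splitlines())


if __name__ == "__main__":
    x = jnp.arange(100)
    compiled = slice_then_update.lower(x).compile()
    print("copy instructions:", count_copies(compiled.as_text()))
```

The count depends on the backend, the JAX/XLA version, and the flags in effect, so no expected number is given here.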
By the semantics of lax.optimization_barrier, I would expect that introducing an explicit dependency of x on y would force the slice to happen first, after which the liveness analysis would kick in and remove the copies.

However, what ends up happening is that XLA still introduces copies and reorders the calls, so the generated code is the same as the one shown above. This seems to violate the scheduling control one expects from optimization_barrier.

Note that for this particular example, setting the XLA flag --xla_cpu_copy_insertion_use_region_analysis=true removes the copy and generates the expected output, with or without optimization_barrier. Also, using a GPU device generates copyless code, again with or without optimization_barrier. Finally, the reverse explicit schedule, which should introduce a copy, does not introduce a copy with --xla_cpu_copy_insertion_use_region_analysis=true.

I'm a bit confused about why the flag workaround works now, since region analysis was introduced more than 3 years ago in 92292d1. The core logic of RemoveUnnecessaryCopies and TryElideCopy doesn't seem to have changed much in that time either. Rather, what has changed recently is that the flag xla_cpu_copy_insertion_use_region_analysis was added to CPU (disabled by default) (#18521) and region analysis was disabled on GPU (#14680). Is there some context I'm missing?

(Originally reported in the discussion jax-ml/jax#19165 and JAX issue jax-ml/jax#25399.)