-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add test case for jit compile of vmap on gpu #25980
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! A couple suggestions below.
tests/jax_jit_test.py
Outdated
if jtu.device_under_test() == "cpu": | ||
self.skipTest("Test only runs on GPU") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the skip? no harm in running this test on CPU as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. The initial thought is to emphasize the issue on GPU. But I'm ok to allow tests on other device types as well.
tests/jax_jit_test.py
Outdated
pos = H @ jnp.concatenate([x, jnp.array([1.0])]) | ||
return pos, R1 | ||
jitted_fn = jax.jit(fn) | ||
x_gpu = jax.device_put(jnp.zeros((2,3)), jax.local_devices()[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for explicit device put I think: when the tests are run on GPU, the default device is GPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
tests/jax_jit_test.py
Outdated
return pos, R1 | ||
jitted_fn = jax.jit(fn) | ||
x_gpu = jax.device_put(jnp.zeros((2,3)), jax.local_devices()[0]) | ||
jitted_fn(x_gpu) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It wouldn't hurt to also validate the expected output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. My style is a little focus on the specific GPU issue. While the suggestion sounds good for the test itself.
Co-authored-by: Jake Vanderplas <[email protected]>
Description
The error case of jit compile of vmap on GPU is reported in openxla/xla#15744. @jakevdp
I fetched the head of jax repository. Build and installed jax from source. Modified the code to use jax.device_put and jax.local_devices which is required for the new jax version. The test passed.
System Configurations
GPU: NVIDIA T4
Google cloud machine type: n1-standard-4
Linux x86-64