Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test case for jit compile of vmap on gpu #25980

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

codinglover222
Copy link

@codinglover222 codinglover222 commented Jan 18, 2025

Description

The error case of jit compile of vmap on GPU is reported in openxla/xla#15744. @jakevdp

I fetched the head of jax repository. Build and installed jax from source. Modified the code to use jax.device_put and jax.local_devices which is required for the new jax version. The test passed.

System Configurations

GPU: NVIDIA T4
Google cloud machine type: n1-standard-4
Linux x86-64

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! A couple suggestions below.

tests/jax_jit_test.py Outdated Show resolved Hide resolved
Comment on lines 209 to 210
if jtu.device_under_test() == "cpu":
self.skipTest("Test only runs on GPU")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the skip? no harm in running this test on CPU as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. The initial thought is to emphasize the issue on GPU. But I'm ok to allow tests on other device types as well.

pos = H @ jnp.concatenate([x, jnp.array([1.0])])
return pos, R1
jitted_fn = jax.jit(fn)
x_gpu = jax.device_put(jnp.zeros((2,3)), jax.local_devices()[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for explicit device put I think: when the tests are run on GPU, the default device is GPU.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

return pos, R1
jitted_fn = jax.jit(fn)
x_gpu = jax.device_put(jnp.zeros((2,3)), jax.local_devices()[0])
jitted_fn(x_gpu)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wouldn't hurt to also validate the expected output.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. My style is a little focus on the specific GPU issue. While the suggestion sounds good for the test itself.

@jakevdp jakevdp self-assigned this Jan 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants