Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test case for jit compile of vmap on gpu #25980

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/jax_jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,26 @@ def f(a, b, c):
jitted_f = jax.jit(f)
self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))

def test_jit_compile_vmap_on_gpu(self):
codinglover222 marked this conversation as resolved.
Show resolved Hide resolved
if jtu.device_under_test() == "cpu":
self.skipTest("Test only runs on GPU")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the skip? no harm in running this test on CPU as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. The initial thought is to emphasize the issue on GPU. But I'm ok to allow tests on other device types as well.


@jax.vmap
def fn(x):
R1 = jnp.array([[x[0], 0, 0],
[0, x[0], 0],
[0, 0, x[0]]])
R2 = jnp.array([[x[0], 0, 0],
[0, x[1], 0],
[0, 0, x[2]]])
H = jnp.eye(4)
H = H.at[:3, :3].set(R2.T)
pos = H @ jnp.concatenate([x, jnp.array([1.0])])
return pos, R1
jitted_fn = jax.jit(fn)
x_gpu = jax.device_put(jnp.zeros((2,3)), jax.local_devices()[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for explicit device put I think: when the tests are run on GPU, the default device is GPU.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

jitted_fn(x_gpu)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wouldn't hurt to also validate the expected output.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. My style is a little focus on the specific GPU issue. While the suggestion sounds good for the test itself.



if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())