Skip to content

Commit

Permalink
[Mosaic GPU] Enable x64 tests for mosaic gpu.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715496496
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 14, 2025
1 parent 57a259f commit ff5cb81
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ jax_multiplatform_test(
enable_backends = [],
enable_configs = [
"gpu_h100_x32",
"gpu_h100",
],
env = {
"JAX_PALLAS_USE_MOSAIC_GPU": "1",
Expand Down
4 changes: 2 additions & 2 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
np.testing.assert_array_equal(result, ref)

def test_gmem_to_smem_with_multiple_smem_indexers(self):
x = jax.random.uniform(jax.random.key(0), (2, 64, 64))
x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32)
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32),
Expand All @@ -392,7 +392,7 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
np.testing.assert_array_equal(extract_x0(x), x[0])

def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self):
x = jnp.arange(512 * 512).reshape(512, 512)
x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512)
@functools.partial(
pl.pallas_call,
grid=(4, 4),
Expand Down

0 comments on commit ff5cb81

Please sign in to comment.