diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 44f11e9512a1..667a65a1c351 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -173,6 +173,7 @@ jax_multiplatform_test( enable_backends = [], enable_configs = [ "gpu_h100_x32", + "gpu_h100", ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e5d8d1077623..92806688d8ac 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -373,7 +373,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(result, ref) def test_gmem_to_smem_with_multiple_smem_indexers(self): - x = jax.random.uniform(jax.random.key(0), (2, 64, 64)) + x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), @@ -392,7 +392,7 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): - x = jnp.arange(512 * 512).reshape(512, 512) + x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) @functools.partial( pl.pallas_call, grid=(4, 4),