Skip to content

Commit

Permalink
Add JAX unit test for Shardy which causes the compiler to introduce t…
Browse files Browse the repository at this point in the history
…he `mlir::tensor::TensorDialect`. This was causing the compiler to crash.

PiperOrigin-RevId: 714896947
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 13, 2025
1 parent 91ffb64 commit c14e5b4
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6813,6 +6813,19 @@ def test_dimension_sharding_repr(self):
self.assertEqual(repr(dim_sharding),
"SdyDimSharding({'data', 'model', ?}p2)")

def test_tensor_dialect(self):
# While this doesn't emit any `mlir::TensorDialect` ops, some pass in the
# compiler pipeline is temporarily introducing it before then discarding it
# again. Make sure this doesn't crash.
mesh = jtu.create_mesh((2,), ('x'))
in_sds = jax.ShapeDtypeStruct((4, 8), jnp.float32)

@partial(jax.jit, out_shardings=NamedSharding(mesh, P('x')))
def gen_dummy_inputs():
return tuple(jax.random.normal(jax.random.key(42), shape=in_sds.shape
).astype(in_sds.dtype))
gen_dummy_inputs() # doesn't crash


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c14e5b4

Please sign in to comment.