diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 12a2ad49306c..7b40b4e5a740 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest
@@ -1410,12 +1411,39 @@ def f(x_ref, o_ref):
     np.testing.assert_allclose(f(x), expected)
 
   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")
@@ -1427,12 +1455,17 @@ def test_dot(self, size, dtype, trans_x, trans_y):
     if jtu.test_device_matches(["gpu"]):
       if dtype == jnp.bfloat16:
         self.skipTest("bfloat16 type are not supported on GPU")
-      if size > 128:
+      if (
+          math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
+          > (256 * 256) * 2
+      ):
         self.skipTest("Shared memory size limit exceeded")
+      if any(x < 16 for x in lhs_shape) or any(x < 16 for x in rhs_shape):
+        self.skipTest("All dimensions of lhs and rhs must be >= 16")
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((size, size), dtype),
+        out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
         grid=1,
     )
     def dot(x_ref, y_ref, o_ref):
@@ -1441,8 +1474,8 @@ def test_dot(self, size, dtype, trans_x, trans_y):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)
 
     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
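
The following standalone Python sketch is not part of the patch; it only restates, for illustration, the shape bookkeeping the new test performs: reversing a 2-D shape when an operand is transposed, skipping mismatched contraction dimensions, computing the output shape, and the rough GPU shared-memory budget check. The helper names (derive_shapes, exceeds_gpu_smem_budget) are hypothetical and exist only in this sketch.

import math

def derive_shapes(lhs_shape, rhs_shape, trans_x, trans_y):
  # pl.dot(x, y, trans_x, trans_y) contracts the last effective dim of lhs
  # with the first effective dim of rhs, so a transpose simply reverses the
  # 2-D shape before checking compatibility.
  final_lhs = lhs_shape[::-1] if trans_x else lhs_shape
  final_rhs = rhs_shape[::-1] if trans_y else rhs_shape
  if final_lhs[1] != final_rhs[0]:
    return None  # the test calls skipTest for these combinations
  return (final_lhs[0], final_rhs[1])

def exceeds_gpu_smem_budget(lhs_shape, rhs_shape, out_shape):
  # The test uses (256 * 256) * 2 total elements across lhs, rhs, and out as
  # a coarse shared-memory budget; larger combinations are skipped on GPU.
  total = math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
  return total > (256 * 256) * 2

assert derive_shapes((8, 128), (128, 256), False, False) == (8, 256)
assert derive_shapes((128, 8), (128, 256), True, False) == (8, 256)
assert derive_shapes((8, 128), (256, 128), False, False) is None
assert not exceeds_gpu_smem_budget((128, 128), (128, 128), (128, 128))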