[Pallas] Add non-square pl.dot test cases.
PiperOrigin-RevId: 704440071
WindQAQ authored and Google-ML-Automation committed Dec 9, 2024
1 parent 978d35f commit f457bb5
Showing 1 changed file with 39 additions and 6 deletions.
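
For context, the test previously covered only square size x size operands; the new lhs_and_rhs_shape parameter adds rectangular pairs. The sketch below is not part of the commit; it is a minimal, illustrative reconstruction (names and shapes are made up) of the pattern the test wraps in self.pallas_call: a kernel that reads both operand blocks and contracts them with pl.dot. On a machine without a GPU/TPU, passing interpret=True to pl.pallas_call should run it in interpreter mode.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def dot_kernel(x_ref, y_ref, o_ref):
  # Read the whole operand blocks and contract them with pl.dot.
  o_ref[:, :] = pl.dot(x_ref[:, :], y_ref[:, :]).astype(o_ref.dtype)


x = jnp.ones((16, 128), jnp.float32)   # non-square lhs
y = jnp.ones((128, 256), jnp.float32)  # non-square rhs
out = pl.pallas_call(
    dot_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32),
    grid=1,
)(x, y)
assert out.shape == (16, 256)

pl.dot also accepts transpose flags, which the test below threads through positionally as trans_x and trans_y.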
45 changes: 39 additions & 6 deletions tests/pallas/ops_test.py
@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest
@@ -1410,12 +1411,39 @@ def f(x_ref, o_ref):
     np.testing.assert_allclose(f(x), expected)

   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")
Expand All @@ -1427,12 +1455,17 @@ def test_dot(self, size, dtype, trans_x, trans_y):
if jtu.test_device_matches(["gpu"]):
if dtype == jnp.bfloat16:
self.skipTest("bfloat16 type are not supported on GPU")
if size > 128:
if (
math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
> (256 * 256) * 2
):
self.skipTest("Shared memory size limit exceeded")
if any(x < 16 for x in lhs_shape) or any(x < 16 for x in rhs_shape):
self.skipTest("All dimensions of lhs and rhs must be >= 16")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((size, size), dtype),
out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
grid=1,
)
def dot(x_ref, y_ref, o_ref):
@@ -1441,8 +1474,8 @@ def dot(x_ref, y_ref, o_ref):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)

     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
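Two details of the reworked test, restated outside the diff: because trans_x and trans_y are parameterized independently of the shapes, the test first derives the effective (possibly transposed) shapes and skips combinations whose contraction dimensions do not match; and the old GPU skip of size > 128 is replaced by an element-count budget over lhs, rhs, and output. The snippet below only restates that arithmetic with made-up helper names; it is not part of the commit.

import math


def effective_shapes(lhs_shape, rhs_shape, trans_x, trans_y):
  # Mirrors the test's shape handling: a transpose flag flips that operand's
  # shape, and the output is (rows of effective lhs, cols of effective rhs).
  lhs = lhs_shape[::-1] if trans_x else lhs_shape
  rhs = rhs_shape[::-1] if trans_y else rhs_shape
  return lhs, rhs, (lhs[0], rhs[1])


def exceeds_gpu_budget(lhs_shape, rhs_shape, out_shape):
  # The new GPU skip: more than (256 * 256) * 2 combined elements across the
  # three operands is treated as exceeding the shared memory limit.
  total = math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
  return total > (256 * 256) * 2


lhs, rhs, out = effective_shapes((16, 128), (256, 128), trans_x=False, trans_y=True)
print(lhs[1] == rhs[0])                   # True: (16, 128) @ (128, 256) contracts
print(exceeds_gpu_budget(lhs, rhs, out))  # False: 2048 + 32768 + 4096 = 38912 <= 131072

For the largest square case, (256, 256) operands and output give 3 * 65536 = 196608 elements, which exceeds the 131072-element budget, so that case is still skipped on GPU just as it was under the old size > 128 check.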
