Skip to content

Commit

Permalink
Add overflow detection based on matrix size.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgale96 committed Aug 31, 2023
1 parent 8aa616a commit f7ba880
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
29 changes: 14 additions & 15 deletions stk/backend/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def _sdd_kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
row_indices, column_indices,
row_indices, column_indices,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid = tl.program_id(0)
pid_m = tl.load(row_indices + pid)
pid_n = tl.load(column_indices + pid)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
Expand Down Expand Up @@ -68,11 +68,11 @@ def _dsd_kernel(A, B, C, M, N, K,
# matrix multiplication
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)

start_inx = tl.load(offsets + pid_m)
end_inx = tl.load(offsets + pid_m + 1)

Expand Down Expand Up @@ -109,7 +109,7 @@ def _dsd_kernel(A, B, C, M, N, K,

a = tl.load(ptr_A)
b = tl.load(ptr_B)
acc += tl.dot(a, b)
acc += tl.dot(a, b)

acc = acc.to(C.dtype.element_ty)

Expand Down Expand Up @@ -140,11 +140,11 @@ def _dds_kernel(A, B, C, M, N, K,
# matrix multiplication
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)

start_inx = tl.load(offsets + pid_n)
end_inx = tl.load(offsets + pid_n + 1)

Expand All @@ -164,7 +164,7 @@ def _dds_kernel(A, B, C, M, N, K,
nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)

BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE

ak_sub_incr = BLOCK_K * stride_ak
ak_block_incr = BLOCK_SIZE * stride_ak
bk_sub_incr = BLOCK_K * stride_bk
Expand All @@ -181,7 +181,7 @@ def _dds_kernel(A, B, C, M, N, K,
ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
a = tl.load(ptr_A)
b = tl.load(ptr_B)
acc += tl.dot(a, b)
acc += tl.dot(a, b)

acc = acc.to(C.dtype.element_ty)
cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
Expand Down Expand Up @@ -213,7 +213,7 @@ def dsd(shape,
assert shape[1] == rhs.shape[0], "incompatible dimensions"
M, K = shape
_, N = rhs.shape

# accumulator types
ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32

Expand All @@ -230,7 +230,7 @@ def dsd(shape,
a_column_indices, a_offsets = column_indices_t, offsets_t

if trans_B:
stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)

_dsd_kernel[grid](
data.data, rhs, out, M, N, K,
Expand Down Expand Up @@ -267,7 +267,7 @@ def dds(lhs,
assert lhs.shape[1] == shape[0], "incompatible dimensions"
M, K = lhs.shape
_, N = shape

# accumulator types
ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32

Expand All @@ -280,7 +280,7 @@ def dds(lhs,
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))

if trans_A:
stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
if trans_B:
stride_bk, stride_bn = data.stride(2), data.stride(1)
b_column_indices, b_offsets = column_indices, offsets
Expand Down Expand Up @@ -331,7 +331,7 @@ def sdd(lhs,
if trans_A:
stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
if trans_B:
stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)

_sdd_kernel[grid](
lhs, rhs, out, M, N, K,
Expand All @@ -355,4 +355,3 @@ def row_indices(
):
block_rows = len(offsets) - 1
_row_indices_kernel[(block_rows, )](offsets, out)

7 changes: 7 additions & 0 deletions stk/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ def __init__(self,

self._transposed = False

# Validate that our metadata will not overflow.
max_dim = np.iinfo(np.int16).max * self.blocking
if column_indices.dtype == torch.int16:
if size[0] > max_dim or size[1] > max_dim:
raise ValueError(
"Sparse matrix with shape {size} exceeds representable "
"size with 16-bit indices.")

def validate(self):
_validate_matrix(self._size,
Expand Down

0 comments on commit f7ba880

Please sign in to comment.