diff --git a/stk/backend/triton_kernels.py b/stk/backend/triton_kernels.py index a116c23..5426415 100644 --- a/stk/backend/triton_kernels.py +++ b/stk/backend/triton_kernels.py @@ -15,12 +15,12 @@ def _sdd_kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - row_indices, column_indices, + row_indices, column_indices, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, ): # matrix multiplication - pid = tl.program_id(0) + pid = tl.program_id(0) pid_m = tl.load(row_indices + pid) pid_n = tl.load(column_indices + pid) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -68,11 +68,11 @@ def _dsd_kernel(A, B, C, M, N, K, # matrix multiplication pid_m = tl.program_id(0) pid_n = tl.program_id(1) - + num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - + start_inx = tl.load(offsets + pid_m) end_inx = tl.load(offsets + pid_m + 1) @@ -109,7 +109,7 @@ def _dsd_kernel(A, B, C, M, N, K, a = tl.load(ptr_A) b = tl.load(ptr_B) - acc += tl.dot(a, b) + acc += tl.dot(a, b) acc = acc.to(C.dtype.element_ty) @@ -140,11 +140,11 @@ def _dds_kernel(A, B, C, M, N, K, # matrix multiplication pid_m = tl.program_id(0) pid_n = tl.program_id(1) - + num_pid_m = tl.num_programs(0) num_pid_n = tl.num_programs(1) pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - + start_inx = tl.load(offsets + pid_n) end_inx = tl.load(offsets + pid_n + 1) @@ -164,7 +164,7 @@ def _dds_kernel(A, B, C, M, N, K, nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - + ak_sub_incr = BLOCK_K * stride_ak ak_block_incr = BLOCK_SIZE * stride_ak bk_sub_incr = BLOCK_K * stride_bk @@ -181,7 +181,7 @@ def _dds_kernel(A, B, C, M, N, K, ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr a = tl.load(ptr_A) b = tl.load(ptr_B) - acc += tl.dot(a, b) + acc += tl.dot(a, b) acc = acc.to(C.dtype.element_ty) cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -213,7 +213,7 @@ def dsd(shape, assert shape[1] == rhs.shape[0], "incompatible dimensions" M, K = shape _, N = rhs.shape - + # accumulator types ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 @@ -230,7 +230,7 @@ def dsd(shape, a_column_indices, a_offsets = column_indices_t, offsets_t if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) _dsd_kernel[grid]( data.data, rhs, out, M, N, K, @@ -267,7 +267,7 @@ def dds(lhs, assert lhs.shape[1] == shape[0], "incompatible dimensions" M, K = lhs.shape _, N = shape - + # accumulator types ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 @@ -280,7 +280,7 @@ def dds(lhs, grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) if trans_B: stride_bk, stride_bn = data.stride(2), data.stride(1) b_column_indices, b_offsets = column_indices, offsets @@ -331,7 +331,7 @@ def sdd(lhs, if trans_A: stride_am, stride_ak = lhs.stride(1), lhs.stride(0) if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) _sdd_kernel[grid]( lhs, rhs, out, M, N, K, @@ -355,4 +355,3 @@ def row_indices( ): block_rows = len(offsets) - 1 _row_indices_kernel[(block_rows, )](offsets, out) - \ No newline at end of file diff --git a/stk/matrix.py b/stk/matrix.py index 127b06f..80f4226 100644 --- a/stk/matrix.py +++ b/stk/matrix.py @@ -158,6 +158,13 @@ def __init__(self, self._transposed = False + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") def validate(self): _validate_matrix(self._size,