Skip to content

Commit

Permalink
Add fast top 2 kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Dec 14, 2023
1 parent 434362c commit 108cb10
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions mlc_llm/relax_model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,48 @@ def scatter_func(
return scatter_func


def get_top2_func(dtype, index_dtype):
    """Build a pre-scheduled TIR kernel that picks the two largest values per row.

    The returned PrimFunc (global symbol ``"top2"``) reads a 2-D buffer ``x``
    of shape ``(total_rows, num_experts)`` and writes two
    ``(total_rows, 2)`` buffers: the two largest values of every row in
    descending order, and the column indices those values came from.

    Each row is handled by a single GPU thread: rows are split into blocks of
    up to 1024 threads, and every thread does one linear scan over its row
    while tracking the running best and second-best entries in a two-element
    ``"local"``-scope scratch buffer.  Comparisons are strict (``>``), so on
    ties the lower column index is kept.

    Parameters
    ----------
    dtype : str
        Element dtype of the input and of the value output.
    index_dtype : str
        Element dtype of the index output.  Only ``"int32"`` is supported.

    Returns
    -------
    The TIR PrimFunc implementing the kernel.

    Raises
    ------
    AssertionError
        If ``index_dtype`` is not ``"int32"``.
    """
    assert index_dtype == "int32", "top2 kernel only supports int32 indices"

    @T.prim_func
    def top2_func(
        x_handle: T.handle,
        out_handle: T.handle,
        out_index_handle: T.handle,
    ) -> None:
        # Both extents are symbolic, so one compiled kernel serves any number
        # of rows / experts.
        total_rows = T.int64()
        num_experts = T.int64()
        x = T.match_buffer(x_handle, (total_rows, num_experts), dtype)
        out = T.match_buffer(out_handle, (total_rows, 2), dtype)
        out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype)
        # Per-thread scratch: slot 0 holds the best entry, slot 1 the runner-up.
        local_top_k = T.alloc_buffer((2,), dtype=dtype, scope="local")
        local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local")
        T.func_attr({"global_symbol": "top2", "tir.noalias": True, "tir.is_scheduled": True})
        for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"):
            for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"):
                # The last block may be only partially filled; skip threads
                # that fall past the final row.
                if io * T.int64(1024) + ii < total_rows:
                    # Seed BOTH slots before scanning.  Without this, the
                    # runner-up comparison below can read an uninitialised
                    # value whenever the first element does not strictly
                    # exceed the minimum (e.g. NaN or the minimum itself), and
                    # slot 1 would be written out undefined for rows with
                    # fewer than two columns.
                    local_top_k[0] = T.min_value(dtype)
                    local_top_k[1] = T.min_value(dtype)
                    local_top_k_index[0] = 0
                    local_top_k_index[1] = 1
                    for k in range(num_experts):
                        if x[io * T.int64(1024) + ii, k] > local_top_k[0]:
                            # New best: demote the previous best to runner-up.
                            local_top_k[1] = local_top_k[0]
                            local_top_k_index[1] = local_top_k_index[0]
                            local_top_k[0] = x[io * T.int64(1024) + ii, k]
                            local_top_k_index[0] = k
                        elif x[io * T.int64(1024) + ii, k] > local_top_k[1]:
                            # Not the best, but beats the current runner-up.
                            local_top_k[1] = x[io * T.int64(1024) + ii, k]
                            local_top_k_index[1] = k

                    # Copy the per-thread result back to the output buffers.
                    for k in T.unroll(2):
                        out[io * T.int64(1024) + ii, k] = local_top_k[k]
                        out_index[io * T.int64(1024) + ii, k] = local_top_k_index[k]

    return top2_func


def emit_tir_funcs(bb: relax.BlockBuilder, config: MixtralConfig):
    """Add the Mixtral helper TIR kernels to the block builder.

    The kernel produced by ``get_scatter_func`` is added under the name
    ``"scatter"`` and the one produced by ``get_top2_func`` (built with
    ``"int32"`` indices) under ``"top2"``.  Both are built for
    ``config.dtype``.
    """
    model_dtype = config.dtype

    scatter_kernel = get_scatter_func(model_dtype)
    bb.add_func(scatter_kernel, "scatter")

    top2_kernel = get_top2_func(model_dtype, "int32")
    bb.add_func(top2_kernel, "top2")


class MoELinear(nn.Module):
Expand Down Expand Up @@ -171,6 +211,21 @@ def __init__(self, config: MixtralConfig):
self.num_experts = config.num_local_experts

def topk(self, x, is_ascend, index_dtype, k=-1):
if not is_ascend and k == 2:
# fast path
total_rows = x.struct_info.shape[0]
result = nn.emit(
relax.call_dps_packed(
"top2",
[x],
out_sinfo=[
relax.TensorStructInfo([total_rows, k], x.struct_info.dtype),
relax.TensorStructInfo([total_rows, k], index_dtype),
],
)
)
return relax.TupleGetItem(result, 0), relax.TupleGetItem(result, 1)

# topk along axis -1
result = nn.emit(
relax.call_dps_packed(
Expand Down

0 comments on commit 108cb10

Please sign in to comment.