diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py index d96c2703a6..4c35fd9abc 100644 --- a/mlc_llm/relax_model/mixtral.py +++ b/mlc_llm/relax_model/mixtral.py @@ -28,8 +28,48 @@ def scatter_func( return scatter_func +def get_top2_func(dtype, index_dtype): + assert index_dtype == "int32" + + @T.prim_func + def top2_func( + x_handle: T.handle, + out_handle: T.handle, + out_index_handle: T.handle, + ) -> None: + total_rows = T.int64() + num_experts = T.int64() + x = T.match_buffer(x_handle, (total_rows, num_experts), dtype) + out = T.match_buffer(out_handle, (total_rows, 2), dtype) + out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype) + local_top_k = T.alloc_buffer((2,), dtype=dtype, scope="local") + local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local") + T.func_attr({"global_symbol": "top2", "tir.noalias": True, "tir.is_scheduled": True}) + for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"): + for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"): + if io * T.int64(1024) + ii < total_rows: + local_top_k[0] = T.min_value(dtype) + local_top_k_index[0] = 0 + for k in range(num_experts): + if x[io * T.int64(1024) + ii, k] > local_top_k[0]: + local_top_k[1] = local_top_k[0] + local_top_k_index[1] = local_top_k_index[0] + local_top_k[0] = x[io * T.int64(1024) + ii, k] + local_top_k_index[0] = k + elif x[io * T.int64(1024) + ii, k] > local_top_k[1]: + local_top_k[1] = x[io * T.int64(1024) + ii, k] + local_top_k_index[1] = k + + for k in T.unroll(2): + out[io * T.int64(1024) + ii, k] = local_top_k[k] + out_index[io * T.int64(1024) + ii, k] = local_top_k_index[k] + + return top2_func + + def emit_tir_funcs(bb: relax.BlockBuilder, config: MixtralConfig): bb.add_func(get_scatter_func(config.dtype), "scatter") + bb.add_func(get_top2_func(config.dtype, "int32"), "top2") class MoELinear(nn.Module): @@ -171,6 +211,21 @@ def __init__(self, config: MixtralConfig): self.num_experts = config.num_local_experts def topk(self, x, is_ascend, index_dtype, k=-1): + if not is_ascend and k == 2: + # fast path + total_rows = x.struct_info.shape[0] + result = nn.emit( + relax.call_dps_packed( + "top2", + [x], + out_sinfo=[ + relax.TensorStructInfo([total_rows, k], x.struct_info.dtype), + relax.TensorStructInfo([total_rows, k], index_dtype), + ], + ) + ) + return relax.TupleGetItem(result, 0), relax.TupleGetItem(result, 1) + # topk along axis -1 result = nn.emit( relax.call_dps_packed(