diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index f2b8073192..18569b3063 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -784,7 +784,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
                 mod_deploy
             )
         )
-    if not args.enable_batching:
+    if not args.enable_batching and target_kind != "cuda":
         mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)
 
     if args.debug_load_script:
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 5b9c0a9e3c..ec1539d530 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -1497,11 +1497,6 @@ def get_model(args, hf_config):
     param_manager = ParamManager(keep_params_after_load)
     bb = relax.BlockBuilder()
 
-    if isinstance(config, MixtralConfig):
-        from .mixtral import emit_tir_funcs
-
-        emit_tir_funcs(bb, config)
-
     if sep_embed:
         create_embed_func(bb, param_manager, config, args.quantization)
 
diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index 9f4b57b617..0b48595eb4 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -663,11 +663,6 @@ def get_model(args, hf_config):
     # The CPU device to copy the result of relax.op.max(seq_lens) to CPU.
     cpu_dev = VDevice("llvm", 0, "global")
 
-    if isinstance(config, MixtralConfig):
-        from .mixtral import emit_tir_funcs
-
-        emit_tir_funcs(bb, config)
-
     create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
     create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
     create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization)
diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py
index 4c35fd9abc..ea4b112780 100644
--- a/mlc_llm/relax_model/mixtral.py
+++ b/mlc_llm/relax_model/mixtral.py
@@ -3,75 +3,10 @@
 from tvm.relax.testing import nn
 from tvm.script import relax as R, tir as T
 from tvm import relax
+from tvm.relax.frontend.nn import Tensor
 from .llama import MixtralConfig, Linear
 
 
-def get_scatter_func(dtype):
-    @T.prim_func
-    def scatter_func(
-        x_handle: T.handle,
-        indices_handle: T.handle,
-        out_handle: T.handle,
-    ) -> None:
-        total_rows = T.int64()
-        hidden_size = T.int64()
-        x = T.match_buffer(x_handle, (total_rows, hidden_size), dtype)
-        indices = T.match_buffer(indices_handle, (total_rows,), "int32")
-        out = T.match_buffer(out_handle, (total_rows, hidden_size), dtype)
-        T.func_attr({"global_symbol": "scatter", "tir.noalias": True})
-        for i in range(total_rows):
-            for j in range(hidden_size):
-                with T.block("scatter"):
-                    vi, vj = T.axis.remap("SS", [i, j])
-                    out[indices[vi], vj] = x[vi, vj]
-
-    return scatter_func
-
-
-def get_top2_func(dtype, index_dtype):
-    assert index_dtype == "int32"
-
-    @T.prim_func
-    def top2_func(
-        x_handle: T.handle,
-        out_handle: T.handle,
-        out_index_handle: T.handle,
-    ) -> None:
-        total_rows = T.int64()
-        num_experts = T.int64()
-        x = T.match_buffer(x_handle, (total_rows, num_experts), dtype)
-        out = T.match_buffer(out_handle, (total_rows, 2), dtype)
-        out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype)
-        local_top_k = T.alloc_buffer((2,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local")
-        T.func_attr({"global_symbol": "top2", "tir.noalias": True, "tir.is_scheduled": True})
-        for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"):
-            for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"):
"threadIdx.x"): - if io * T.int64(1024) + ii < total_rows: - local_top_k[0] = T.min_value(dtype) - local_top_k_index[0] = 0 - for k in range(num_experts): - if x[io * T.int64(1024) + ii, k] > local_top_k[0]: - local_top_k[1] = local_top_k[0] - local_top_k_index[1] = local_top_k_index[0] - local_top_k[0] = x[io * T.int64(1024) + ii, k] - local_top_k_index[0] = k - elif x[io * T.int64(1024) + ii, k] > local_top_k[1]: - local_top_k[1] = x[io * T.int64(1024) + ii, k] - local_top_k_index[1] = k - - for k in T.unroll(2): - out[io * T.int64(1024) + ii, k] = local_top_k[k] - out_index[io * T.int64(1024) + ii, k] = local_top_k_index[k] - - return top2_func - - -def emit_tir_funcs(bb: relax.BlockBuilder, config: MixtralConfig): - bb.add_func(get_scatter_func(config.dtype), "scatter") - bb.add_func(get_top2_func(config.dtype, "int32"), "top2") - - class MoELinear(nn.Module): def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False): assert not bias, "bias not supported" @@ -179,7 +114,6 @@ def __init__(self, config: MixtralConfig): ) def forward(self, hidden_states: relax.Expr, rows_before: relax.Expr): - # TODO: disco if self.combine_matmul: gate_up_results = nn.emit( relax.op.split( @@ -209,60 +143,31 @@ def __init__(self, config: MixtralConfig): ) self.num_experts_per_tok = config.num_experts_per_tok self.num_experts = config.num_local_experts - - def topk(self, x, is_ascend, index_dtype, k=-1): - if not is_ascend and k == 2: - # fast path - total_rows = x.struct_info.shape[0] - result = nn.emit( - relax.call_dps_packed( - "top2", - [x], - out_sinfo=[ - relax.TensorStructInfo([total_rows, k], x.struct_info.dtype), - relax.TensorStructInfo([total_rows, k], index_dtype), - ], - ) - ) - return relax.TupleGetItem(result, 0), relax.TupleGetItem(result, 1) - - # topk along axis -1 - result = nn.emit( - relax.call_dps_packed( - "tvm.contrib.thrust.sort_dps", - [x, is_ascend], - out_sinfo=[ - x.struct_info, - relax.TensorStructInfo(x.struct_info.shape, index_dtype), - ], - ) - ) - sorted_x = relax.TupleGetItem(result, 0) - indices = relax.TupleGetItem(result, 1) - if k != -1: - ndim = len(x.struct_info.shape) - beg = [0] * ndim - end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k] - axes = list(range(ndim)) - sorted_x = nn.emit( - relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True) - ) - indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True)) - return sorted_x, indices - - def compute_rows_before(self, sorted_expert_ids): - return nn.emit( - relax.call_dps_packed( - "moe_compute_rows_before", - [sorted_expert_ids], - out_sinfo=relax.TensorStructInfo([self.num_experts], "int64"), - ) - ) + self.dtype = config.dtype + self.hidden_size = config.hidden_size def scatter(self, linear_out, indices): + @T.prim_func + def scatter_func( + x_handle: T.handle, + indices_handle: T.handle, + out_handle: T.handle, + ) -> None: + total_rows = T.int64() + x = T.match_buffer(x_handle, (total_rows, self.hidden_size), self.dtype) + indices = T.match_buffer(indices_handle, (total_rows,), "int32") + out = T.match_buffer(out_handle, (total_rows, self.hidden_size), self.dtype) + T.func_attr({"global_symbol": "scatter", "tir.noalias": True}) + for i in range(total_rows): + for j in range(self.hidden_size): + with T.block("scatter"): + vi, vj = T.axis.remap("SS", [i, j]) + out[indices[vi], vj] = x[vi, vj] + + scatter = relax.BlockBuilder.current().add_func(scatter_func, "scatter") return nn.emit( relax.call_dps_packed( - "scatter", + 
+                scatter,
                 [linear_out, indices],
                 out_sinfo=linear_out.struct_info,
             )
@@ -279,29 +184,233 @@ def te_compute(x):
 
         return nn.emit_te(te_compute, indices)
 
+    def topk_mask(self, indices):
+        from functools import reduce
+
+        def te_topk_mask_op(topk_indices):
+            ntokens = topk_indices.shape[0]
+            assert topk_indices.shape[1] == self.num_experts_per_tok
+            return te.compute(
+                (ntokens, self.num_experts),
+                lambda i, j: tir.expr.Select(
+                    reduce(
+                        lambda a, b: tir.Or(a, b),
+                        [topk_indices[i, k] == j for k in range(self.num_experts_per_tok)],
+                    ),
+                    true_value=tir.const(1, "int32"),
+                    false_value=tir.const(0, "int32"),
+                ),
+            )
+
+        return nn.emit_te(te_topk_mask_op, indices)
+
+    def get_indices(
+        self, cumsum_colwise_flattened: relax.Expr, expert_indices: relax.Expr
+    ) -> relax.Expr:
+        from tvm import relax
+        from tvm.script import tir as T
+
+        @T.prim_func
+        def get_flattened_expert_indices_scheduled(
+            var_cumsum_colwise_flattened: T.handle,
+            var_expert_indices: T.handle,
+            var_flattened_expert_indices: T.handle,
+        ):
+            T.func_attr({"tir.is_scheduled": 1})
+            batch_size = T.SizeVar("batch_size", "int32")
+            cumsum_flattened_length = T.SizeVar("cumsum_flattened_length", "int32")
+
+            cumsum_colwise_flattened = T.match_buffer(
+                var_cumsum_colwise_flattened, shape=[cumsum_flattened_length], dtype="int32"
+            )
+            expert_indices = T.match_buffer(
+                var_expert_indices, shape=[batch_size, self.num_experts_per_tok], dtype="int32"
+            )
+            flattened_expert_indices = T.match_buffer(
+                var_flattened_expert_indices,
+                shape=[batch_size * self.num_experts_per_tok],
+                dtype="int32",
+            )
+
+            for io in T.thread_binding(
+                0, T.ceildiv(cumsum_flattened_length, T.int32(1024)), "blockIdx.x"
+            ):
+                for ii in T.thread_binding(
+                    0, T.min(cumsum_flattened_length, T.int32(1024)), "threadIdx.x"
+                ):
+                    with T.block("get_indices"):
+                        vi = T.axis.spatial(cumsum_flattened_length, io * T.int32(1024) + ii)
+                        T.where(io * T.int32(1024) + ii < cumsum_flattened_length)
+                        T.reads(
+                            cumsum_colwise_flattened[vi - 1 : vi - 1 + 2],
+                            expert_indices[:, 0 : self.num_experts_per_tok],
+                        )
+                        T.writes(flattened_expert_indices[:])
+                        expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
+                        if cumsum_colwise_flattened[vi] > T.if_then_else(
+                            vi == 0, T.int32(0), cumsum_colwise_flattened[vi - 1]
+                        ):
+                            idx: T.SizeVar("idx", "int32") = cumsum_colwise_flattened[vi] - 1
+                            instance_id: T.SizeVar("instance_id", "int32") = T.truncmod(
+                                vi, batch_size
+                            )
+                            expert_id: T.SizeVar("expert_id", "int32") = T.truncdiv(vi, batch_size)
+                            for j in T.serial(0, self.num_experts_per_tok):
+                                with T.block("select_expert"):
+                                    vj = T.axis.spatial(self.num_experts_per_tok, j)
+                                    vinstance_id = T.axis.spatial(batch_size, instance_id)
+                                    vexpert_id = T.axis.spatial(
+                                        T.truncdiv(cumsum_flattened_length, batch_size), expert_id
+                                    )
+                                    if expert_indices[vinstance_id, vj] == vexpert_id:
+                                        expert_idx[()] = vj
+                            flattened_expert_indices[idx] = (
+                                instance_id * self.num_experts_per_tok + expert_idx[()]
+                            )
+
+        bb = relax.BlockBuilder.current()
+        gvar = bb.add_func(get_flattened_expert_indices_scheduled, "get_flattened_expert_indices")
+        return bb.emit(
+            relax.call_tir(
+                gvar,
+                [cumsum_colwise_flattened, expert_indices],
+                out_sinfo=relax.TensorStructInfo(
+                    [expert_indices.struct_info.shape[0] * self.num_experts_per_tok], "int32"
+                ),
+            )
+        )
+
+    def cumsum(self, data: relax.Expr) -> relax.Expr:
+        return nn.emit(
+            relax.call_dps_packed(
+                "tvm.contrib.thrust.sum_scan",
+                [data],
+                out_sinfo=data.struct_info,
+            )
+        )
+
+    def get_indptr(self, cumsum_colwise_flattened: relax.Expr) -> relax.Expr:
+        from tvm import relax
+        from tvm.script import tir as T
+
+        @T.prim_func
+        def get_expert_instance_indptr(
+            var_cumsum_colwise_flattened: T.handle,
+            var_expert_instance_indptr: T.handle,
+            batch_size: T.int32,
+        ):
+            cumsum_colwise_flattened = T.match_buffer(
+                var_cumsum_colwise_flattened, shape=[batch_size * self.num_experts], dtype="int32"
+            )
+            expert_instance_indptr = T.match_buffer(
+                var_expert_instance_indptr, shape=[self.num_experts], dtype="int64"
+            )
+
+            for expert_id in range(self.num_experts):
+                with T.block("indptr"):
+                    vexpert_id = T.axis.spatial(self.num_experts, expert_id)
+                    expert_instance_indptr[vexpert_id] = T.cast(
+                        cumsum_colwise_flattened[(vexpert_id + 1) * batch_size - 1], "int64"
+                    )
+
+        bb = relax.BlockBuilder.current()
+        gvar = bb.add_func(get_expert_instance_indptr, "get_expert_instance_indptr")
+        return bb.emit(
+            relax.call_tir(
+                gvar,
+                [cumsum_colwise_flattened],
+                out_sinfo=relax.TensorStructInfo([self.num_experts], "int64"),
+                tir_vars=[cumsum_colwise_flattened.struct_info.shape[0] // self.num_experts],
+            )
+        )
+
+    def topk(self, x, k):
+        index_dtype = "int32"
+
+        @T.prim_func
+        def top2_func(
+            x_handle: T.handle,
+            out_handle: T.handle,
+            out_index_handle: T.handle,
+        ) -> None:
+            total_rows = T.int64()
+            x = T.match_buffer(x_handle, (total_rows, self.num_experts), self.dtype)
+            out = T.match_buffer(out_handle, (total_rows, 2), self.dtype)
+            out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype)
+            local_top_k = T.alloc_buffer((2,), dtype=self.dtype, scope="local")
+            local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local")
+            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
+            for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"):
+                for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"):
+                    with T.block("top2"):
+                        vi = T.axis.spatial(total_rows, io * T.int64(1024) + ii)
+                        T.where(io * T.int64(1024) + ii < total_rows)
+                        with T.block("init"):
+                            local_top_k[0] = T.min_value(self.dtype)
+                            local_top_k_index[0] = 0
+                        for k in range(self.num_experts):
+                            with T.block("update"):
+                                vk = T.axis.remap("S", [k])
+                                if x[vi, vk] > local_top_k[0]:
+                                    local_top_k[1] = local_top_k[0]
+                                    local_top_k_index[1] = local_top_k_index[0]
+                                    local_top_k[0] = x[vi, vk]
+                                    local_top_k_index[0] = vk
+                                elif x[vi, vk] > local_top_k[1]:
+                                    local_top_k[1] = x[vi, vk]
+                                    local_top_k_index[1] = vk
+                        for j in T.unroll(2):
+                            with T.block("output"):
+                                vj = T.axis.remap("S", [j])
+                                out[vi, vj] = local_top_k[vj]
+                                out_index[vi, vj] = local_top_k_index[vj]
+
+        if k != 2:
+            raise NotImplementedError("only support num_experts_per_token=2 for now")
+        bb = relax.BlockBuilder.current()
+        gvar = bb.add_func(top2_func, "top2")
+        return bb.emit(
+            relax.call_tir(
+                gvar,
+                [x],
+                out_sinfo=[
+                    relax.TensorStructInfo([x.struct_info.shape[0], k], x.struct_info.dtype),
+                    relax.TensorStructInfo([x.struct_info.shape[0], k], index_dtype),
+                ],
+            )
+        )
+
     def forward(self, hidden_states):
         hidden_states_shape = hidden_states.struct_info.shape
         hidden_size = hidden_states_shape[-1]
         # reshape to 2D
         hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size)))
-        gate = self.gate(hidden_states)
-        scores = nn.emit(relax.op.nn.softmax(gate, axis=-1))
+        router_logits = self.gate(hidden_states)
+
+        if router_logits.struct_info.dtype != "float32":
+            router_logits = nn.emit(relax.op.astype(router_logits, "float32"))
+        expert_weights = nn.emit(relax.op.nn.softmax(router_logits))
+        if expert_weights.struct_info.dtype != self.dtype:
+            expert_weights = nn.emit(relax.op.astype(expert_weights, self.dtype))
 
-        expert_weights, expert_indices = self.topk(
-            scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32"
-        )  # (num_tokens, top_k), (num_tokens, top_k)
-        expert_weights = nn.emit(expert_weights / R.sum(expert_weights, axis=-1, keepdims=True))
-        flattened_indices = nn.emit(relax.op.flatten(expert_indices))
-        sorted_expert_ids, indices = self.topk(
-            flattened_indices, is_ascend=True, index_dtype="int32"
+        expert_weights, expert_indices = self.topk(expert_weights, k=self.num_experts_per_tok)
+        expert_weights = nn.emit(
+            relax.op.divide(expert_weights, relax.op.sum(expert_weights, axis=1, keepdims=True))
         )
-        rows_before = self.compute_rows_before(sorted_expert_ids)
-        token_indices = self.get_token_indices(indices)
+        expert_mask = self.topk_mask(expert_indices)
+        mask_T_flattened = nn.emit(relax.op.flatten(relax.op.permute_dims(expert_mask)))
+
+        cumsum_colwise_flattened = self.cumsum(mask_T_flattened)
+        flattened_indices = self.get_indices(cumsum_colwise_flattened, expert_indices)
+        indptr = self.get_indptr(cumsum_colwise_flattened)
+        token_indices = self.get_token_indices(flattened_indices)
+
         gathered_x = nn.emit(relax.op.take(hidden_states, token_indices, axis=0))
-        linear_out = self.experts(gathered_x, rows_before)
-        unpermuted = self.scatter(linear_out, indices)
+        linear_out = self.experts(gathered_x, indptr)
+        unpermuted = self.scatter(linear_out, flattened_indices)
+
         unflattened = nn.emit(
             relax.op.reshape(unpermuted, (-1, self.num_experts_per_tok, hidden_size))
         )
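
Note on the new routing scheme: `forward` now derives the expert dispatch from a top-2 selection, a 0/1 expert mask, and one column-wise inclusive scan, replacing the sort-based path (`tvm.contrib.thrust.sort_dps` + `moe_compute_rows_before`). The NumPy sketch below illustrates the bookkeeping that `topk_mask`, `cumsum`, `get_indices`, `get_indptr`, and `get_token_indices` implement as TIR kernels plus a thrust scan; it is illustrative only, `route_reference` is a hypothetical name, and it assumes `get_token_indices` recovers the token id as `flattened // top_k`.

    import numpy as np

    def route_reference(router_logits: np.ndarray, top_k: int = 2):
        """Pure-NumPy model of the cumsum-based MoE routing (illustrative only)."""
        ntokens, num_experts = router_logits.shape
        # softmax in float32, as in forward()
        scores = np.exp(router_logits - router_logits.max(axis=-1, keepdims=True))
        scores /= scores.sum(axis=-1, keepdims=True)
        # top-k expert ids and renormalized weights per token (the `top2` kernel)
        expert_indices = np.argsort(-scores, axis=-1)[:, :top_k].astype(np.int32)
        expert_weights = np.take_along_axis(scores, expert_indices, axis=-1)
        expert_weights /= expert_weights.sum(axis=-1, keepdims=True)
        # 0/1 selection mask of shape (ntokens, num_experts) -- `topk_mask`
        mask = np.zeros((ntokens, num_experts), dtype=np.int32)
        np.put_along_axis(mask, expert_indices, 1, axis=-1)
        # transpose + flatten + inclusive scan -- `cumsum` over the flattened mask.T
        cumsum = np.cumsum(mask.T.reshape(-1))
        # end offset of each expert's rows in expert-major order -- `get_indptr`
        indptr = cumsum.reshape(num_experts, ntokens)[:, -1].astype(np.int64)
        # for every selected (expert, token) pair, record token_id * top_k + slot
        # at its scan position -- `get_flattened_expert_indices`
        flattened = np.empty(ntokens * top_k, dtype=np.int32)
        for pos in range(num_experts * ntokens):
            prev = cumsum[pos - 1] if pos > 0 else 0
            if cumsum[pos] > prev:
                token_id, expert_id = pos % ntokens, pos // ntokens
                slot = int(np.argmax(expert_indices[token_id] == expert_id))
                flattened[cumsum[pos] - 1] = token_id * top_k + slot
        # token to gather for each expert-major row (assumed `get_token_indices`)
        token_indices = flattened // top_k
        return expert_weights, expert_indices, flattened, indptr, token_indices

Rows of `linear_out` are thus produced in expert-major order: `indptr` gives each expert's end offset into that order for the grouped GEMM, and `scatter` writes row `i` back to position `flattened[i]` of the `(ntokens, num_experts_per_tok, hidden_size)` layout before the weighted sum over the top-k dimension.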