From c1bb219919f549c6d81da1699b7ea4c1e5c5cfdb Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Fri, 15 Dec 2023 23:45:03 +0000
Subject: [PATCH] unfuse softmax

---
 mlc_llm/relax_model/mixtral.py | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py
index 16839178cd..ea4b112780 100644
--- a/mlc_llm/relax_model/mixtral.py
+++ b/mlc_llm/relax_model/mixtral.py
@@ -7,10 +7,6 @@
 from .llama import MixtralConfig, Linear
 
 
-def get_top2_func(dtype, index_dtype):
-    assert index_dtype == "int32"
-
-
 class MoELinear(nn.Module):
     def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False):
         assert not bias, "bias not supported"
@@ -251,9 +247,9 @@ def get_flattened_expert_indices_scheduled(
                         )
                         T.writes(flattened_expert_indices[:])
                         expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
-                        if (
-                            vi == 0 and cumsum_colwise_flattened[vi] > 0
-                        ) or cumsum_colwise_flattened[vi] != cumsum_colwise_flattened[vi - 1]:
+                        if cumsum_colwise_flattened[vi] > T.if_then_else(
+                            vi == 0, T.int32(0), cumsum_colwise_flattened[vi - 1]
+                        ):
                             idx: T.SizeVar("idx", "int32") = cumsum_colwise_flattened[vi] - 1
                             instance_id: T.SizeVar("instance_id", "int32") = T.truncmod(
                                 vi, batch_size
@@ -328,11 +324,11 @@ def get_expert_instance_indptr(
             )
         )
 
-    def topk_softmax(self, x, k):
+    def topk(self, x, k):
         index_dtype = "int32"
 
         @T.prim_func
-        def top2_softmax(
+        def top2_func(
             x_handle: T.handle,
             out_handle: T.handle,
             out_index_handle: T.handle,
@@ -366,16 +362,13 @@ def top2_softmax(
                         for j in T.unroll(2):
                             with T.block("output"):
                                 vj = T.axis.remap("S", [j])
-                                out[vi, vj] = T.exp(local_top_k[vj] - local_top_k[0]) / (
-                                    T.exp(local_top_k[0] - local_top_k[0])
-                                    + T.exp(local_top_k[1] - local_top_k[0])
-                                )
+                                out[vi, vj] = local_top_k[vj]
                                 out_index[vi, vj] = local_top_k_index[vj]
 
         if k != 2:
             raise NotImplementedError("only support num_experts_per_token=2 for now")
         bb = relax.BlockBuilder.current()
-        gvar = bb.add_func(top2_softmax, "top2_softmax")
+        gvar = bb.add_func(top2_func, "top2")
         return bb.emit(
             relax.call_tir(
                 gvar,
@@ -393,9 +386,19 @@ def forward(self, hidden_states):
         # reshape to 2D
         hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size)))
 
-        gate = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
+
+        if router_logits.struct_info.dtype != "float32":
+            router_logits = nn.emit(relax.op.astype(router_logits, "float32"))
+        expert_weights = nn.emit(relax.op.nn.softmax(router_logits))
+        if expert_weights.struct_info.dtype != self.dtype:
+            expert_weights = nn.emit(relax.op.astype(expert_weights, self.dtype))
+
+        expert_weights, expert_indices = self.topk(expert_weights, k=self.num_experts_per_tok)
+        expert_weights = nn.emit(
+            relax.op.divide(expert_weights, relax.op.sum(expert_weights, axis=1, keepdims=True))
+        )
 
-        expert_weights, expert_indices = self.topk_softmax(gate, k=self.num_experts_per_tok)
         expert_mask = self.topk_mask(expert_indices)
         mask_T_flattened = nn.emit(relax.op.flatten(relax.op.permute_dims(expert_mask)))
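
Note on the change (not part of the patch itself): the old fused top2_softmax kernel computed a softmax over only the two largest router logits, while the new path takes a full float32 softmax over all experts, picks the top-2 with the plain topk kernel, and renormalizes by their sum. Because softmax is monotonic, the two formulations give the same expert weights. Below is a minimal standalone NumPy sketch of that identity; the function names, shapes, and tolerance are illustrative assumptions, not code from mlc-llm.

    import numpy as np

    def fused_top2_softmax(logits):
        # Old behavior: softmax taken directly over the two largest logits per token.
        idx = np.argsort(logits, axis=-1)[:, ::-1][:, :2]
        top = np.take_along_axis(logits, idx, axis=-1)
        shifted = np.exp(top - top[:, :1])
        return shifted / shifted.sum(axis=-1, keepdims=True), idx

    def unfused_softmax_topk(logits, k=2):
        # New behavior: full softmax over all experts, then top-k, then renormalize.
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        idx = np.argsort(probs, axis=-1)[:, ::-1][:, :k]
        top = np.take_along_axis(probs, idx, axis=-1)
        return top / top.sum(axis=-1, keepdims=True), idx

    logits = np.random.randn(4, 8).astype("float32")  # 4 tokens, 8 experts (made-up sizes)
    w_old, i_old = fused_top2_softmax(logits)
    w_new, i_new = unfused_softmax_topk(logits)
    assert (i_old == i_new).all() and np.allclose(w_old, w_new, atol=1e-6)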