unfuse softmax
vinx13 committed Dec 17, 2023
1 parent 724fc62 commit c1bb219
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions mlc_llm/relax_model/mixtral.py
@@ -7,10 +7,6 @@
 from .llama import MixtralConfig, Linear
 
 
-def get_top2_func(dtype, index_dtype):
-    assert index_dtype == "int32"
-
-
 class MoELinear(nn.Module):
     def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False):
         assert not bias, "bias not supported"
@@ -251,9 +247,9 @@ def get_flattened_expert_indices_scheduled(
                     )
                     T.writes(flattened_expert_indices[:])
                     expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
-                    if (
-                        vi == 0 and cumsum_colwise_flattened[vi] > 0
-                    ) or cumsum_colwise_flattened[vi] != cumsum_colwise_flattened[vi - 1]:
+                    if cumsum_colwise_flattened[vi] > T.if_then_else(
+                        vi == 0, T.int32(0), cumsum_colwise_flattened[vi - 1]
+                    ):
                         idx: T.SizeVar("idx", "int32") = cumsum_colwise_flattened[vi] - 1
                         instance_id: T.SizeVar("instance_id", "int32") = T.truncmod(
                             vi, batch_size
@@ -328,11 +324,11 @@ def get_expert_instance_indptr(
                 )
             )
 
-    def topk_softmax(self, x, k):
+    def topk(self, x, k):
         index_dtype = "int32"
 
         @T.prim_func
-        def top2_softmax(
+        def top2_func(
             x_handle: T.handle,
             out_handle: T.handle,
             out_index_handle: T.handle,
@@ -366,16 +362,13 @@ def top2_softmax(
                 for j in T.unroll(2):
                     with T.block("output"):
                         vj = T.axis.remap("S", [j])
-                        out[vi, vj] = T.exp(local_top_k[vj] - local_top_k[0]) / (
-                            T.exp(local_top_k[0] - local_top_k[0])
-                            + T.exp(local_top_k[1] - local_top_k[0])
-                        )
+                        out[vi, vj] = local_top_k[vj]
                         out_index[vi, vj] = local_top_k_index[vj]
 
         if k != 2:
             raise NotImplementedError("only support num_experts_per_token=2 for now")
         bb = relax.BlockBuilder.current()
-        gvar = bb.add_func(top2_softmax, "top2_softmax")
+        gvar = bb.add_func(top2_func, "top2")
         return bb.emit(
             relax.call_tir(
                 gvar,
@@ -393,9 +386,19 @@ def forward(self, hidden_states):
         # reshape to 2D
         hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size)))
 
-        gate = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
 
-        expert_weights, expert_indices = self.topk_softmax(gate, k=self.num_experts_per_tok)
+        if router_logits.struct_info.dtype != "float32":
+            router_logits = nn.emit(relax.op.astype(router_logits, "float32"))
+        expert_weights = nn.emit(relax.op.nn.softmax(router_logits))
+        if expert_weights.struct_info.dtype != self.dtype:
+            expert_weights = nn.emit(relax.op.astype(expert_weights, self.dtype))
+
+        expert_weights, expert_indices = self.topk(expert_weights, k=self.num_experts_per_tok)
+        expert_weights = nn.emit(
+            relax.op.divide(expert_weights, relax.op.sum(expert_weights, axis=1, keepdims=True))
+        )
+
         expert_mask = self.topk_mask(expert_indices)
         mask_T_flattened = nn.emit(relax.op.flatten(relax.op.permute_dims(expert_mask)))
 

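Note on the change: the old fused top2_softmax kernel computed a softmax over only the two selected logits, while the unfused path now applies a full softmax to the router logits, selects the top-2 probabilities, and renormalizes them with a divide-by-sum. The two are mathematically equivalent for distinct logits. The following minimal NumPy sketch (not part of the commit; function names here are illustrative only) demonstrates that equivalence.

import numpy as np

def fused_top2_softmax(logits):
    # old path: pick the top-2 logits, then softmax over just those two values
    idx = np.argsort(logits)[::-1][:2]
    top = logits[idx]
    scaled = np.exp(top - top[0])
    return scaled / scaled.sum(), idx

def unfused_top2(logits):
    # new path: softmax over all experts, pick the top-2 probabilities,
    # then renormalize them so they sum to one
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    idx = np.argsort(probs)[::-1][:2]
    return probs[idx] / probs[idx].sum(), idx

logits = np.array([0.3, 2.1, -0.7, 1.4], dtype=np.float32)
w_fused, i_fused = fused_top2_softmax(logits)
w_unfused, i_unfused = unfused_top2(logits)
assert (i_fused == i_unfused).all()
assert np.allclose(w_fused, w_unfused)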