Skip to content

Commit

Permalink
Switch to cumsum based impl
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Dec 15, 2023
1 parent 38327f1 commit d848568
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 105 deletions.
2 changes: 1 addition & 1 deletion mlc_llm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
mod_deploy
)
)
if not args.enable_batching:
if not args.enable_batching and target_kind != "cuda":
mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

if args.debug_load_script:
Expand Down
5 changes: 0 additions & 5 deletions mlc_llm/relax_model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,11 +1497,6 @@ def get_model(args, hf_config):
param_manager = ParamManager(keep_params_after_load)
bb = relax.BlockBuilder()

if isinstance(config, MixtralConfig):
from .mixtral import emit_tir_funcs

emit_tir_funcs(bb, config)

if sep_embed:
create_embed_func(bb, param_manager, config, args.quantization)

Expand Down
5 changes: 0 additions & 5 deletions mlc_llm/relax_model/llama_batched_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,6 @@ def get_model(args, hf_config):
# The CPU device to copy the result of relax.op.max(seq_lens) to CPU.
cpu_dev = VDevice("llvm", 0, "global")

if isinstance(config, MixtralConfig):
from .mixtral import emit_tir_funcs

emit_tir_funcs(bb, config)

create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization)
Expand Down
181 changes: 87 additions & 94 deletions mlc_llm/relax_model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,9 @@
from .llama import MixtralConfig, Linear


def get_scatter_func(dtype):
    """Build the TIR function that scatters the rows of a 2-D buffer.

    The returned function copies row ``i`` of ``x`` into row ``indices[i]`` of
    ``out``. ``x`` and ``out`` share the symbolic shape
    ``(total_rows, hidden_size)``; ``indices`` holds ``total_rows`` int32 values.

    Parameters
    ----------
    dtype
        Element dtype of the ``x`` and ``out`` buffers.

    Returns
    -------
    The ``scatter_func`` prim_func, tagged with ``global_symbol`` ``"scatter"``.
    """

    @T.prim_func
    def scatter_func(
        x_handle: T.handle,
        indices_handle: T.handle,
        out_handle: T.handle,
    ) -> None:
        # Symbolic int64 extents; they are bound through the buffer shapes below.
        total_rows = T.int64()
        hidden_size = T.int64()
        x = T.match_buffer(x_handle, (total_rows, hidden_size), dtype)
        indices = T.match_buffer(indices_handle, (total_rows,), "int32")
        out = T.match_buffer(out_handle, (total_rows, hidden_size), dtype)
        T.func_attr({"global_symbol": "scatter", "tir.noalias": True})
        for i in range(total_rows):
            for j in range(hidden_size):
                with T.block("scatter"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    # Source row vi lands in destination row indices[vi].
                    # NOTE(review): rows of `out` that no index points at are
                    # never written here; presumably `indices` is a permutation
                    # of 0..total_rows-1 -- confirm against the caller.
                    out[indices[vi], vj] = x[vi, vj]

    return scatter_func


def get_top2_func(dtype, index_dtype):
    """Build the TIR kernel that picks the two largest entries of every row.

    For each row of ``x`` (shape ``(total_rows, num_experts)``) the kernel
    writes the largest and second-largest values to ``out[row, 0:2]`` and
    their column indices to ``out_index[row, 0:2]``.

    Parameters
    ----------
    dtype
        Element dtype of ``x`` and ``out``.
    index_dtype
        Element dtype of ``out_index``; only ``"int32"`` is accepted.

    Returns
    -------
    The ``top2_func`` prim_func, tagged with ``global_symbol`` ``"top2"``.

    Raises
    ------
    AssertionError
        If ``index_dtype`` is not ``"int32"``.
    """
    assert index_dtype == "int32"

    @T.prim_func
    def top2_func(
        x_handle: T.handle,
        out_handle: T.handle,
        out_index_handle: T.handle,
    ) -> None:
        total_rows = T.int64()
        num_experts = T.int64()
        x = T.match_buffer(x_handle, (total_rows, num_experts), dtype)
        out = T.match_buffer(out_handle, (total_rows, 2), dtype)
        out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype)
        # Two-slot running best, allocated with scope "local":
        # slot 0 = largest so far, slot 1 = runner-up.
        local_top_k = T.alloc_buffer((2,), dtype=dtype, scope="local")
        local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local")
        T.func_attr({"global_symbol": "top2", "tir.noalias": True, "tir.is_scheduled": True})
        # One row per (blockIdx.x, threadIdx.x) pair, up to 1024 threads per block.
        for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"):
            for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"):
                # Guard the tail block: threads past the last row do nothing.
                if io * T.int64(1024) + ii < total_rows:
                    local_top_k[0] = T.min_value(dtype)
                    local_top_k_index[0] = 0
                    # Initialise the runner-up slot as well. Without this, the
                    # `elif` below reads an uninitialised value whenever the
                    # first element fails to beat min_value (e.g. -inf or NaN).
                    local_top_k[1] = T.min_value(dtype)
                    local_top_k_index[1] = 0
                    for k in range(num_experts):
                        if x[io * T.int64(1024) + ii, k] > local_top_k[0]:
                            # New maximum: the previous maximum becomes the runner-up.
                            local_top_k[1] = local_top_k[0]
                            local_top_k_index[1] = local_top_k_index[0]
                            local_top_k[0] = x[io * T.int64(1024) + ii, k]
                            local_top_k_index[0] = k
                        elif x[io * T.int64(1024) + ii, k] > local_top_k[1]:
                            local_top_k[1] = x[io * T.int64(1024) + ii, k]
                            local_top_k_index[1] = k

                    for k in T.unroll(2):
                        out[io * T.int64(1024) + ii, k] = local_top_k[k]
                        out_index[io * T.int64(1024) + ii, k] = local_top_k_index[k]

    return top2_func


def emit_tir_funcs(bb: relax.BlockBuilder, config: MixtralConfig):
    """Register the ``scatter`` and ``top2`` TIR functions on ``bb``.

    Both functions are built for ``config.dtype``; ``top2`` uses ``"int32"``
    indices. They are added under the names ``"scatter"`` and ``"top2"``.
    """
    scatter_prim = get_scatter_func(config.dtype)
    bb.add_func(scatter_prim, "scatter")

    top2_prim = get_top2_func(config.dtype, "int32")
    bb.add_func(top2_prim, "top2")


class MoELinear(nn.Module):
def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False):
Expand Down Expand Up @@ -180,7 +118,6 @@ def __init__(self, config: MixtralConfig):
)

def forward(self, hidden_states: relax.Expr, rows_before: relax.Expr):
# TODO: disco
if self.combine_matmul:
gate_up_results = nn.emit(
relax.op.split(
Expand Down Expand Up @@ -210,6 +147,8 @@ def __init__(self, config: MixtralConfig):
)
self.num_experts_per_tok = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.dtype = config.dtype
self.hidden_size = config.hidden_size

def topk(self, x, is_ascend, index_dtype, k=-1):
if not is_ascend and k == 2:
Expand Down Expand Up @@ -251,19 +190,28 @@ def topk(self, x, is_ascend, index_dtype, k=-1):
indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True))
return sorted_x, indices

def compute_rows_before(self, sorted_expert_ids):
    """Emit a call to the packed function ``"moe_compute_rows_before"``.

    ``sorted_expert_ids`` is passed as the only argument. The result is
    declared as an int64 tensor with ``self.num_experts`` entries.
    """
    result_sinfo = relax.TensorStructInfo([self.num_experts], "int64")
    packed_call = relax.call_dps_packed(
        "moe_compute_rows_before",
        [sorted_expert_ids],
        out_sinfo=result_sinfo,
    )
    return nn.emit(packed_call)

def scatter(self, linear_out, indices):
@T.prim_func
def scatter_func(
x_handle: T.handle,
indices_handle: T.handle,
out_handle: T.handle,
) -> None:
total_rows = T.int64()
x = T.match_buffer(x_handle, (total_rows, self.hidden_size), self.dtype)
indices = T.match_buffer(indices_handle, (total_rows,), "int32")
out = T.match_buffer(out_handle, (total_rows, self.hidden_size), self.dtype)
T.func_attr({"global_symbol": "scatter", "tir.noalias": True})
for i in range(total_rows):
for j in range(self.hidden_size):
with T.block("scatter"):
vi, vj = T.axis.remap("SS", [i, j])
out[indices[vi], vj] = x[vi, vj]

scatter = relax.BlockBuilder.current().add_func(scatter_func, "scatter")
return nn.emit(
relax.call_dps_packed(
"scatter",
scatter,
[linear_out, indices],
out_sinfo=linear_out.struct_info,
)
Expand Down Expand Up @@ -338,7 +286,8 @@ def get_flattened_expert_indices_scheduled(
vi = T.axis.spatial(cumsum_flattened_length, io * T.int32(1024) + ii)
T.where(io * T.int32(1024) + ii < cumsum_flattened_length)
T.reads(
cumsum_colwise_flattened[vi - 1 : vi - 1 + 2], expert_indices[:, 0:2]
cumsum_colwise_flattened[vi - 1 : vi - 1 + 2],
expert_indices[:, 0 : self.num_experts_per_tok],
)
T.writes(flattened_expert_indices[:])
expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
Expand Down Expand Up @@ -401,12 +350,12 @@ def get_expert_instance_indptr(
var_expert_instance_indptr, shape=[self.num_experts], dtype="int64"
)

for expert_id in T.serial(0, self.num_experts):
for expert_id in range(self.num_experts):
with T.block("indptr"):
vexpert_id = T.axis.spatial(self.num_experts, expert_id)
expert_instance_indptr[vexpert_id] = cumsum_colwise_flattened[
vexpert_id * batch_size - 1
]
expert_instance_indptr[vexpert_id] = T.cast(
cumsum_colwise_flattened[(vexpert_id + 1) * batch_size - 1], "int64"
)

bb = relax.BlockBuilder.current()
gvar = bb.add_func(get_expert_instance_indptr, "get_expert_instance_indptr")
Expand All @@ -419,6 +368,64 @@ def get_expert_instance_indptr(
)
)

def topk_softmax(self, x, k):
    """Select the top-2 entries of every row of ``x`` and softmax them.

    A TIR kernel is registered on the current block builder under the name
    ``"top2_softmax"``. For every row it finds the largest and second-largest
    entries, then normalises the pair as
    ``exp(v_j - v_0) / (exp(0) + exp(v_1 - v_0))``, where ``v_0`` is the row
    maximum (a softmax over the two selected values).

    Parameters
    ----------
    x
        Relax expression whose struct info gives a 2-D shape; the kernel
        matches it as ``(total_rows, self.num_experts)`` with ``self.dtype``.
    k
        Number of entries to keep per row. Only ``2`` is supported.

    Returns
    -------
    The emitted ``call_tir`` result with two outputs: the ``(rows, k)``
    normalised weights in ``x``'s dtype and the ``(rows, k)`` int32 indices.

    Raises
    ------
    NotImplementedError
        If ``k`` is not 2.
    """
    # Fail fast: the kernel below is hard-wired to two slots per row, so
    # reject other values before spending any work on building it.
    if k != 2:
        raise NotImplementedError("only support num_experts_per_token=2 for now")

    index_dtype = "int32"

    @T.prim_func
    def top2_softmax(
        x_handle: T.handle,
        out_handle: T.handle,
        out_index_handle: T.handle,
    ) -> None:
        total_rows = T.int64()
        x = T.match_buffer(x_handle, (total_rows, self.num_experts), self.dtype)
        out = T.match_buffer(out_handle, (total_rows, 2), self.dtype)
        out_index = T.match_buffer(out_index_handle, (total_rows, 2), index_dtype)
        # Two-slot running best, allocated with scope "local":
        # slot 0 = largest so far, slot 1 = runner-up.
        local_top_k = T.alloc_buffer((2,), dtype=self.dtype, scope="local")
        local_top_k_index = T.alloc_buffer((2,), dtype=index_dtype, scope="local")
        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
        # One row per (blockIdx.x, threadIdx.x) pair, up to 1024 threads per block.
        for io in T.thread_binding(0, T.ceildiv(total_rows, T.int64(1024)), "blockIdx.x"):
            for ii in T.thread_binding(0, T.min(total_rows, T.int64(1024)), "threadIdx.x"):
                with T.block("top2"):
                    vi = T.axis.spatial(total_rows, io * T.int64(1024) + ii)
                    # Tail guard: threads past the last row are masked out.
                    T.where(io * T.int64(1024) + ii < total_rows)
                    with T.block("init"):
                        local_top_k[0] = T.min_value(self.dtype)
                        local_top_k_index[0] = 0
                        # Initialise the runner-up slot as well. Without this,
                        # the `elif` below reads an uninitialised value whenever
                        # the first element fails to beat min_value (e.g. -inf
                        # or NaN).
                        local_top_k[1] = T.min_value(self.dtype)
                        local_top_k_index[1] = 0
                    # `k` here is the TIR loop variable of the kernel, not the
                    # Python argument `k` of the enclosing method.
                    for k in range(self.num_experts):
                        with T.block("update"):
                            vk = T.axis.remap("S", [k])
                            if x[vi, vk] > local_top_k[0]:
                                # New maximum: the old maximum becomes the runner-up.
                                local_top_k[1] = local_top_k[0]
                                local_top_k_index[1] = local_top_k_index[0]
                                local_top_k[0] = x[vi, vk]
                                local_top_k_index[0] = vk
                            elif x[vi, vk] > local_top_k[1]:
                                local_top_k[1] = x[vi, vk]
                                local_top_k_index[1] = vk
                    for j in T.unroll(2):
                        with T.block("output"):
                            vj = T.axis.remap("S", [j])
                            # Softmax over the two kept values; the row maximum
                            # is subtracted inside every exp.
                            out[vi, vj] = T.exp(local_top_k[vj] - local_top_k[0]) / (
                                T.exp(local_top_k[0] - local_top_k[0])
                                + T.exp(local_top_k[1] - local_top_k[0])
                            )
                            out_index[vi, vj] = local_top_k_index[vj]

    bb = relax.BlockBuilder.current()
    gvar = bb.add_func(top2_softmax, "top2_softmax")
    return bb.emit(
        relax.call_tir(
            gvar,
            [x],
            out_sinfo=[
                relax.TensorStructInfo([x.struct_info.shape[0], k], x.struct_info.dtype),
                relax.TensorStructInfo([x.struct_info.shape[0], k], index_dtype),
            ],
        )
    )

def forward(self, hidden_states):
hidden_states_shape = hidden_states.struct_info.shape
Expand All @@ -427,31 +434,17 @@ def forward(self, hidden_states):
hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size)))

gate = self.gate(hidden_states)
scores = nn.emit(relax.op.nn.softmax(gate, axis=-1))

expert_weights, expert_indices = self.topk(
scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32"
) # (num_tokens, top_k), (num_tokens, top_k)
expert_weights = nn.emit(expert_weights / R.sum(expert_weights, axis=-1, keepdims=True))

expert_weights, expert_indices = self.topk_softmax(gate, k=self.num_experts_per_tok)
expert_mask = self.topk_mask(expert_indices)
mask_T_flattened = nn.emit(relax.op.flatten(relax.op.permute_dims(expert_mask)))

cumsum_colwise_flattened = self.cumsum(mask_T_flattened)
flattened_indices = self.get_indices(cumsum_colwise_flattened, expert_indices)
indptr = self.get_indptr(cumsum_colwise_flattened)
# indptr = nn.emit(relax.op.strided_slice(indptr, axes=[0], begin=[1], assume_inbound=True))

# flattened_indices = nn.emit(relax.op.flatten(expert_indices))
# sorted_expert_ids, indices = self.topk(
# flattened_indices, is_ascend=True, index_dtype="int32"
# )

# rows_before = self.compute_rows_before(sorted_expert_ids)
token_indices = self.get_token_indices(flattened_indices)

gathered_x = nn.emit(relax.op.take(hidden_states, token_indices, axis=0))
# linear_out = self.experts(gathered_x, rows_before)
# unpermuted = self.scatter(linear_out, indices)
linear_out = self.experts(gathered_x, indptr)
unpermuted = self.scatter(linear_out, flattened_indices)

Expand Down

0 comments on commit d848568

Please sign in to comment.