Skip to content

Commit

Permalink
Switch to cumsum based impl (#123)
Browse files Browse the repository at this point in the history
* WIP

* fix

* Switch to cumsum based impl

* cleanup

* unfuse softmax
  • Loading branch information
vinx13 authored Dec 18, 2023
1 parent fa424e2 commit 134b8ee
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 141 deletions.
2 changes: 1 addition & 1 deletion mlc_llm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
mod_deploy
)
)
if not args.enable_batching:
if not args.enable_batching and target_kind != "cuda":
mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

if args.debug_load_script:
Expand Down
5 changes: 0 additions & 5 deletions mlc_llm/relax_model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,11 +1497,6 @@ def get_model(args, hf_config):
param_manager = ParamManager(keep_params_after_load)
bb = relax.BlockBuilder()

if isinstance(config, MixtralConfig):
from .mixtral import emit_tir_funcs

emit_tir_funcs(bb, config)

if sep_embed:
create_embed_func(bb, param_manager, config, args.quantization)

Expand Down
5 changes: 0 additions & 5 deletions mlc_llm/relax_model/llama_batched_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,6 @@ def get_model(args, hf_config):
# The CPU device to copy the result of relax.op.max(seq_lens) to CPU.
cpu_dev = VDevice("llvm", 0, "global")

if isinstance(config, MixtralConfig):
from .mixtral import emit_tir_funcs

emit_tir_funcs(bb, config)

create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization)
Expand Down
Loading

0 comments on commit 134b8ee

Please sign in to comment.