Commit 5f253c7: FT group Q + multigpu works

Masahiro Masuda committed Nov 20, 2023 (parent: ba1c5d7)

Showing 6 changed files with 14 additions and 61 deletions.

mlc_llm/build.py (3 changes: 1 addition & 2 deletions)
@@ -42,11 +42,10 @@ def main():
 
     # if num_shard>1 without -convert-weight-only or --build-model-only, we implicitly run it sequentially
     if parsed_args.num_shards > 1 and not (parsed_args.build_model_only or parsed_args.convert_weights_only):
-        parsed_args.build_model_only = Truep
+        parsed_args.build_model_only = True
         parsed_args.convert_weights_only = False # just to be explicit
         core.build_model_from_args(parsed_args)
 
-
         parsed_args.build_model_only = False
         parsed_args.convert_weights_only = True
         core.build_model_from_args(parsed_args)
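
The hunk above fixes the "Truep" typo and, together with the surrounding code, makes a multi-shard build run its two phases back to back. A minimal sketch of that control flow, using only the parsed_args attributes visible in the diff (the wrapper function and its name are illustrative, not code from the repository):

import argparse

from mlc_llm import core  # module referenced in the hunk above


def build_sharded(parsed_args: argparse.Namespace) -> None:
    # Pass 1: compile the model library only.
    parsed_args.build_model_only = True
    parsed_args.convert_weights_only = False
    core.build_model_from_args(parsed_args)

    # Pass 2: convert (and shard) the weights for the multi-GPU case.
    parsed_args.build_model_only = False
    parsed_args.convert_weights_only = True
    core.build_model_from_args(parsed_args)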

mlc_llm/core.py (4 changes: 2 additions & 2 deletions)
@@ -800,7 +800,7 @@ def build_model_from_args(args: argparse.Namespace):
             "and it is highly recommended to use q4f16_1 instead"
         )
 
-    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
+    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"]
 
     if args.num_shards > 1:
         if (not args.build_model_only) and (not args.convert_weights_only):
@@ -907,7 +907,7 @@ def build_model_from_args(args: argparse.Namespace):
     if args.num_shards > 1 and use_ft_quant:
         preprocessed = []
         weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight")
-        is_int4 = args.quantization.name == "q4f16_ft"
+        is_int4 = args.quantization.name in ["q4f16_ft", "q4f16_ft_group"]
         sm = get_cuda_sm_version()
 
         for p in params:
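
Both hunks extend the same checks to the new group-quantized FT modes: q4f16_ft_group and q8f16_ft_group now take the use_ft_quant path, and is_int4 covers both 4-bit variants. A small summary of the mapping implied by the membership tests above, written as an illustrative snippet rather than anything defined in the repository:

# How the four FasterTransformer-style quantization names map onto the flags used above.
# "grouped" marks the group-quantized variants added in this commit.
FT_QUANT_MODES = {
    "q4f16_ft":       {"grouped": False, "is_int4": True},
    "q4f16_ft_group": {"grouped": True,  "is_int4": True},
    "q8f16_ft":       {"grouped": False, "is_int4": False},
    "q8f16_ft_group": {"grouped": True,  "is_int4": False},
}

quantization_name = "q4f16_ft_group"
use_ft_quant = quantization_name in FT_QUANT_MODES      # mirrors the membership test above
is_int4 = quantization_name.startswith("q4f16_ft")      # true for both 4-bit variants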

mlc_llm/quantization/ft_quantization.py (2 changes: 1 addition & 1 deletion)
@@ -162,7 +162,7 @@ def f_scale_weight(i, j):
 def decoding_func(nbit: int, storage_nbit: int, group_size: int):
     def te_decode_sym(data, scale):
         n_float_per_int = storage_nbit // nbit
-        cur_group_size = weight.shape[1] if group_size == -1 else group_size
+        cur_group_size = data.shape[1] if group_size == -1 else group_size
 
         def f_decode_sym(i, j):
             if n_float_per_int == 1:
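
The one-line fix makes the group_size == -1 fallback read the shape of data, the tensor te_decode_sym actually receives, instead of an out-of-scope weight. For intuition, here is a minimal numpy sketch of symmetric group dequantization; the row/column layout and the scale shape are assumptions for illustration, and only the group-size fallback mirrors the line changed above:

import numpy as np


def dequantize_sym(q, scale, group_size=-1):
    # q: (rows, cols) signed quantized values; scale: (rows, n_groups) per-group scales.
    # group_size == -1 means the whole axis forms a single group (per-channel quantization).
    rows, cols = q.shape
    cur_group_size = cols if group_size == -1 else group_size  # the fallback fixed above
    group_idx = np.arange(cols) // cur_group_size              # group index of each column
    return q.astype(np.float16) * scale[:, group_idx]


# Example: 8 columns split into 2 groups of 4.
q = np.random.randint(-8, 8, size=(4, 8)).astype(np.int8)
scale = np.random.rand(4, 2).astype(np.float16)
w = dequantize_sym(q, scale, group_size=4)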

mlc_llm/relax_model/commons.py (59 changes: 7 additions & 52 deletions)
@@ -94,8 +94,8 @@ def _get_shard_strategies_ft(
     q_heads = model_config.num_attention_heads
     kv_heads = model_config.get_num_key_value_heads()
 
-    def shard_qkv_weight(weight: relax.TensorStructInfo):
-        (red, spatial), dtype = weight.shape, weight.dtype
+    def shard_qkv_weight_scale(x: relax.TensorStructInfo):
+        (red, spatial), dtype = x.shape, x.dtype
         red, spatial = int(red), int(spatial)
         if param_shape_is_already_sharded:
             spatial *= num_shards
@@ -114,31 +114,6 @@ def shard_qkv_weight(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func
 
-    def shard_qkv_scale(scale: relax.TensorStructInfo):
-        (spatial,), dtype = scale.shape, scale.dtype
-        spatial = int(spatial)
-        if param_shape_is_already_sharded:
-            spatial *= num_shards
-        head_dim = spatial // (q_heads + 2 * kv_heads)
-        a = te.placeholder((spatial,), dtype=dtype)
-        w = topi.reshape(a, (spatial // head_dim, head_dim))
-        q = te.compute((q_heads, head_dim), lambda i, j: w[i, j])
-        k = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + i, j])
-        v = te.compute((kv_heads, head_dim), lambda i, j: w[q_heads + kv_heads + i, j])
-        q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim))
-        k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim))
-        v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim))
-        w = topi.concatenate((q, k, v), axis=1)
-        w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim))
-        func = te.create_prim_func([a, w])
-        return func
-
-    def shard_qkv_weight_scale(x: relax.TensorStructInfo):
-        if x.ndim == 2:
-            return shard_qkv_weight(x)
-        else:
-            return shard_qkv_scale(x)
-
     def shard_k_weight(weight: relax.TensorStructInfo):
         (red, spatial), dtype = weight.shape, weight.dtype
         red, spatial = int(red), int(spatial)
@@ -149,8 +124,8 @@ def shard_k_weight(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func
 
-    def shard_gate_up_weight(weight: relax.TensorStructInfo):
-        (red, spatial), dtype = weight.shape, weight.dtype
+    def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
+        (red, spatial), dtype = x.shape, x.dtype
         red, spatial = int(red), int(spatial)
         if param_shape_is_already_sharded:
             spatial *= num_shards
@@ -165,27 +140,6 @@ def shard_gate_up_weight(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func
 
-    def shard_gate_up_scale(weight: relax.TensorStructInfo):
-        (spatial,), dtype = weight.shape, weight.dtype
-        spatial = int(spatial)
-        if param_shape_is_already_sharded:
-            spatial *= num_shards
-        a = te.placeholder((spatial,), dtype=dtype)
-        g = te.compute((spatial // 2,), lambda i: a[i])
-        u = te.compute((spatial // 2,), lambda i: a[spatial // 2 + i])
-        g = topi.reshape(g, (num_shards, spatial // 2 // num_shards))
-        u = topi.reshape(u, (num_shards, spatial // 2 // num_shards))
-        w = topi.concatenate((g, u), axis=1)
-        w = topi.reshape(w, (num_shards, spatial // num_shards))
-        func = te.create_prim_func([a, w])
-        return func
-
-    def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
-        if x.ndim == 2:
-            return shard_gate_up_weight(x)
-        else:
-            return shard_gate_up_scale(x)
-
     return {
         "shard_qkv": shard_qkv_weight_scale,
         "shard_mlp_k": shard_k_weight,
@@ -246,7 +200,7 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]):
 
 
 def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule:
-    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
+    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"]
 
     if use_ft_quant:
         shard_strategy_to_func = _get_shard_strategies_ft(
@@ -307,13 +261,14 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule:
             if param.shard_strategy is None or (
                 use_ft_quant
                 and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"]
-                and len(qparam_sinfo.shape) == 1
+                and qparam_sinfo.shape[0] == 1
             ):
                 sharded = arg
             else:
                 strategy_func = shard_strategy_to_func[param.shard_strategy](
                     qparam_sinfo
                 ).without_attr("global_symbol")
+
                 strategy_gvar = bb.add_func(
                     strategy_func,
                     func_name=f"{arg_name}.sharding_func",
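
Taken together, these hunks let the group-quantized FT modes reuse the 2-D weight sharding path for their scales: the 1-D scale helpers and the ndim-based dispatchers are deleted, and create_shard_transformation_func now leaves a shard_mlp_k / shard_o_proj_k parameter unsharded only when its leading dimension is 1. As an illustration of the fused-QKV split-and-reshard pattern that shard_qkv_weight_scale follows, here is a numpy sketch inferred from the deleted 1-D shard_qkv_scale helper above; the real code builds a TVM TE PrimFunc whose full body is elided from this diff, so treat the layout as an assumption:

import numpy as np


def shard_qkv_2d(w, num_shards, q_heads, kv_heads):
    # w: (red, spatial) with spatial = (q_heads + 2 * kv_heads) * head_dim.
    # Returns (num_shards, red, spatial // num_shards): each shard gets a contiguous
    # block of query heads plus its share of the key/value heads.
    red, spatial = w.shape
    head_dim = spatial // (q_heads + 2 * kv_heads)
    w = w.reshape(red, q_heads + 2 * kv_heads, head_dim)
    q, k, v = np.split(w, [q_heads, q_heads + kv_heads], axis=1)
    q = q.reshape(red, num_shards, q_heads // num_shards, head_dim)
    k = k.reshape(red, num_shards, kv_heads // num_shards, head_dim)
    v = v.reshape(red, num_shards, kv_heads // num_shards, head_dim)
    out = np.concatenate([q, k, v], axis=2)  # (red, num_shards, heads_per_shard, head_dim)
    return out.transpose(1, 0, 2, 3).reshape(num_shards, red, spatial // num_shards)


# Example: 2 shards, 8 query heads, 2 key/value heads, head_dim 4, reduction dim 16.
w = np.arange(16 * 48, dtype=np.float32).reshape(16, 48)
assert shard_qkv_2d(w, num_shards=2, q_heads=8, kv_heads=2).shape == (2, 16, 24)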

mlc_llm/utils.py (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
     if args.quantization not in quantization_schemes:
         raise ValueError(f'Quantization "{args.quantization}" is not supported.')
 
-    use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft"]
+    use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"]
     args.quantization = quantization_schemes[args.quantization]
 
     if use_ft_quant and args.num_shards > 1:

serve/tests/test_engine_paged_cache_model.py (5 changes: 2 additions & 3 deletions)
@@ -33,8 +33,8 @@ def test(args: argparse.Namespace):
 
     engine_config = get_engine_config({
         "use_staging_engine": args.use_staging_engine,
-        "max_num_batched_tokens": args.max_num_batched_tokens,
-        "max_input_len": args.max_input_len,
+        "max_num_batched_tokens": args.max_num_batched_tokens,
+        "max_input_len": args.max_input_len,
         "min_decode_steps": args.min_decode_steps,
         "max_decode_steps": args.max_decode_steps,
         "prompt_allocate_ratio": args.prompt_allocate_ratio
@@ -119,7 +119,6 @@ def test(args: argparse.Namespace):
     parser = argparse.ArgumentParser()
     parser.add_argument("--local-id", type=str, required=True)
     parser.add_argument("--artifact-path", type=str, default="dist")
-    parser.add_argument("--num-shards", type=int, default=1)
     parser.add_argument("--max-num-batched-tokens", type=int, default=-1)
    parser.add_argument("--max-input-len", type=int, default=-1)
     parser.add_argument("--max-output-len", type=int, default=20)
