Accelerating quantization computation by weight compression (octoml#45)
This PR enables weight compression on GPU. Previously, the weight
compression ran on CPU because the uncompressed weights are too large
to fit in GPU memory, and running on CPU is quite slow in the fp16
case. Now we switch to GPU. The technique we use to fit the
uncompressed weights into GPU memory is lazy loading: we load each
weight right before its first use and free it immediately after its
last use.
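
For intuition, here is a minimal, self-contained sketch of the
lazy-loading pattern. It is an illustration only: `load_param`,
`compress`, and the NumPy stand-ins are hypothetical, while the actual
implementation in this PR relies on TVM's `LazyTransformParams` pass
together with registered `get_item`/`set_item` hooks (see the diff
below).

```python
# Minimal sketch of lazy loading (illustration only, not the TVM implementation).
# `load_param` and `compress` are hypothetical stand-ins for "fetch weight i onto
# the device" and "quantize one weight".
from typing import Callable, List

import numpy as np


def lazily_transform_params(
    load_param: Callable[[int], np.ndarray],
    compress: Callable[[np.ndarray], np.ndarray],
    num_params: int,
) -> List[np.ndarray]:
    results = []
    for i in range(num_params):
        weight = load_param(i)            # load right before the first use
        results.append(compress(weight))  # transform while it is resident
        del weight                        # drop it right after the last use
    return results


# Toy usage: fp32 "weights" compressed to fp16, one at a time.
params = [np.random.rand(4, 4).astype("float32") for _ in range(3)]
compressed = lazily_transform_params(
    load_param=lambda i: params[i],
    compress=lambda w: w.astype("float16"),
    num_params=len(params),
)
```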

In our tests, this PR reduces the quantization computation time for
Vicuna-v1-7b under the `q3f16_0` quantization setting by **6 min** on a
Linux machine with an **RTX 4090, Ryzen 3970X, and 64 GB of RAM**, and
by 40 sec on a Mac Studio with 32 GB of memory.

At this moment, building the vicuna-7b-v1 model on a Linux machine
requires slightly more than 28 GB of memory in total (compared with
over 50 GB previously). We are continuing to work on reducing the
memory requirement.
jinhongyii authored May 22, 2023
1 parent 8f78235 commit 68be032
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions mlc_llm/utils.py
@@ -108,17 +108,48 @@ def split_transform_deploy_mod(


 def transform_params(
-    mod_transform: tvm.IRModule, model_params: List[tvm.nd.NDArray]
+    mod_transform: tvm.IRModule,
+    model_params: List[tvm.nd.NDArray],
 ) -> List[tvm.nd.NDArray]:
+    # Remove the dataflow block inside the param transform function,
+    # so that the LazyTransformParams pass can be applied.
+    mod_transform = relax.transform.ToNonDataflow()(mod_transform)
+    mod_transform = relax.transform.LazyTransformParams()(mod_transform)
+
     transform_func_name = None
     for gv, func in mod_transform.functions.items():
         if isinstance(func, relax.Function):
             transform_func_name = gv.name_hint
     assert transform_func_name is not None

-    ex = relax.build(mod_transform, target="llvm")
-    vm = relax.vm.VirtualMachine(ex, tvm.cpu())
-    res = vm[transform_func_name](model_params)
+    if tvm.cuda().exist:
+        target = "cuda"
+    elif tvm.metal().exist:
+        target = "metal"
+    else:
+        target = "llvm"
+    target = tvm.target.Target(target)
+    device = tvm.device(target.kind.default_keys[0])
+
+    @tvm.register_func("get_item", override=True)
+    def get_item(i):
+        gpu_input = tvm.nd.array(model_params[i], device=device)
+        return gpu_input
+
+    res = []
+
+    @tvm.register_func("set_item", override=True)
+    def set_item(i, value):
+        if len(res) <= i:
+            res.extend([None for _ in range(i - len(res) + 1)])
+        res[i] = tvm.nd.array(value, device=tvm.cpu())
+        return tvm.nd.empty((1,), device=device)
+
+    with tvm.target.Target(target):
+        mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform)
+    ex = relax.build(mod_transform, target=target)
+    vm = relax.vm.VirtualMachine(ex, device)
+    vm[transform_func_name]()
     return res

