Skip to content

Commit

Permalink
Add auto-detect support for Vulkan (octoml#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored May 24, 2023
1 parent f1dcc7f commit eed5a28
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions mlc_llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
args.device_name = "cuda"
elif tvm.metal().exist:
args.device_name = "metal"
elif tvm.vulkan().exist:
args.device_name = "vulkan"
elif tvm.opencl().exist:
args.device_name = "opencl"
else:
raise ValueError("Cannot auto deduce device-name, please set it")
supported_model_prefix = {
Expand Down Expand Up @@ -259,6 +263,31 @@ def parse_target(args: argparse.Namespace) -> None:
if args.target == "auto":
if system() == "Darwin":
target = tvm.target.Target("apple/m1-gpu")
elif tvm.cuda().exist:
dev = tvm.cuda()
target = tvm.target.Target(
{
"kind": "cuda",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"registers_per_block": 65536,
"arch": "sm_" + tvm.cuda().compute_version.replace(".", ""),
}
),
elif tvm.vulkan().exist:
dev = tvm.vulkan()
target = tvm.target.Target(
{
"kind": "vulkan",
"max_threads_per_block": dev.max_threads_per_block,
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"thread_warp_size": dev.warp_size,
"supports_float16": 1,
"supports_int16": 1,
"supports_16bit_buffer": 1,
}
),
else:
has_gpu = tvm.cuda().exist
target = tvm.target.Target(
Expand Down

0 comments on commit eed5a28

Please sign in to comment.