Skip to content

Commit

Permalink
[FIX] fix transform params without GPUs (octoml#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored May 23, 2023
1 parent 68be032 commit 110c6d3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mlc_llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ def set_item(i, value):
res[i] = tvm.nd.array(value, device=tvm.cpu())
return tvm.nd.empty((1,), device=device)

with tvm.target.Target(target):
mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform)
if target.kind.name != "llvm":
with tvm.target.Target(target):
mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform)

ex = relax.build(mod_transform, target=target)
vm = relax.vm.VirtualMachine(ex, device)
vm[transform_func_name]()
Expand Down

0 comments on commit 110c6d3

Please sign in to comment.