diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 460907d0a1..937ff4d920 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -145,8 +145,10 @@ def set_item(i, value): res[i] = tvm.nd.array(value, device=tvm.cpu()) return tvm.nd.empty((1,), device=device) - with tvm.target.Target(target): - mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform) + if target.kind.name != "llvm": + with tvm.target.Target(target): + mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform) + ex = relax.build(mod_transform, target=target) vm = relax.vm.VirtualMachine(ex, device) vm[transform_func_name]()