diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 4d8a70783e..64c291aa1c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -146,7 +146,7 @@ def _is_safe_to_sample(prob_like): def load_disco_module(artifact_path, lib_path, num_shards): - sess = di.ProcessSession(num_workers=num_shards) + sess = di.ProcessSession(num_workers=num_shards, entrypoint="tvm.exec.disco_worker") devices = range(num_shards) sess.init_ccl("nccl", *devices) module = sess.load_vm_module(lib_path)