Skip to content

Commit

Permalink
Change the default value of DISABLE_TENSOR_CACHE to False, and allow it to
Browse files Browse the repository at this point in the history
be set through an environment variable

Signed-off-by: kaixuanliu <[email protected]>
  • Loading branch information
kaixuanliu committed Dec 16, 2024
1 parent 70d5d79 commit cb20b47
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
30 changes: 26 additions & 4 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

HTCORE_AVAILABLE = True
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
"true",
"1",
]

try:
import habana_frameworks.torch.core as htcore
Expand Down Expand Up @@ -71,7 +75,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
else:
if config.architectures[0].endswith("Classification"):
return ClassificationModel(
model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE
model_path,
device,
dtype,
disable_tensor_cache=DISABLE_TENSOR_CACHE,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.architectures[0] == "BertForMaskedLM":
return DefaultModel(
Expand All @@ -84,17 +92,31 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
)
else:
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
model_path,
device,
dtype,
pool,
disable_tensor_cache=DISABLE_TENSOR_CACHE,
trust_remote=TRUST_REMOTE_CODE,
)
else:
try:
if config.architectures[0].endswith("Classification"):
return ClassificationModel(
model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE
model_path,
device,
dtype,
disable_tensor_cache=DISABLE_TENSOR_CACHE,
trust_remote=TRUST_REMOTE_CODE,
)
else:
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
model_path,
device,
dtype,
pool,
disable_tensor_cache=DISABLE_TENSOR_CACHE,
trust_remote=TRUST_REMOTE_CODE,
)
except:
raise RuntimeError(f"Unsupported model_type {config.model_type}")
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
model_path: Path,
device: torch.device,
dtype: torch.dtype,
disable_tensor_cache: bool = False,
trust_remote: bool = False,
):
if device == torch.device("hpu"):
Expand All @@ -33,7 +34,7 @@ def __init__(
model = model.to(dtype).to(device)
if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = wrap_in_hpu_graph(model, disable_tensor_cache=disable_tensor_cache)

self.hidden_size = model.config.hidden_size
position_offset = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
disable_tensor_cache: bool = False,
trust_remote: bool = False,
model_class: type[PreTrainedModel] = AutoModel, # type: ignore
):
Expand All @@ -37,7 +38,7 @@ def __init__(

if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = wrap_in_hpu_graph(model, disable_tensor_cache=disable_tensor_cache)
self.hidden_size = model.config.hidden_size
self.vocab_size = model.config.vocab_size
self.pooling_mode = pool
Expand Down

0 comments on commit cb20b47

Please sign in to comment.