diff --git a/README.md b/README.md
index c5c34ea2..04b8c2f5 100644
--- a/README.md
+++ b/README.md
@@ -38,15 +38,16 @@ Not all features of TEI are currently supported as this is still a work in progr
## Validated Models
-| Architecture | Model Type | Models |
-|--------------|------------|--------|
-| BERT | Embedding | [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)[sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)[sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)[sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)[sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)[sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2) |
-| MPNet | Embedding | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)[sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)[sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1) |
-| ALBERT | Embedding | [sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2) |
-| Mistral | Embedding | [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)[Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R) |
-| GTE | Embedding | [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5) |
-| JinaBERT | Embedding | [jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) |
-| Roberta | Sequence Classification | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
+| Architecture | Model Type | Pooling | Models |
+|--------------|------------|---------|--------|
+| BERT | Embedding | Cls | [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)<br>[sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)<br>[sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)<br>[sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)<br>[sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)<br>[sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2) |
+| BERT | Embedding | Splade | [naver/efficient-splade-VI-BT-large-query](https://huggingface.co/naver/efficient-splade-VI-BT-large-query) |
+| MPNet | Embedding | Mean | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)<br>[sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)<br>[sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1) |
+| ALBERT | Embedding | Mean | [sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2) |
+| Mistral | Embedding | Last token | [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)<br>[Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R) |
+| GTE | Embedding | Cls | [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5) |
+| JinaBERT | Embedding | Mean | [jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) |
+| Roberta | Sequence Classification | N/A | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
> The license for using TEI on Habana Gaudi is TEI's own license: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
>
diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt
index a0a758c6..ddf04531 100644
--- a/backends/python/server/requirements.txt
+++ b/backends/python/server/requirements.txt
@@ -62,7 +62,7 @@ safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index c4783061..d28ae0f5 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -4,7 +4,7 @@
from loguru import logger
from pathlib import Path
from typing import Optional
-from transformers import AutoConfig
+from transformers import AutoConfig, BertForMaskedLM
from transformers.models.bert import BertConfig
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
@@ -77,6 +77,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
        in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
    ):
        return ClassificationModel(model_path, device, dtype)
+    elif config.architectures[0] == "BertForMaskedLM":
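+        # BertForMaskedLM keeps the masked-language-modeling head, so the forward
+        # pass returns vocab-sized logits, which SPLADE pooling consumes downstream.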
+        return DefaultModel(
+            model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE, model_class=BertForMaskedLM
+        )
    else:
        return DefaultModel(
            model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index 54003465..da3f342b 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -4,14 +4,14 @@
from loguru import logger
from pathlib import Path
from typing import Type, List
-from transformers import AutoModel
-from sentence_transformers.models import Pooling
+from transformers import AutoModel, PreTrainedModel
from opentelemetry import trace
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from text_embeddings_server.models import Model
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
from text_embeddings_server.models.types import PaddedBatch, Embedding
tracer = trace.get_tracer(__name__)
@@ -25,19 +25,26 @@ def __init__(
        dtype: torch.dtype,
        pool: str = "cls",
        trust_remote: bool = False,
+        model_class: type[PreTrainedModel] = AutoModel,  # type: ignore
    ):
        if device == torch.device("hpu"):
            adapt_transformers_to_gaudi()
        model = (
-            AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
-            .to(dtype)
-            .to(device)
+            model_class.from_pretrained(model_path, trust_remote_code=trust_remote)  # type: ignore
+            .to(dtype=dtype)
+            .to(device=device)
        )
+
        if device == torch.device("hpu"):
            logger.info("Use graph mode for HPU")
            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
        self.hidden_size = model.config.hidden_size
-        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
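+        # vocab_size is recorded so embed() can sanity-check SPLADE outputs, whose
+        # width is the vocabulary size rather than the hidden size.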
+        self.vocab_size = model.config.vocab_size
+        self.pooling_mode = pool
+        if pool == "splade":
+            self.pooling = SpladePooling()
+        else:
+            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
        position_offset = 0
        model_type = model.config.model_type
        if model_type in ["xlm-roberta", "camembert", "roberta"]:
@@ -72,17 +79,19 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids
output = self.model(**kwargs)
- pooling_features = {
- "token_embeddings": output[0],
- "attention_mask": batch.attention_mask,
- }
- embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+ embedding = self.pooling.forward(output, batch.attention_mask)
cpu_results = embedding.reshape(-1).tolist()
-
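+        # step_size is the flattened width of one sequence's embedding: vocab_size
+        # for SPLADE (sparse weights over the vocabulary), hidden_size for dense pooling.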
+        step_size = embedding.shape[-1]
+        if self.pooling_mode == "splade":
+            assert (
+                step_size == self.vocab_size
+            ), f"Splade pooling expected embeddings of vocab size ({self.vocab_size}) but got {step_size}. Check the splade pooling implementation."
+        else:
+            assert (
+                step_size == self.hidden_size
+            ), f"Expected embeddings of hidden size ({self.hidden_size}) but got {step_size}. Please check the model outputs."
        return [
-            Embedding(
-                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
-            )
+            Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
            for i in range(len(batch))
        ]
diff --git a/backends/python/server/text_embeddings_server/models/pooling.py b/backends/python/server/text_embeddings_server/models/pooling.py
new file mode 100644
index 00000000..43f77b14
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/pooling.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+
+import torch
+from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from torch import Tensor
+
+tracer = trace.get_tracer(__name__)
+
+
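+# Pooling strategies for the Python backend: DefaultPooling delegates to
+# sentence-transformers' Pooling (cls/mean/lasttoken), while SpladePooling turns
+# masked-LM logits into sparse vocabulary-space embeddings.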
+class _Pooling(ABC):
+    @abstractmethod
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pass
+
+
+class DefaultPooling(_Pooling):
+    def __init__(self, hidden_size, pooling_mode) -> None:
+        assert (
+            pooling_mode != "splade"
+        ), "Splade pooling is not supported for DefaultPooling"
+        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)
+
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pooling_features = {
+            "token_embeddings": model_output[0],
+            "attention_mask": attention_mask,
+        }
+        return self.pooling.forward(pooling_features)["sentence_embedding"]
+
+
+class SpladePooling(_Pooling):
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        # SPLADE pooling: apply log(1 + ReLU(logits)) to the masked-LM logits,
+        # zero out padding positions, then max-pool over the sequence dimension.
+        hidden_states = torch.relu(model_output[0])
+        hidden_states = (1 + hidden_states).log()
+        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
+        return hidden_states.max(dim=1).values
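+
+
+# Shape example (illustrative): for a batch of 2 sequences of length 7 from a BERT
+# masked-LM model with a 30522-token vocabulary, model_output[0] is (2, 7, 30522)
+# and SpladePooling().forward(model_output, attention_mask) returns a (2, 30522)
+# tensor of non-negative term weights.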
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index 7068865c..f1d280c8 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -24,15 +24,9 @@ impl PythonBackend {
        otlp_endpoint: Option<String>,
        otlp_service_name: String,
    ) -> Result<Self, BackendError> {
-        let mut pool_type = Pool::Cls;
-        match model_type {
-            ModelType::Classifier => {}
-            ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
-                }
-                pool_type = pool;
-            }
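+        // Splade is no longer rejected here: the Python server now implements it via
+        // SpladePooling, so every pooling variant is simply forwarded.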
+        let pool_type = match model_type {
+            ModelType::Classifier => Pool::Cls,
+            ModelType::Embedding(pool) => pool,
        };
        let backend_process = management::BackendProcess::new(
diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs
index 1ae004fd..a635cb03 100644
--- a/backends/python/src/management.rs
+++ b/backends/python/src/management.rs
@@ -37,7 +37,7 @@ impl BackendProcess {
            Pool::Mean => "mean",
            Pool::LastToken => "lasttoken",
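+            // "splade" is forwarded to the Python server, where it selects
+            // SpladePooling (see default_model.py).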
            Pool::Splade => {
-                return Err(BackendError::Start(format!("{pool:?} is not supported")));
+                "splade"
            }
        };