From 4b38c437797d2aadf562080cfb1c78aedeae9560 Mon Sep 17 00:00:00 2001
From: Daniel Huang
Date: Mon, 25 Nov 2024 06:00:50 -0800
Subject: [PATCH] Splade enabling (#35)

Signed-off-by: Daniel Huang
---
 README.md                                      | 19 ++++-----
 backends/python/server/requirements.txt        |  2 +-
 .../text_embeddings_server/models/__init__.py  |  6 ++-
 .../models/default_model.py                    | 39 +++++++++++-------
 .../text_embeddings_server/models/pooling.py   | 40 +++++++++++++++++++
 backends/python/src/lib.rs                     | 12 ++----
 backends/python/src/management.rs              |  2 +-
 7 files changed, 84 insertions(+), 36 deletions(-)
 create mode 100644 backends/python/server/text_embeddings_server/models/pooling.py

diff --git a/README.md b/README.md
index c5c34ea2..04b8c2f5 100644
--- a/README.md
+++ b/README.md
@@ -38,15 +38,16 @@ Not all features of TEI are currently supported as this is still a work in progr
 
 ## Validated Models
 
-| Architecture | Model Type | Models |
-|--------------|------------|--------|
-| BERT | Embedding |<br>
  • [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)
  • [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
  • [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)
  • [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)
  • [sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)
  • [sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2)
  • |
-| MPNet | Embedding |<br>
  • [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
  • [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
  • [sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1)
  • |
-| ALBERT | Embedding |<br>
  • [sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2)
  • |
-| Mistral | Embedding |<br>
  • [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)
  • [Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R)
  • |
-| GTE | Embedding |<br>
  • [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)
  • |
-| JinaBERT | Embedding |<br>
  • [jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)
  • |
-| Roberta | Sequence Classification |<br>
  • [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions)
  • |
+| Architecture | Model Type | Pooling | Models |
+|--------------|------------|---------|--------|
+| BERT | Embedding | Cls |<br>
  • [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)
  • [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
  • [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)
  • [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)
  • [sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)
  • [sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2)
  • |
+| BERT | Embedding | Splade |<br>
  • [naver/efficient-splade-VI-BT-large-query](https://huggingface.co/naver/efficient-splade-VI-BT-large-query)
  • |
+| MPNet | Embedding | Mean |<br>
  • [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
  • [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
  • [sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1)
  • |
+| ALBERT | Embedding | Mean |<br>
  • [sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2)
  • |
+| Mistral | Embedding | Last token |<br>
  • [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)
  • [Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R)
  • |
+| GTE | Embedding | Cls |<br>
  • [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)
  • |
+| JinaBERT | Embedding | Mean |<br>
  • [jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)
  • |
+| Roberta | Sequence Classification | N/A |<br>
  • [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions)
  • |
 
 > The license to use TEI on Habana Gaudi is the one of TEI: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
 >

diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt
index a0a758c6..ddf04531 100644
--- a/backends/python/server/requirements.txt
+++ b/backends/python/server/requirements.txt
@@ -62,7 +62,7 @@ safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
 six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index c4783061..d28ae0f5 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -4,7 +4,7 @@
 from loguru import logger
 from pathlib import Path
 from typing import Optional
-from transformers import AutoConfig
+from transformers import AutoConfig, BertForMaskedLM
 from transformers.models.bert import BertConfig
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
@@ -77,6 +77,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
     ):
         return ClassificationModel(model_path, device, dtype)
+    elif config.architectures[0] == "BertForMaskedLM":
+        return DefaultModel(
+            model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE, model_class=BertForMaskedLM
+        )
     else:
         return DefaultModel(
             model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index 54003465..da3f342b 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -4,14 +4,14 @@
 from loguru import logger
 from pathlib import Path
 from typing import Type, List
-from transformers import AutoModel
-from sentence_transformers.models import Pooling
+from transformers import AutoModel, PreTrainedModel
 from opentelemetry import trace
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 from text_embeddings_server.models import Model
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
 from text_embeddings_server.models.types import PaddedBatch, Embedding
 
 tracer = trace.get_tracer(__name__)
 
@@ -25,19 +25,26 @@ def __init__(
         dtype: torch.dtype,
         pool: str = "cls",
         trust_remote: bool = False,
+        model_class: type[PreTrainedModel] = AutoModel,  # type: ignore
     ):
         if device == torch.device("hpu"):
             adapt_transformers_to_gaudi()
         model = (
-            AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
-            .to(dtype)
-            .to(device)
+            model_class.from_pretrained(model_path, trust_remote_code=trust_remote)  # type: ignore
+            .to(dtype=dtype)
+            .to(device=device)
         )
+
         if device == torch.device("hpu"):
             logger.info("Use graph mode for HPU")
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         self.hidden_size = model.config.hidden_size
-        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
+        self.vocab_size = model.config.vocab_size
+        self.pooling_mode = pool
+        if pool == "splade":
+            self.pooling = SpladePooling()
+        else:
+            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
         position_offset = 0
         model_type = model.config.model_type
         if model_type in ["xlm-roberta", "camembert", "roberta"]:
@@ -72,17 +79,19 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids
 
         output = self.model(**kwargs)
-        pooling_features = {
-            "token_embeddings": output[0],
-            "attention_mask": batch.attention_mask,
-        }
-        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+        embedding = self.pooling.forward(output, batch.attention_mask)
 
         cpu_results = embedding.reshape(-1).tolist()
-
+        step_size = embedding.shape[-1]
+        if self.pooling_mode == "splade":
+            assert (
+                step_size == self.vocab_size
+            ), f"Step size for splade pooling expected vocab size ({self.vocab_size}) but got {step_size}. Check splade pooling implementation"
+        else:
+            assert (
+                step_size == self.hidden_size
+            ), f"Step size expected hidden size ({self.hidden_size}) but got {step_size}. Please check model outputs."
         return [
-            Embedding(
-                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
-            )
+            Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
             for i in range(len(batch))
         ]
diff --git a/backends/python/server/text_embeddings_server/models/pooling.py b/backends/python/server/text_embeddings_server/models/pooling.py
new file mode 100644
index 00000000..43f77b14
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/pooling.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+
+import torch
+from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from torch import Tensor
+
+tracer = trace.get_tracer(__name__)
+
+
+class _Pooling(ABC):
+    @abstractmethod
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pass
+
+
+class DefaultPooling(_Pooling):
+    def __init__(self, hidden_size, pooling_mode) -> None:
+        assert (
+            pooling_mode != "splade"
+        ), "Splade pooling is not supported for DefaultPooling"
+        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)
+
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pooling_features = {
+            "token_embeddings": model_output[0],
+            "attention_mask": attention_mask,
+        }
+        return self.pooling.forward(pooling_features)["sentence_embedding"]
+
+
+class SpladePooling(_Pooling):
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        # Implement Splade pooling
+        hidden_states = torch.relu(model_output[0])
+        hidden_states = (1 + hidden_states).log()
+        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
+        return hidden_states.max(dim=1).values
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index 7068865c..f1d280c8 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -24,15 +24,9 @@ impl PythonBackend {
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
     ) -> Result<Self, BackendError> {
-        let mut pool_type = Pool::Cls;
-        match model_type {
-            ModelType::Classifier => {}
-            ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
-                }
-                pool_type = pool;
-            }
+        let pool_type = match model_type {
+            ModelType::Classifier => Pool::Cls,
+            ModelType::Embedding(pool) => pool,
         };
 
         let backend_process = management::BackendProcess::new(
diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs
index 1ae004fd..a635cb03 100644
--- a/backends/python/src/management.rs
+++ b/backends/python/src/management.rs
@@ -37,7 +37,7 @@ impl BackendProcess {
             Pool::Mean => "mean",
             Pool::LastToken => "lasttoken",
             Pool::Splade => {
-                return Err(BackendError::Start(format!("{pool:?} is not supported")));
+                "splade"
             }
         };
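
For reference (this note and the snippet below are not part of the patch): the new `SpladePooling.forward` turns the masked-LM logits into a sparse lexical embedding by applying ReLU, a `log(1 + x)` saturation, masking out padding positions with the attention mask, and max-pooling over the sequence dimension. A minimal standalone sketch of the same computation, with made-up tensor shapes:

```python
import torch

# Toy MLM-head output: 2 sequences, 4 token positions, a vocabulary of 5 terms.
logits = torch.randn(2, 4, 5)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 0 marks padding

# Same steps as SpladePooling.forward:
weights = torch.relu(logits)                      # keep positive activations only
weights = (1 + weights).log()                     # log-saturate the term weights
weights = weights * attention_mask.unsqueeze(-1)  # zero out padded positions
sparse_embedding = weights.max(dim=1).values      # max over the sequence axis

print(sparse_embedding.shape)  # torch.Size([2, 5]): one weight per vocabulary term
```

Because the pooled vector has `vocab_size` entries rather than `hidden_size`, `DefaultModel.embed` now derives the per-request step size from the embedding's last dimension and asserts that it equals `vocab_size` when splade pooling is active.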