diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py
index 9497dc20..8398f65f 100644
--- a/backends/python/server/text_embeddings_server/cli.py
+++ b/backends/python/server/text_embeddings_server/cli.py
@@ -24,6 +24,7 @@ def serve(
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
     otlp_service_name: str = "text-embeddings-inference.server",
+    pool: str = "cls",
 ):
     # Remove default handler
     logger.remove()
@@ -48,7 +49,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
 
-    server.serve(model_path, dtype, uds_path)
+    server.serve(model_path, dtype, uds_path, pool)
 
 
 if __name__ == "__main__":
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index d6966046..c4783061 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -6,7 +6,9 @@ from typing import Optional
 
 from transformers import AutoConfig
 from transformers.models.bert import BertConfig
-from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+)
 
 from text_embeddings_server.models.model import Model, B
 from text_embeddings_server.models.default_model import DefaultModel
@@ -37,7 +39,7 @@
     __all__.append(FlashBert)
 
 
-def get_model(model_path: Path, dtype: Optional[str]):
+def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
         dtype = torch.float32
     elif dtype == "float16":
@@ -66,17 +68,29 @@ def get_model(model_path: Path, dtype: Optional[str]):
             and dtype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
         ):
+            if pool != "cls":
+                raise ValueError("FlashBert only supports cls pooling")
             return FlashBert(model_path, device, dtype)
         else:
-            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+            if (
+                config.architectures[0]
+                in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
+            ):
                 return ClassificationModel(model_path, device, dtype)
             else:
-                return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
+                return DefaultModel(
+                    model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
     else:
         try:
-            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+            if (
+                config.architectures[0]
+                in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
+            ):
                 return ClassificationModel(model_path, device, dtype)
             else:
-                return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
+                return DefaultModel(
+                    model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
         except:
             raise RuntimeError(f"Unsupported model_type {config.model_type}")
diff --git a/backends/python/server/text_embeddings_server/models/classification_model.py b/backends/python/server/text_embeddings_server/models/classification_model.py
index c63b45b2..d94a237f 100644
--- a/backends/python/server/text_embeddings_server/models/classification_model.py
+++ b/backends/python/server/text_embeddings_server/models/classification_model.py
@@ -57,7 +57,9 @@ def batch_type(self) -> Type[PaddedBatch]:
 
     @tracer.start_as_current_span("embed")
     def embed(self, batch):
-        raise NotImplementedError(f"Embed is not a valid operation for model type {self.model.config.model_type}")
+        raise NotImplementedError(
+            f"Embed is not a valid operation for model type {self.model.config.model_type}"
+        )
 
     @tracer.start_as_current_span("predict")
     def predict(self, batch: PaddedBatch) -> List[Score]:
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index ea12f847..54003465 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Type, List
 
 from transformers import AutoModel
+from sentence_transformers.models import Pooling
 from opentelemetry import trace
 
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -17,18 +18,26 @@
 
 
 class DefaultModel(Model):
-    def __init__(self,
-                model_path: Path,
-                device: torch.device,
-                dtype: torch.dtype,
-                trust_remote: bool=False):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str = "cls",
+        trust_remote: bool = False,
+    ):
         if device == torch.device("hpu"):
             adapt_transformers_to_gaudi()
-        model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote).to(dtype).to(device)
+        model = (
+            AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
+            .to(dtype)
+            .to(device)
+        )
         if device == torch.device("hpu"):
             logger.info("Use graph mode for HPU")
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         self.hidden_size = model.config.hidden_size
+        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
         position_offset = 0
         model_type = model.config.model_type
         if model_type in ["xlm-roberta", "camembert", "roberta"]:
@@ -63,7 +72,11 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids
         output = self.model(**kwargs)
-        embedding = output[0][:, 0]
+        pooling_features = {
+            "token_embeddings": output[0],
+            "attention_mask": batch.attention_mask,
+        }
+        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
 
         cpu_results = embedding.reshape(-1).tolist()
 
         return [
@@ -75,4 +88,6 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
 
     @tracer.start_as_current_span("predict")
     def predict(self, batch):
-        raise NotImplementedError(f"Predict is not a valid operation for model type {self.model.config.model_type}")
+        raise NotImplementedError(
+            f"Predict is not a valid operation for model type {self.model.config.model_type}"
+        )
diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py
index 5fcb2e35..009f31d4 100644
--- a/backends/python/server/text_embeddings_server/models/types.py
+++ b/backends/python/server/text_embeddings_server/models/types.py
@@ -9,15 +9,19 @@
 from text_embeddings_server.pb.embed_pb2 import Embedding, Score
 
 tracer = trace.get_tracer(__name__)
-PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+
 
 def round_up(number, k):
     return (number + k - 1) // k * k
 
+
 class Batch(ABC):
     @classmethod
     @abstractmethod
-    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device, *args, **kwargs) -> "Batch":
+    def from_pb(
+        cls, pb: embed_pb2.EmbedRequest, device: torch.device, *args, **kwargs
+    ) -> "Batch":
         raise NotImplementedError
 
     @abstractmethod
@@ -34,10 +38,9 @@ class PaddedBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("from_pb")
-    def from_pb(cls,
-                pb: embed_pb2.EmbedRequest,
-                device: torch.device,
-                max_input_length: int) -> "PaddedBatch":
+    def from_pb(
+        cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int
+    ) -> "PaddedBatch":
         if pb.max_length > max_input_length:
             raise RuntimeError(f"input length exceeds model config's max_input_length")
 
@@ -46,9 +49,7 @@ def from_pb(cls,
         batch_size = len(pb.cu_seq_lengths) - 1
         new_bs = 2 ** math.ceil(math.log2(batch_size))
         # Allocate padded tensors all at once
-        all_tensors = torch.zeros(
-            [4, new_bs, max_length], dtype=torch.int32
-        )
+        all_tensors = torch.zeros([4, new_bs, max_length], dtype=torch.int32)
         for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
             end_index = pb.cu_seq_lengths[i + 1]
             input_length = end_index - start_index
diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py
index 1135fa34..e1abf522 100644
--- a/backends/python/server/text_embeddings_server/server.py
+++ b/backends/python/server/text_embeddings_server/server.py
@@ -27,7 +27,9 @@ async def Health(self, request, context):
 
     async def Embed(self, request, context):
         max_input_length = self.model.max_input_length
-        batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)
+        batch = self.model.batch_type.from_pb(
+            request, self.model.device, max_input_length
+        )
 
         embeddings = self.model.embed(batch)
 
@@ -35,7 +37,9 @@ async def Embed(self, request, context):
 
     async def Predict(self, request, context):
         max_input_length = self.model.max_input_length
-        batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)
+        batch = self.model.batch_type.from_pb(
+            request, self.model.device, max_input_length
+        )
 
         scores = self.model.predict(batch)
 
@@ -46,6 +50,7 @@ def serve(
     model_path: Path,
     dtype: Optional[str],
     uds_path: Path,
+    pool: str,
 ):
     async def serve_inner(
         model_path: Path,
@@ -54,7 +59,7 @@ async def serve_inner(
         unix_socket = f"unix://{uds_path}"
 
         try:
-            model = get_model(model_path, dtype)
+            model = get_model(model_path, dtype, pool)
         except Exception:
             logger.exception("Error when initializing model")
             raise
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index a4610711..736cb88d 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -24,11 +24,17 @@ impl PythonBackend {
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
     ) -> Result<Self, BackendError> {
+        let mut pool_type = Pool::Cls;
         match model_type {
             ModelType::Classifier => {}
             ModelType::Embedding(pool) => {
-                if pool != Pool::Cls {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
+                match pool {
+                    Pool::Splade => {
+                        return Err(BackendError::Start(format!("{pool:?} is not supported")));
+                    },
+                    _ => {
+                        pool_type = pool;
+                    }
                 }
             }
         };
@@ -39,6 +45,7 @@ impl PythonBackend {
             &uds_path,
             otlp_endpoint,
             otlp_service_name,
+            pool_type,
         )?;
         let tokio_runtime = tokio::runtime::Builder::new_current_thread()
             .enable_all()
diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs
index 911c6984..977ab045 100644
--- a/backends/python/src/management.rs
+++ b/backends/python/src/management.rs
@@ -8,7 +8,7 @@ use std::sync::mpsc;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{env, fs, io, thread};
-use text_embeddings_backend_core::BackendError;
+use text_embeddings_backend_core::{BackendError, Pool};
 
 #[derive(Debug)]
 pub(crate) struct BackendProcess {
@@ -22,6 +22,7 @@ impl BackendProcess {
         uds_path: &str,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        pool: Pool,
     ) -> Result<Self, BackendError> {
         // Get UDS path
         let uds = Path::new(uds_path);
@@ -31,6 +32,15 @@ impl BackendProcess {
             fs::remove_file(uds).expect("could not remove UDS file");
         }
 
+        let pool = match pool {
+            Pool::Cls => "cls",
+            Pool::Mean => "mean",
+            Pool::LastToken => "lasttoken",
+            Pool::Splade => {
+                return Err(BackendError::Start(format!("{pool:?} is not supported")));
+            },
+        };
+
         // Process args
         let mut python_server_args = vec![
            model_path,
@@ -41,6 +51,8 @@ impl BackendProcess {
             "--logger-level".to_owned(),
             "INFO".to_owned(),
             "--json-output".to_owned(),
+            "--pool".to_owned(),
+            pool.to_owned(),
         ];
 
         // OpenTelemetry
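
For context on what the new `--pool` flag changes at embedding time: `DefaultModel.embed` now delegates pooling to sentence-transformers' `Pooling` module instead of always taking the CLS token, and `management.rs` maps `Pool::Cls`/`Pool::Mean`/`Pool::LastToken` to the strings `cls`/`mean`/`lasttoken` passed to the Python server via `--pool`. The snippet below is a minimal, self-contained sketch of that pooling path, not part of the diff; the tensor shapes and values are toy data for illustration and it only assumes `torch` and `sentence-transformers` are installed.

```python
# Sketch of the pooling path DefaultModel.embed now uses (toy data, not the real batch).
import torch
from sentence_transformers.models import Pooling

hidden_size = 4
token_embeddings = torch.randn(2, 3, hidden_size)      # [batch, seq_len, hidden]
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second sequence has one padded position

for mode in ("cls", "mean", "lasttoken"):               # the modes management.rs forwards
    pooling = Pooling(hidden_size, pooling_mode=mode)
    features = {
        "token_embeddings": token_embeddings,
        "attention_mask": attention_mask,
    }
    sentence_embedding = pooling.forward(features)["sentence_embedding"]
    print(mode, tuple(sentence_embedding.shape))         # (2, 4) for every mode
```

FlashBert itself is unchanged: `get_model` raises a `ValueError` for any non-`cls` pool on that path, and `Pool::Splade` is rejected in both `lib.rs` and `management.rs` before the Python server is launched.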