Skip to content

Commit

Permalink
feat: Python pooling (#442)
Browse files Browse the repository at this point in the history
Co-authored-by: Oskar Liew <[email protected]>
  • Loading branch information
OlivierDehaene and OskarLiew authored Dec 11, 2024
1 parent 0bfeb7e commit e27a4fb
Show file tree
Hide file tree
Showing 10 changed files with 1,749 additions and 640 deletions.
3 changes: 2 additions & 1 deletion backends/python/server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ unit-tests:

gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
mkdir text_embeddings_server/pb || true
python -m grpc_tools.protoc -I../../proto --python_out=text_embeddings_server/pb \
--grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb ../../proto/embed.proto
Expand All @@ -15,6 +15,7 @@ gen-server:

install: gen-server
pip install pip --upgrade
pip install torch==2.5.1
pip install -r requirements.txt
pip install -e .

Expand Down
2,219 changes: 1,641 additions & 578 deletions backends/python/server/poetry.lock

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions backends/python/server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ python-text-embeddings-server = 'text_embeddings_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
protobuf = ">=4.25.3,<6"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.3.2"
safetensors = "^0.4"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
torch = { version = "^2.0.1" }
opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
sentence-transformers = "^3.3.1"
torch = "^2.5.1"

[tool.poetry.extras]

Expand Down
99 changes: 62 additions & 37 deletions backends/python/server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,68 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.15 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.66.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.68.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.26.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==2.0.2 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cuda-runtime-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cufft-cu12==11.2.1.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-curand-cu12==10.3.5.147 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusolver-cu12==11.6.1.9 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-cusparse-cu12==12.3.1.170 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nccl-cu12==2.21.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvjitlink-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
nvidia-nvtx-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.3 ; python_version >= "3.9" and python_version < "3.13"
torch==2.5.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.46.3 ; python_version >= "3.9" and python_version < "3.13"
triton==3.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
wrapt==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
3 changes: 2 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def serve(
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-embeddings-inference.server",
pool: str = "cls",
):
# Remove default handler
logger.remove()
Expand All @@ -48,7 +49,7 @@ def serve(
# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

server.serve(model_path, dtype, uds_path)
server.serve(model_path, dtype, uds_path, pool)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
__all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
def get_model(model_path: Path, dtype: Optional[str], pool: str):
if dtype == "float32":
dtype = torch.float32
elif dtype == "float16":
Expand All @@ -38,8 +38,6 @@ def get_model(model_path: Path, dtype: Optional[str]):
if torch.cuda.is_available():
device = torch.device("cuda")
else:
if dtype != torch.float32:
raise ValueError("CPU device only supports float32 dtype")
device = torch.device("cpu")

config = AutoConfig.from_pretrained(model_path)
Expand All @@ -52,8 +50,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
and dtype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype)
return DefaultModel(model_path, device, dtype, pool)

raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Type, List
from transformers import AutoModel
from opentelemetry import trace
from sentence_transformers.models import Pooling

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding
Expand All @@ -13,9 +14,12 @@


class DefaultModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
def __init__(
self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
):
model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
self.hidden_size = model.config.hidden_size
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
Expand All @@ -41,7 +45,13 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)
embedding = output[0][:, 0]

pooling_features = {
"token_embeddings": output[0],
"attention_mask": batch.attention_mask,
}
embedding = self.pooling.forward(pooling_features)["sentence_embedding"]

cpu_results = embedding.view(-1).tolist()

return [
Expand Down
4 changes: 2 additions & 2 deletions backends/python/server/text_embeddings_server/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import torch

from grpc import aio
from loguru import logger

Expand Down Expand Up @@ -37,6 +36,7 @@ def serve(
model_path: Path,
dtype: Optional[str],
uds_path: Path,
pool: str,
):
async def serve_inner(
model_path: Path,
Expand All @@ -45,7 +45,7 @@ async def serve_inner(
unix_socket = f"unix://{uds_path}"

try:
model = get_model(model_path, dtype)
model = get_model(model_path, dtype, pool)
except Exception:
logger.exception("Error when initializing model")
raise
Expand Down
12 changes: 4 additions & 8 deletions backends/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};
use tokio::runtime::Runtime;

Expand All @@ -24,18 +24,13 @@ impl PythonBackend {
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
match model_type {
let pool = match model_type {
ModelType::Classifier => {
return Err(BackendError::Start(
"`classifier` model type is not supported".to_string(),
))
}
ModelType::Embedding(pool) => {
if pool != Pool::Cls {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
pool
}
ModelType::Embedding(pool) => pool,
};

let backend_process = management::BackendProcess::new(
Expand All @@ -44,6 +39,7 @@ impl PythonBackend {
&uds_path,
otlp_endpoint,
otlp_service_name,
pool,
)?;
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand Down
14 changes: 13 additions & 1 deletion backends/python/src/management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::sync::mpsc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{env, fs, io, thread};
use text_embeddings_backend_core::BackendError;
use text_embeddings_backend_core::{BackendError, Pool};

#[derive(Debug)]
pub(crate) struct BackendProcess {
Expand All @@ -22,6 +22,7 @@ impl BackendProcess {
uds_path: &str,
otlp_endpoint: Option<String>,
otlp_service_name: String,
pool: Pool,
) -> Result<Self, BackendError> {
// Get UDS path
let uds = Path::new(uds_path);
Expand All @@ -31,6 +32,15 @@ impl BackendProcess {
fs::remove_file(uds).expect("could not remove UDS file");
}

let pool = match pool {
Pool::Cls => "cls",
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
};

// Process args
let mut python_server_args = vec![
model_path,
Expand All @@ -41,6 +51,8 @@ impl BackendProcess {
"--logger-level".to_owned(),
"INFO".to_owned(),
"--json-output".to_owned(),
"--pool".to_owned(),
pool.to_owned(),
];

// OpenTelemetry
Expand Down

0 comments on commit e27a4fb

Please sign in to comment.