Skip to content

Commit

Permalink
Merge branch 'main' into reranker-python
Browse files Browse the repository at this point in the history
  • Loading branch information
kaixuanliu authored Jan 10, 2025
2 parents a28ad28 + d3a8098 commit ac60fd1
Show file tree
Hide file tree
Showing 16 changed files with 723 additions and 102 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ tracing = "0.1"
serde = { version = "1.0", features = ["serde_derive"] }
serde_json = "1.0"
thiserror = "1.0"
rand = "0.8"


[patch.crates-io]
Expand Down
146 changes: 146 additions & 0 deletions Dockerfile-intel
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
ARG PLATFORM=cpu
FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
WORKDIR /usr/src
ENV SCCACHE=0.5.4
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json

RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

FROM builder as http-builder

RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

FROM builder as grpc-builder

COPY proto proto

RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s

FROM intel/intel-optimized-pytorch:2.4.0-pip-base AS cpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
git \
cmake \
ninja-build \
python3-dev &&\
rm -rf /var/lib/apt/lists/*

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt

RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu

RUN cd backends/python/server && \
make install

FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
git \
cmake \
ninja-build \
python3-dev &&\
rm -rf /var/lib/apt/lists/*

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-hpu.txt backends/python/server/requirements.txt

RUN cd backends/python/server && \
make install

FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
WORKDIR /usr/src
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir

ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include

COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
RUN cd backends/python/server && \
make install

FROM ${PLATFORM} AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM ${PLATFORM}

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
1 change: 1 addition & 0 deletions backends/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
text-embeddings-backend-ort = { path = "ort", optional = true }
tokio = { workspace = true }
tracing = { workspace = true }
rand = { workspace = true }

[features]
clap = ["dep:clap", "text-embeddings-backend-core/clap"]
Expand Down
3 changes: 1 addition & 2 deletions backends/python/server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ gen-server:

install: gen-server
pip install pip --upgrade
pip install torch==2.5.1
pip install -r requirements.txt
pip install --no-deps -r requirements.txt
pip install -e .

run-dev:
Expand Down
61 changes: 61 additions & 0 deletions backends/python/server/requirements-hpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.24.5 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13"
pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13"
44 changes: 44 additions & 0 deletions backends/python/server/requirements-intel.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.3 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
1 change: 0 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def serve(

# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

server.serve(model_path, dtype, uds_path, pool)


Expand Down
38 changes: 23 additions & 15 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.utils.device import get_device, use_ipex

__all__ = ["Model"]

Expand Down Expand Up @@ -37,34 +38,41 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
else:
raise RuntimeError(f"Unknown dtype {dtype}")

if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
device = get_device()
logger.info(f"backend device: {device}")

config = AutoConfig.from_pretrained(model_path)

if config.model_type == "bert":
config: BertConfig
if (
device.type == "cuda"
use_ipex() or device.type in ["cuda", "hpu"]
and config.position_embedding_type == "absolute"
and datatype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, datatype)
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, datatype)
else:
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype, pool)
return DefaultModel(model_path, device, datatype, pool)
else:
try:
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import (
adapt_transformers_to_gaudi,
)

adapt_transformers_to_gaudi()
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, datatype)
model_handle = ClassificationModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype, pool)
except:
raise RuntimeError(f"Unsupported model_type {config.model_type}")
model_handle = DefaultModel(model_path, device, datatype)
model_handle.model = wrap_in_hpu_graph(model_handle.model)
return model_handle
elif use_ipex():
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype)
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["token_type_ids"] = batch.token_type_ids
if self.has_position_ids:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)

pooling_features = {
Expand Down
Loading

0 comments on commit ac60fd1

Please sign in to comment.