diff --git a/Cargo.lock b/Cargo.lock
index 22ad1ee7..ffe0b446 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4229,6 +4229,7 @@ version = "1.6.0"
 dependencies = [
  "clap",
  "hf-hub",
+ "rand",
  "serde_json",
  "text-embeddings-backend-candle",
  "text-embeddings-backend-core",
diff --git a/Cargo.toml b/Cargo.toml
index 8c4a4726..b979d86c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,6 +30,7 @@ tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
+rand = "0.8"
 
 [patch.crates-io]
diff --git a/Dockerfile-intel b/Dockerfile-intel
new file mode 100644
index 00000000..8b3b6e21
--- /dev/null
+++ b/Dockerfile-intel
@@ -0,0 +1,146 @@
+ARG PLATFORM=cpu
+FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
+WORKDIR /usr/src
+ENV SCCACHE=0.5.4
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+
+# Download and configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+FROM chef AS planner
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+# sccache specific variables
+ARG ACTIONS_CACHE_URL
+ARG ACTIONS_RUNTIME_TOKEN
+ARG SCCACHE_GHA_ENABLED
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+
+RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+FROM builder as http-builder
+
+RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
+
+FROM builder as grpc-builder
+
+COPY proto proto
+
+RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s
+
+FROM intel/intel-optimized-pytorch:2.4.0-pip-base AS cpu
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    git \
+    cmake \
+    ninja-build \
+    python3-dev &&\
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+
+RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
+
+RUN cd backends/python/server && \
+    make install
+
+FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    git \
+    cmake \
+    ninja-build \
+    python3-dev &&\
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-hpu.txt backends/python/server/requirements.txt
+
+RUN cd backends/python/server && \
+    make install
+
+FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
+WORKDIR /usr/src
+RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
+
+ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
+ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
+ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
+ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
+ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+RUN cd backends/python/server && \
+    make install
+
+FROM ${PLATFORM} AS grpc
+
+COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
+
+FROM ${PLATFORM}
+
+COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
diff --git a/backends/Cargo.toml b/backends/Cargo.toml
index f29283ad..42ecb20f 100644
--- a/backends/Cargo.toml
+++ b/backends/Cargo.toml
@@ -15,6 +15,7 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
 text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
+rand = { workspace = true }
 
 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
diff --git a/backends/python/server/Makefile b/backends/python/server/Makefile
index 4e9d7d7c..6402d63f 100644
--- a/backends/python/server/Makefile
+++ b/backends/python/server/Makefile
@@ -15,8 +15,7 @@ gen-server:
 
 install: gen-server
 	pip install pip --upgrade
-	pip install torch==2.5.1
-	pip install -r requirements.txt
+	pip install --no-deps -r requirements.txt
 	pip install -e .
 
 run-dev:
diff --git a/backends/python/server/requirements-hpu.txt b/backends/python/server/requirements-hpu.txt
new file mode 100644
index 00000000..6bf51a7f
--- /dev/null
+++ b/backends/python/server/requirements-hpu.txt
@@ -0,0 +1,61 @@
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.24.5 ; python_version >= "3.9" and python_version < "3.13"
+humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
+six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
+transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
+yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
+einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/python/server/requirements-intel.txt b/backends/python/server/requirements-intel.txt
new file mode 100644
index 00000000..36b330db
--- /dev/null
+++ b/backends/python/server/requirements-intel.txt
@@ -0,0 +1,44 @@
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.3 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py
index 8398f65f..c4dfaa4c 100644
--- a/backends/python/server/text_embeddings_server/cli.py
+++ b/backends/python/server/text_embeddings_server/cli.py
@@ -48,7 +48,6 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
-
     server.serve(model_path, dtype, uds_path, pool)
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 2605f2df..eb0d19f6 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -10,6 +10,7 @@ from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
+from text_embeddings_server.utils.device import get_device, use_ipex
 
 __all__ = ["Model"]
@@ -37,17 +38,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
+    device = get_device()
+    logger.info(f"backend device: {device}")
 
     config = AutoConfig.from_pretrained(model_path)
-
     if config.model_type == "bert":
         config: BertConfig
         if (
-            device.type == "cuda"
+            use_ipex() or device.type in ["cuda", "hpu"]
             and config.position_embedding_type == "absolute"
             and datatype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
@@ -55,16 +53,26 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             if pool != "cls":
                 raise ValueError("FlashBert only supports cls pooling")
             return FlashBert(model_path, device, datatype)
+        if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+            return ClassificationModel(model_path, device, datatype)
         else:
-            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
-                return ClassificationModel(model_path, device, datatype)
-            else:
-                return DefaultModel(model_path, device, datatype, pool)
+            return DefaultModel(model_path, device, datatype, pool)
     else:
-        try:
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            from optimum.habana.transformers.modeling_utils import (
+                adapt_transformers_to_gaudi,
+            )
+
+            adapt_transformers_to_gaudi()
             if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
-                return ClassificationModel(model_path, device, datatype)
+                model_handle = ClassificationModel(model_path, device, datatype)
             else:
-                return DefaultModel(model_path, device, datatype, pool)
-        except:
-            raise RuntimeError(f"Unsupported model_type {config.model_type}")
+                model_handle = DefaultModel(model_path, device, datatype)
+            model_handle.model = wrap_in_hpu_graph(model_handle.model)
+            return model_handle
+        elif use_ipex():
+            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+                return ClassificationModel(model_path, device, datatype)
+            else:
+                return DefaultModel(model_path, device, datatype)
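
Note: as context for the HPU branch added above, on Gaudi the loaded `transformers` model is first patched for Gaudi and then captured into an HPU graph. A minimal stand-alone sketch of that pattern, assuming a Gaudi host with `optimum-habana` installed (the model id is a placeholder, not part of this patch):

```python
# Sketch of the Gaudi path used in get_model; only runs on an HPU machine.
import torch
from transformers import AutoModel
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

adapt_transformers_to_gaudi()                      # patch transformers for Gaudi kernels
model = AutoModel.from_pretrained("MY_MODEL_ID")   # placeholder model id
model = model.to("hpu", dtype=torch.bfloat16)
model = wrap_in_hpu_graph(model)                   # capture forward passes into HPU graphs
```
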
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index 759203da..80b8b09b 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -43,7 +43,6 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["token_type_ids"] = batch.token_type_ids
         if self.has_position_ids:
             kwargs["position_ids"] = batch.position_ids
-
         output = self.model(**kwargs)
 
         pooling_features = {
diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py
index 50b8d70d..fce5c3f2 100644
--- a/backends/python/server/text_embeddings_server/models/flash_bert.py
+++ b/backends/python/server/text_embeddings_server/models/flash_bert.py
@@ -1,50 +1,95 @@
 import torch
-
 from pathlib import Path
 from torch import nn
+import torch.nn.functional as F
 from typing import Type, List
 from safetensors import safe_open
 from transformers.activations import ACT2FN
 from transformers.models.bert import BertConfig
 from opentelemetry import trace
-
-# Flash attention imports
-
-import dropout_layer_norm
-
-
 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import FlashBatch, Embedding
 from text_embeddings_server.utils.flash_attn import attention
+from text_embeddings_server.utils.device import use_ipex
 
 tracer = trace.get_tracer(__name__)
 
 
+def hpu_add_layer_norm(
+    add: torch.Tensor,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    epsilon: float,
+    add_back: bool,
+):
+    if add is not None:
+        added_tensor = torch.add(add, x, alpha=1.0)
+        output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
+        if add_back:
+            add.add_(x)
+        return output
+    else:
+        return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)
+
+
 class FastLayerNorm:
     def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
         self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
         self.variance_epsilon = config.layer_norm_eps
+        self.device = device
+        self.use_ipex = use_ipex()
 
     def forward(self, hidden_states, residual=None):
-        normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.weight,
-            self.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.variance_epsilon,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
-        if res is None:
-            res = hidden_states
+        # Flash attention imports
+        normed_hidden_states = None
+        res = None
+        if self.device.type == "cuda":
+            import dropout_layer_norm
+
+            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.variance_epsilon,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if res is None:
+                res = hidden_states
+        elif self.use_ipex:
+            import intel_extension_for_pytorch as ipex
+
+            normed_hidden_states = ipex.llm.functional.add_layer_norm(
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.variance_epsilon,
+                residual is not None,
+            )
+            res = residual if residual is not None else hidden_states
+        elif self.device.type == "hpu":
+            normed_hidden_states = hpu_add_layer_norm(
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.variance_epsilon,
+                residual is not None,
+            )
+            res = residual if residual is not None else hidden_states
 
         return normed_hidden_states, res
@@ -119,6 +164,7 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.head_size = config.hidden_size // config.num_attention_heads
         self.softmax_scale = self.head_size**-0.5
         self.num_heads = config.num_attention_heads
+        self.device = device
 
     def forward(self, hidden_states, cu_seqlens, max_s):
         residual = hidden_states
@@ -225,7 +271,10 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         config = BertConfig.from_pretrained(model_path)
         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size
 
         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
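
For clarity, `hpu_add_layer_norm` above implements "add the residual, then layer-norm", with an optional in-place write-back of the sum into the residual buffer. A tiny stand-alone sketch of the same arithmetic in plain PyTorch (it mirrors the helper rather than importing it, so it runs on any device):

```python
import torch
import torch.nn.functional as F

def add_layer_norm_ref(add, x, weight, bias, eps, add_back):
    # Mirrors hpu_add_layer_norm: normalize (add + x); optionally fold x back into `add`.
    if add is None:
        return F.layer_norm(x, [x.size(-1)], weight, bias, eps)
    out = F.layer_norm(add + x, [x.size(-1)], weight, bias, eps)
    if add_back:
        add.add_(x)  # FastLayerNorm relies on this to keep the running residual
    return out

x, res = torch.randn(2, 8), torch.randn(2, 8)
w, b = torch.ones(8), torch.zeros(8)
print(add_layer_norm_ref(res, x, w, b, 1e-12, add_back=True).shape)  # torch.Size([2, 8])
```
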
"cuda": - raise RuntimeError(f"FlashBatch does not support device {device}") - batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device) batch_token_type_ids = torch.tensor( pb.token_type_ids, dtype=torch.int32, device=device diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py new file mode 100644 index 00000000..d450b373 --- /dev/null +++ b/backends/python/server/text_embeddings_server/utils/device.py @@ -0,0 +1,66 @@ +import os +from loguru import logger +import importlib.metadata +import importlib.util +from packaging import version +import torch +import subprocess + + +def _is_ipex_available(): + def get_major_and_minor_from_version(full_version): + return ( + str(version.parse(full_version).major) + + "." + + str(version.parse(full_version).minor) + ) + + _torch_version = importlib.metadata.version("torch") + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + _ipex_version = "N/A" + try: + _ipex_version = importlib.metadata.version("intel_extension_for_pytorch") + except importlib.metadata.PackageNotFoundError: + return False + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + logger.warning( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." + ) + return False + return True + + +def is_hpu() -> bool: + is_hpu_available = True + try: + subprocess.run(["hl-smi"], capture_output=True, check=True) + except: + is_hpu_available = False + return is_hpu_available + + +def use_ipex() -> bool: + value = os.environ.get("USE_IPEX", "True").lower() + return value in ["true", "1"] and _is_ipex_available() + + +def get_device(): + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + elif is_hpu(): + import habana_frameworks.torch.core as htcore + + if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore + device = torch.device("hpu") + elif use_ipex(): + import intel_extension_for_pytorch as ipex + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + + return device diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/utils/flash_attn.py index 1d325351..ecfc4f9e 100644 --- a/backends/python/server/text_embeddings_server/utils/flash_attn.py +++ b/backends/python/server/text_embeddings_server/utils/flash_attn.py @@ -1,74 +1,189 @@ import os import torch +from text_embeddings_server.utils.device import use_ipex, is_hpu from loguru import logger if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 - HAS_FLASH_ATTN = False HAS_FLASH_ATTN_V2 = False -try: - try: - import flash_attn_2_cuda - except ImportError: - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention v2 with `cd 
diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/utils/flash_attn.py
index 1d325351..ecfc4f9e 100644
--- a/backends/python/server/text_embeddings_server/utils/flash_attn.py
+++ b/backends/python/server/text_embeddings_server/utils/flash_attn.py
@@ -1,74 +1,189 @@
 import os
 import torch
+from text_embeddings_server.utils.device import use_ipex, is_hpu
 
 from loguru import logger
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
-
-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2 = False
-try:
-    try:
-        import flash_attn_2_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
-        )
-    if not (is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
+
+is_hpu = is_hpu()
+use_ipex = use_ipex()
+
+if use_ipex or is_hpu:
     HAS_FLASH_ATTN_V2 = True
-except ImportError as e:
+else:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
     try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
-
-    if not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported"
-        ) from e
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+            )
+        if not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2 = True
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
+
+
+def hpu_attn(
+    q,
+    k,
+    v,
+    out,
+    seqlen_q,
+    seqlen_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    softmax_scale,
+    is_causal=False,
+):
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+
+    total_q, num_head, head_size = q.size()
+    total_k, num_head_k, _ = k.size()
+    batch_size = seqlen_q.size(0) - 1
+    seqlen_q_ = seqlen_q.clone()
+    seqlen_q_[:batch_size] = seqlen_q[1:]
+    seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
+    seqlen_k_ = seqlen_k.clone()
+    seqlen_k_[:batch_size] = seqlen_k[1:]
+    seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]
+
+    pad_q = torch.zeros(
+        [batch_size, max_seqlen_q, num_head, head_size],
+        dtype=q.dtype,
+        device=q.device,
+    )
+    pad_k = torch.zeros(
+        [batch_size, max_seqlen_k, num_head_k, head_size],
+        dtype=k.dtype,
+        device=k.device,
+    )
+    pad_v = torch.zeros(
+        [batch_size, max_seqlen_k, num_head_k, head_size],
+        dtype=v.dtype,
+        device=v.device,
+    )
+    q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
+        batch_size, 1
+    )
+    q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
+    k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
+        batch_size, 1
+    )
+    k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
+    align_mask_seqlen = max_seqlen_k
+    attn_mask = torch.empty(
+        [batch_size, 1, 1, align_mask_seqlen],
+        dtype=q.dtype,
+        device=q.device,
+    ).fill_(float("-inf"))
+    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
+
+    pad_q[q_mask] = q
+    pad_k[k_mask] = k
+    pad_v[k_mask] = v
+
+    pad_q = pad_q.permute(0, 2, 1, 3)
+    pad_k = pad_k.permute(0, 2, 1, 3)
+    pad_v = pad_v.permute(0, 2, 1, 3)
+    if is_causal:
+        attn_mask = None
+
+    out_ = FusedSDPA.apply(
+        pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale
+    )
+    out_ = out_.permute(0, 2, 1, 3)
+    out.copy_(out_[q_mask])
+    return out
 
 
 def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
     if HAS_FLASH_ATTN_V2:
-        return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            is_causal,
-            -1,
-            -1,
-            False,
-            None,
-        )
+        if use_ipex:
+            import intel_extension_for_pytorch as ipex
+
+            return ipex.llm.functional.varlen_attention(
+                q.contiguous() if q.device.type == "xpu" else q,
+                k.contiguous() if k.device.type == "xpu" else k,
+                v.contiguous() if v.device.type == "xpu" else v,
+                out,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0,
+                softmax_scale,
+                zero_tensors=False,
+                is_causal=False,
+                return_softmax=False,
+                gen_=None,
+            )
+        elif is_hpu:
+            return hpu_attn(
+                q,
+                k,
+                v,
+                out,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                softmax_scale,
+                is_causal=False,
+            )
+
+        else:
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                k,
+                v,
+                out,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                softmax_scale,
+                False,
+                is_causal,
+                -1,
+                -1,
+                False,
+                None,
+            )
 
     if HAS_FLASH_ATTN:
         return flash_attn_cuda.fwd(
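
The `hpu_attn` fallback above converts the flash-attention "varlen" layout (all sequences packed along one token axis, delimited by cumulative offsets) back into a padded `[batch, seq, heads, head_size]` layout plus an additive mask before calling FusedSDPA. A small sketch of just that unpack step, with plain tensors and toy sizes:

```python
import torch

# cu_seqlens as passed to attention(): cumulative token offsets for 3 packed sequences.
cu_seqlens = torch.tensor([0, 2, 5, 9], dtype=torch.int32)
per_seq_len = cu_seqlens[1:] - cu_seqlens[:-1]        # tensor([2, 3, 4]); same result as the
batch_size, max_s = per_seq_len.numel(), int(per_seq_len.max())  # clone/shift trick in hpu_attn

packed = torch.randn(int(cu_seqlens[-1]), 4, 16)      # [total_tokens, heads, head_size]
padded = packed.new_zeros(batch_size, max_s, 4, 16)   # [batch, max_s, heads, head_size]
mask = torch.arange(max_s)[None, :] < per_seq_len[:, None]
padded[mask] = packed                                 # scatter tokens sequence by sequence
print(padded.shape, mask.sum().item())                # torch.Size([3, 4, 4, 16]) 9
```
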
diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs
index 53193ef7..3b08e92a 100644
--- a/backends/src/dtype.rs
+++ b/backends/src/dtype.rs
@@ -14,6 +14,8 @@ pub enum DType {
     Float16,
     #[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
     Float32,
+    #[cfg(feature = "python")]
+    Bfloat16,
 }
 
 impl fmt::Display for DType {
@@ -27,6 +29,8 @@ impl fmt::Display for DType {
             DType::Float16 => write!(f, "float16"),
             #[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
             DType::Float32 => write!(f, "float32"),
+            #[cfg(feature = "python")]
+            DType::Bfloat16 => write!(f, "bfloat16"),
         }
     }
 }
@@ -47,10 +51,15 @@ impl Default for DType {
             feature = "accelerate",
             feature = "mkl",
             feature = "mkl-dynamic",
-            feature = "ort"
+            feature = "ort",
+            feature = "python"
        )))]
        {
            DType::Float16
        }
+        #[cfg(feature = "python")]
+        {
+            DType::Bfloat16
+        }
    }
 }
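
With `Bfloat16` added and made the default when the `python` feature is enabled, the router now hands the string "bfloat16" down to the Python server. A hedged sketch of the kind of string-to-torch-dtype mapping the server needs on its side (the helper name and the None-defaults-to-bfloat16 behaviour are illustrative assumptions, not the committed code; `get_model` performs its own check and raises `Unknown dtype` otherwise):

```python
from typing import Optional
import torch

def resolve_dtype(dtype: Optional[str]) -> torch.dtype:
    # Hypothetical helper mirroring the router's bfloat16 default for the python backend.
    table = {None: torch.bfloat16, "bfloat16": torch.bfloat16,
             "float16": torch.float16, "float32": torch.float32}
    if dtype not in table:
        raise RuntimeError(f"Unknown dtype {dtype}")
    return table[dtype]

print(resolve_dtype("bfloat16"))  # torch.bfloat16
```
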
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index bafd97ac..74ba5e8a 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -1,8 +1,11 @@
 mod dtype;
 
 use hf_hub::api::tokio::{ApiError, ApiRepo};
+use rand::Rng;
 use std::cmp::{max, min};
+use std::env;
 use std::path::PathBuf;
+use std::process::Command;
 use std::sync::Arc;
 use std::thread::JoinHandle;
 use std::time::{Duration, Instant};
@@ -24,6 +27,28 @@ use text_embeddings_backend_ort::OrtBackend;
 #[cfg(feature = "python")]
 use text_embeddings_backend_python::PythonBackend;
 
+fn powers_of_two(max_value: usize) -> Vec<usize> {
+    let mut result = Vec::new();
+    let mut power: usize = 1;
+
+    while power <= max_value {
+        result.push(power);
+        power *= 2;
+    }
+
+    result
+}
+
+fn is_hpu() -> bool {
+    match Command::new("hl-smi")
+        .args(["-Q", "name", "-f", "csv"])
+        .output()
+    {
+        Ok(output) => output.status.success(),
+        Err(_) => false,
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct Backend {
     /// Channel to communicate with the background thread
@@ -75,6 +100,101 @@ impl Backend {
         })
     }
 
+    #[instrument(skip(self))]
+    pub async fn warmup_hpu(
+        &self,
+        mut max_input_length: usize,
+        max_token: usize,
+        max_bs: Option<usize>,
+    ) -> Result<(), BackendError> {
+        let read_env_var = |key: &str, default: usize| -> usize {
+            env::var(key)
+                .ok()
+                .map_or(default, |value| value.parse::<usize>().unwrap())
+        };
+        let seq_bucket_size: usize = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
+        let max_warmup_length: usize = read_env_var("MAX_WARMUP_SEQUENCE_LENGTH", 1024);
+
+        let max_batch_size = max_bs.unwrap_or_else(|| read_env_var("MAX_WARMUP_BATCH_SIZE", 8));
+
+        let mut batch_sizes: Vec<usize> = powers_of_two(max_batch_size);
+        if let Some(&last) = batch_sizes.last() {
+            if last < max_batch_size {
+                batch_sizes.push(max_batch_size);
+            }
+        }
+        if max_warmup_length > max_input_length {
+            return Err(BackendError::Start(
+                format!("max_warmup_length ({max_warmup_length}) exceeds model's max_input_length ({max_input_length}), you can modify this value adding `-e MAX_WARMUP_SEQUENCE_LENGTH=` to your Docker run command")
+            ));
+        }
+        if seq_bucket_size > max_warmup_length {
+            return Err(BackendError::Start(
+                format!("PAD_SEQUENCE_TO_MULTIPLE_OF ({seq_bucket_size}) exceeds model's max warmup length ({max_warmup_length}), you can modify these values adding `-e PAD_SEQUENCE_TO_MULTIPLE_OF=` or `-e MAX_WARMUP_SEQUENCE_LENGTH= to your Docker run command`")
+            ));
+        }
+
+        max_input_length = std::cmp::min(max_input_length, max_warmup_length);
+        let mut seq_lengths: Vec<usize> = (seq_bucket_size..=max_input_length)
+            .step_by(seq_bucket_size)
+            .collect();
+        if let Some(&last) = seq_lengths.last() {
+            if last < max_input_length {
+                seq_lengths.push(max_input_length);
+            }
+        }
+
+        let mut shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len() * seq_lengths.len());
+        for batch_size in &batch_sizes {
+            for seq_length in &seq_lengths {
+                shapes.push((*batch_size as u32, *seq_length as u32));
+            }
+        }
+        for shape in shapes.iter() {
+            let batch = self.create_warmup_batch(*shape, max_token as u32);
+            match &self.model_type {
+                ModelType::Classifier => self.predict(batch).await.map(|_| ()),
+                ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
+            }?;
+            tracing::info!("finish warmup for batch: {}, length: {}", shape.0, shape.1);
+        }
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32) -> Batch {
+        let (batch_size, length) = shape;
+        let mut batched_input_ids = Vec::new();
+        let mut batched_token_type_ids = Vec::new();
+        let mut batched_position_ids = Vec::new();
+        let mut cumulative_seq_lengths = Vec::with_capacity(batch_size as usize + 1);
+        let mut pooled_indices = Vec::with_capacity(batch_size as usize);
+        cumulative_seq_lengths.push(0);
+        let input_ids: Vec<u32> = (0..length)
+            .map(|_| rand::thread_rng().gen_range(0..max_token))
+            .collect();
+        let token_type_ids: Vec<u32> = vec![0; length as usize];
+        let position_ids: Vec<u32> = (0..length).collect();
+        let mut current_length = 0;
+        for batch_id in 0..batch_size {
+            batched_input_ids.extend(input_ids.iter().cloned());
+            batched_token_type_ids.extend(token_type_ids.iter().cloned());
+            batched_position_ids.extend(position_ids.iter().cloned());
+            current_length += input_ids.len();
+            cumulative_seq_lengths.push(current_length as u32);
+            pooled_indices.push(batch_id);
+        }
+        Batch {
+            input_ids: batched_input_ids,
+            token_type_ids: batched_token_type_ids,
+            position_ids: batched_position_ids,
+            cumulative_seq_lengths,
+            max_length: length,
+            pooled_indices,
+            raw_indices: vec![],
+        }
+    }
+
     #[instrument(skip(self))]
     pub async fn warmup(
         &self,
@@ -82,6 +202,12 @@ impl Backend {
         max_batch_tokens: usize,
         max_batch_requests: Option<usize>,
     ) -> Result<(), BackendError> {
+        if is_hpu() {
+            return self
+                .warmup_hpu(max_input_length, max_batch_tokens, max_batch_requests)
+                .await;
+        }
+
         let mut input_ids = Vec::with_capacity(max_batch_tokens);
         let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
         let mut position_ids = Vec::with_capacity(max_batch_tokens);
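
To make the HPU warmup schedule above concrete: batch sizes are the powers of two up to `MAX_WARMUP_BATCH_SIZE`, sequence lengths are multiples of `PAD_SEQUENCE_TO_MULTIPLE_OF` up to the (capped) maximum length, and the exact maxima are appended whenever the grid misses them. A quick sketch of the same enumeration, written in Python for brevity rather than Rust:

```python
def powers_of_two(max_value: int) -> list:
    result, power = [], 1
    while power <= max_value:
        result.append(power)
        power *= 2
    return result

def warmup_shapes(max_batch_size=8, bucket=128, max_len=512):
    # Mirrors warmup_hpu: (batch_size, seq_length) pairs that get pre-compiled.
    batch_sizes = powers_of_two(max_batch_size)
    if batch_sizes[-1] < max_batch_size:
        batch_sizes.append(max_batch_size)
    seq_lengths = list(range(bucket, max_len + 1, bucket))
    if seq_lengths and seq_lengths[-1] < max_len:
        seq_lengths.append(max_len)
    return [(b, s) for b in batch_sizes for s in seq_lengths]

print(warmup_shapes())  # [(1, 128), (1, 256), ..., (8, 512)]
```
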