Skip to content

Commit

Permalink
Add rerank model support for python backend (#25)
Browse files Browse the repository at this point in the history
Signed-off-by: kaixuanliu <[email protected]>
  • Loading branch information
kaixuanliu authored Jul 26, 2024
1 parent 3c8fd6f commit 53e8f76
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 16 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,3 @@ Not all features of TEI are currently supported as this is still a work in progr
> The license to use TEI on Habana Gaudi is the one of TEI: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
>
> Please reach out to [email protected] if you have any question.
21 changes: 21 additions & 0 deletions backends/grpc-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,25 @@ impl Client {
let response = self.stub.embed(request).await?.into_inner();
Ok(response.embeddings)
}

#[instrument(skip_all)]
pub async fn predict(
&mut self,
input_ids: Vec<u32>,
token_type_ids: Vec<u32>,
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
Ok(response.scores)
}
}
10 changes: 10 additions & 0 deletions backends/proto/embed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ service EmbeddingService {
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
/// Predict
rpc Predict (EmbedRequest) returns (PredictResponse);
}

message HealthRequest {}
Expand All @@ -28,3 +30,11 @@ message Embedding {
message EmbedResponse {
repeated Embedding embeddings = 1;
}

message Score {
repeated float values = 1;
}

message PredictResponse {
repeated Score scores = 1;
}
15 changes: 11 additions & 4 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import Optional
from transformers import AutoConfig
from transformers.models.bert import BertConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel

__all__ = ["Model"]

Expand Down Expand Up @@ -66,10 +68,15 @@ def get_model(model_path: Path, dtype: Optional[str]):
):
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
else:
try:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
return ClassificationModel(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
except:
raise RuntimeError(f"Unknown model_type {config.model_type}")

raise RuntimeError(f"Unsupported model_type {config.model_type}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import inspect
import torch

from loguru import logger
from pathlib import Path
from typing import Type, List
from transformers import AutoModelForSequenceClassification
from opentelemetry import trace

from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)

class ClassificationModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
if device == torch.device("hpu"):
adapt_transformers_to_gaudi()

model = AutoModelForSequenceClassification.from_pretrained(model_path)
model = model.to(dtype).to(device)
if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

self.hidden_size = model.config.hidden_size
position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
position_offset = model.config.pad_token_id + 1
max_input_length = 0
if hasattr(model.config, "max_seq_length"):
max_input_length = model.config.max_seq_length
else:
max_input_length = model.config.max_position_embeddings - position_offset
self.max_input_length = max_input_length
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.has_token_type_ids = (
inspect.signature(model.forward).parameters.get("token_type_ids", None)
is not None
)

super(ClassificationModel, self).__init__(model=model, dtype=dtype, device=device)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
pass

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
if self.has_token_type_ids:
kwargs["token_type_ids"] = batch.token_type_ids
if self.has_position_ids:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs, return_dict=True)
scores = output.logits.view(-1, ).tolist()
return [
Score(
values=scores[i:i+1]
)
for i in range(len(batch))
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)

Expand Down Expand Up @@ -72,3 +72,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
)
for i in range(len(batch))
]

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import List, TypeVar, Type

from text_embeddings_server.models.types import Batch, Embedding
from text_embeddings_server.models.types import Batch, Embedding, Score

B = TypeVar("B", bound=Batch)

Expand All @@ -27,3 +27,7 @@ def batch_type(self) -> Type[B]:
@abstractmethod
def embed(self, batch: B) -> List[Embedding]:
raise NotImplementedError

@abstractmethod
def predict(self, batch: B) -> List[Score]:
raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from opentelemetry import trace

from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.pb.embed_pb2 import Embedding
from text_embeddings_server.pb.embed_pb2 import Embedding, Score

tracer = trace.get_tracer(__name__)
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
Expand Down
8 changes: 8 additions & 0 deletions backends/python/server/text_embeddings_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ async def Embed(self, request, context):

return embed_pb2.EmbedResponse(embeddings=embeddings)

async def Predict(self, request, context):
max_input_length = self.model.max_input_length
batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)

scores = self.model.predict(batch)

return embed_pb2.PredictResponse(scores=scores)


def serve(
model_path: Path,
Expand Down
37 changes: 29 additions & 8 deletions backends/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ impl PythonBackend {
) -> Result<Self, BackendError> {
match model_type {
ModelType::Classifier => {
return Err(BackendError::Start(
"`classifier` model type is not supported".to_string(),
))
None
}
ModelType::Embedding(pool) => {
if pool != Pool::Cls {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
pool
Some(pool)
}
};

Expand Down Expand Up @@ -109,9 +107,32 @@ impl Backend for PythonBackend {
Ok(embeddings)
}

fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
Err(BackendError::Inference(
"`predict` is not implemented".to_string(),
))
fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
if !batch.raw_indices.is_empty() {
return Err(BackendError::Inference(
"raw embeddings are not supported for the Python backend.".to_string(),
));
}
let batch_size = batch.len();
let results = self
.tokio_runtime
.block_on(self.backend_client.clone().predict(
batch.input_ids,
batch.token_type_ids,
batch.position_ids,
batch.cumulative_seq_lengths,
batch.max_length,
))
.map_err(|err| BackendError::Inference(err.to_string()))?;
let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

for (i, r) in raw_results.into_iter().enumerate() {
predictions.insert(i, r);
}

Ok(predictions)
}
}

0 comments on commit 53e8f76

Please sign in to comment.