Skip to content

Commit

Permalink
feat(router): add truncation direction parameter (huggingface#299)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Jun 21, 2024
1 parent 734780c commit 99cdf22
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 52 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@ candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev =
candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "33b7ecf9ed82bb7c20f1a94555218fabfbaa2fe3", package = "candle-flash-attn" }
hf-hub = { git = "https://github.com/huggingface/hf-hub", rev = "b167f69692be5f49eb8003788f7f8a499a98b096" }


[profile.release]
# No debug info in release binaries.
debug = 0
# NOTE(review): Cargo normally disables incremental compilation for
# LTO'd release builds — confirm this flag has any effect with lto = "fat".
incremental = true
# Full cross-crate link-time optimization for maximum runtime performance.
lto = "fat"
opt-level = 3
# Single codegen unit: best optimization, slowest compile.
codegen-units = 1
# Strip the symbol table from the final binary to reduce size.
strip = "symbols"
# Abort on panic: smaller binary, no unwinding machinery.
panic = "abort"

# Release-like profile that keeps debug info and faster build settings,
# intended for profiling/debugging optimized builds.
[profile.release-debug]
inherits = "release"
# Line tables only — enough for backtraces and profilers.
debug = 1
# Thin LTO + more codegen units: trades a little runtime speed for much
# faster compile/link times than the parent release profile.
lto = "thin"
codegen-units = 16
# Keep symbols so profilers/debuggers can resolve names.
strip = "none"
incremental = true
5 changes: 5 additions & 0 deletions Dockerfile-cuda
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ ARG CUDA_COMPUTE_CAP=80
ARG GIT_SHA
ARG DOCKER_LABEL

# Limit parallelism
ARG RAYON_NUM_THREADS
ARG CARGO_BUILD_JOBS
ARG CARGO_BUILD_INCREMENTAL

# sccache specific variables
ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
Expand Down
4 changes: 3 additions & 1 deletion Dockerfile-cuda-all
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

# limit the number of kernels built at the same time
# Limit parallelism
ARG RAYON_NUM_THREADS=4
ARG CARGO_BUILD_JOBS
ARG CARGO_BUILD_INCREMENTAL

WORKDIR /usr/src

Expand Down
37 changes: 32 additions & 5 deletions core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::TextEmbeddingsError;
use std::sync::Arc;
use std::time::{Duration, Instant};
use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
use tracing::instrument;

Expand Down Expand Up @@ -117,6 +118,7 @@ impl Infer {
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
permit: OwnedSemaphorePermit,
) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
Expand All @@ -131,7 +133,14 @@ impl Infer {
}

let results = self
.embed(inputs, truncate, false, &start_time, permit)
.embed(
inputs,
truncate,
truncation_direction,
false,
&start_time,
permit,
)
.await?;

let InferResult::AllEmbedding(response) = results else {
Expand Down Expand Up @@ -165,6 +174,7 @@ impl Infer {
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
Expand All @@ -179,7 +189,14 @@ impl Infer {
}

let results = self
.embed(inputs, truncate, true, &start_time, permit)
.embed(
inputs,
truncate,
truncation_direction,
true,
&start_time,
permit,
)
.await?;

let InferResult::PooledEmbedding(response) = results else {
Expand Down Expand Up @@ -213,6 +230,7 @@ impl Infer {
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
normalize: bool,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
Expand All @@ -228,7 +246,14 @@ impl Infer {
}

let results = self
.embed(inputs, truncate, true, &start_time, permit)
.embed(
inputs,
truncate,
truncation_direction,
true,
&start_time,
permit,
)
.await?;

let InferResult::PooledEmbedding(mut response) = results else {
Expand Down Expand Up @@ -278,6 +303,7 @@ impl Infer {
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
pooling: bool,
start_time: &Instant,
_permit: OwnedSemaphorePermit,
Expand All @@ -296,7 +322,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate)
.encode(inputs.into(), truncate, truncation_direction)
.await
.map_err(|err| {
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
Expand Down Expand Up @@ -340,6 +366,7 @@ impl Infer {
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
raw_scores: bool,
_permit: OwnedSemaphorePermit,
) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
Expand All @@ -357,7 +384,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate)
.encode(inputs.into(), truncate, truncation_direction)
.await
.map_err(|err| {
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
Expand Down
15 changes: 13 additions & 2 deletions core/src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ impl Tokenization {
&self,
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
Expand All @@ -80,6 +81,7 @@ impl Tokenization {
.send(TokenizerRequest::Encode(
inputs,
truncate,
truncation_direction,
response_sender,
Span::current(),
))
Expand Down Expand Up @@ -163,14 +165,21 @@ fn tokenizer_worker(
// Loop over requests
while let Some(request) = receiver.blocking_recv() {
match request {
TokenizerRequest::Encode(inputs, truncate, response_tx, parent_span) => {
TokenizerRequest::Encode(
inputs,
truncate,
truncation_direction,
response_tx,
parent_span,
) => {
parent_span.in_scope(|| {
if !response_tx.is_closed() {
// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(encode_input(
inputs,
truncate,
truncation_direction,
max_input_length,
position_offset,
&mut tokenizer,
Expand Down Expand Up @@ -247,13 +256,14 @@ fn tokenize_input(
fn encode_input(
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
max_input_length: usize,
position_offset: usize,
tokenizer: &mut Tokenizer,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Default truncation params
let truncate_params = truncate.then_some(TruncationParams {
direction: TruncationDirection::Right,
direction: truncation_direction,
max_length: max_input_length,
strategy: TruncationStrategy::LongestFirst,
stride: 0,
Expand Down Expand Up @@ -316,6 +326,7 @@ enum TokenizerRequest {
Encode(
EncodingInput,
bool,
TruncationDirection,
oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
Span,
),
Expand Down
12 changes: 12 additions & 0 deletions proto/tei.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ message Metadata {
uint64 inference_time_ns = 6;
}

// Which side of the input to drop tokens from when truncation applies.
// RIGHT is value 0, so it is the proto3 default when the field is unset —
// matching the tokenizer's previous hardcoded TruncationDirection::Right.
enum TruncationDirection {
TRUNCATION_DIRECTION_RIGHT = 0;
TRUNCATION_DIRECTION_LEFT = 1;
}

// Request for a pooled (single-vector) embedding of one input text.
message EmbedRequest {
// Raw input text to embed.
string inputs = 1;
// If true, inputs longer than the model's max length are truncated.
bool truncate = 2;
// Whether to normalize the returned embedding (presumably L2 — confirm
// against the server implementation).
bool normalize = 3;
// Side to truncate from when `truncate` is set; defaults to RIGHT.
TruncationDirection truncation_direction = 4;
}

message EmbedResponse {
Expand All @@ -83,6 +89,7 @@ message EmbedResponse {
// Request for a sparse embedding of one input text.
message EmbedSparseRequest {
// Raw input text to embed.
string inputs = 1;
// If true, inputs longer than the model's max length are truncated.
bool truncate = 2;
// Side to truncate from when `truncate` is set; defaults to RIGHT.
TruncationDirection truncation_direction = 3;
}

message SparseValue {
Expand All @@ -98,6 +105,7 @@ message EmbedSparseResponse {
// Request for per-token embeddings (no pooling) of one input text.
message EmbedAllRequest {
// Raw input text to embed.
string inputs = 1;
// If true, inputs longer than the model's max length are truncated.
bool truncate = 2;
// Side to truncate from when `truncate` is set; defaults to RIGHT.
TruncationDirection truncation_direction = 3;
}

message TokenEmbedding {
Expand All @@ -113,12 +121,14 @@ message PredictRequest {
string inputs = 1;
bool truncate = 2;
bool raw_scores = 3;
TruncationDirection truncation_direction = 4;
}

message PredictPairRequest {
repeated string inputs = 1;
bool truncate = 2;
bool raw_scores = 3;
TruncationDirection truncation_direction = 4;
}

message Prediction {
Expand All @@ -137,6 +147,7 @@ message RerankRequest {
bool truncate = 3;
bool raw_scores = 4;
bool return_text = 5;
TruncationDirection truncation_direction = 6;
}

message RerankStreamRequest{
Expand All @@ -147,6 +158,7 @@ message RerankStreamRequest{
bool raw_scores = 4;
// The server will only consider the first value
bool return_text = 5;
TruncationDirection truncation_direction = 6;
}

message Rank {
Expand Down
Loading

0 comments on commit 99cdf22

Please sign in to comment.