Add warmup function to speed up the first embedding (huggingface#8)
Signed-off-by: Liu, Kaixuan <[email protected]>
kaixuanliu authored May 6, 2024
1 parent 54beebd commit 5e584b8
Showing 3 changed files with 57 additions and 4 deletions.
2 changes: 1 addition & 1 deletion backends/Cargo.toml
@@ -12,7 +12,7 @@
text-embeddings-backend-python = { path = "python", optional = true }
text-embeddings-backend-candle = { path = "candle", optional = true }
tokio = { version = "^1.25", features = ["sync"] }
tracing = "^0.1"

rand = "^0.8"
[features]
clap = ["dep:clap", "text-embeddings-backend-core/clap"]
python = ["dep:text-embeddings-backend-python"]
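The new rand dependency supplies the random warmup token ids generated by create_warmup_batch in backends/src/lib.rs below.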
53 changes: 50 additions & 3 deletions backends/src/lib.rs
@@ -1,13 +1,13 @@
mod dtype;

use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use text_embeddings_backend_core::{Backend as CoreBackend, Predictions};
use tokio::sync::{mpsc, oneshot, watch};
use tracing::{instrument, Span};

use rand::Rng;
pub use crate::dtype::DType;
pub use text_embeddings_backend_core::{
    BackendError, Batch, Embedding, Embeddings, ModelType, Pool,
@@ -98,6 +98,54 @@ impl Backend {
}
}

#[instrument(skip(self))]
pub async fn warmup(
    &self,
    max_input_length: u32,
    max_token: u32,
) -> Result<(), BackendError> {
    let read_env_var = |key: &str, default: u32| -> u32 {
        env::var(key)
            .ok()
            .map_or(default, |value| value.parse::<u32>().unwrap())
    };
    // Collect every sequence length to warm up for prefill: all multiples
    // of the padding bucket up to max_input_length.
    let bucket_size: u32 = read_env_var("PAD_SEQUENCE_TO_MULTIPLE_OF", 128);
    let mut seq_lengths: Vec<u32> = (bucket_size..=max_input_length)
        .step_by(bucket_size as usize)
        .collect();
    // If max_input_length is not itself a bucket multiple, warm it up too.
    if let Some(&last) = seq_lengths.last() {
        if last < max_input_length {
            seq_lengths.push(max_input_length);
        }
    }
    for &length in seq_lengths.iter() {
        tracing::info!("warmup for length: {}", length);
        let batch = self.create_warmup_batch(length, max_token);
        match &self.model_type {
            ModelType::Classifier => self.predict(batch).await.map(|_| ()),
            ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
        }?;
    }
    Ok(())
}

#[instrument(skip_all)]
pub fn create_warmup_batch(&self, length: u32, max_token: u32) -> Batch {
    // A single sequence of `length` random token ids drawn from [0, max_token).
    let input_ids: Vec<u32> = (0..length)
        .map(|_| rand::thread_rng().gen_range(0..max_token))
        .collect();
    let token_type_ids: Vec<u32> = vec![0; length as usize];
    let position_ids: Vec<u32> = (0..length).collect();
    // Cumulative sequence lengths for a one-sequence batch: [0, length].
    let cumulative_seq_lengths: Vec<u32> = vec![0, length];
    Batch {
        input_ids,
        token_type_ids,
        position_ids,
        cumulative_seq_lengths,
        max_length: length,
        pooled_indices: vec![0],
        raw_indices: vec![],
    }
}

#[instrument(skip(self))]
pub fn health_watcher(&self) -> watch::Receiver<bool> {
self.health_receiver.clone()
@@ -106,7 +154,6 @@ impl Backend {
#[instrument(skip_all)]
pub async fn embed(&self, batch: Batch) -> Result<(Embeddings, Duration), BackendError> {
    let (sender, receiver) = oneshot::channel();

    self.backend_sender
        .send(BackendCommand::Embed(batch, Span::current(), sender))
        .expect("No backend receiver. This is a bug.");
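Note: warmup pre-runs every padded sequence length the server can encounter, so that, per the commit title, the first real embedding no longer pays the one-time cost of an unseen shape. A self-contained sketch of the same bucketing rule, using a hypothetical bucket_lengths helper that is not part of this commit:

fn bucket_lengths(bucket: u32, max_len: u32) -> Vec<u32> {
    // Every multiple of `bucket` up to `max_len`, plus `max_len` itself
    // when it is not already a multiple.
    let mut lengths: Vec<u32> = (bucket..=max_len).step_by(bucket as usize).collect();
    if let Some(&last) = lengths.last() {
        if last < max_len {
            lengths.push(max_len);
        }
    }
    lengths
}

fn main() {
    // PAD_SEQUENCE_TO_MULTIPLE_OF=128 (the default), max_input_length=512:
    assert_eq!(bucket_lengths(128, 512), vec![128, 256, 384, 512]);
    // A maximum that is not a bucket multiple is warmed up as a final shape:
    assert_eq!(bucket_lengths(128, 500), vec![128, 256, 384, 500]);
}

Each length becomes one random-token batch, so the first real request after startup hits an already-warmed shape.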
6 changes: 6 additions & 0 deletions router/src/lib.rs
@@ -205,6 +205,12 @@ pub async fn run(
    .await
    .context("Model backend is not healthy")?;

// Warmup
if backend
    .warmup(max_input_length as u32, max_batch_tokens as u32)
    .await
    .is_ok()
{
    tracing::info!("Warmup succeeded");
}

let max_batch_requests = backend
    .max_batch_size
    .map(|s| {
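The router treats warmup as best-effort: a failure is ignored apart from skipping the success log. A sketch of a stricter alternative, assuming the same anyhow::Context extension the health check above already uses, would abort startup instead:

// Propagate warmup errors rather than ignoring them (not part of this commit).
backend
    .warmup(max_input_length as u32, max_batch_tokens as u32)
    .await
    .context("Model backend warmup failed")?;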
