feat: splade pooling (#174)
OlivierDehaene authored Feb 29, 2024
1 parent 337fbd6 commit 9aa020e
Showing 19 changed files with 2,448 additions and 348 deletions.
README.md (6 additions & 3 deletions)
@@ -152,13 +152,16 @@ Options:
--pooling <POOLING>
Optionally control the pooling method for embedding models.
If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
configuration.
If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` configuration.
If `pooling` is set, it will override the model pooling configuration
[env: POOLING=]
[possible values: cls, mean]
Possible values:
- cls: Select the CLS token as embedding
- mean: Apply Mean pooling to the model embeddings
- splade: Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. This option is only available if the loaded model is a `ForMaskedLM` Transformer model
--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
The maximum amount of concurrent requests for this particular deployment.
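For context on the new `splade` value: SPLADE builds a vocabulary-sized sparse embedding from the logits of a `ForMaskedLM` head. The sketch below shows the usual recipe (a log-saturated ReLU followed by max pooling over the sequence) in candle; the function name and shapes are illustrative, and this is not the code added by the commit.

```rust
use candle::{Result, Tensor};

// Illustrative SPLADE pooling over masked-LM logits, assuming a
// [batch, seq_len, vocab_size] input; not code from this commit.
fn splade_pool(mlm_logits: &Tensor) -> Result<Tensor> {
    // log(1 + relu(logits)): zeroes out negative logits and dampens large ones,
    // yielding a sparse, non-negative weight per vocabulary term and token.
    let weights = mlm_logits.relu()?.affine(1.0, 1.0)?.log()?;
    // Max over the sequence dimension keeps each term's strongest activation,
    // producing one [batch, vocab_size] sparse embedding per input.
    weights.max(1)
}
```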
backends/candle/src/layers/layer_norm.rs (20 additions & 11 deletions)
@@ -23,12 +23,15 @@ impl LayerNorm {
})
}

pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();

match hidden_states.device() {
Device::Cpu | Device::Metal(_) => {
let hidden_states = hidden_states.add(residual)?;
let mut hidden_states = hidden_states.clone();
if let Some(residual) = residual {
hidden_states = hidden_states.add(residual)?;
}
let hidden_states_dtype = hidden_states.dtype();
let internal_dtype = match hidden_states_dtype {
DType::F16 | DType::BF16 => DType::F32,
@@ -51,19 +54,25 @@ impl LayerNorm {
Device::Cuda(_) => {
#[cfg(feature = "cuda")]
{
use candle_layer_norm::fused_add_layer_norm;
use candle_layer_norm::{fused_add_layer_norm, layer_norm};

let original_shape = hidden_states.shape();
let hidden_states = hidden_states.flatten_to(D::Minus2)?;
let residual = residual.flatten_to(D::Minus2)?;

let (result, _) = fused_add_layer_norm(
&hidden_states,
&residual,
&self.weight,
Some(&self.bias),
self.epsilon,
)?;
let result = if let Some(residual) = residual {
let residual = residual.flatten_to(D::Minus2)?;

let (result, _) = fused_add_layer_norm(
&hidden_states,
&residual,
&self.weight,
Some(&self.bias),
self.epsilon,
)?;
Ok(result)
} else {
layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon)
}?;
result.reshape(original_shape)
}
#[cfg(not(feature = "cuda"))]
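To make the new `Option<&Tensor>` residual argument concrete, here is a minimal standalone restatement of the CPU/Metal path as a free function: add the residual only when one is provided, then apply a plain layer norm. It mirrors the logic in the hunk above but is a hypothetical sketch, not the commit's code, and it up-casts to f32 unconditionally where the real code only up-casts f16/bf16 inputs.

```rust
use candle::{DType, Result, Tensor, D};

// Hypothetical standalone version of the optional-residual layer norm used above.
fn layer_norm_opt_residual(
    hidden_states: &Tensor,
    residual: Option<&Tensor>,
    weight: &Tensor,
    bias: &Tensor,
    epsilon: f32,
) -> Result<Tensor> {
    // Fuse the residual add only if a residual is provided.
    let x = match residual {
        Some(r) => hidden_states.add(r)?,
        None => hidden_states.clone(),
    };
    // Normalize in f32 for numerical stability.
    let dtype = x.dtype();
    let x = x.to_dtype(DType::F32)?;
    let mean = x.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_sub(&mean)?;
    let var = x.sqr()?.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_div(&var.affine(1.0, epsilon as f64)?.sqrt()?)?;
    // Apply the learned scale and shift.
    x.to_dtype(dtype)?.broadcast_mul(weight)?.broadcast_add(bias)
}
```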
backends/candle/src/lib.rs (33 additions & 2 deletions)
@@ -11,10 +11,13 @@ use crate::compute_cap::{
get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
};
use crate::models::{
BertModel, JinaBertModel, Model, NomicBertModel, NomicConfig, PositionEmbeddingType,
BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
NomicConfig, PositionEmbeddingType,
};
#[cfg(feature = "cuda")]
use crate::models::{FlashBertModel, FlashJinaBertModel, FlashNomicBertModel};
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
};
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
@@ -33,6 +36,8 @@ enum Config {
XlmRoberta(BertConfig),
Camembert(BertConfig),
Roberta(BertConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
NomicBert(NomicConfig),
}
@@ -119,6 +124,12 @@ impl CandleBackend {
BertModel::load_roberta(vb, &config, model_type).s()?,
))
}
(Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting DistilBertModel model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting NomicBertModel model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
@@ -175,6 +186,26 @@ impl CandleBackend {
}
}
#[cfg(feature = "cuda")]
(Config::DistilBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
&& &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
== "true"
{
tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
Ok(Box::new(
FlashDistilBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting DistilBertModel model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
}
#[cfg(feature = "cuda")]
(Config::NomicBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
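Each CUDA branch above gates the flash variant with the same three conditions. A small hypothetical helper (not present in the commit) that captures the check:

```rust
use candle::DType;

// Hypothetical helper mirroring the gating used in the CUDA branches above:
// flash kernels are selected only when the crate was built with `flash-attn`,
// the model runs in f16, and USE_FLASH_ATTENTION is unset or set to "true".
fn should_use_flash_attention(dtype: DType) -> bool {
    cfg!(feature = "flash-attn")
        && dtype == DType::F16
        && std::env::var("USE_FLASH_ATTENTION")
            .unwrap_or_else(|_| "True".to_string())
            .to_lowercase()
            == "true"
}
```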
backends/candle/src/models.rs (10 additions & 2 deletions)
@@ -5,20 +5,25 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

mod bert;
mod distilbert;
mod jina;
mod nomic;

#[cfg(feature = "cuda")]
mod flash_bert;

#[cfg(feature = "cuda")]
mod flash_jina;
mod jina;

#[cfg(feature = "cuda")]
mod flash_nomic;
mod nomic;

#[cfg(feature = "cuda")]
mod flash_distilbert;

pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use distilbert::{DistilBertConfig, DistilBertModel};
pub use jina::JinaBertModel;
pub use nomic::{NomicBertModel, NomicConfig};
use text_embeddings_backend_core::Batch;
@@ -32,6 +37,9 @@ pub use flash_jina::FlashJinaBertModel;
#[cfg(feature = "cuda")]
pub use flash_nomic::FlashNomicBertModel;

#[cfg(feature = "cuda")]
pub use flash_distilbert::FlashDistilBertModel;

pub(crate) trait Model {
fn is_padded(&self) -> bool;

backends/candle/src/models/bert.rs (14 additions & 7 deletions)
@@ -26,7 +26,6 @@ pub struct BertConfig {
#[serde(default)]
pub use_cache: bool,
pub classifier_dropout: Option<f64>,
pub model_type: Option<String>,
pub id2label: Option<HashMap<String, String>>,
}

@@ -39,7 +38,7 @@ pub enum PositionEmbeddingType {
}

#[derive(Debug)]
struct BertEmbeddings {
pub struct BertEmbeddings {
word_embeddings: Embedding,
token_type_embeddings: Embedding,
position_embeddings: Embedding,
@@ -80,7 +79,7 @@ impl BertEmbeddings {
})
}

fn forward(
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
@@ -93,7 +92,9 @@ impl BertEmbeddings {
let position_embeddings = self.position_embeddings.forward(position_ids)?;

let embeddings = input_embeddings.add(&token_type_embeddings)?;
let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?;
let embeddings = self
.layer_norm
.forward(&embeddings, Some(&position_embeddings))?;

Ok(embeddings)
}
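Note that the `Some(&position_embeddings)` argument simply reuses the fused residual-add: the call is equivalent to summing all three embeddings and then normalizing. Roughly, with the variable names from the snippet above (illustrative only, not commit code):

```rust
// Equivalent unfused form of the call above.
let summed = input_embeddings
    .add(&token_type_embeddings)?
    .add(&position_embeddings)?;
let embeddings = self.layer_norm.forward(&summed, None)?;
```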
@@ -255,7 +256,7 @@ impl BertAttention {
let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

let hidden_states = self.dense.forward(&context_layer)?;
let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -324,7 +325,7 @@ impl BertLayer {

let hidden_states = self.intermediate.forward(&hidden_states)?;
let hidden_states = self.output.forward(&hidden_states)?;
let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -469,7 +470,12 @@ impl BertModel {
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
(pool, Some(classifier))
}
ModelType::Embedding(pool) => (pool, None),
ModelType::Embedding(pool) => {
if pool == Pool::Splade {
candle::bail!("`splade` is not supported for Bert")
}
(pool, None)
}
};

let (embeddings, encoder) = match (
@@ -724,6 +730,7 @@ impl BertModel {

(outputs.sum(1)?.broadcast_div(&input_lengths))?
}
Pool::Splade => unreachable!(),
};
Some(pooled_embeddings)
} else {
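On the `Pool::Splade => unreachable!()` arm: Bert rejects `splade` at load time (see the bail added above), so only `Cls` and `Mean` are reachable here, while SPLADE pooling is reserved for `ForMaskedLM`-style models such as the DistilBert added in this commit. A rough sketch of how the three strategies differ, using stand-in types rather than the crate's real wiring:

```rust
use candle::{Result, Tensor};

// Stand-in pooling enum and dispatch, for illustration only.
enum Pool {
    Cls,
    Mean,
    Splade,
}

/// `outputs`: [batch, seq_len, dim] hidden states (or MLM logits for `Splade`);
/// `input_lengths`: [batch, 1] count of real (non-padding) tokens per sequence.
fn pool(outputs: &Tensor, input_lengths: &Tensor, strategy: &Pool) -> Result<Tensor> {
    match strategy {
        // Take the first ([CLS]) token of every sequence.
        Pool::Cls => outputs.narrow(1, 0, 1)?.squeeze(1),
        // Sum over the sequence and divide by the true sequence lengths.
        Pool::Mean => outputs.sum(1)?.broadcast_div(input_lengths),
        // SPLADE: keep each vocabulary term's strongest log(1 + relu) activation.
        Pool::Splade => outputs.relu()?.affine(1.0, 1.0)?.log()?.max(1),
    }
}
```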