diff --git a/README.md b/README.md index 021ded52..001a1e6f 100644 --- a/README.md +++ b/README.md @@ -152,13 +152,16 @@ Options: --pooling Optionally control the pooling method for embedding models. - If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` - configuration. + If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` configuration. If `pooling` is set, it will override the model pooling configuration [env: POOLING=] - [possible values: cls, mean] + + Possible values: + - cls: Select the CLS token as embedding + - mean: Apply Mean pooling to the model embeddings + - splade: Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. This option is only available if the loaded model is a `ForMaskedLM` Transformer model --max-concurrent-requests The maximum amount of concurrent requests for this particular deployment. diff --git a/backends/candle/src/layers/layer_norm.rs b/backends/candle/src/layers/layer_norm.rs index 879e2b1f..0c360572 100644 --- a/backends/candle/src/layers/layer_norm.rs +++ b/backends/candle/src/layers/layer_norm.rs @@ -23,12 +23,15 @@ impl LayerNorm { }) } - pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result { + pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result { let _enter = self.span.enter(); match hidden_states.device() { Device::Cpu | Device::Metal(_) => { - let hidden_states = hidden_states.add(residual)?; + let mut hidden_states = hidden_states.clone(); + if let Some(residual) = residual { + hidden_states = hidden_states.add(residual)?; + } let hidden_states_dtype = hidden_states.dtype(); let internal_dtype = match hidden_states_dtype { DType::F16 | DType::BF16 => DType::F32, @@ -51,19 +54,25 @@ impl LayerNorm { Device::Cuda(_) => { #[cfg(feature = "cuda")] { - use candle_layer_norm::fused_add_layer_norm; + use candle_layer_norm::{fused_add_layer_norm, layer_norm}; let original_shape = hidden_states.shape(); let hidden_states = hidden_states.flatten_to(D::Minus2)?; - let residual = residual.flatten_to(D::Minus2)?; - let (result, _) = fused_add_layer_norm( - &hidden_states, - &residual, - &self.weight, - Some(&self.bias), - self.epsilon, - )?; + let result = if let Some(residual) = residual { + let residual = residual.flatten_to(D::Minus2)?; + + let (result, _) = fused_add_layer_norm( + &hidden_states, + &residual, + &self.weight, + Some(&self.bias), + self.epsilon, + )?; + Ok(result) + } else { + layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon) + }?; result.reshape(original_shape) } #[cfg(not(feature = "cuda"))] diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index c25413bf..972148db 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -11,10 +11,13 @@ use crate::compute_cap::{ get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap, }; use crate::models::{ - BertModel, JinaBertModel, Model, NomicBertModel, NomicConfig, PositionEmbeddingType, + BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel, + NomicConfig, PositionEmbeddingType, }; #[cfg(feature = "cuda")] -use crate::models::{FlashBertModel, FlashJinaBertModel, FlashNomicBertModel}; +use crate::models::{ + FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel, +}; use candle::{DType, Device}; use candle_nn::VarBuilder; use models::BertConfig; @@ -33,6 +36,8 @@ enum Config { 
XlmRoberta(BertConfig), Camembert(BertConfig), Roberta(BertConfig), + #[serde(rename(deserialize = "distilbert"))] + DistilBert(DistilBertConfig), #[serde(rename(deserialize = "nomic_bert"))] NomicBert(NomicConfig), } @@ -119,6 +124,12 @@ impl CandleBackend { BertModel::load_roberta(vb, &config, model_type).s()?, )) } + (Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => { + tracing::info!("Starting DistilBertModel model on {:?}", device); + Ok(Box::new( + DistilBertModel::load(vb, &config, model_type).s()?, + )) + } (Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => { tracing::info!("Starting NomicBertModel model on {:?}", device); Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?)) @@ -175,6 +186,26 @@ impl CandleBackend { } } #[cfg(feature = "cuda")] + (Config::DistilBert(config), Device::Cuda(_)) => { + if cfg!(feature = "flash-attn") + && dtype == DType::F16 + && &std::env::var("USE_FLASH_ATTENTION") + .unwrap_or("True".to_string()) + .to_lowercase() + == "true" + { + tracing::info!("Starting FlashDistilBertModel model on {:?}", device); + Ok(Box::new( + FlashDistilBertModel::load(vb, &config, model_type).s()?, + )) + } else { + tracing::info!("Starting DistilBertModel model on {:?}", device); + Ok(Box::new( + DistilBertModel::load(vb, &config, model_type).s()?, + )) + } + } + #[cfg(feature = "cuda")] (Config::NomicBert(config), Device::Cuda(_)) => { if cfg!(feature = "flash-attn") && dtype == DType::F16 diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs index 4e54c99d..ed89482e 100644 --- a/backends/candle/src/models.rs +++ b/backends/candle/src/models.rs @@ -5,20 +5,25 @@ extern crate intel_mkl_src; extern crate accelerate_src; mod bert; +mod distilbert; +mod jina; +mod nomic; #[cfg(feature = "cuda")] mod flash_bert; #[cfg(feature = "cuda")] mod flash_jina; -mod jina; #[cfg(feature = "cuda")] mod flash_nomic; -mod nomic; + +#[cfg(feature = "cuda")] +mod flash_distilbert; pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; +pub use distilbert::{DistilBertConfig, DistilBertModel}; pub use jina::JinaBertModel; pub use nomic::{NomicBertModel, NomicConfig}; use text_embeddings_backend_core::Batch; @@ -32,6 +37,9 @@ pub use flash_jina::FlashJinaBertModel; #[cfg(feature = "cuda")] pub use flash_nomic::FlashNomicBertModel; +#[cfg(feature = "cuda")] +pub use flash_distilbert::FlashDistilBertModel; + pub(crate) trait Model { fn is_padded(&self) -> bool; diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index 7620fe37..f5c74133 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -26,7 +26,6 @@ pub struct BertConfig { #[serde(default)] pub use_cache: bool, pub classifier_dropout: Option, - pub model_type: Option, pub id2label: Option>, } @@ -39,7 +38,7 @@ pub enum PositionEmbeddingType { } #[derive(Debug)] -struct BertEmbeddings { +pub struct BertEmbeddings { word_embeddings: Embedding, token_type_embeddings: Embedding, position_embeddings: Embedding, @@ -80,7 +79,7 @@ impl BertEmbeddings { }) } - fn forward( + pub fn forward( &self, input_ids: &Tensor, token_type_ids: &Tensor, @@ -93,7 +92,9 @@ impl BertEmbeddings { let position_embeddings = self.position_embeddings.forward(position_ids)?; let embeddings = input_embeddings.add(&token_type_embeddings)?; - let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?; + let embeddings = self + .layer_norm + .forward(&embeddings, Some(&position_embeddings))?;
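+ // The residual is now passed as an Option so the shared LayerNorm::forward can also run without one (e.g. DistilBERT's SPLADE vocab_layer_norm calls it with None).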
Ok(embeddings) } @@ -255,7 +256,7 @@ impl BertAttention { let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; let hidden_states = self.dense.forward(&context_layer)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -324,7 +325,7 @@ impl BertLayer { let hidden_states = self.intermediate.forward(&hidden_states)?; let hidden_states = self.output.forward(&hidden_states)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -469,7 +470,12 @@ impl BertModel { Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?); (pool, Some(classifier)) } - ModelType::Embedding(pool) => (pool, None), + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Nomic") + } + (pool, None) + } }; let (embeddings, encoder) = match ( @@ -724,6 +730,7 @@ impl BertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::Splade => unreachable!(), }; Some(pooled_embeddings) } else { diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs new file mode 100644 index 00000000..5caa18af --- /dev/null +++ b/backends/candle/src/models/distilbert.rs @@ -0,0 +1,658 @@ +use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; +use crate::models::Model; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Embedding, VarBuilder}; +use serde::Deserialize; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct DistilBertConfig { + pub vocab_size: usize, + pub dim: usize, + pub n_layers: usize, + pub n_heads: usize, + pub hidden_dim: usize, + pub activation: HiddenAct, + pub max_position_embeddings: usize, + pub pad_token_id: usize, + pub model_type: Option, +} + +#[derive(Debug)] +pub struct DistilBertEmbeddings { + word_embeddings: Embedding, + position_embeddings: Embedding, + layer_norm: LayerNorm, + span: tracing::Span, +} + +impl DistilBertEmbeddings { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + Ok(Self { + word_embeddings: Embedding::new( + vb.pp("word_embeddings") + .get((config.vocab_size, config.dim), "weight")?, + config.dim, + ), + position_embeddings: Embedding::new( + vb.pp("position_embeddings") + .get((config.max_position_embeddings, config.dim), "weight")?, + config.dim, + ), + layer_norm: LayerNorm::load(vb.pp("LayerNorm"), config.dim, 1e-12f32)?, + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + pub fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result { + let _enter = self.span.enter(); + + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let position_embeddings = self.position_embeddings.forward(position_ids)?; + + let embeddings = self + .layer_norm + .forward(&input_embeddings, Some(&position_embeddings))?; + + Ok(embeddings) + } +} + +#[derive(Debug)] +struct DistilBertAttention { + qkv_linear: Linear, + dense: Linear, + + num_attention_heads: usize, + attention_head_size: usize, + softmax_scale: f64, + + span: tracing::Span, +} + +impl DistilBertAttention { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let attention_head_size = config.dim / config.n_heads; + let all_head_size = config.n_heads * 
attention_head_size; + let hidden_size = config.dim; + + let query_weight = vb.pp("q_lin").get((all_head_size, hidden_size), "weight")?; + let query_bias = vb.pp("q_lin").get(all_head_size, "bias")?; + + let key_weight = vb.pp("k_lin").get((all_head_size, hidden_size), "weight")?; + let key_bias = vb.pp("k_lin").get(all_head_size, "bias")?; + + let value_weight = vb.pp("v_lin").get((all_head_size, hidden_size), "weight")?; + let value_bias = vb.pp("v_lin").get(all_head_size, "bias")?; + + let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; + let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?; + + let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None); + + let dense_weight = vb.pp("out_lin").get((hidden_size, hidden_size), "weight")?; + let dense_bias = vb.pp("out_lin").get(hidden_size, "bias")?; + + let dense = Linear::new(dense_weight, Some(dense_bias), None); + + let softmax_scale = 1. / (attention_head_size as f64).sqrt(); + + Ok(Self { + qkv_linear, + dense, + num_attention_heads: config.n_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + let device = hidden_states.device(); + + let qkv = self.qkv_linear.forward(hidden_states)?; + + let mut new_qkv_shape = qkv.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads * 3); + new_qkv_shape.push(self.attention_head_size); + let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + + let qkv = qkv.chunk(3, 1)?; + let query_layer = &qkv[0].contiguous()?; + let key_layer = &qkv[1].contiguous()?; + let value_layer = &qkv[2]; + + #[allow(unused_variables)] + let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = + (device, get_cublas_lt_wrapper()) + { + #[cfg(feature = "cuda")] + { + // cuBLASLt batch matmul implementation requires inputs to be dims3 + let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?; + let key_layer = key_layer.flatten(0, 1)?; + let query_layer = query_layer.flatten(0, 1)?; + let value_layer = value_layer.flatten(0, 1)?; + let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?; + + // If attention_bias is set, we fuse the add by giving it as the output matrix + // and setting beta to 1.0 + let beta = match attention_bias.is_some() { + true => Some(1.0), + false => None, + }; + + // Batch matrix multiplication + // Fuse softmax scale and attention_bias add + let attention_scores = cublaslt.batch_matmul( + &key_layer, + &query_layer, + attention_bias.as_ref(), + Some(self.softmax_scale as f32), + beta, + None, + None, + )?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + + let context_layer = cublaslt.batch_matmul( + &value_layer.t()?.contiguous()?, + &attention_probs, + // We save one allocation + Some(&query_layer), + None, + None, + None, + None, + )?; + + // Reshape to dims4 + context_layer.reshape(( + batch_size, + self.num_attention_heads, + seq_len, + self.attention_head_size, + )) + } + #[cfg(not(feature = "cuda"))] + { + candle::bail!("`cuda` feature is not enabled") + } + } else { + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let mut attention_scores = (attention_scores * self.softmax_scale)?; + + if let Some(attention_bias) = attention_bias { + attention_scores = attention_scores.add(attention_bias)?; + } + + let attention_probs = 
candle_nn::ops::softmax_last_dim(&attention_scores)?; + attention_probs.matmul(&value_layer.contiguous()?) + }?; + + let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; + + let hidden_states = self.dense.forward(&context_layer)?; + + Ok(hidden_states) + } +} + +#[derive(Debug)] +pub struct DistilBertMLP { + lin1: Linear, + lin2: Linear, + + span: tracing::Span, +} + +impl DistilBertMLP { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let lin1_weight = vb + .pp("lin1") + .get((config.hidden_dim, config.dim), "weight")?; + let lin1_bias = vb.pp("lin1").get(config.hidden_dim, "bias")?; + let lin1 = Linear::new( + lin1_weight, + Some(lin1_bias), + Some(config.activation.clone()), + ); + + let lin2_weight = vb + .pp("lin2") + .get((config.dim, config.hidden_dim), "weight")?; + let lin2_bias = vb.pp("lin2").get(config.dim, "bias")?; + let lin2 = Linear::new(lin2_weight, Some(lin2_bias), None); + + Ok(Self { + lin1, + lin2, + span: tracing::span!(tracing::Level::TRACE, "mlp"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let hidden_states = self.lin1.forward(hidden_states)?; + self.lin2.forward(&hidden_states) + } +} + +#[derive(Debug)] +struct DistilBertBlock { + attention: DistilBertAttention, + mlp: DistilBertMLP, + post_attention_layer_norm: LayerNorm, + output_layer_norm: LayerNorm, + + span: tracing::Span, +} + +impl DistilBertBlock { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let attention = DistilBertAttention::load(vb.pp("attention"), config)?; + let mlp = DistilBertMLP::load(vb.pp("ffn"), config)?; + + let post_attention_layer_norm = + LayerNorm::load(vb.pp("sa_layer_norm"), config.dim, 1e-12f32)?; + let output_layer_norm = LayerNorm::load(vb.pp("output_layer_norm"), config.dim, 1e-12f32)?; + + Ok(Self { + attention, + mlp, + post_attention_layer_norm, + output_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + attention_bias: Option<&Tensor>, + ) -> Result { + let _enter = self.span.enter(); + + let attn_output = self.attention.forward(hidden_states, attention_bias)?; + let hidden_states = self + .post_attention_layer_norm + .forward(hidden_states, Some(&attn_output))?; + + let mlp_out = self.mlp.forward(&hidden_states)?; + + self.output_layer_norm + .forward(&hidden_states, Some(&mlp_out)) + } +} + +#[derive(Debug)] +struct DistilBertEncoder { + layers: Vec, + span: tracing::Span, +} + +impl DistilBertEncoder { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let layers = (0..config.n_layers) + .map(|index| DistilBertBlock::load(vb.pp(format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + + Ok(DistilBertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.clone(); + + // Use a loop rather than a fold as it's easier to modify when adding debug/... 
+ for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, attention_bias)?; + } + + Ok(hidden_states) + } +} + +#[derive(Debug)] +pub struct DistilBertSpladeHead { + vocab_transform: Linear, + vocab_projector: Linear, + vocab_layer_norm: LayerNorm, + span: tracing::Span, +} + +impl DistilBertSpladeHead { + pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let vocab_transform_weight = vb + .pp("vocab_transform") + .get((config.dim, config.dim), "weight")?; + let vocab_transform_bias = vb.pp("vocab_transform").get(config.dim, "bias")?; + let vocab_transform = Linear::new( + vocab_transform_weight, + Some(vocab_transform_bias), + Some(config.activation.clone()), + ); + + let vocab_projector_weight = vb + .pp("vocab_projector") + .get((config.vocab_size, config.dim), "weight")?; + let vocab_projector_bias = vb.pp("vocab_projector").get(config.vocab_size, "bias")?; + let vocab_projector = Linear::new( + vocab_projector_weight, + Some(vocab_projector_bias), + Some(HiddenAct::Relu), + ); + + let vocab_layer_norm = LayerNorm::load(vb.pp("vocab_layer_norm"), config.dim, 1e-12f32)?; + + Ok(Self { + vocab_transform, + vocab_projector, + vocab_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "splade"), + }) + } + + pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let hidden_states = self.vocab_transform.forward(hidden_states)?; + let hidden_states = self.vocab_layer_norm.forward(&hidden_states, None)?; + let hidden_states = self.vocab_projector.forward(&hidden_states)?; + Ok(hidden_states) + + // (1.0 + hidden_states)?.log() + } +} + +#[derive(Debug)] +pub struct DistilBertModel { + embeddings: DistilBertEmbeddings, + encoder: DistilBertEncoder, + pool: Pool, + splade: Option, + + num_attention_heads: usize, + + device: Device, + dtype: DType, + + span: tracing::Span, +} + +impl DistilBertModel { + pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result { + let pool = match model_type { + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for DistilBert") + } + ModelType::Embedding(pool) => pool, + }; + + let (embeddings, encoder) = match ( + DistilBertEmbeddings::load(vb.pp("embeddings"), config), + DistilBertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(embeddings), Ok(encoder)) = ( + DistilBertEmbeddings::load(vb.pp("distilbert.embeddings"), config), + DistilBertEncoder::load(vb.pp("distilbert.transformer"), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } + }; + + let splade = if pool == Pool::Splade { + Some(DistilBertSpladeHead::load(vb.clone(), config)?) 
+ } else { + None + }; + + Ok(Self { + embeddings, + encoder, + pool, + splade, + num_attention_heads: config.n_heads, + device: vb.device().clone(), + dtype: vb.dtype(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = batch.len(); + let max_length = batch.max_length as usize; + + let shape = (batch_size, max_length); + + let (input_ids, position_ids, input_lengths, attention_bias, attention_mask) = + if batch_size > 1 { + // Prepare padded batch + let elems = batch_size * max_length; + + let mut input_ids = Vec::with_capacity(elems); + let mut position_ids = Vec::with_capacity(elems); + let mut attention_mask = Vec::with_capacity(elems); + let mut attention_bias = Vec::with_capacity(elems); + let mut input_lengths = Vec::with_capacity(batch_size); + // Bool to know if we need to use the attention mask + let mut masking = false; + + for i in 0..batch_size { + let start = batch.cumulative_seq_lengths[i] as usize; + let end = batch.cumulative_seq_lengths[i + 1] as usize; + let seq_length = (end - start) as u32; + input_lengths.push(seq_length as f32); + + // Copy values + for j in start..end { + input_ids.push(batch.input_ids[j]); + position_ids.push(batch.position_ids[j]); + attention_mask.push(1.0_f32); + attention_bias.push(0.0); + } + + // Add padding if needed + let padding = batch.max_length - seq_length; + if padding > 0 { + // Set bool to use attention mask + masking = true; + for _ in 0..padding { + input_ids.push(0); + position_ids.push(0); + attention_mask.push(0.0_f32); + attention_bias.push(f32::NEG_INFINITY); + } + } + } + + let (attention_bias, attention_mask) = match masking { + true => { + // We only need the mask if we use mean pooling + // For CLS pooling, the bias is enough + let attention_mask = if self.pool == Pool::Mean { + let attention_mask = Tensor::from_vec( + attention_mask, + (batch_size, max_length, 1), + &self.device, + )? + .to_dtype(self.dtype)?; + + Some(attention_mask) + } else { + None + }; + + let attention_bias = Tensor::from_vec( + attention_bias, + (batch_size, 1, 1, max_length), + &self.device, + )? + .to_dtype(self.dtype)?; + // Broadcast once instead of at every layer + let attention_bias = attention_bias + .broadcast_as(( + batch_size, + self.num_attention_heads, + max_length, + max_length, + ))? 
+ .contiguous()?; + (Some(attention_bias), attention_mask) + } + false => (None, None), + }; + + ( + input_ids, + position_ids, + input_lengths, + attention_bias, + attention_mask, + ) + } else { + ( + batch.input_ids, + batch.position_ids, + vec![batch.max_length as f32], + None, + None, + ) + }; + + // Create CPU tensors + let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?; + let input_lengths = + Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?; + + let embedding_output = self.embeddings.forward(&input_ids, &position_ids)?; + + let outputs = self + .encoder + .forward(&embedding_output, attention_bias.as_ref())?; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + let pooled_indices_length = batch.pooled_indices.len(); + let mut outputs = outputs.clone(); + + // Only use pooled_indices if at least one member of the batch ask for raw embeddings + let pooled_indices = if has_raw_requests { + let pooled_indices = + Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?; + + // Select values in the batch + outputs = outputs.index_select(&pooled_indices, 0)?; + Some(pooled_indices) + } else { + None + }; + + let pooled_embeddings = match self.pool { + // CLS pooling + Pool::Cls => outputs.i((.., 0))?, + // Mean pooling + Pool::Mean => { + if let Some(ref attention_mask) = attention_mask { + let mut attention_mask = attention_mask.clone(); + + if let Some(pooled_indices) = pooled_indices { + // Select values in the batch + attention_mask = attention_mask.index_select(&pooled_indices, 0)?; + }; + + // Mask padded values + outputs = outputs.broadcast_mul(&attention_mask)?; + } + + (outputs.sum(1)?.broadcast_div(&input_lengths))? + } + Pool::Splade => { + // Unwrap is safe here + let splade_head = self.splade.as_ref().unwrap(); + let mut relu_log = splade_head.forward(&outputs)?; + + if let Some(ref attention_mask) = attention_mask { + let mut attention_mask = attention_mask.clone(); + + if let Some(pooled_indices) = pooled_indices { + // Select values in the batch + attention_mask = attention_mask.index_select(&pooled_indices, 0)?; + }; + + // Mask padded values + relu_log = relu_log.broadcast_mul(&attention_mask)?; + } + + relu_log.max(1)? 
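+ // Max over the sequence dimension collapses the (batch, seq_len, vocab) activations into one sparse, vocab-sized embedding per pooled request.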
+ } + }; + Some(pooled_embeddings) + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + // Reshape outputs + let (b, l, h) = outputs.shape().dims3()?; + let outputs = outputs.reshape((b * l, h))?; + + // We need to remove the padding tokens only if batch_size > 1 and there are some + // member of the batch that require pooling + // or if batch_size > 1 and the members of the batch have different lengths + if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 { + let mut final_indices: Vec = Vec::with_capacity(batch_size * max_length); + + for i in batch.raw_indices.into_iter() { + let start = i * batch.max_length; + let i = i as usize; + let length = + batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]; + + for j in start..start + length { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; + + Ok((pooled_embeddings, raw_embeddings)) + } +} + +impl Model for DistilBertModel { + fn is_padded(&self) -> bool { + true + } + + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } +} diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index f32137fc..01a5cc9d 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -1,76 +1,14 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{LayerNorm, Linear}; use crate::models::bert::{ - BertClassificationHead, BertConfig, ClassificationHead, PositionEmbeddingType, + BertClassificationHead, BertConfig, BertEmbeddings, ClassificationHead, PositionEmbeddingType, RobertaClassificationHead, }; use crate::models::Model; use candle::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::VarBuilder; use text_embeddings_backend_core::{Batch, ModelType, Pool}; -#[derive(Debug)] -struct BertEmbeddings { - word_embeddings: Embedding, - token_type_embeddings: Embedding, - position_embeddings: Embedding, - layer_norm: LayerNorm, - span: tracing::Span, -} - -impl BertEmbeddings { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { - if config.position_embedding_type != PositionEmbeddingType::Absolute { - candle::bail!("FlashBert only supports absolute position embeddings"); - } - - Ok(Self { - word_embeddings: Embedding::new( - vb.pp("word_embeddings") - .get((config.vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ), - token_type_embeddings: Embedding::new( - vb.pp("token_type_embeddings") - .get((config.type_vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ), - position_embeddings: Embedding::new( - vb.pp("position_embeddings").get( - (config.max_position_embeddings, config.hidden_size), - "weight", - )?, - config.hidden_size, - ), - layer_norm: LayerNorm::load( - vb.pp("LayerNorm"), - config.hidden_size, - config.layer_norm_eps as f32, - )?, - span: tracing::span!(tracing::Level::TRACE, "embeddings"), - }) - } - - fn forward( - &self, - input_ids: &Tensor, - token_type_ids: &Tensor, - position_ids: &Tensor, - ) -> Result { - let _enter = self.span.enter(); - - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let 
token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - let embeddings = input_embeddings.add(&token_type_embeddings)?; - - let position_embeddings = self.position_embeddings.forward(position_ids)?; - - let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?; - - Ok(embeddings) - } -} - struct BertAttention { qkv_linear: Linear, dense: Linear, @@ -169,7 +107,7 @@ impl BertAttention { let attention = attention.flatten_from(candle::D::Minus2)?; let hidden_states = self.dense.forward(&attention)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -239,7 +177,7 @@ impl BertLayer { let hidden_states = self.intermediate.forward(&hidden_states)?; let hidden_states = self.output.forward(&hidden_states)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -309,7 +247,12 @@ impl FlashBertModel { Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?); (pool, Some(classifier)) } - ModelType::Embedding(pool) => (pool, None), + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Nomic") + } + (pool, None) + } }; let (embeddings, encoder) = match ( @@ -368,7 +311,12 @@ impl FlashBertModel { ); (pool, Some(classifier)) } - ModelType::Embedding(pool) => (pool, None), + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Nomic") + } + (pool, None) + } }; let (embeddings, encoder) = match ( @@ -483,6 +431,9 @@ impl FlashBertModel { Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) 
} } + Pool::Splade => { + unreachable!(); + } } } else { None diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs new file mode 100644 index 00000000..26e99721 --- /dev/null +++ b/backends/candle/src/models/flash_distilbert.rs @@ -0,0 +1,379 @@ +use crate::flash_attn::flash_attn_varlen; +use crate::layers::{LayerNorm, Linear}; +use crate::models::distilbert::{ + DistilBertConfig, DistilBertEmbeddings, DistilBertMLP, DistilBertSpladeHead, +}; +use crate::models::Model; +use candle::{DType, Device, Result, Tensor}; +use candle_nn::VarBuilder; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; + +#[derive(Debug)] +struct DistilBertAttention { + qkv_linear: Linear, + dense: Linear, + + num_attention_heads: usize, + attention_head_size: usize, + softmax_scale: f32, + + span: tracing::Span, +} + +impl DistilBertAttention { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let attention_head_size = config.dim / config.n_heads; + let all_head_size = config.n_heads * attention_head_size; + let hidden_size = config.dim; + + let query_weight = vb.pp("q_lin").get((all_head_size, hidden_size), "weight")?; + let query_bias = vb.pp("q_lin").get(all_head_size, "bias")?; + let key_weight = vb.pp("k_lin").get((all_head_size, hidden_size), "weight")?; + let key_bias = vb.pp("k_lin").get(all_head_size, "bias")?; + let value_weight = vb.pp("v_lin").get((all_head_size, hidden_size), "weight")?; + let value_bias = vb.pp("v_lin").get(all_head_size, "bias")?; + + let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?; + let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?; + + let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None); + + let dense_weight = vb.pp("out_lin").get((hidden_size, hidden_size), "weight")?; + let dense_bias = vb.pp("out_lin").get(hidden_size, "bias")?; + + let dense = Linear::new(dense_weight, Some(dense_bias), None); + + let softmax_scale = (1. 
/ (attention_head_size as f64).sqrt()) as f32; + + Ok(Self { + qkv_linear, + dense, + num_attention_heads: config.n_heads, + attention_head_size, + softmax_scale, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + cu_seqlens: &Tensor, + max_s: usize, + ) -> Result { + let _enter = self.span.enter(); + + let qkv = self.qkv_linear.forward(hidden_states)?; + + let mut new_qkv_shape = qkv.dims().to_vec(); + new_qkv_shape.pop(); + new_qkv_shape.push(self.num_attention_heads * 3); + new_qkv_shape.push(self.attention_head_size); + + let qkv = qkv.reshape(new_qkv_shape.as_slice())?; + let qkv = qkv.chunk(3, 1)?; + + let attention = flash_attn_varlen( + &qkv[0], + &qkv[1], + &qkv[2], + None, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + self.softmax_scale, + false, + )?; + let attention = attention.flatten_from(candle::D::Minus2)?; + + let hidden_states = self.dense.forward(&attention)?; + + Ok(hidden_states) + } +} + +#[derive(Debug)] +struct DistilBertBlock { + attention: DistilBertAttention, + mlp: DistilBertMLP, + post_attention_layer_norm: LayerNorm, + output_layer_norm: LayerNorm, + + span: tracing::Span, +} + +impl DistilBertBlock { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let attention = DistilBertAttention::load(vb.pp("attention"), config)?; + let mlp = DistilBertMLP::load(vb.pp("ffn"), config)?; + + let post_attention_layer_norm = + LayerNorm::load(vb.pp("sa_layer_norm"), config.dim, 1e-12f32)?; + let output_layer_norm = LayerNorm::load(vb.pp("output_layer_norm"), config.dim, 1e-12f32)?; + + Ok(Self { + attention, + mlp, + post_attention_layer_norm, + output_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + cu_seqlens: &Tensor, + max_s: usize, + ) -> Result { + let _enter = self.span.enter(); + + let attn_output = self.attention.forward(hidden_states, cu_seqlens, max_s)?; + let hidden_states = self + .post_attention_layer_norm + .forward(hidden_states, Some(&attn_output))?; + + let mlp_out = self.mlp.forward(&hidden_states)?; + + self.output_layer_norm + .forward(&hidden_states, Some(&mlp_out)) + } +} + +#[derive(Debug)] +struct DistilBertEncoder { + layers: Vec, + span: tracing::Span, +} + +impl DistilBertEncoder { + pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result { + let layers = (0..config.n_layers) + .map(|index| DistilBertBlock::load(vb.pp(format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + + Ok(DistilBertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.clone(); + + // Use a loop rather than a fold as it's easier to modify when adding debug/... 
+ for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, cu_seqlens, max_s)?; + } + + Ok(hidden_states) + } +} + +pub struct FlashDistilBertModel { + embeddings: DistilBertEmbeddings, + encoder: DistilBertEncoder, + pool: Pool, + splade: Option, + + pub device: Device, + + span: tracing::Span, +} + +impl FlashDistilBertModel { + pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result { + match vb.device() { + Device::Cuda(_) => {} + _ => candle::bail!("FlashDistilBert requires Cuda"), + } + + if vb.dtype() != DType::F16 { + candle::bail!("FlashDistilBert requires DType::F16") + } + + let pool = match model_type { + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for DistilBert") + } + ModelType::Embedding(pool) => pool, + }; + + let (embeddings, encoder) = match ( + DistilBertEmbeddings::load(vb.pp("embeddings"), config), + DistilBertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(embeddings), Ok(encoder)) = ( + DistilBertEmbeddings::load(vb.pp("distilbert.embeddings"), config), + DistilBertEncoder::load(vb.pp("distilbert.transformer"), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } + }; + + let splade = if pool == Pool::Splade { + Some(DistilBertSpladeHead::load(vb.clone(), config)?) + } else { + None + }; + + Ok(Self { + embeddings, + encoder, + pool, + splade, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, batch: Batch) -> Result<(Option, Option)> { + let _enter = self.span.enter(); + + let batch_size = batch.len(); + let shape = batch.input_ids.len(); + + // Create Cuda tensors + let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; + let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + let cu_seqlens = Tensor::from_vec( + batch.cumulative_seq_lengths.clone(), + batch_size + 1, + &self.device, + )?; + + let embedding_output = self.embeddings.forward(&input_ids, &position_ids)?; + + let outputs = + self.encoder + .forward(&embedding_output, &cu_seqlens, batch.max_length as usize)?; + + let has_pooling_requests = !batch.pooled_indices.is_empty(); + let has_raw_requests = !batch.raw_indices.is_empty(); + + let pooled_embeddings = if has_pooling_requests { + let pooled_embeddings = match self.pool { + // CLS pooling + Pool::Cls => { + // Get the indices of the cls tokens from cu_seqlens + let mut cls_indices = cu_seqlens.narrow(0, 0, batch_size)?; + + // If raw_indices is empty, we don't need to do anything with + // the pooled_indices + if has_raw_requests { + // We need the pooled indices to select the correct cls indices + let pooled_indices = Tensor::from_vec( + batch.pooled_indices.clone(), + batch.pooled_indices.len(), + &self.device, + )?; + + // Only select indices that requires pooling + cls_indices = cls_indices.index_select(&pooled_indices, 0)? + } + + // Select cls tokens + outputs.index_select(&cls_indices, 0)? + } + // Mean pooling + Pool::Mean => { + if batch_size > 1 { + // for each request that requires pooling + let results: Result> = batch + .pooled_indices + .into_iter() + .map(|i| { + let i = i as usize; + let start = batch.cumulative_seq_lengths[i]; + let len = batch.cumulative_seq_lengths[i + 1] - start; + + // Mean + let embeddings = outputs.narrow(0, start as usize, len as usize)?; + embeddings.sum_keepdim(0)? 
/ (len as f64) + }) + .collect(); + + // Concatenate all results + Tensor::cat(&results?, 0)? + } else { + (outputs.sum_keepdim(0)? / (batch.max_length as f64))? + } + } + Pool::Splade => { + // Unwrap is safe here + let splade_head = self.splade.as_ref().unwrap(); + let relu_log = splade_head.forward(&outputs)?; + + if batch_size > 1 { + // for each request that requires pooling + let results: Result> = batch + .pooled_indices + .into_iter() + .map(|i| { + let i = i as usize; + let start = batch.cumulative_seq_lengths[i]; + let len = batch.cumulative_seq_lengths[i + 1] - start; + + relu_log.narrow(0, start as usize, len as usize)?.max(0) + }) + .collect(); + + // Concatenate all results + Tensor::cat(&results?, 0)? + } else { + relu_log.max_keepdim(0)? + } + } + }; + Some(pooled_embeddings) + } else { + None + }; + + let raw_embeddings = if has_raw_requests { + if batch_size > 1 && has_pooling_requests { + // Create indexing vector for the embeddings + let mut final_indices: Vec = Vec::with_capacity(shape); + for i in batch.raw_indices.into_iter() { + let i = i as usize; + // Get start/end token index of this specific member of the batch + let start = batch.cumulative_seq_lengths[i]; + let end = batch.cumulative_seq_lengths[i + 1]; + + for j in start..end { + // Add indices for the tokens of this specific member of the batch + final_indices.push(j); + } + } + + let final_indices_length = final_indices.len(); + let final_indices = + Tensor::from_vec(final_indices, final_indices_length, &self.device)?; + + // Select the tokens with final indices + Some(outputs.index_select(&final_indices, 0)?) + } else { + Some(outputs) + } + } else { + None + }; + + Ok((pooled_embeddings, raw_embeddings)) + } +} + +impl Model for FlashDistilBertModel { + fn is_padded(&self) -> bool { + false + } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { + self.forward(batch) + } +} diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index 13029dd8..ebf2c292 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -2,78 +2,12 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; use crate::models::bert::{BertConfig, PositionEmbeddingType}; +use crate::models::jina::BertEmbeddings; use crate::models::Model; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::VarBuilder; use text_embeddings_backend_core::{Batch, ModelType, Pool}; -#[derive(Debug)] -struct BertEmbeddings { - word_embeddings: Embedding, - token_type_embeddings: Embedding, - position_embeddings: Option, - layer_norm: LayerNorm, - span: tracing::Span, -} - -impl BertEmbeddings { - pub fn load(vb: VarBuilder, config: &BertConfig) -> Result { - let position_embeddings = - if config.position_embedding_type == PositionEmbeddingType::Absolute { - Some(Embedding::new( - vb.pp("position_embeddings").get( - (config.max_position_embeddings, config.hidden_size), - "weight", - )?, - config.hidden_size, - )) - } else { - None - }; - - Ok(Self { - word_embeddings: Embedding::new( - vb.pp("word_embeddings") - .get((config.vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ), - token_type_embeddings: Embedding::new( - vb.pp("token_type_embeddings") - .get((config.type_vocab_size, config.hidden_size), "weight")?, - config.hidden_size, - ), - position_embeddings, - layer_norm: LayerNorm::load( - 
vb.pp("LayerNorm"), - config.hidden_size, - config.layer_norm_eps as f32, - )?, - span: tracing::span!(tracing::Level::TRACE, "embeddings"), - }) - } - - fn forward( - &self, - input_ids: &Tensor, - token_type_ids: &Tensor, - position_ids: &Tensor, - ) -> Result { - let _enter = self.span.enter(); - - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - - if let Some(position_embeddings) = &self.position_embeddings { - let position_embeddings = position_embeddings.forward(position_ids)?; - let embeddings = input_embeddings.add(&token_type_embeddings)?; - self.layer_norm.forward(&embeddings, &position_embeddings) - } else { - self.layer_norm - .forward(&input_embeddings, &token_type_embeddings) - } - } -} - struct AlibiBertAttention { qkv_linear: Linear, dense: Linear, @@ -175,7 +109,7 @@ impl AlibiBertAttention { let attention = attention.flatten_from(candle::D::Minus2)?; let hidden_states = self.dense.forward(&attention)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -250,7 +184,7 @@ impl JinaBertLayer { let hidden_states = (gated * non_gated)?; let hidden_states = self.output.forward(&hidden_states)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -319,11 +253,15 @@ impl FlashJinaBertModel { } let pool = match model_type { - // Classifier models always use CLS pooling ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for Jina") } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Jina") + } + pool + } }; let (embeddings, encoder) = match ( @@ -332,18 +270,7 @@ impl FlashJinaBertModel { ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { - let model_type = config.model_type.clone().unwrap_or("bert".to_string()); - if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config), - BertEncoder::load( - vb.pp(format!("{model_type}.encoder")), - config, - alibi.clone(), - ), - ) { - (embeddings, encoder) - } else if let (Ok(embeddings), Ok(encoder)) = ( BertEmbeddings::load(vb.pp("bert.embeddings"), config), BertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()), ) { @@ -442,6 +369,9 @@ impl FlashJinaBertModel { Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) 
} } + Pool::Splade => { + unreachable!(); + } } } else { None diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 2400d745..12eff0f1 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,96 +1,11 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{LayerNorm, Linear}; +use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP}; use crate::models::{Model, NomicConfig}; -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{Embedding, VarBuilder}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::VarBuilder; use text_embeddings_backend_core::{Batch, ModelType, Pool}; -#[derive(Debug)] -struct NomicBertEmbeddings { - word_embeddings: Embedding, - token_type_embeddings: Embedding, - layer_norm: LayerNorm, - span: tracing::Span, -} - -impl NomicBertEmbeddings { - pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { - Ok(Self { - word_embeddings: Embedding::new( - vb.pp("embeddings.word_embeddings") - .get((config.vocab_size, config.n_embd), "weight")?, - config.n_embd, - ), - token_type_embeddings: Embedding::new( - vb.pp("embeddings.token_type_embeddings") - .get((config.type_vocab_size, config.n_embd), "weight")?, - config.n_embd, - ), - layer_norm: LayerNorm::load(vb.pp("emb_ln"), config.n_embd, config.layer_norm_epsilon)?, - span: tracing::span!(tracing::Level::TRACE, "embeddings"), - }) - } - - fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { - let _enter = self.span.enter(); - - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - - let embeddings = self - .layer_norm - .forward(&input_embeddings, &token_type_embeddings)?; - - Ok(embeddings) - } -} - -struct NomicBertGatedMLP { - gate_up_proj: Linear, - down_proj: Linear, - - span: tracing::Span, -} - -impl NomicBertGatedMLP { - pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { - let intermediate_size = config.n_inner; - - let gate_proj_weight = vb - .pp("fc12") - .get((intermediate_size, config.n_embd), "weight")?; - - let up_proj_weight = vb - .pp("fc11") - .get((intermediate_size, config.n_embd), "weight")?; - - let gate_up_proj_weight = Tensor::cat(&[&gate_proj_weight, &up_proj_weight], 0)?; - let gate_up_proj = Linear::new( - gate_up_proj_weight, - None, - Some(config.activation_function.clone()), - ); - - let down_proj_weight = vb - .pp("fc2") - .get((config.n_embd, intermediate_size), "weight")?; - let down_proj = Linear::new(down_proj_weight, None, None); - - Ok(Self { - gate_up_proj, - down_proj, - span: tracing::span!(tracing::Level::TRACE, "mlp"), - }) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let gate_up_states = self.gate_up_proj.forward(hidden_states)?; - self.down_proj.forward(&gate_up_states) - } -} - struct NomicAttention { qkv_linear: Linear, out_proj: Linear, @@ -216,11 +131,12 @@ impl NomicBertBlock { .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?; let hidden_states = self .post_attention_layer_norm - .forward(&hidden_states, &attn_output)?; + .forward(&hidden_states, Some(&attn_output))?; let mlp_out = self.mlp.forward(&hidden_states)?; - self.output_layer_norm.forward(&hidden_states, &mlp_out) + self.output_layer_norm + .forward(&hidden_states, Some(&mlp_out)) } } @@ -289,11 +205,15 @@ impl FlashNomicBertModel { } let pool = match 
model_type { - // Classifier models always use CLS pooling ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Jina") + candle::bail!("`classifier` model type is not supported for Nomic") + } + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Nomic") + } + pool } - ModelType::Embedding(pool) => pool, }; let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?; @@ -434,6 +354,9 @@ impl FlashNomicBertModel { Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?) } } + Pool::Splade => { + unreachable!(); + } } } else { None diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index 65c57217..97bc7f96 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -51,7 +51,7 @@ impl BertEmbeddings { }) } - fn forward( + pub fn forward( &self, input_ids: &Tensor, token_type_ids: &Tensor, @@ -65,10 +65,11 @@ impl BertEmbeddings { if let Some(position_embeddings) = &self.position_embeddings { let position_embeddings = position_embeddings.forward(position_ids)?; let embeddings = input_embeddings.add(&token_type_embeddings)?; - self.layer_norm.forward(&embeddings, &position_embeddings) + self.layer_norm + .forward(&embeddings, Some(&position_embeddings)) } else { self.layer_norm - .forward(&input_embeddings, &token_type_embeddings) + .forward(&input_embeddings, Some(&token_type_embeddings)) } } } @@ -229,7 +230,7 @@ impl BertAttention { let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; let hidden_states = self.dense.forward(&context_layer)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -303,7 +304,7 @@ impl JinaBertLayer { let hidden_states = (gated * non_gated)?; let hidden_states = self.output.forward(&hidden_states)?; - let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?; + let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -365,11 +366,15 @@ impl JinaBertModel { }; let pool = match model_type { - // Classifier models always use CLS pooling ModelType::Classifier => { candle::bail!("`classifier` model type is not supported for Jina") } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Jina") + } + pool + } }; let (embeddings, encoder) = match ( @@ -378,14 +383,7 @@ impl JinaBertModel { ) { (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), (Err(err), _) | (_, Err(err)) => { - let model_type = config.model_type.clone().unwrap_or("bert".to_string()); - if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config), - BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config), - ) { - (embeddings, encoder) - } else if let (Ok(embeddings), Ok(encoder)) = ( BertEmbeddings::load(vb.pp("bert.embeddings"), config), BertEncoder::load(vb.pp("bert.encoder"), config), ) { @@ -612,6 +610,7 @@ impl JinaBertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? 
} + Pool::Splade => unreachable!(), }; Some(pooled_embeddings) } else { diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 18b7fa15..4f9e7551 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -50,7 +50,7 @@ impl NomicConfig { } #[derive(Debug)] -struct NomicBertEmbeddings { +pub struct NomicBertEmbeddings { word_embeddings: Embedding, token_type_embeddings: Embedding, layer_norm: LayerNorm, @@ -75,7 +75,7 @@ impl NomicBertEmbeddings { }) } - fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { + pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result { let _enter = self.span.enter(); let input_embeddings = self.word_embeddings.forward(input_ids)?; @@ -83,13 +83,13 @@ impl NomicBertEmbeddings { let embeddings = self .layer_norm - .forward(&input_embeddings, &token_type_embeddings)?; + .forward(&input_embeddings, Some(&token_type_embeddings))?; Ok(embeddings) } } -struct NomicBertGatedMLP { +pub struct NomicBertGatedMLP { gate_up_proj: Linear, down_proj: Linear, @@ -329,11 +329,12 @@ impl NomicBertBlock { .forward(hidden_states, attention_bias, cos, sin)?; let hidden_states = self .post_attention_layer_norm - .forward(hidden_states, &attn_output)?; + .forward(hidden_states, Some(&attn_output))?; let mlp_out = self.mlp.forward(&hidden_states)?; - self.output_layer_norm.forward(&hidden_states, &mlp_out) + self.output_layer_norm + .forward(&hidden_states, Some(&mlp_out)) } } @@ -398,9 +399,14 @@ impl NomicBertModel { let pool = match model_type { // Classifier models always use CLS pooling ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for Jina") + candle::bail!("`classifier` model type is not supported for Nomic") + } + ModelType::Embedding(pool) => { + if pool == Pool::Splade { + candle::bail!("`splade` is not supported for Nomic") + } + pool } - ModelType::Embedding(pool) => pool, }; let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?; @@ -620,6 +626,7 @@ impl NomicBertModel { (outputs.sum(1)?.broadcast_div(&input_lengths))? } + Pool::Splade => unreachable!(), }; Some(pooled_embeddings) } else { diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 79f2761d..06cef3ed 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -55,8 +55,14 @@ pub enum ModelType { #[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "clap", derive(ValueEnum))] pub enum Pool { + /// Select the CLS token as embedding Cls, + /// Apply Mean pooling to the model embeddings Mean, + /// Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. + /// This option is only available if the loaded model is a `ForMaskedLM` Transformer + /// model. 
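+ /// The sparse embedding is built from the masked-LM prediction head and max-pooled over the sequence dimension.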
+ Splade, } impl fmt::Display for Pool { @@ -64,6 +70,7 @@ impl fmt::Display for Pool { match self { Pool::Cls => write!(f, "cls"), Pool::Mean => write!(f, "mean"), + Pool::Splade => write!(f, "splade"), } } } diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 68c570f7..f3519ee5 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -23,19 +23,20 @@ impl PythonBackend { uds_path: String, otlp_endpoint: Option, ) -> Result { - let pool = match model_type { + match model_type { ModelType::Classifier => { return Err(BackendError::Start( "`classifier` model type is not supported".to_string(), )) } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => { + if pool != Pool::Cls { + return Err(BackendError::Start(format!("{pool:?} is not supported"))); + } + pool + } }; - if pool != Pool::Cls { - return Err(BackendError::Start(format!("{pool:?} is not supported"))); - } - let backend_process = management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() diff --git a/core/src/infer.rs b/core/src/infer.rs index e19899f4..f9428a82 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -105,6 +105,15 @@ impl Infer { ) -> Result { let start_time = Instant::now(); + if self.is_splade() { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "`embed_all` is not available for SPLADE models".to_string(); + tracing::error!("{message}"); + return Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))); + } + let results = self .embed(inputs, truncate, false, &start_time, permit) .await?; @@ -145,6 +154,15 @@ impl Infer { ) -> Result { let start_time = Instant::now(); + if self.is_splade() && normalize { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "`normalize` is not available for SPLADE models".to_string(); + tracing::error!("{message}"); + return Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))); + } + let results = self .embed(inputs, truncate, true, &start_time, permit) .await?; @@ -366,6 +384,14 @@ impl Infer { matches!(self.backend.model_type, ModelType::Classifier) } + #[instrument(skip(self))] + pub fn is_splade(&self) -> bool { + matches!( + self.backend.model_type, + ModelType::Embedding(text_embeddings_backend::Pool::Splade) + ) + } + #[instrument(skip(self))] pub async fn health(&self) -> bool { self.backend.health().await.is_ok() diff --git a/docs/source/en/cli_arguments.md b/docs/source/en/cli_arguments.md index d5b1eb16..9dd407bd 100644 --- a/docs/source/en/cli_arguments.md +++ b/docs/source/en/cli_arguments.md @@ -55,13 +55,16 @@ Options: --pooling Optionally control the pooling method for embedding models. - If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` - configuration. + If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` configuration. If `pooling` is set, it will override the model pooling configuration [env: POOLING=] - [possible values: cls, mean] + + Possible values: + - cls: Select the CLS token as embedding + - mean: Apply Mean pooling to the model embeddings + - splade: Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. 
+            This option is only available if the loaded model is a `ForMaskedLM` Transformer model
 
   --max-concurrent-requests
           The maximum amount of concurrent requests for this particular deployment.
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index 4220d4b2..b7958ece 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -300,17 +300,10 @@ async fn rerank(
     }
 
     match &info.model_type {
-        ModelType::Classifier(_) => {
-            metrics::increment_counter!("te_request_failure", "err" => "model_type");
-            let message = "model is not a re-ranker model".to_string();
-            Err(TextEmbeddingsError::Backend(BackendError::Inference(
-                message,
-            )))
-        }
         ModelType::Reranker(_) => Ok(()),
-        ModelType::Embedding(_) => {
+        ModelType::Classifier(_) | ModelType::Embedding(_) => {
             metrics::increment_counter!("te_request_failure", "err" => "model_type");
-            let message = "model is not a classifier model".to_string();
+            let message = "model is not a re-ranker model".to_string();
             Err(TextEmbeddingsError::Backend(BackendError::Inference(
                 message,
             )))
diff --git a/router/src/lib.rs b/router/src/lib.rs
index a85b2e71..600e9d33 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -101,45 +101,7 @@ pub async fn run(
         serde_json::from_str(&config).context("Failed to parse `config.json`")?;
 
     // Set model type from config
-    let backend_model_type = {
-        // Check if the model is a classifier
-        let mut classifier = false;
-        for arch in &config.architectures {
-            if arch.ends_with("Classification") {
-                classifier = true;
-                break;
-            }
-        }
-
-        if classifier {
-            if pooling.is_some() {
-                tracing::warn!(
-                    "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
-                );
-            }
-            text_embeddings_backend::ModelType::Classifier
-        } else {
-            // Set pooling
-            let pool = match pooling {
-                Some(pool) => pool,
-                None => {
-                    // Load pooling config
-                    let config_path = model_root.join("1_Pooling/config.json");
-                    let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
-                    let config: PoolConfig = serde_json::from_str(&config)
-                        .context("Failed to parse `1_Pooling/config.json`")?;
-                    if config.pooling_mode_cls_token {
-                        text_embeddings_backend::Pool::Cls
-                    } else if config.pooling_mode_mean_tokens {
-                        text_embeddings_backend::Pool::Mean
-                    } else {
-                        return Err(anyhow!("Pooling config {config:?} is not supported"));
-                    }
-                }
-            };
-            text_embeddings_backend::ModelType::Embedding(pool)
-        }
-    };
+    let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?;
 
     // Info model type
     let model_type = match &backend_model_type {
@@ -315,6 +277,53 @@ pub async fn run(
     Ok(())
 }
 
+fn get_backend_model_type(
+    config: &ModelConfig,
+    model_root: &Path,
+    pooling: Option<text_embeddings_backend::Pool>,
+) -> Result<text_embeddings_backend::ModelType> {
+    for arch in &config.architectures {
+        if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") {
+            return Ok(text_embeddings_backend::ModelType::Embedding(
+                text_embeddings_backend::Pool::Splade,
+            ));
+        } else if arch.ends_with("Classification") {
+            if pooling.is_some() {
+                tracing::warn!(
+                    "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
+ ); + } + return Ok(text_embeddings_backend::ModelType::Classifier); + } + } + + if Some(text_embeddings_backend::Pool::Splade) == pooling { + return Err(anyhow!( + "Splade pooling is not supported: model is not a ForMaskedLM model" + )); + } + + // Set pooling + let pool = match pooling { + Some(pool) => pool, + None => { + // Load pooling config + let config_path = model_root.join("1_Pooling/config.json"); + let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; + let config: PoolConfig = + serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?; + if config.pooling_mode_cls_token { + text_embeddings_backend::Pool::Cls + } else if config.pooling_mode_mean_tokens { + text_embeddings_backend::Pool::Mean + } else { + return Err(anyhow!("Pooling config {config:?} is not supported")); + } + } + }; + Ok(text_embeddings_backend::ModelType::Embedding(pool)) +} + #[derive(Debug, Deserialize)] pub struct ModelConfig { pub architectures: Vec, diff --git a/router/tests/snapshots/test_http_embed__embeddings_raw.snap b/router/tests/snapshots/test_http_embed__embeddings_raw.snap new file mode 100644 index 00000000..b5b0c056 --- /dev/null +++ b/router/tests/snapshots/test_http_embed__embeddings_raw.snap @@ -0,0 +1,1156 @@ +--- +source: router/tests/test_http_embed.rs +expression: embeddings_raw +--- +- - - 0.05362602 + - -0.09624873 + - -0.15422677 + - 0.16551508 + - 0.023489304 + - -0.21831265 + - -0.0044252407 + - 0.002727163 + - 0.055650685 + - -0.047087643 + - -0.059601896 + - -0.02418546 + - -0.01694933 + - 0.04551375 + - -0.15409167 + - -0.12283319 + - -0.05245691 + - -0.31283477 + - 0.13004285 + - -0.20680268 + - -0.031662747 + - -0.22320485 + - 0.14813782 + - 0.13430233 + - 0.058092654 + - -0.015724655 + - -0.009829037 + - 0.13547266 + - -0.16554263 + - -0.7596838 + - 0.14668062 + - -0.020452041 + - -0.3560821 + - 0.06347646 + - -0.03694484 + - 0.19616315 + - 0.3361846 + - 0.047190003 + - 0.15021427 + - -0.03535195 + - 0.29383343 + - -0.25674072 + - 0.46257654 + - 0.14315909 + - 0.03592509 + - 0.049371216 + - -0.011901165 + - 0.023781417 + - -0.1727463 + - -0.2934507 + - -0.15773995 + - -0.040760476 + - -0.29319274 + - -0.06475893 + - 0.07598518 + - 0.027925171 + - 0.16007681 + - -0.17865865 + - 0.22355102 + - 0.19506413 + - 0.06829366 + - 0.05443536 + - -0.09455403 + - 0.11227459 + - -0.41728616 + - -0.026201978 + - 0.030254029 + - -0.12712762 + - -0.11732851 + - -0.001122945 + - -0.03057484 + - 0.0064175464 + - -0.11026013 + - 0.25587565 + - 0.19633786 + - 0.30550367 + - 0.0022042356 + - -0.16191646 + - 0.2945474 + - 0.17519158 + - -0.25888672 + - -0.10123868 + - -0.10138594 + - 0.3301158 + - 0.04399976 + - 0.5158396 + - 0.24492678 + - 0.15301128 + - -0.23632856 + - -0.026038088 + - -0.05722322 + - 0.29818803 + - 0.1868367 + - -0.17865065 + - -0.25335371 + - -0.10080598 + - -0.1607885 + - -0.119588226 + - -0.057394847 + - 5.723344 + - 0.032490775 + - 0.3530313 + - 0.1348905 + - 0.24220605 + - -0.38658923 + - -0.01821479 + - 0.18107991 + - -0.16821508 + - 0.4036339 + - 0.025022302 + - -0.030914325 + - -0.15392987 + - 0.38371697 + - -0.19422346 + - -0.022569854 + - -0.29219532 + - -0.32863507 + - 0.26936218 + - 0.014036803 + - 0.14080173 + - 0.37050244 + - 0.045396734 + - -0.21831934 + - 0.12694392 + - 0.41365933 + - -0.90396655 + - 0.040971227 + - -0.00000000000000000000000000000005246612 + - 0.27052212 + - -0.2654379 + - 0.03946123 + - 0.19328071 
+ - 0.037692957 + - 0.31038493 + - 0.072074585 + - 0.22249489 + - -0.10029659 + - -0.1680239 + - 0.33254364 + - -0.12934144 + - -0.025849117 + - 0.10875057 + - -0.07489428 + - 0.24911256 + - 0.1105005 + - -0.08491806 + - -0.014170787 + - 0.2530447 + - 0.042965006 + - 0.34765762 + - -0.07048108 + - -0.30054837 + - -0.0813425 + - -0.26823276 + - 0.021108601 + - 0.03587471 + - -0.018586367 + - -0.30330718 + - -0.6559383 + - 0.023775786 + - -0.40882462 + - -0.04372612 + - 0.07074663 + - -0.11261686 + - -0.21638495 + - 0.0054929573 + - 0.13377005 + - 0.25102043 + - 0.13641334 + - -0.14129911 + - -0.00960681 + - -0.1292076 + - 0.092134364 + - 0.11985314 + - -0.27784556 + - -0.22292812 + - 0.09219065 + - -0.075195044 + - 0.033164244 + - 0.0669493 + - 0.14347024 + - -0.015243015 + - -0.4721483 + - 0.01026145 + - -0.043858107 + - 0.14971209 + - -0.101460405 + - 0.5573691 + - -0.13179484 + - -0.03349352 + - -0.070452206 + - 0.27212715 + - 0.035100214 + - -0.41319823 + - -0.0101666935 + - -0.3604911 + - 0.17697893 + - 0.33473808 + - -0.041392684 + - -0.17730354 + - -0.16893569 + - -0.097044975 + - -0.06261413 + - -0.11515323 + - -0.21717575 + - 0.64800483 + - 0.05039451 + - -0.12829977 + - 0.36524624 + - -0.45265996 + - 0.12899241 + - -0.1518194 + - 0.39424738 + - -0.031050805 + - -0.16742408 + - -0.10804913 + - -0.1206848 + - -0.14689165 + - -0.29266074 + - 0.030588716 + - -0.010485575 + - -0.10741934 + - 0.21359915 + - 0.000000000000000000000000000000043513933 + - -0.037409775 + - 0.030618764 + - -0.2371208 + - 0.17326707 + - 0.12895373 + - 0.03190629 + - 0.25342023 + - -0.08149324 + - -0.22749813 + - 0.1638176 + - 0.1661469 + - -0.15858074 + - -0.113731846 + - 0.14054196 + - -0.03540598 + - -0.08210864 + - 0.15317258 + - -0.07926808 + - 0.02742753 + - -0.17011492 + - -0.070088625 + - 0.12648112 + - 0.1570639 + - 0.13020608 + - -0.09826183 + - 0.1441314 + - -0.09297671 + - -0.049254175 + - 0.066838585 + - 0.13196804 + - 0.002856657 + - 0.11642868 + - 0.043407857 + - 0.18099733 + - -0.018144768 + - 0.042576294 + - 0.044383463 + - -0.18780808 + - 0.012117827 + - 0.17451885 + - 0.09892663 + - 0.13026209 + - 0.15904307 + - 0.11400986 + - -0.073692165 + - 0.19707805 + - -0.02425171 + - 0.0018985122 + - 0.065153405 + - 0.09005826 + - -0.25437978 + - 0.24253485 + - 0.13076745 + - -0.1160331 + - 0.056037664 + - -0.119841136 + - -0.25754046 + - -0.025066173 + - -0.009398602 + - 0.27659163 + - -0.16430894 + - 0.5105751 + - -0.094437465 + - -0.066454455 + - -0.12024814 + - -0.02433209 + - -0.05913095 + - 0.06453659 + - -0.3723569 + - -0.05593538 + - -0.23421241 + - -0.26118255 + - -0.14977878 + - 0.078412734 + - 0.21446957 + - 0.008954648 + - -0.29328513 + - -0.1369865 + - 0.03630624 + - 0.033576246 + - -0.040133554 + - -0.09483389 + - -0.1597432 + - 0.04354272 + - -0.0567531 + - 0.15691322 + - 0.2740578 + - 0.26771182 + - -0.13835649 + - 0.099879526 + - 0.17656097 + - 0.18187986 + - 0.037354108 + - -0.5158529 + - -0.21966133 + - -0.000000088768445 + - -0.07895537 + - -0.24415097 + - -0.14226882 + - 0.31566626 + - -0.2678174 + - -0.10634744 + - -0.14548448 + - -0.16813678 + - -0.022673666 + - 0.07112027 + - 0.1473964 + - -0.13740689 + - 0.06979442 + - -0.027914286 + - -0.074311025 + - -0.41687658 + - -0.24032494 + - 0.032679375 + - 0.14397258 + - -0.052540865 + - 0.23543003 + - 0.07898207 + - -0.26782948 + - 0.21912418 + - -0.2746623 + - 0.41120437 + - 0.12615636 + - -0.19732329 + - -0.042119402 + - 0.163528 + - 0.103388235 + - -0.024959773 + - -0.18867652 + - -0.2189487 + - -0.021287393 + - 0.0031497534 + - 
-0.007843923 + - 0.14438185 + - 0.24991998 + - -0.0503937 + - -0.4090227 + - -0.13807997 + - 0.21715388 + - -0.02965941 + - -0.22383071 + - -0.47643352 + - -0.20129524 + - -0.37402472 + - -0.097741075 + - -0.32544568 + - -0.16130292 + - 0.13543358 + - -0.04548286 + - -0.09423275 + - 0.16028474 + - 0.17153065 + - -0.21021502 + - 0.29995695 + - -0.4298109 + - -0.12928377 + - 0.39504942 + - -0.38475 + - -0.07365963 + - -0.26514596 + - - -0.0011101477 + - 0.22351609 + - -0.2819248 + - 0.15021688 + - 0.2076316 + - -0.12000399 + - 0.6105142 + - -0.2081645 + - -0.32259408 + - -0.09641644 + - 0.366515 + - -0.90361553 + - 0.16291445 + - 0.35989842 + - 0.15060869 + - -0.6891818 + - -0.02622461 + - 0.72510093 + - -0.25998372 + - 0.06258218 + - -0.36001733 + - -0.5024996 + - -0.009923387 + - 0.2848148 + - -0.021727366 + - -0.089146 + - -0.708148 + - 0.072484136 + - 0.049141254 + - 0.0150754005 + - 0.4793997 + - -0.10826402 + - 0.08019836 + - 0.7663549 + - 0.80127734 + - -0.45062473 + - -0.09214136 + - 0.24421115 + - 0.57252073 + - 0.039397143 + - -0.37853286 + - -1.1295989 + - 0.48242953 + - -0.045319088 + - 0.33401513 + - 0.7360113 + - -0.39879486 + - 0.7584456 + - -0.60499007 + - 0.19188018 + - 0.16678968 + - -0.6256801 + - -0.08726675 + - -0.38615274 + - -0.6096996 + - 0.25548834 + - 0.045096576 + - -0.16942133 + - 0.32358563 + - 0.58938545 + - -0.25391635 + - -0.35265166 + - 0.055043034 + - -0.11067196 + - 0.47675478 + - 0.49513584 + - -0.068828225 + - -0.3197979 + - 0.32452348 + - -0.09713984 + - -0.2066382 + - -0.099625096 + - 0.5186132 + - 0.1827107 + - 0.45652673 + - -0.21659859 + - -0.002651425 + - -0.63069254 + - -0.011915099 + - -0.13378514 + - -0.8085194 + - -1.0689027 + - -0.14780587 + - 0.35633856 + - 0.33309394 + - 0.8085466 + - 1.0160096 + - -0.0007634163 + - -1.2137728 + - -0.006277621 + - 0.5938595 + - -0.11088293 + - -1.3181666 + - 0.2841096 + - -0.17579404 + - 0.32134178 + - 0.53678215 + - 0.031394973 + - 0.03894441 + - -0.34276098 + - 0.7996672 + - -0.26595217 + - -0.0505279 + - -0.5646885 + - -0.40646145 + - -0.14950627 + - 0.44034967 + - -1.2384757 + - 0.43281502 + - -0.0043197703 + - -0.17284618 + - -0.0014570281 + - 0.1140669 + - 0.28086355 + - 0.62751055 + - -1.0683147 + - -0.66996926 + - 0.42204607 + - -0.72811526 + - 1.1021469 + - 0.72578615 + - 0.09043567 + - 0.111438684 + - 0.0144068785 + - 0.33309186 + - 0.6349906 + - 0.62506115 + - 0.000000000000000000000000000000023458383 + - -0.0735795 + - -0.78800666 + - 1.0388364 + - 0.78434354 + - -0.6645803 + - -0.041556235 + - -0.142865 + - 1.1463563 + - -0.18437272 + - 0.577402 + - -0.09207101 + - -0.44458994 + - -0.19949526 + - -0.31356952 + - 0.3270907 + - 1.2421025 + - -0.6594614 + - -0.38223565 + - -0.6424595 + - 0.055277124 + - 0.22758417 + - -1.3904781 + - 0.4756206 + - -0.5897731 + - -0.456838 + - -0.6231428 + - -0.60977304 + - 0.026388507 + - 0.18365285 + - 0.007828049 + - 0.0006129034 + - 0.05306398 + - -1.1498935 + - 1.3087151 + - -0.36580715 + - 0.461096 + - 0.3202829 + - 0.055462047 + - -0.36584195 + - -0.36046287 + - -0.7300884 + - 0.062068004 + - 0.72145945 + - 0.64125955 + - 0.17172499 + - -0.78715354 + - -0.40075088 + - 0.3041808 + - 1.2194345 + - 0.010402157 + - -0.35985434 + - 0.6371716 + - -0.9040255 + - -0.08123861 + - 0.27690005 + - 0.60005456 + - 0.058723483 + - -0.56821483 + - 0.24865676 + - 0.36972472 + - 0.40694168 + - 0.105341464 + - -0.084122464 + - 1.1564518 + - -0.14731297 + - 0.06922119 + - -0.542001 + - -0.96185637 + - 0.51607627 + - 0.38785064 + - 0.32991126 + - -0.7751901 + - -0.27361813 + - 
0.5187437 + - -0.19126584 + - -0.09179941 + - -0.17602682 + - 0.37244323 + - 0.11454858 + - -1.0551739 + - 1.4880986 + - -0.20994979 + - -0.29970407 + - -0.59106255 + - -0.51381326 + - -0.3563717 + - -0.13052636 + - -0.6334936 + - 0.37844363 + - -1.1808226 + - -0.021034423 + - 0.24003777 + - -0.055096976 + - -0.52118546 + - 1.493361 + - -0.000000000000000000000000000000017590748 + - -0.10167974 + - 0.42653853 + - 0.07086732 + - 1.4198663 + - 0.47516215 + - -0.2634161 + - 1.1395591 + - -0.2590642 + - -0.4823106 + - 1.4196008 + - 0.5712795 + - -0.0884507 + - 0.108477175 + - 0.035103593 + - -0.49815798 + - 0.13197437 + - -0.4934045 + - -0.89961886 + - -0.6748505 + - 0.06333339 + - -0.8040523 + - 1.1400731 + - 0.6720452 + - -0.3144133 + - -1.1666874 + - 0.035853047 + - 0.7197842 + - -0.6696387 + - 0.20240572 + - -0.42529634 + - 0.8146539 + - 0.26578462 + - -0.16155258 + - 0.7239907 + - 0.87415946 + - 0.0340038 + - 1.681239 + - -0.062231563 + - -0.09118725 + - 0.62527025 + - 0.2636379 + - 0.6311775 + - 0.22614485 + - -0.6042256 + - -0.25460654 + - 0.63968974 + - 0.7669817 + - -0.41930574 + - 0.25942168 + - 0.47131112 + - 0.1475179 + - -0.05814456 + - 0.10057445 + - 0.20424807 + - 0.13070352 + - -0.20877042 + - -0.9100716 + - 0.23903795 + - -0.86168754 + - 0.018424904 + - -0.57887846 + - 0.5038811 + - -0.13605328 + - 1.1866546 + - -0.82205665 + - -0.39026535 + - -0.26206505 + - 0.7296555 + - 1.3452846 + - 0.35876042 + - -0.52601403 + - 0.4158501 + - -0.21309114 + - -0.5892406 + - 0.31792742 + - -0.24144782 + - -0.7067909 + - -0.3850805 + - 0.41709447 + - -0.26764 + - -0.63474464 + - -0.5768752 + - -0.84135 + - 0.98949665 + - -0.15636528 + - 0.75872886 + - 0.27294034 + - 0.3248108 + - -0.3986218 + - 0.27769497 + - 0.6221113 + - 0.1270932 + - -0.45853662 + - -0.56592476 + - 0.21361037 + - -0.00000009438203 + - -0.5528122 + - -1.0602882 + - 0.768973 + - -0.10135596 + - -0.6583706 + - -0.10448761 + - -0.72117114 + - -0.36864772 + - -0.28054118 + - -0.47200808 + - 0.01777897 + - 0.23944177 + - -0.2570129 + - 0.6022824 + - 0.14370564 + - -0.5324604 + - -0.16479526 + - 0.36381578 + - -0.09372194 + - 0.07523826 + - -0.80028725 + - 0.26521978 + - 0.38890174 + - 0.9902756 + - -0.3516004 + - 0.6648714 + - 0.81488633 + - 1.0781528 + - -0.0027297195 + - -0.37004703 + - 1.0858134 + - 0.21160294 + - -0.5628791 + - -0.94626594 + - 0.39612126 + - 0.45894173 + - -0.2326393 + - 0.18993452 + - 0.02287276 + - 0.7622639 + - -0.67571396 + - 0.52507395 + - -0.034063846 + - -0.5302775 + - -0.35149163 + - -0.51815295 + - -0.77198505 + - -0.21484563 + - -0.31557128 + - -0.78601134 + - -0.16753308 + - -0.4690193 + - -0.37844452 + - -0.0045899805 + - -0.29465005 + - -0.61118555 + - 0.1455608 + - 0.1368874 + - -1.173136 + - 0.45374128 + - 1.1111366 + - 0.7236544 + - 0.5502881 + - -0.5017581 + - - 0.1663418 + - 0.34938747 + - -0.25814858 + - 0.80541164 + - -0.36740482 + - -0.44127563 + - 0.8543592 + - 0.9133245 + - 0.50391877 + - 0.027067095 + - 0.015010182 + - -0.529218 + - -0.15398294 + - 0.120985106 + - -0.29893336 + - -0.47913072 + - 0.2470217 + - -0.79693604 + - -1.3914946 + - -0.10353897 + - -0.3862775 + - 0.23836382 + - -0.64160454 + - 0.20638946 + - -0.45504192 + - 0.5028848 + - -0.37783185 + - 0.41650078 + - 0.67931604 + - -0.43285006 + - 0.10810716 + - 0.73358005 + - 0.5659016 + - 0.028506333 + - 0.24026747 + - 0.5070906 + - 0.5354525 + - 0.24056222 + - -0.35499325 + - -0.06580766 + - 0.016674206 + - -1.3168372 + - -0.22523041 + - -0.30558297 + - 0.12400336 + - -0.001492721 + - 0.10647156 + - 0.23126043 + - 
-0.29642436 + - 0.33245444 + - -0.67227095 + - -0.08549255 + - -0.49178472 + - -0.2892329 + - 0.19275403 + - 0.12623283 + - -0.32789695 + - -0.15411153 + - 0.37329254 + - 0.036564704 + - 1.061008 + - -0.35394213 + - -0.417361 + - 1.0518453 + - 0.40620863 + - -0.100823075 + - 0.25229195 + - -0.039583635 + - -0.8173082 + - -0.5204507 + - -0.039146986 + - 0.3708074 + - -0.21257849 + - 0.96459395 + - 0.8614282 + - -0.16850083 + - 0.05153028 + - -0.65903175 + - 0.8301618 + - -0.13228785 + - -0.5859121 + - 0.29194397 + - -0.500299 + - 0.3067681 + - 0.091077566 + - 0.20201316 + - 1.0065311 + - 0.44932112 + - -0.7054799 + - 0.23466057 + - -0.16275388 + - 0.5443229 + - 0.47633296 + - -0.22141032 + - -0.37649545 + - 0.20842686 + - -0.21745 + - -0.28055814 + - -0.42242056 + - -0.73893553 + - 0.103442356 + - 0.4653465 + - -0.1933622 + - -0.06425872 + - -0.67594767 + - -0.64596534 + - -0.61386466 + - 0.14173445 + - 0.50167876 + - 0.10199101 + - -0.90301186 + - 0.05929415 + - 0.08684504 + - 0.43414304 + - -0.091117054 + - 0.070790194 + - -0.24508373 + - 0.42712212 + - 0.11139226 + - -0.47403342 + - -0.017764144 + - -0.0035439734 + - 0.22429024 + - -0.39128265 + - -0.1562304 + - -0.4944849 + - 0.6740789 + - -0.00000000000000000000000000000006272999 + - 0.21808404 + - -0.8899419 + - -0.014505066 + - 0.87075824 + - -0.37142944 + - 0.09255007 + - -0.1266155 + - -0.03899938 + - 0.11129426 + - 0.71940273 + - -0.06369302 + - 0.28126672 + - -0.2148789 + - 0.65834224 + - 1.6800551 + - 0.27259567 + - -0.07883709 + - 0.6889739 + - -0.5847211 + - 0.28605342 + - -0.72014546 + - 0.1216401 + - -0.172548 + - 0.13226005 + - -0.90891045 + - 0.21820723 + - 0.56289876 + - -0.37412217 + - 0.15388325 + - 0.28750116 + - -0.34232554 + - 0.48939657 + - 0.1806356 + - 0.042681795 + - -0.038344674 + - -0.45354804 + - 0.14029938 + - -0.49126667 + - 0.7383026 + - -0.3784876 + - -0.16706632 + - -0.17556168 + - -0.69728386 + - 0.028124485 + - 0.34003454 + - 0.07052002 + - 0.12027936 + - -0.4630919 + - -0.4032769 + - 0.040038325 + - 0.06653214 + - -0.3250033 + - -0.53998965 + - -0.3179719 + - -0.39666858 + - 0.32221445 + - 0.21350656 + - -0.69490504 + - -0.6479269 + - 0.20426556 + - 1.006199 + - 0.571869 + - -0.84744525 + - -0.29662162 + - -0.36982152 + - -0.047038477 + - -0.46558896 + - -0.52000314 + - 0.21525657 + - 0.26805466 + - -0.73056364 + - -0.35449937 + - 0.7584346 + - 0.11875821 + - -0.29811698 + - -0.46503803 + - 0.15451741 + - 0.36625987 + - -0.5335682 + - -0.010969831 + - -0.03570492 + - -0.80088884 + - 0.47180107 + - -0.104975715 + - 0.4291523 + - 0.37317455 + - 0.13336691 + - -1.1287363 + - -0.1485198 + - -0.033127356 + - -0.30511004 + - 0.10746859 + - 0.1380052 + - 0.14448021 + - -0.14783889 + - 0.00000000000000000000000000000005137103 + - -0.41975737 + - 0.025760822 + - -0.80016124 + - 1.3465704 + - 0.38620642 + - -0.42292282 + - 1.1258832 + - -0.022942519 + - -0.4213821 + - 0.74355364 + - -0.54590017 + - -0.69432557 + - 1.1108004 + - 0.065217584 + - 0.21873109 + - 0.1913206 + - 1.205181 + - -0.3430985 + - -0.2713815 + - 0.012206523 + - -0.8360778 + - 0.056488294 + - -0.65050673 + - 0.056151483 + - -0.7422195 + - 0.40661868 + - 0.30962563 + - -0.12942404 + - -0.41059434 + - -0.34254158 + - 0.6225023 + - -0.24742776 + - -0.6807441 + - 0.43497828 + - 0.042185314 + - 0.3178843 + - 1.1764826 + - 0.09102576 + - -0.4120454 + - 0.34984812 + - 0.9001419 + - 0.46242487 + - 0.55659086 + - 2.1694803 + - -0.2955632 + - 0.49362123 + - -0.41688752 + - -0.15270649 + - 0.40339592 + - 0.35465178 + - -1.0368934 + - 0.39372078 + - 
0.06356082 + - -0.66449136 + - -0.36622268 + - -0.4497857 + - -0.11897237 + - -0.020388396 + - 0.38540548 + - 0.11557645 + - -0.57879794 + - 0.73401904 + - -0.44543827 + - 0.2081637 + - -0.25936866 + - -0.20650204 + - -0.55193406 + - 0.22964026 + - 0.0046762805 + - 0.5105565 + - 1.0150495 + - 0.15882817 + - -0.47342417 + - 0.1382411 + - 0.15166539 + - -0.14027333 + - -1.189387 + - 0.6349987 + - -0.3674529 + - -0.6175469 + - -0.620966 + - -0.9433099 + - -0.3418888 + - 0.48906717 + - -0.8840631 + - 0.17825761 + - 0.4025382 + - 0.53189224 + - -0.15343714 + - -0.2003792 + - 0.19673516 + - 0.22192545 + - -0.20846549 + - 0.12140231 + - -0.23825766 + - -0.00000008734853 + - -0.13470313 + - -0.31860018 + - 0.22797514 + - 0.19470197 + - 0.50286853 + - 0.44204932 + - 0.25239787 + - 0.21702372 + - -0.20960683 + - 0.5131216 + - 0.5973904 + - 0.40830952 + - -0.48740363 + - 0.8268245 + - 0.54409564 + - -0.76430416 + - -0.19396517 + - 0.37688574 + - -0.23808461 + - 0.5568868 + - -0.8895464 + - 0.44158688 + - -0.11718406 + - -0.021308616 + - -0.056336947 + - -0.15322049 + - 0.087012604 + - 1.5152171 + - -0.027880192 + - 0.2222344 + - 0.12768978 + - 0.6452863 + - 0.14440456 + - -0.4466827 + - -0.11587736 + - -0.031724244 + - 0.18780065 + - -0.46342078 + - 0.43756333 + - -0.04429803 + - -0.17902891 + - 0.021462742 + - 0.5280802 + - -0.29984543 + - -0.39147717 + - -0.29082248 + - 0.38956422 + - -0.09818647 + - 0.13247946 + - -0.6619555 + - -0.27032954 + - 0.52747786 + - 0.7106601 + - 0.05472342 + - 0.5912403 + - 0.18672591 + - 0.1890892 + - 0.026695427 + - -0.9920362 + - 0.88510543 + - 2.2166533 + - -0.45323366 + - 0.53000206 + - -0.2772841