Skip to content

Commit

Permalink
feat: GTE classification head (huggingface#441)
Browse files Browse the repository at this point in the history
Co-authored-by: Hyeongchan Kim <[email protected]>
  • Loading branch information
OlivierDehaene and kozistr authored Nov 27, 2024
1 parent 7c4f67e commit 0bfeb7e
Show file tree
Hide file tree
Showing 18 changed files with 145 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ Text Embeddings Inference currently supports CamemBERT, and XLM-RoBERTa Sequence

Below are some examples of the currently supported models:

| Task | Model Type | Model ID |
|--------------------|-------------|---------------------------------------------------------------------------------------------|
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
| Task | Model Type | Model ID |
|--------------------|-------------|-----------------------------------------------------------------------------------------------------------------|
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
| Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) |
| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |

### Docker

Expand Down Expand Up @@ -372,7 +373,7 @@ docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingf

### Using Re-rankers models

`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa, XLM-RoBERTa, and GTE Sequence Classification models.
Re-ranker models are Sequence Classification cross-encoder models with a single class that scores the similarity
between a query and a text.

Expand All @@ -392,7 +393,7 @@ And then you can rank the similarity between a query and a list of texts with:
```bash
curl 127.0.0.1:8080/rerank \
-X POST \
-d '{"query":"What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
-d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
-H 'Content-Type: application/json'
```

Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/models/flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ impl Model for FlashBertModel {
/// Always returns `false` for this model.
// NOTE(review): presumably tells the caller that batches are consumed without padding —
// confirm against the `Model` trait definition.
fn is_padded(&self) -> bool {
false
}

/// Runs the batch through `self.forward` and returns its result unchanged.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
Expand Down
76 changes: 73 additions & 3 deletions backends/candle/src/models/flash_gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,66 @@ impl GTELayer {
}
}

/// Classification head for GTE models: an optional `pooler` projection followed by a
/// `classifier` projection.
pub struct GTEClassificationHead {
/// Dense layer loaded from `pooler.dense`; `None` when its weight could not be loaded.
pooler: Option<Linear>,
/// Final projection loaded from `classifier`, with one output row per `id2label` entry.
classifier: Linear,
/// Tracing span entered for the duration of `forward`.
span: tracing::Span,
}

impl GTEClassificationHead {
    /// Builds the head from the tensors under `pooler.dense` and `classifier`.
    ///
    /// The number of output classes is the number of entries in `config.id2label`.
    /// The pooler is optional: when its weight cannot be fetched the head is built
    /// without it. Once the weight is found, a missing bias is an error.
    ///
    /// # Errors
    /// Fails when `config.id2label` is unset, or when a classifier tensor (or the
    /// pooler bias, if the pooler weight exists) cannot be loaded.
    #[allow(dead_code)]
    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
        let n_classes = match config.id2label.as_ref() {
            Some(labels) => labels.len(),
            None => candle::bail!("`id2label` must be set for classifier models"),
        };
        let hidden = config.hidden_size;

        // A failed weight lookup means "no pooler" rather than a load failure.
        let pooler_vb = vb.pp("pooler.dense");
        let pooler = match pooler_vb.get((hidden, hidden), "weight") {
            Ok(weight) => {
                let bias = pooler_vb.get(hidden, "bias")?;
                Some(Linear::new(weight, Some(bias), None))
            }
            Err(_) => None,
        };

        let classifier_vb = vb.pp("classifier");
        let weight = classifier_vb.get((n_classes, hidden), "weight")?;
        let bias = classifier_vb.get(n_classes, "bias")?;

        Ok(Self {
            pooler,
            classifier: Linear::new(weight, Some(bias), None),
            span: tracing::span!(tracing::Level::TRACE, "classifier"),
        })
    }

    /// Applies the optional pooler (dense layer followed by `tanh`) and then the
    /// classifier to `hidden_states`.
    ///
    /// A dimension is inserted at index 1 before the layers run and removed again
    /// from the classifier output.
    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();

        let expanded = hidden_states.unsqueeze(1)?;
        let pooled = match &self.pooler {
            Some(pooler) => pooler.forward(&expanded)?.tanh()?,
            None => expanded,
        };

        self.classifier.forward(&pooled)?.squeeze(1)
    }
}

pub struct FlashGTEModel {
word_embeddings: Embedding,
token_type_embeddings: Option<Embedding>,
layers: Vec<GTELayer>,
embeddings_norm: LayerNorm,
cos_cache: Tensor,
sin_cache: Tensor,
classifier: Option<GTEClassificationHead>,
pool: Pool,
pub device: Device,

Expand Down Expand Up @@ -233,11 +286,14 @@ impl FlashGTEModel {
candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
}

let pool = match model_type {
let (pool, classifier) = match model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for GTE")
let pool = Pool::Cls;

let classifier = GTEClassificationHead::load(vb.clone(), config)?;
(pool, Some(classifier))
}
ModelType::Embedding(pool) => pool,
ModelType::Embedding(pool) => (pool, None),
};

let word_embeddings = Embedding::new(
Expand Down Expand Up @@ -292,6 +348,7 @@ impl FlashGTEModel {
embeddings_norm,
cos_cache,
sin_cache,
classifier,
pool,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
Expand Down Expand Up @@ -457,7 +514,20 @@ impl Model for FlashGTEModel {
/// Always returns `false` for this model.
// NOTE(review): presumably tells the caller that batches are consumed without padding —
// confirm against the `Model` trait definition.
fn is_padded(&self) -> bool {
false
}

/// Runs the batch through `self.forward` and returns its (pooled, raw) pair unchanged.
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

/// Scores a batch with the classification head.
///
/// The batch is run through `forward`, and the pooled output is handed to the head.
///
/// # Errors
/// Fails when the model has no classifier, or when `forward` or the head fails.
///
/// # Panics
/// Panics if `forward` returns no pooled embeddings.
fn predict(&self, batch: Batch) -> Result<Tensor> {
    // Reject the call up front so no forward pass is run for a model without a head.
    let head = match self.classifier.as_ref() {
        Some(head) => head,
        None => candle::bail!("`predict` is not implemented for this model"),
    };

    let (pooled, _raw) = self.forward(batch)?;
    let pooled = pooled.expect("pooled_embeddings is empty. This is a bug.");
    head.forward(&pooled)
}
}
2 changes: 2 additions & 0 deletions backends/candle/src/models/gte.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::layers::HiddenAct;
use crate::models::PositionEmbeddingType;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct NTKScaling {
Expand Down Expand Up @@ -32,4 +33,5 @@ pub struct GTEConfig {
pub logn_attention_scale: bool,
#[serde(default)]
pub logn_attention_clip1: bool,
pub id2label: Option<HashMap<String, String>>,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
source: backends/candle/tests/test_flash_gte.rs
assertion_line: 86
expression: predictions_single
---
- - -0.74365234
8 changes: 4 additions & 4 deletions backends/candle/tests/test_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn test_mini() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down Expand Up @@ -73,7 +73,7 @@ fn test_mini_pooled_raw() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
Expand Down Expand Up @@ -142,7 +142,7 @@ fn test_emotions() -> Result<()> {
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;

let input_batch = batch(
vec![
Expand Down Expand Up @@ -192,7 +192,7 @@ fn test_bert_classification() -> Result<()> {
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;

let input_single = batch(
vec![tokenizer
Expand Down
8 changes: 4 additions & 4 deletions backends/candle/tests/test_flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn test_flash_mini() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down Expand Up @@ -83,7 +83,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
Expand Down Expand Up @@ -156,7 +156,7 @@ fn test_flash_emotions() -> Result<()> {
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

let input_batch = batch(
vec![
Expand Down Expand Up @@ -210,7 +210,7 @@ fn test_flash_bert_classification() -> Result<()> {
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

let input_single = batch(
vec![tokenizer
Expand Down
39 changes: 36 additions & 3 deletions backends/candle/tests/test_flash_gte.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![allow(dead_code, unused_imports)]
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use text_embeddings_backend_core::{Backend, ModelType, Pool};

Expand All @@ -15,7 +15,7 @@ fn test_flash_gte() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
Expand Down Expand Up @@ -51,3 +51,36 @@ fn test_flash_gte() -> Result<()> {

Ok(())
}

// Loads the GTE reranker checkpoint as a classifier, scores one (query, text) pair and
// compares the prediction against the stored `gte_classification_single` snapshot.
// Only compiled with `cuda` plus one of the flash-attention features.
#[test]
#[serial_test::serial]
#[cfg(all(
feature = "cuda",
any(feature = "flash-attn", feature = "flash-attn-v1")
))]
fn test_flash_gte_classification() -> Result<()> {
let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

// The query and the candidate text are encoded together as one pair input.
let input_single = batch(
vec![tokenizer
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
.unwrap()],
[0].to_vec(),
vec![],
);

// Keep only the score vectors; the first element of each returned pair is dropped.
let predictions: Vec<Vec<f32>> = backend
.predict(input_single)?
.into_iter()
.map(|(_, v)| v)
.collect();
let predictions_single = SnapshotScores::from(predictions);

let matcher = relative_matcher();
insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher);

Ok(())
}
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_flash_jina_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_flash_jina_code_base() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_flash_mistral() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_nomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_flash_nomic_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn test_flash_qwen2() -> Result<()> {
};

let backend = CandleBackend::new(
model_root,
&model_root,
"float16".to_string(),
ModelType::Embedding(Pool::LastToken),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn test_jina_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn test_jina_code_base() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
2 changes: 1 addition & 1 deletion backends/candle/tests/test_nomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn test_nomic_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down
Loading

0 comments on commit 0bfeb7e

Please sign in to comment.