feat: GTE classification head #441

Merged 2 commits on Nov 27, 2024
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

15 changes: 8 additions & 7 deletions README.md
@@ -92,11 +92,12 @@ Text Embeddings Inference currently supports CamemBERT, and XLM-RoBERTa Sequence

Below are some examples of the currently supported models:

- | Task | Model Type | Model ID |
- |--------------------|-------------|---------------------------------------------------------------------------------------------|
- | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
- | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
- | Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
+ | Task | Model Type | Model ID |
+ |--------------------|-------------|-----------------------------------------------------------------------------------------------------------------|
+ | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) |
+ | Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) |
+ | Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) |
+ | Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) |

### Docker

@@ -372,7 +373,7 @@ docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingf

### Using Re-rankers models

- `text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
+ `text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa, XLM-RoBERTa, and GTE Sequence Classification models.
Re-ranker models are Sequence Classification cross-encoder models with a single class that scores the similarity
between a query and a text.

@@ -392,7 +393,7 @@ And then you can rank the similarity between a query and a list of texts with:
```bash
curl 127.0.0.1:8080/rerank \
-X POST \
-    -d '{"query":"What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
+    -d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
-H 'Content-Type: application/json'
```
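For readers who prefer calling the endpoint from Rust, here is a minimal client sketch. It is not part of this PR; it assumes the `reqwest` crate with the `blocking` and `json` features plus `serde_json` as dependencies:

```rust
// Sketch: POST /rerank with a query and candidate texts, then print the scores.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "query": "What is Deep Learning?",
        "texts": ["Deep Learning is not...", "Deep learning is..."],
    });
    let scores: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:8080/rerank")
        .json(&body)
        .send()?
        .json()?;
    println!("{scores}"); // one similarity score per input text
    Ok(())
}
```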

1 change: 1 addition & 0 deletions backends/candle/src/models/flash_bert.rs
@@ -529,6 +529,7 @@ impl Model for FlashBertModel {
fn is_padded(&self) -> bool {
false
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
76 changes: 73 additions & 3 deletions backends/candle/src/models/flash_gte.rs
@@ -198,13 +198,66 @@ impl GTELayer {
}
}

pub struct GTEClassificationHead {
pooler: Option<Linear>,
classifier: Linear,
span: tracing::Span,
}

impl GTEClassificationHead {
#[allow(dead_code)]
pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
let n_classes = match &config.id2label {
None => candle::bail!("`id2label` must be set for classifier models"),
Some(id2label) => id2label.len(),
};

let pooler = if let Ok(pooler_weight) = vb
.pp("pooler.dense")
.get((config.hidden_size, config.hidden_size), "weight")
{
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
} else {
None
};

let classifier_weight = vb
.pp("classifier")
.get((n_classes, config.hidden_size), "weight")?;
let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);

Ok(Self {
classifier,
pooler,
span: tracing::span!(tracing::Level::TRACE, "classifier"),
})
}

pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.unsqueeze(1)?;
if let Some(pooler) = self.pooler.as_ref() {
hidden_states = pooler.forward(&hidden_states)?;
hidden_states = hidden_states.tanh()?;
}

let hidden_states = self.classifier.forward(&hidden_states)?;
let hidden_states = hidden_states.squeeze(1)?;
Ok(hidden_states)
}
}
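
For anyone skimming the diff, the new head boils down to: take the pooled (CLS) hidden state, optionally project it through a pooler layer with a `tanh` activation, then apply a linear classifier to get one logit per class (a single class for rerankers). A plain-Rust sketch of that computation, with no candle dependency and illustrative names and shapes only:

```rust
/// y = W.x + b, with W given as rows of length x.len() and b of length W.len().
fn linear(w: &[Vec<f32>], b: &[f32], x: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
        .collect()
}

/// Mirrors GTEClassificationHead::forward: optional pooler + tanh, then classifier.
fn classification_head(
    hidden: &[f32],                        // pooled CLS state, length = hidden_size
    pooler: Option<(&[Vec<f32>], &[f32])>, // (W_p, b_p): hidden_size x hidden_size
    classifier: (&[Vec<f32>], &[f32]),     // (W_c, b_c): n_classes x hidden_size
) -> Vec<f32> {
    let activated = match pooler {
        Some((w_p, b_p)) => linear(w_p, b_p, hidden)
            .into_iter()
            .map(f32::tanh)
            .collect::<Vec<_>>(),
        None => hidden.to_vec(),
    };
    linear(classifier.0, classifier.1, &activated) // raw logits, one per class
}
```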

pub struct FlashGTEModel {
word_embeddings: Embedding,
token_type_embeddings: Option<Embedding>,
layers: Vec<GTELayer>,
embeddings_norm: LayerNorm,
cos_cache: Tensor,
sin_cache: Tensor,
classifier: Option<GTEClassificationHead>,
pool: Pool,
pub device: Device,

@@ -233,11 +286,14 @@ impl FlashGTEModel {
candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
}

- let pool = match model_type {
+ let (pool, classifier) = match model_type {
ModelType::Classifier => {
- candle::bail!("`classifier` model type is not supported for GTE")
+ let pool = Pool::Cls;
+
+ let classifier = GTEClassificationHead::load(vb.clone(), config)?;
+ (pool, Some(classifier))
}
- ModelType::Embedding(pool) => pool,
+ ModelType::Embedding(pool) => (pool, None),
};

let word_embeddings = Embedding::new(
@@ -292,6 +348,7 @@ impl FlashGTEModel {
embeddings_norm,
cos_cache,
sin_cache,
classifier,
pool,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -457,7 +514,20 @@ impl Model for FlashGTEModel {
fn is_padded(&self) -> bool {
false
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

fn predict(&self, batch: Batch) -> Result<Tensor> {
match &self.classifier {
None => candle::bail!("`predict` is not implemented for this model"),
Some(classifier) => {
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
let pooled_embeddings =
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
classifier.forward(&pooled_embeddings)
}
}
}
}
2 changes: 2 additions & 0 deletions backends/candle/src/models/gte.rs
@@ -1,6 +1,7 @@
use crate::layers::HiddenAct;
use crate::models::PositionEmbeddingType;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct NTKScaling {
@@ -32,4 +33,5 @@ pub struct GTEConfig {
pub logn_attention_scale: bool,
#[serde(default)]
pub logn_attention_clip1: bool,
pub id2label: Option<HashMap<String, String>>,
}
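
The new `id2label` field follows the standard Hugging Face `config.json` convention, and its length is what `GTEClassificationHead::load` uses as the number of classes. A self-contained sketch of that round trip (struct trimmed to just the new field; `serde` with the `derive` feature and `serde_json` assumed as dependencies):

```rust
use serde::Deserialize;
use std::collections::HashMap;

// Stand-in for GTEConfig, reduced to the field this PR adds.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    id2label: Option<HashMap<String, String>>,
}

fn main() {
    // A single-label config, as shipped by cross-encoder rerankers.
    let raw = r#"{ "id2label": { "0": "LABEL_0" } }"#;
    let cfg: MiniConfig = serde_json::from_str(raw).unwrap();

    // Mirrors the bail-or-count logic in GTEClassificationHead::load.
    let n_classes = cfg
        .id2label
        .as_ref()
        .expect("`id2label` must be set for classifier models")
        .len();
    assert_eq!(n_classes, 1); // classifier output dimension
}
```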
@@ -0,0 +1,6 @@
---
source: backends/candle/tests/test_flash_gte.rs
assertion_line: 86
expression: predictions_single
---
- - -0.74365234
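
The file above is an `insta` YAML snapshot: the test serializes the predicted scores and compares them against this stored value on later runs, using the repo's `relative_matcher` helper (presumably a relative-tolerance float comparison). A minimal illustration of the same mechanism, independent of this repo's helpers:

```rust
// Sketch: insta records the value on the first run, then diffs against the
// stored .snap file on subsequent runs. Names here are illustrative.
#[test]
fn snapshot_example() {
    let predictions: Vec<Vec<f32>> = vec![vec![-0.74365234]];
    insta::assert_yaml_snapshot!("example_scores", predictions);
}
```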
8 changes: 4 additions & 4 deletions backends/candle/tests/test_bert.rs
@@ -13,7 +13,7 @@ fn test_mini() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
@@ -73,7 +73,7 @@ fn test_mini_pooled_raw() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
@@ -142,7 +142,7 @@ fn test_emotions() -> Result<()> {
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
let tokenizer = load_tokenizer(&model_root)?;

- let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
+ let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;

let input_batch = batch(
vec![
@@ -192,7 +192,7 @@ fn test_bert_classification() -> Result<()> {
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
let tokenizer = load_tokenizer(&model_root)?;

- let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?;
+ let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?;

let input_single = batch(
vec![tokenizer
8 changes: 4 additions & 4 deletions backends/candle/tests/test_flash_bert.rs
@@ -19,7 +19,7 @@ fn test_flash_mini() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
@@ -83,7 +83,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
@@ -156,7 +156,7 @@ fn test_flash_emotions() -> Result<()> {
let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?;
let tokenizer = load_tokenizer(&model_root)?;

- let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
+ let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

let input_batch = batch(
vec![
@@ -210,7 +210,7 @@ fn test_flash_bert_classification() -> Result<()> {
let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?;
let tokenizer = load_tokenizer(&model_root)?;

- let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?;
+ let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

let input_single = batch(
vec![tokenizer
39 changes: 36 additions & 3 deletions backends/candle/tests/test_flash_gte.rs
@@ -1,9 +1,9 @@
#![allow(dead_code, unused_imports)]
mod common;

- use crate::common::{sort_embeddings, SnapshotEmbeddings};
+ use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
- use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
+ use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use text_embeddings_backend_core::{Backend, ModelType, Pool};

@@ -15,7 +15,7 @@ fn test_flash_gte() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
@@ -51,3 +51,36 @@

Ok(())
}

#[test]
#[serial_test::serial]
#[cfg(all(
feature = "cuda",
any(feature = "flash-attn", feature = "flash-attn-v1")
))]
fn test_flash_gte_classification() -> Result<()> {
let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?;

let input_single = batch(
vec![tokenizer
.encode(("What is Deep Learning?", "Deep Learning is not..."), true)
.unwrap()],
[0].to_vec(),
vec![],
);

let predictions: Vec<Vec<f32>> = backend
.predict(input_single)?
.into_iter()
.map(|(_, v)| v)
.collect();
let predictions_single = SnapshotScores::from(predictions);

let matcher = relative_matcher();
insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher);

Ok(())
}
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_jina.rs
@@ -15,7 +15,7 @@ fn test_flash_jina_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_jina_code.rs
@@ -15,7 +15,7 @@ fn test_flash_jina_code_base() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_mistral.rs
@@ -15,7 +15,7 @@ fn test_flash_mistral() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_nomic.rs
@@ -15,7 +15,7 @@ fn test_flash_nomic_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_flash_qwen2.rs
@@ -39,7 +39,7 @@ fn test_flash_qwen2() -> Result<()> {
};

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float16".to_string(),
ModelType::Embedding(Pool::LastToken),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_jina.rs
@@ -12,7 +12,7 @@ fn test_jina_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_jina_code.rs
@@ -12,7 +12,7 @@ fn test_jina_code_base() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
2 changes: 1 addition & 1 deletion backends/candle/tests/test_nomic.rs
@@ -12,7 +12,7 @@ fn test_nomic_small() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
- model_root,
+ &model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;