From 0a0384f424cc982fdfb717e2f275608d45c4a563 Mon Sep 17 00:00:00 2001 From: Hyeongchan Kim Date: Thu, 28 Nov 2024 00:15:09 +0900 Subject: [PATCH 1/2] Support for Classification Head in GTE Family (#438) --- Cargo.lock | 9 +++ README.md | 15 ++--- backends/candle/src/models/flash_bert.rs | 1 + backends/candle/src/models/flash_gte.rs | 26 +++++++- backends/candle/src/models/gte.rs | 63 +++++++++++++++++++ ..._flash_gte__gte_classification_single.snap | 6 ++ backends/candle/tests/test_flash_gte.rs | 40 +++++++++++- backends/ort/Cargo.toml | 2 +- 8 files changed, 149 insertions(+), 13 deletions(-) create mode 100644 backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap diff --git a/Cargo.lock b/Cargo.lock index 815568ee..9866d52a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3055,6 +3055,15 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "portable-atomic-util" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" diff --git a/README.md b/README.md index 002bcdfc..7584d987 100644 --- a/README.md +++ b/README.md @@ -92,11 +92,12 @@ Text Embeddings Inference currently supports CamemBERT, and XLM-RoBERTa Sequence Below are some examples of the currently supported models: -| Task | Model Type | Model ID | -|--------------------|-------------|---------------------------------------------------------------------------------------------| -| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | -| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | -| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | +| Task | Model Type | Model ID | +|--------------------|-------------|-----------------------------------------------------------------------------------------------------------------| +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | +| Re-Ranking | GTE | [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) | +| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | ### Docker @@ -372,7 +373,7 @@ docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingf ### Using Re-rankers models -`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models. +`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa, XLM-RoBERTa, and GTE Sequence Classification models. Re-rankers models are Sequence Classification cross-encoders models with a single class that scores the similarity between a query and a text. 
@@ -392,7 +393,7 @@ And then you can rank the similarity between a query and a list of texts with: ```bash curl 127.0.0.1:8080/rerank \ -X POST \ - -d '{"query":"What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \ + -d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \ -H 'Content-Type: application/json' ``` diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index 8951dcc1..bfd29cab 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -529,6 +529,7 @@ impl Model for FlashBertModel { fn is_padded(&self) -> bool { false } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index f3aec220..088fa42f 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,5 +1,6 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; +use crate::models::gte::{ClassificationHead, GTEClassificationHead}; use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -205,6 +206,7 @@ pub struct FlashGTEModel { embeddings_norm: LayerNorm, cos_cache: Tensor, sin_cache: Tensor, + classifier: Option>, pool: Pool, pub device: Device, @@ -233,11 +235,15 @@ impl FlashGTEModel { candle::bail!("Only `PositionEmbeddingType::Rope` is supported"); } - let pool = match model_type { + let (pool, classifier) = match model_type { ModelType::Classifier => { - candle::bail!("`classifier` model type is not supported for GTE") + let pool = Pool::Cls; + + let classifier: Box = + Box::new(GTEClassificationHead::load(vb.clone(), config)?); + (pool, Some(classifier)) } - ModelType::Embedding(pool) => pool, + ModelType::Embedding(pool) => (pool, None), }; let word_embeddings = Embedding::new( @@ -292,6 +298,7 @@ impl FlashGTEModel { embeddings_norm, cos_cache, sin_cache, + classifier, pool, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), @@ -457,7 +464,20 @@ impl Model for FlashGTEModel { fn is_padded(&self) -> bool { false } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; + let pooled_embeddings = + pooled_embeddings.expect("pooled_embeddings is empty. 
This is a bug."); + classifier.forward(&pooled_embeddings) + } + } + } } diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index e5e75638..e0eadb35 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,6 +1,10 @@ use crate::layers::HiddenAct; +use crate::layers::Linear; use crate::models::PositionEmbeddingType; +use candle::{Result, Tensor}; +use candle_nn::VarBuilder; use serde::Deserialize; +use std::collections::HashMap; #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct NTKScaling { @@ -32,4 +36,63 @@ pub struct GTEConfig { pub logn_attention_scale: bool, #[serde(default)] pub logn_attention_clip1: bool, + pub id2label: Option>, +} + +pub trait ClassificationHead { + fn forward(&self, hidden_states: &Tensor) -> Result; +} + +pub struct GTEClassificationHead { + pooler: Option, + classifier: Linear, + span: tracing::Span, +} + +impl GTEClassificationHead { + #[allow(dead_code)] + pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let pooler = if let Ok(pooler_weight) = vb + .pp("pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + { + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; + Some(Linear::new(pooler_weight, Some(pooler_bias), None)) + } else { + None + }; + + let classifier_weight = vb + .pp("classifier") + .get((n_classes, config.hidden_size), "weight")?; + let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; + let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); + + Ok(Self { + classifier, + pooler, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } +} + +impl ClassificationHead for GTEClassificationHead { + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.unsqueeze(1)?; + if let Some(pooler) = self.pooler.as_ref() { + hidden_states = pooler.forward(&hidden_states)?; + hidden_states = hidden_states.tanh()?; + } + + let hidden_states = self.classifier.forward(&hidden_states)?; + let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) + } } diff --git a/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap new file mode 100644 index 00000000..2de17c81 --- /dev/null +++ b/backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap @@ -0,0 +1,6 @@ +--- +source: backends/candle/tests/test_flash_gte.rs +assertion_line: 86 +expression: predictions_single +--- +- - -0.74365234 diff --git a/backends/candle/tests/test_flash_gte.rs b/backends/candle/tests/test_flash_gte.rs index 20b06b2f..1ecfc059 100644 --- a/backends/candle/tests/test_flash_gte.rs +++ b/backends/candle/tests/test_flash_gte.rs @@ -1,9 +1,9 @@ #![allow(dead_code, unused_imports)] mod common; -use crate::common::{sort_embeddings, SnapshotEmbeddings}; +use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; use anyhow::Result; -use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; +use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; use text_embeddings_backend_candle::CandleBackend; use text_embeddings_backend_core::{Backend, ModelType, Pool}; @@ -51,3 +51,39 @@ fn 
test_flash_gte() -> Result<()> { Ok(()) } + +#[test] +#[serial_test::serial] +#[cfg(all( + feature = "cuda", + any(feature = "flash-attn", feature = "flash-attn-v1") +))] +fn test_flash_gte_classification() -> Result<()> { + let model_root = download_artifacts( + "Alibaba-NLP/gte-multilingual-reranker-base", + Some("refs/pr/11"), + )?; + let tokenizer = load_tokenizer(&model_root)?; + + let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?; + + let input_single = batch( + vec![tokenizer + .encode(("What is Deep Learning?", "Deep Learning is not..."), true) + .unwrap()], + [0].to_vec(), + vec![], + ); + + let predictions: Vec> = backend + .predict(input_single)? + .into_iter() + .map(|(_, v)| v) + .collect(); + let predictions_single = SnapshotScores::from(predictions); + + let matcher = relative_matcher(); + insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher); + + Ok(()) +} diff --git a/backends/ort/Cargo.toml b/backends/ort/Cargo.toml index db2795a6..40ccc560 100644 --- a/backends/ort/Cargo.toml +++ b/backends/ort/Cargo.toml @@ -10,7 +10,7 @@ anyhow = { workspace = true } nohash-hasher = { workspace = true } ndarray = "0.16.1" num_cpus = { workspace = true } -ort = { version = "2.0.0-rc.4", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] } +ort = { version = "2.0.0-rc.8", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] } text-embeddings-backend-core = { path = "../core" } tracing = { workspace = true } thiserror = { workspace = true } From 59740cce46bb0a0db5ddd24c4f934146088b4a4b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:32:59 +0100 Subject: [PATCH 2/2] feat: GTE classification head --- Cargo.lock | 9 --- backends/candle/src/models/flash_gte.rs | 58 ++++++++++++++++-- backends/candle/src/models/gte.rs | 61 ------------------- backends/candle/tests/test_bert.rs | 8 +-- backends/candle/tests/test_flash_bert.rs | 8 +-- backends/candle/tests/test_flash_gte.rs | 9 +-- backends/candle/tests/test_flash_jina.rs | 2 +- backends/candle/tests/test_flash_jina_code.rs | 2 +- backends/candle/tests/test_flash_mistral.rs | 2 +- backends/candle/tests/test_flash_nomic.rs | 2 +- backends/candle/tests/test_flash_qwen2.rs | 2 +- backends/candle/tests/test_jina.rs | 2 +- backends/candle/tests/test_jina_code.rs | 2 +- backends/candle/tests/test_nomic.rs | 2 +- 14 files changed, 73 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9866d52a..9bcf91ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3046,15 +3046,6 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" -[[package]] -name = "portable-atomic-util" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" -dependencies = [ - "portable-atomic", -] - [[package]] name = "portable-atomic-util" version = "0.2.3" diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 088fa42f..53e62f6d 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,6 +1,5 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; -use crate::models::gte::{ClassificationHead, 
GTEClassificationHead}; use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -199,6 +198,58 @@ impl GTELayer { } } +pub struct GTEClassificationHead { + pooler: Option, + classifier: Linear, + span: tracing::Span, +} + +impl GTEClassificationHead { + #[allow(dead_code)] + pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let pooler = if let Ok(pooler_weight) = vb + .pp("pooler.dense") + .get((config.hidden_size, config.hidden_size), "weight") + { + let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; + Some(Linear::new(pooler_weight, Some(pooler_bias), None)) + } else { + None + }; + + let classifier_weight = vb + .pp("classifier") + .get((n_classes, config.hidden_size), "weight")?; + let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; + let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); + + Ok(Self { + classifier, + pooler, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } + + pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let mut hidden_states = hidden_states.unsqueeze(1)?; + if let Some(pooler) = self.pooler.as_ref() { + hidden_states = pooler.forward(&hidden_states)?; + hidden_states = hidden_states.tanh()?; + } + + let hidden_states = self.classifier.forward(&hidden_states)?; + let hidden_states = hidden_states.squeeze(1)?; + Ok(hidden_states) + } +} + pub struct FlashGTEModel { word_embeddings: Embedding, token_type_embeddings: Option, @@ -206,7 +257,7 @@ pub struct FlashGTEModel { embeddings_norm: LayerNorm, cos_cache: Tensor, sin_cache: Tensor, - classifier: Option>, + classifier: Option, pool: Pool, pub device: Device, @@ -239,8 +290,7 @@ impl FlashGTEModel { ModelType::Classifier => { let pool = Pool::Cls; - let classifier: Box = - Box::new(GTEClassificationHead::load(vb.clone(), config)?); + let classifier = GTEClassificationHead::load(vb.clone(), config)?; (pool, Some(classifier)) } ModelType::Embedding(pool) => (pool, None), diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index e0eadb35..bc4bfdce 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,8 +1,5 @@ use crate::layers::HiddenAct; -use crate::layers::Linear; use crate::models::PositionEmbeddingType; -use candle::{Result, Tensor}; -use candle_nn::VarBuilder; use serde::Deserialize; use std::collections::HashMap; @@ -38,61 +35,3 @@ pub struct GTEConfig { pub logn_attention_clip1: bool, pub id2label: Option>, } - -pub trait ClassificationHead { - fn forward(&self, hidden_states: &Tensor) -> Result; -} - -pub struct GTEClassificationHead { - pooler: Option, - classifier: Linear, - span: tracing::Span, -} - -impl GTEClassificationHead { - #[allow(dead_code)] - pub(crate) fn load(vb: VarBuilder, config: >EConfig) -> Result { - let n_classes = match &config.id2label { - None => candle::bail!("`id2label` must be set for classifier models"), - Some(id2label) => id2label.len(), - }; - - let pooler = if let Ok(pooler_weight) = vb - .pp("pooler.dense") - .get((config.hidden_size, config.hidden_size), "weight") - { - let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?; - 
Some(Linear::new(pooler_weight, Some(pooler_bias), None)) - } else { - None - }; - - let classifier_weight = vb - .pp("classifier") - .get((n_classes, config.hidden_size), "weight")?; - let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?; - let classifier = Linear::new(classifier_weight, Some(classifier_bias), None); - - Ok(Self { - classifier, - pooler, - span: tracing::span!(tracing::Level::TRACE, "classifier"), - }) - } -} - -impl ClassificationHead for GTEClassificationHead { - fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let mut hidden_states = hidden_states.unsqueeze(1)?; - if let Some(pooler) = self.pooler.as_ref() { - hidden_states = pooler.forward(&hidden_states)?; - hidden_states = hidden_states.tanh()?; - } - - let hidden_states = self.classifier.forward(&hidden_states)?; - let hidden_states = hidden_states.squeeze(1)?; - Ok(hidden_states) - } -} diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs index 1bd5017f..35fb2f83 100644 --- a/backends/candle/tests/test_bert.rs +++ b/backends/candle/tests/test_bert.rs @@ -13,7 +13,7 @@ fn test_mini() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float32".to_string(), ModelType::Embedding(Pool::Mean), )?; @@ -73,7 +73,7 @@ fn test_mini_pooled_raw() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float32".to_string(), ModelType::Embedding(Pool::Cls), )?; @@ -142,7 +142,7 @@ fn test_emotions() -> Result<()> { let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?; let tokenizer = load_tokenizer(&model_root)?; - let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?; + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; let input_batch = batch( vec![ @@ -192,7 +192,7 @@ fn test_bert_classification() -> Result<()> { let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?; let tokenizer = load_tokenizer(&model_root)?; - let backend = CandleBackend::new(model_root, "float32".to_string(), ModelType::Classifier)?; + let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; let input_single = batch( vec![tokenizer diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index ea150e7f..3333e2f5 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -19,7 +19,7 @@ fn test_flash_mini() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Mean), )?; @@ -83,7 +83,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Cls), )?; @@ -156,7 +156,7 @@ fn test_flash_emotions() -> Result<()> { let model_root = download_artifacts("SamLowe/roberta-base-go_emotions", None)?; let tokenizer = load_tokenizer(&model_root)?; - let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?; + let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?; let input_batch = batch( vec![ @@ -210,7 +210,7 @@ fn test_flash_bert_classification() 
-> Result<()> { let model_root = download_artifacts("ibm/re2g-reranker-nq", Some("refs/pr/3"))?; let tokenizer = load_tokenizer(&model_root)?; - let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?; + let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?; let input_single = batch( vec![tokenizer diff --git a/backends/candle/tests/test_flash_gte.rs b/backends/candle/tests/test_flash_gte.rs index 1ecfc059..c8012eb2 100644 --- a/backends/candle/tests/test_flash_gte.rs +++ b/backends/candle/tests/test_flash_gte.rs @@ -15,7 +15,7 @@ fn test_flash_gte() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Cls), )?; @@ -59,13 +59,10 @@ fn test_flash_gte() -> Result<()> { any(feature = "flash-attn", feature = "flash-attn-v1") ))] fn test_flash_gte_classification() -> Result<()> { - let model_root = download_artifacts( - "Alibaba-NLP/gte-multilingual-reranker-base", - Some("refs/pr/11"), - )?; + let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?; let tokenizer = load_tokenizer(&model_root)?; - let backend = CandleBackend::new(model_root, "float16".to_string(), ModelType::Classifier)?; + let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?; let input_single = batch( vec![tokenizer diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index 255b82a2..d0ff5cf7 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -15,7 +15,7 @@ fn test_flash_jina_small() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index d84848dc..aa518b8e 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -15,7 +15,7 @@ fn test_flash_jina_code_base() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_flash_mistral.rs b/backends/candle/tests/test_flash_mistral.rs index 71749c8b..2c3ac47f 100644 --- a/backends/candle/tests/test_flash_mistral.rs +++ b/backends/candle/tests/test_flash_mistral.rs @@ -15,7 +15,7 @@ fn test_flash_mistral() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 263bbe43..ad45438e 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -15,7 +15,7 @@ fn test_flash_nomic_small() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_flash_qwen2.rs b/backends/candle/tests/test_flash_qwen2.rs index 38e45553..2b1d78b3 100644 --- a/backends/candle/tests/test_flash_qwen2.rs +++ b/backends/candle/tests/test_flash_qwen2.rs @@ 
-39,7 +39,7 @@ fn test_flash_qwen2() -> Result<()> { }; let backend = CandleBackend::new( - model_root, + &model_root, "float16".to_string(), ModelType::Embedding(Pool::LastToken), )?; diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 4aa30d03..ae162368 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -12,7 +12,7 @@ fn test_jina_small() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float32".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 6c3b3f20..83781ffa 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -12,7 +12,7 @@ fn test_jina_code_base() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float32".to_string(), ModelType::Embedding(Pool::Mean), )?; diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index ce0a4559..6c444ea7 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -12,7 +12,7 @@ fn test_nomic_small() -> Result<()> { let tokenizer = load_tokenizer(&model_root)?; let backend = CandleBackend::new( - model_root, + &model_root, "float32".to_string(), ModelType::Embedding(Pool::Mean), )?;
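
With both patches applied, the rerank flow already documented in the README extends to the GTE family. Below is a minimal end-to-end sketch, assuming a local build of the TEI `text-embeddings-router` binary and that the `--model-id`/`--port` flags match your version (adjust accordingly if you deploy via the Docker images instead); the request payload mirrors the README's own `/rerank` example and the query/text pair used in `test_flash_gte_classification`:

```bash
# Serve the GTE reranker added by these patches; TEI selects the new
# classification head (ModelType::Classifier) from the model's config.
text-embeddings-router --model-id Alibaba-NLP/gte-multilingual-reranker-base --port 8080

# Score query/text pairs with the GTE classification head, using the same
# payload shape as the README's rerank example.
curl 127.0.0.1:8080/rerank \
    -X POST \
    -d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
    -H 'Content-Type: application/json'
```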