feat(candle): Support for Jina Code model (#292)
patricebechard authored Jun 21, 2024
1 parent ce2f210 commit b8f6c78
Showing 15 changed files with 4,477 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -77,6 +77,7 @@ Examples of supported models:
 | N/A | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
 | N/A | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
 | N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
+| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
 
 You can explore the list of best performing text embeddings
 models [here](https://huggingface.co/spaces/mteb/leaderboard).
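
As a quick usage sketch: once a `text-embeddings-inference` server is running with the new model, embeddings come back through the usual `/embed` route. The snippet below is an illustration, not part of this commit; it assumes a server launched with `--model-id jinaai/jina-embeddings-v2-base-code` listening on `localhost:8080`, plus the `reqwest`, `serde_json`, and `tokio` crates.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumes a TEI server started with something like:
    //   text-embeddings-router --model-id jinaai/jina-embeddings-v2-base-code --port 8080
    let client = reqwest::Client::new();
    let response = client
        .post("http://localhost:8080/embed")
        .json(&json!({ "inputs": "fn add(a: i32, b: i32) -> i32 { a + b }" }))
        .send()
        .await?;

    // The /embed route returns one embedding vector per input string.
    let embeddings: Vec<Vec<f32>> = response.json().await?;
    println!("embedding dimensions: {}", embeddings[0].len());
    Ok(())
}
```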
67 changes: 47 additions & 20 deletions backends/candle/src/lib.rs
@@ -11,12 +11,12 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
-    NomicConfig, PositionEmbeddingType,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
+    Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -37,6 +37,10 @@ enum Config {
     XlmRoberta(BertConfig),
     Camembert(BertConfig),
     Roberta(BertConfig),
+    #[serde(rename(deserialize = "jina_bert"))]
+    JinaBert(JinaConfig),
+    #[serde(rename(deserialize = "jina_code_bert"))]
+    JinaCodeBert(JinaCodeConfig),
     #[serde(rename(deserialize = "distilbert"))]
     DistilBert(DistilBertConfig),
     #[serde(rename(deserialize = "nomic_bert"))]
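
For context on how these new variants get selected: the backend deserializes the model's `config.json` and serde dispatches on its `model_type` field, which the `rename` attributes above map onto the snake_case names used on disk. A minimal, self-contained sketch of that mechanism (the `tag` attribute and the config fields here are illustrative assumptions, not the crate's exact definitions):

```rust
use serde::Deserialize;

// Illustrative stand-in for the real Jina config struct.
#[derive(Debug, Deserialize)]
struct JinaConfig {
    hidden_size: usize,
}

// The `model_type` value in config.json picks the enum variant;
// `rename` maps the on-disk snake_case name onto the Rust variant.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum Config {
    #[serde(rename = "jina_bert")]
    JinaBert(JinaConfig),
}

fn main() {
    let raw = r#"{ "model_type": "jina_bert", "hidden_size": 768 }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    println!("{config:?}"); // JinaBert(JinaConfig { hidden_size: 768 })
}
```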
@@ -120,13 +124,16 @@ impl CandleBackend {
                 "`cuda` feature is not enabled".to_string(),
             )),
             (Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
-                if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
-                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-                } else {
-                    tracing::info!("Starting Bert model on {:?}", device);
-                    Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
-                }
+                tracing::info!("Starting Bert model on {:?}", device);
+                Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
             }
+            (Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting JinaBertModel model on {:?}", device);
+                Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+            }
+            (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
+                Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+            }
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -157,23 +164,43 @@ impl CandleBackend {
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                 {
-                    if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                        tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                        Ok(Box::new(
-                            FlashJinaBertModel::load(vb, &config, model_type).s()?,
-                        ))
-                    } else {
                     tracing::info!("Starting FlashBert model on {:?}", device);
                     Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
-                    }
                 } else {
-                    if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                        tracing::info!("Starting JinaBertModel model on {:?}", device);
-                        Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-                    } else {
                     tracing::info!("Starting Bert model on {:?}", device);
                     Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
-                    }
                 }
             }
+            #[cfg(feature = "cuda")]
+            (Config::JinaBert(config), Device::Cuda(_)) => {
+                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                    && dtype == DType::F16
+                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
+                    // Allow disabling because of flash attention v1 precision problems
+                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                {
+                    tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
+                    Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
+                } else {
+                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                }
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
+                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                    && dtype == DType::F16
+                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
+                    // Allow disabling because of flash attention v1 precision problems
+                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                {
+                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
+                } else {
+                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+                }
+            }
             #[cfg(feature = "cuda")]
             (
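Both new CUDA arms gate flash attention on the same four conditions as the existing `Config::Bert` arm. Distilled into a standalone helper for readability (a sketch, not code from this commit; the function name and shape are invented):

```rust
#[derive(PartialEq)]
enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// Sketch of the gating predicate: flash attention is used only when the
// crate was built with a flash-attn feature, the model runs in f16, the
// position embedding type is supported, and the user has not opted out
// via USE_FLASH_ATTENTION=false (flash-attn v1 has precision problems,
// see issue #37 linked in the comments above).
fn use_flash_attention(is_f16: bool, pos: PositionEmbeddingType) -> bool {
    let built_with_flash = cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"));
    let pos_supported =
        pos == PositionEmbeddingType::Absolute || pos == PositionEmbeddingType::Alibi;
    let enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "true".to_string())
        .to_lowercase()
        == "true";
    built_with_flash && is_f16 && pos_supported && enabled
}
```

Note that the diff joins the two position-embedding comparisons with the bitwise `|` rather than logical `||`; on plain `bool`s the result is identical, only short-circuiting is lost.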
10 changes: 9 additions & 1 deletion backends/candle/src/models.rs
@@ -15,6 +15,9 @@ mod flash_bert;
 #[cfg(feature = "cuda")]
 mod flash_jina;
 
+#[cfg(feature = "cuda")]
+mod flash_jina_code;
+
 #[cfg(feature = "cuda")]
 mod flash_nomic;
 
@@ -24,7 +27,8 @@ mod flash_distilbert;
 pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
-pub use jina::JinaBertModel;
+pub use jina::{JinaConfig, JinaBertModel};
+pub use jina_code::{JinaCodeConfig, JinaCodeBertModel};
 pub use nomic::{NomicBertModel, NomicConfig};
 use text_embeddings_backend_core::Batch;
 
@@ -34,6 +38,10 @@ pub use flash_bert::FlashBertModel;
 #[cfg(feature = "cuda")]
 pub use flash_jina::FlashJinaBertModel;
 
+#[cfg(feature = "cuda")]
+pub use flash_jina_code::FlashJinaCodeBertModel;
+
+
 #[cfg(feature = "cuda")]
 pub use flash_nomic::FlashNomicBertModel;
 
3 changes: 2 additions & 1 deletion backends/candle/src/models/flash_jina.rs
@@ -1,7 +1,8 @@
 use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::bert::{BertConfig, PositionEmbeddingType};
-use crate::models::jina::BertEmbeddings;
+use crate::models::bert::PositionEmbeddingType;
+use crate::models::jina::{JinaConfig, BertEmbeddings};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
