Commit

fix implementation
OlivierDehaene committed Dec 11, 2024
1 parent 9ced12b commit 3d2cc71
Showing 12 changed files with 3,597 additions and 275 deletions.
2 changes: 2 additions & 0 deletions backends/candle/src/layers/mod.rs
@@ -4,9 +4,11 @@ mod layer_norm;
 mod linear;
 #[allow(dead_code, unused)]
 mod rms_norm;
+mod rotary;
 
 pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
 pub use rms_norm::RMSNorm;
+pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
73 changes: 73 additions & 0 deletions backends/candle/src/layers/rotary.rs
@@ -0,0 +1,73 @@
+use candle::{DType, Device, Result, Tensor, D};
+use serde::Deserialize;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct NTKScaling {
+    pub factor: f32,
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "type", rename_all = "kebab-case")]
+pub enum RopeScaling {
+    Ntk(NTKScaling),
+}
+
+pub fn get_inv_freqs(
+    dim: usize,
+    base: f32,
+    device: &Device,
+    rope_scaling: Option<&RopeScaling>,
+) -> Result<Tensor> {
+    let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| {
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / base.powf(i as f32 / dim as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
+    };
+
+    if let Some(rope_scaling) = rope_scaling {
+        match rope_scaling {
+            RopeScaling::Ntk(ntk_scaling) => {
+                let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
+                let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
+                return inv_freqs / s;
+            }
+        }
+    }
+    get_inv_freqs_inner(dim, base, device)
+}
+
+pub fn get_cos_sin(
+    length: usize,
+    inv_freqs: &Tensor,
+    dtype: DType,
+    repeat_freqs: bool,
+) -> Result<(Tensor, Tensor)> {
+    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
+        .to_dtype(DType::F32)?
+        .reshape((length, 1))?;
+    let mut freqs = t.matmul(inv_freqs)?;
+    if repeat_freqs {
+        freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
+    }
+
+    let cos = freqs.cos()?.to_dtype(dtype)?;
+    let sin = freqs.sin()?.to_dtype(dtype)?;
+    Ok((cos, sin))
+}
+
+pub fn apply_rotary(
+    x: &Tensor,
+    cos: &Tensor,
+    sin: &Tensor,
+    attention_head_size: usize,
+) -> Result<Tensor> {
+    let dim = attention_head_size / 2;
+    let x1 = x.narrow(D::Minus1, 0, dim)?;
+    let x2 = x.narrow(D::Minus1, dim, dim)?;
+    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+    let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
+    Ok(rope)
+}
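For orientation, here is a minimal sketch of how the new helpers compose. It is not part of the commit; the shapes, the `rope_sketch` name, and the assumption that the items above are in scope (they live in the backend's private `layers` module) are mine.

```rust
use candle::{DType, Device, Result, Tensor};

// Assumes get_inv_freqs, get_cos_sin, apply_rotary, RopeScaling and NTKScaling
// from the rotary module above are in scope.
fn rope_sketch() -> Result<()> {
    let device = Device::Cpu;
    let head_size = 64;
    let seq_len = 8;

    // Plain RoPE inverse frequencies, and an NTK-scaled variant.
    let inv_freqs = get_inv_freqs(head_size, 10_000.0, &device, None)?;
    let ntk = RopeScaling::Ntk(NTKScaling { factor: 2.0 });
    let _inv_freqs_ntk = get_inv_freqs(head_size, 10_000.0, &device, Some(&ntk))?;

    // Cos/sin tables for seq_len positions. `repeat_freqs = true` duplicates the
    // frequencies so the tables are (seq_len, head_size), which is what the
    // rotate-half `apply_rotary` above expects; the flash models in this commit
    // pass `false` because `apply_rotary_inplace` uses half-width tables.
    let (cos, sin) = get_cos_sin(seq_len, &inv_freqs, DType::F32, true)?;

    // Rotate a dummy (seq_len, num_heads, head_size) query tensor; the reshape
    // to (seq_len, 1, head_size) is only there so the tables broadcast.
    let q = Tensor::randn(0f32, 1f32, (seq_len, 2, head_size), &device)?;
    let cos = cos.reshape((seq_len, 1, head_size))?;
    let sin = sin.reshape((seq_len, 1, head_size))?;
    let _q_rotated = apply_rotary(&q, &cos, &sin, head_size)?;
    Ok(())
}
```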
40 changes: 17 additions & 23 deletions backends/candle/src/models/flash_gte.rs
@@ -1,10 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{LayerNorm, Linear};
-use crate::models::{
-    GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, GTEMLP,
-};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct GTEAttention {
@@ -74,7 +73,7 @@ impl GTEAttention {
         let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
         let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -219,24 +218,19 @@ impl FlashGTEModel {
             config.layer_norm_eps,
         )?;
 
-        let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
-            let inv_freqs = candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta * factor,
-                vb.device(),
-            )?;
-            let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
-            inv_freqs / s
-        } else {
-            candle_rotary::inv_freqs(
-                layers[0].attention.attention_head_size,
-                config.rope_theta,
-                vb.device(),
-            )
-        }?;
-
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(
+            layers[0].attention.attention_head_size,
+            config.rope_theta,
+            vb.device(),
+            config.rope_scaling.as_ref(),
+        )?;
+
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;
 
         Ok(Self {
             word_embeddings,
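The NTK branch deleted above is now handled inside `get_inv_freqs` via `config.rope_scaling.as_ref()`: the helper scales the base by the factor and then divides the frequencies by `factor^(2/dim)`, just as the inline code did. A rough equivalence check, not part of the commit, assuming the new helpers are in scope and an even `dim`:

```rust
use candle::{Device, Result, Tensor};

// Hypothetical sanity check: the removed inline GTE path and the shared helper
// should produce the same NTK-scaled inverse frequencies.
fn ntk_paths_agree(dim: usize, base: f32, factor: f32) -> Result<bool> {
    let device = Device::Cpu;

    // Old flash_gte.rs branch: build frequencies from base * factor, then
    // divide by factor^(2 / dim).
    let inv_freq: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / (base * factor).powf(i as f32 / dim as f32))
        .collect();
    let old = Tensor::from_vec(inv_freq, (1, dim / 2), &device)?;
    let old = (old / factor.powf(2.0 / dim as f32) as f64)?;

    // New shared helper, with the same scaling expressed as RopeScaling::Ntk.
    let scaling = RopeScaling::Ntk(NTKScaling { factor });
    let new = get_inv_freqs(dim, base, &device, Some(&scaling))?;

    let total_abs_diff = (old - new)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    Ok(total_abs_diff < 1e-6)
}
```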
16 changes: 11 additions & 5 deletions backends/candle/src/models/flash_mistral.rs
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, Linear, RMSNorm};
+use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
 use crate::models::{MistralConfig, Model};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct MistralAttention {
@@ -90,7 +91,7 @@ impl MistralAttention {
             self.num_key_value_heads,
         )?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -267,13 +268,18 @@ impl FlashMistralModel {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
-        let inv_freqs = candle_rotary::inv_freqs(
+        let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
             vb.device(),
+            None,
         )?;
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
 
         Ok(Self {
             embeddings,
14 changes: 8 additions & 6 deletions backends/candle/src/models/flash_nomic.rs
@@ -1,9 +1,10 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{LayerNorm, Linear};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
 use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
 use crate::models::{Model, NomicConfig};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::VarBuilder;
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct NomicAttention {
@@ -68,7 +69,7 @@ impl NomicAttention {
         let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
         let qkv = qkv.chunk(3, 1)?;
 
-        candle_rotary::apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
+        apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &qkv[0],
@@ -221,20 +222,21 @@ impl FlashNomicBertModel {
         let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;
 
         let rotary_dim = encoder.layers[0].attention.attention_head_size;
-        let inv_freqs = candle_rotary::inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?;
-        let rotary_cache = candle_rotary::cos_sin(config.n_positions, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
+        let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;
 
         let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
             let new_base = (config.rotary_emb_base
                 * ((scaling_factor * config.n_positions as f32
                     / config.max_trained_positions as f32)
                     - (scaling_factor - 1.0)))
                 .powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
-            let inv_freqs = candle_rotary::inv_freqs(rotary_dim, new_base, vb.device())?;
-            Some(candle_rotary::cos_sin(
+            let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
+            Some(get_cos_sin(
                 config.n_positions,
                 &inv_freqs,
                 vb.dtype(),
+                false,
             )?)
         } else {
             None
16 changes: 11 additions & 5 deletions backends/candle/src/models/flash_qwen2.rs
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, Linear, RMSNorm};
+use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
 use crate::models::{Model, Qwen2Config};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct Qwen2Attention {
@@ -98,7 +99,7 @@ impl Qwen2Attention {
             self.num_key_value_heads,
         )?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -277,13 +278,18 @@ impl FlashQwen2Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
-        let inv_freqs = candle_rotary::inv_freqs(
+        let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
             vb.device(),
+            None,
         )?;
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
+        )?;
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
 
         Ok(Self {
             embeddings,
(Diffs for the remaining 6 changed files are not shown here.)
