feat(candle): add FlashMistral (#308)
OlivierDehaene authored Jun 27, 2024
1 parent 5c6151c commit 7c9b7cb
Showing 39 changed files with 17,578 additions and 176 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


18 changes: 13 additions & 5 deletions backends/candle/src/flash_attn.rs
@@ -31,13 +31,17 @@ pub(crate) fn flash_attn_varlen(
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
window_size_left: Option<usize>,
) -> Result<Tensor, candle::Error> {
let runtime_compute_cap = get_runtime_compute_cap();

if runtime_compute_cap == 75 {
if alibi_slopes.is_some() {
candle::bail!("Flash attention v1 does not support alibi");
}
if window_size_left.is_some() {
candle::bail!("Flash attention v1 does not support attention windowing");
}

#[cfg(feature = "flash-attn-v1")]
{
@@ -59,10 +63,12 @@ pub(crate) fn flash_attn_varlen(
} else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 {
#[cfg(feature = "flash-attn")]
{
use candle_flash_attn::{flash_attn_varlen, flash_attn_varlen_alibi};
use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};

let window_size_right = if causal { Some(0) } else { None };

let attention = if let Some(alibi_slopes) = alibi_slopes {
flash_attn_varlen_alibi(
flash_attn_varlen_alibi_windowed(
q,
k,
v,
@@ -72,10 +78,11 @@
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
window_size_left,
window_size_right,
)
} else {
flash_attn_varlen(
flash_attn_varlen_windowed(
q,
k,
v,
@@ -84,7 +91,8 @@
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
window_size_left,
window_size_right,
)
};

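For context, the new `(window_size_left, window_size_right)` pair follows the standard flash-attention windowing convention, which is what lets the FlashMistral model express sliding-window attention through this wrapper. A minimal sketch of the equivalent dense masking rule (illustrative only; the fused kernel never materializes such a mask):

fn attends(i: usize, j: usize, window_size_left: Option<usize>, window_size_right: Option<usize>) -> bool {
    // Query position i may attend to key position j if j is at most
    // `window_size_left` behind i and at most `window_size_right` ahead of it.
    let left_ok = window_size_left.map_or(true, |w| j + w >= i);
    let right_ok = window_size_right.map_or(true, |w| j <= i + w);
    left_ok && right_ok
}

With `causal = true` the wrapper sets `window_size_right = Some(0)`, so a query at position i only sees keys in `[i - window_size_left, i]`; Mistral's sliding-window attention is this rule with the window taken from the model configuration.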
1 change: 1 addition & 0 deletions backends/candle/src/layers/linear.rs
@@ -7,6 +7,7 @@ use serde::Deserialize;
pub enum HiddenAct {
Gelu,
Relu,
#[serde(alias = "silu")]
Swiglu,
}

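The `silu` alias matters because Mistral-style configs declare their activation as "silu"; with the alias those configs deserialize onto the existing `Swiglu` variant. A quick sketch, assuming `HiddenAct` is in scope (the `ActOnly` wrapper is illustrative and not part of the crate):

use serde::Deserialize;

#[derive(Deserialize)]
struct ActOnly {
    hidden_act: HiddenAct,
}

fn main() -> serde_json::Result<()> {
    // A config fragment using the Mistral spelling of the activation.
    let parsed: ActOnly = serde_json::from_str(r#"{ "hidden_act": "silu" }"#)?;
    assert!(matches!(parsed.hidden_act, HiddenAct::Swiglu));
    Ok(())
}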
4 changes: 4 additions & 0 deletions backends/candle/src/layers/mod.rs
@@ -2,7 +2,11 @@
mod cublaslt;
mod layer_norm;
mod linear;
#[allow(dead_code, unused)]
mod rms_norm;

pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
96 changes: 96 additions & 0 deletions backends/candle/src/layers/rms_norm.rs
@@ -0,0 +1,96 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
pub struct RMSNorm {
weight: Tensor,
epsilon: f32,
span: tracing::Span,
}

impl RMSNorm {
pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
Ok(Self {
weight: vb
.get(hidden_size, "weight")
.or_else(|_| vb.get(hidden_size, "gamma"))?,
epsilon,
span: tracing::span!(tracing::Level::TRACE, "rms-norm"),
})
}

pub fn forward(
&self,
hidden_states: &Tensor,
residual: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();

match hidden_states.device() {
Device::Cpu | Device::Metal(_) => {
let mut hidden_states = hidden_states.clone();
let residual_add = if let Some(residual) = residual {
let residual_add = hidden_states.add(residual)?;
hidden_states = residual_add.clone();
residual_add
} else {
hidden_states.clone()
};

let hidden_states_dtype = hidden_states.dtype();
let internal_dtype = match hidden_states_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = hidden_states.dim(D::Minus1)?;
let hidden_states = hidden_states.to_dtype(internal_dtype)?;
let norm_hidden_states =
(hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let hidden_states_normed = hidden_states
.broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
Ok((
hidden_states_normed
.to_dtype(hidden_states_dtype)?
.broadcast_mul(&self.weight)?,
residual_add,
))
}
Device::Cuda(_) => {
#[cfg(feature = "cuda")]
{
use candle_layer_norm::{fused_add_rms_norm, rms_norm};

let original_shape = hidden_states.shape();
let hidden_states = hidden_states.flatten_to(D::Minus2)?;

if let Some(residual) = residual {
let residual = residual.flatten_to(D::Minus2)?;

let (result, residual_add) = fused_add_rms_norm(
&hidden_states,
&residual,
&self.weight,
None,
self.epsilon,
)?;
Ok((
result.reshape(original_shape)?,
residual_add.reshape(original_shape)?,
))
} else {
let residual_add = hidden_states.clone();

let result = rms_norm(&hidden_states, &self.weight, None, self.epsilon)?;

Ok((
result.reshape(original_shape)?,
residual_add.reshape(original_shape)?,
))
}
}
#[cfg(not(feature = "cuda"))]
candle::bail!("`cuda` feature is not enabled")
}
}
}
}
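The CPU/Metal branch above implements the standard RMSNorm: each row is scaled by the reciprocal root mean square of its elements (no mean subtraction and no bias, unlike LayerNorm) and then multiplied by the learned weight. A scalar reference for a single row, as a sketch of the same arithmetic:

fn rms_norm_row(x: &[f32], weight: &[f32], epsilon: f32) -> Vec<f32> {
    // mean(x^2), accumulated and divided by the row width as in the tensor code above
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // y = x / sqrt(mean(x^2) + eps) * weight
    let inv_rms = 1.0 / (mean_sq + epsilon).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

On CUDA the same result comes from the fused `rms_norm` and `fused_add_rms_norm` kernels, which also fold the optional residual addition into a single pass.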
97 changes: 78 additions & 19 deletions backends/candle/src/lib.rs
@@ -12,12 +12,12 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
Model, NomicBertModel, NomicConfig,
MistralConfig, Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
FlashNomicBertModel,
FlashMistralModel, FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
@@ -56,6 +56,7 @@ enum Config {
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
NomicBert(NomicConfig),
Mistral(MistralConfig),
}

pub struct CandleBackend {
@@ -69,6 +70,54 @@ impl CandleBackend {
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
// Default files
let default_safetensors = model_path.join("model.safetensors");
let default_pytorch = model_path.join("pytorch_model.bin");

// Single Files
let model_files = if default_safetensors.exists() {
vec![default_safetensors]
} else if default_pytorch.exists() {
vec![default_pytorch]
}
// Sharded weights
else {
// Get index file
let index_file = model_path.join("model.safetensors.index.json");

// Parse file
let index_file_string: String = std::fs::read_to_string(&index_file)
.map_err(|err| BackendError::Start(err.to_string()))?;
let json: serde_json::Value = serde_json::from_str(&index_file_string)
.map_err(|err| BackendError::Start(err.to_string()))?;

let weight_map = match json.get("weight_map") {
None => {
return Err(BackendError::Start(format!(
"no weight map in {index_file:?}"
)));
}
Some(serde_json::Value::Object(map)) => map,
Some(_) => {
return Err(BackendError::Start(format!(
"weight map in {index_file:?} is not a map"
)));
}
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}

// Collect paths
safetensors_files
.iter()
.map(|n| model_path.join(n))
.collect()
};

// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))
.context("Unable to read config file")
@@ -115,17 +164,10 @@ impl CandleBackend {
)))
}?;

let safetensors_path = model_path.join("model.safetensors");
let vb = if safetensors_path.exists() {
unsafe {
VarBuilder::from_mmaped_safetensors(
&[model_path.join("model.safetensors")],
dtype,
&device,
)
}
let vb = if model_files.len() == 1 && model_files[0].extension().unwrap() == "bin" {
VarBuilder::from_pth(&model_files[0], dtype, &device)
} else {
VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device) }
}
.s()?;
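The sharded branch above relies on `model.safetensors.index.json`, which follows the Hugging Face sharded-checkpoint layout: a `weight_map` object mapping each tensor name to the shard file that contains it. A small standalone sketch of that shape and of how the shard set falls out of it (shard names and sizes are illustrative):

use std::collections::HashSet;

const EXAMPLE_INDEX: &str = r#"{
    "metadata": { "total_size": 14000000000 },
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "model.norm.weight": "model-00002-of-00002.safetensors"
    }
}"#;

fn main() -> serde_json::Result<()> {
    let index: serde_json::Value = serde_json::from_str(EXAMPLE_INDEX)?;
    // Deduplicating the values of `weight_map` yields the shard files to mmap,
    // mirroring the loading logic added in this commit.
    let shards: HashSet<&str> = index["weight_map"]
        .as_object()
        .map(|m| m.values().filter_map(|v| v.as_str()).collect())
        .unwrap_or_default();
    println!("{shards:?}");
    Ok(())
}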

@@ -136,7 +178,7 @@
)),
(Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
tracing::info!("Starting JinaBert model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
@@ -160,15 +202,19 @@
))
}
(Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting DistilBertModel model on {:?}", device);
tracing::info!("Starting DistilBert model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting NomicBertModel model on {:?}", device);
tracing::info!("Starting NomicBert model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
}
(Config::Mistral(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -198,7 +244,7 @@
} else {
match config {
BertConfigWrapper::JinaBert(config) => {
tracing::info!("Starting JinaBertModel model on {:?}", device);
tracing::info!("Starting JinaBert model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
BertConfigWrapper::JinaCodeBert(config) => {
@@ -245,7 +291,7 @@
.to_lowercase()
== "true"
{
tracing::info!("Starting FlashDistilBertModel model on {:?}", device);
tracing::info!("Starting FlashDistilBert model on {:?}", device);
Ok(Box::new(
FlashDistilBertModel::load(vb, &config, model_type).s()?,
))
@@ -265,15 +311,28 @@
.to_lowercase()
== "true"
{
tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
tracing::info!("Starting FlashNomicBert model on {:?}", device);
Ok(Box::new(
FlashNomicBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting NomicBertModel model on {:?}", device);
tracing::info!("Starting NomicBert model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::Mistral(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(feature = "flash-attn")
|| get_runtime_compute_cap().unwrap() < 80
{
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
tracing::info!("Starting FlashMistral model on {:?}", device);
Ok(Box::new(
FlashMistralModel::load(vb, &config, model_type).s()?,
))
}
};

Ok(Self {
6 changes: 6 additions & 0 deletions backends/candle/src/models/bert.rs
@@ -638,6 +638,10 @@ impl BertModel {
(pool, Some(classifier), None)
}
ModelType::Embedding(pool) => {
if pool == Pool::LastToken {
candle::bail!("`last_token` is not supported for Bert");
}

let splade = if pool == Pool::Splade {
Some(BertSpladeHead::load_roberta(vb.clone(), config)?)
} else {
@@ -832,6 +836,8 @@ impl BertModel {
let pooled_embeddings = match self.pool {
// CLS pooling
Pool::Cls => outputs.i((.., 0))?,
// Last token pooling is not supported for this model
Pool::LastToken => unreachable!(),
// Mean pooling
Pool::Mean => {
if let Some(ref attention_mask) = attention_mask {
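`last_token` pooling is the strategy used by decoder-style embedding models such as the Mistral-based ones this commit targets: the embedding is the hidden state of each sequence's final token, which is why `BertModel` now rejects it explicitly. A sketch of the index arithmetic in the packed varlen layout this backend uses (cumulative sequence lengths, as in the flash-attention path); the function name is illustrative:

// For sequence i occupying rows [cu_seqlens[i], cu_seqlens[i + 1]) of the packed
// hidden-state matrix, its last token sits at row cu_seqlens[i + 1] - 1.
fn last_token_rows(cu_seqlens: &[u32]) -> Vec<u32> {
    cu_seqlens.windows(2).map(|w| w[1] - 1).collect()
}

Gathering those rows from the final hidden states produces one embedding per sequence.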
(The remaining changed files in this commit are not rendered here.)
