feat(candle): add flash gte
OlivierDehaene committed Jun 27, 2024
1 parent 7c9b7cb commit 264a079
Showing 9 changed files with 520 additions and 4 deletions.
25 changes: 21 additions & 4 deletions backends/candle/src/lib.rs
@@ -11,13 +11,13 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    MistralConfig, Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
-    FlashMistralModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -57,6 +57,8 @@ enum Config {
     #[serde(rename(deserialize = "nomic_bert"))]
     NomicBert(NomicConfig),
     Mistral(MistralConfig),
+    #[serde(rename = "new")]
+    Gte(GTEConfig),
 }

 pub struct CandleBackend {
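GTE checkpoints identify themselves with "model_type": "new" in config.json, which is why the Gte variant carries an explicit serde rename. Assuming the Config enum is internally tagged on the model_type field (as the existing rename(deserialize = ...) attributes suggest), a minimal, self-contained sketch of that deserialization pattern, with hypothetical stand-in config structs, looks like this:

use serde::Deserialize;

// Hypothetical stand-ins for the real config structs.
#[derive(Debug, Deserialize)]
struct BertConfig { hidden_size: usize }

#[derive(Debug, Deserialize)]
struct GTEConfig { hidden_size: usize }

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Bert(BertConfig),
    // GTE checkpoints declare `"model_type": "new"`, hence the rename.
    #[serde(rename = "new")]
    Gte(GTEConfig),
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{ "model_type": "new", "hidden_size": 1024 }"#;
    let config: Config = serde_json::from_str(raw)?;
    println!("{config:?}"); // Gte(GTEConfig { hidden_size: 1024 })
    Ok(())
}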
@@ -215,6 +217,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -333,6 +339,17 @@ impl CandleBackend {
                     FlashMistralModel::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Gte(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(feature = "flash-attn")
+                    || get_runtime_compute_cap().unwrap() < 80
+                {
+                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashGTE model on {:?}", device);
+                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
+            }
         };

         Ok(Self {
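The new CUDA arm gates FlashGTE on three conditions: fp16 weights, the flash-attn (v2) build feature, and a runtime compute capability of at least 8.0 (Ampere or newer). A minimal sketch of that guard as a standalone helper; the parameters are hypothetical stand-ins for the dtype, feature flag, and get_runtime_compute_cap() values used above:

/// Mirrors the gate above: FlashGTE needs fp16, flash-attn v2, and SM >= 80.
fn can_run_flash_gte(is_f16: bool, flash_attn_v2: bool, compute_cap: usize) -> Result<(), String> {
    if is_f16 && flash_attn_v2 && compute_cap >= 80 {
        Ok(())
    } else {
        Err(
            "GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled"
                .to_string(),
        )
    }
}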
1 change: 1 addition & 0 deletions backends/candle/src/models/bert.rs
@@ -35,6 +35,7 @@ pub enum PositionEmbeddingType {
     #[default]
     Absolute,
     Alibi,
+    Rope,
 }

 #[derive(Debug)]
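The new Rope variant covers rotary position embeddings, which GTE uses instead of absolute learned positions: each consecutive pair of channels (x_2i, x_2i+1) in a query/key head is rotated by an angle that grows with the token position and shrinks with the channel index. A plain-Rust illustration of that rotation (not the fused candle/CUDA path used by the flash models):

/// Apply rotary position embeddings to one head's vector in place.
/// `pos` is the token position; `theta` is the RoPE base (commonly 10000.0).
fn apply_rope(x: &mut [f32], pos: usize, theta: f32) {
    let dim = x.len();
    for i in 0..dim / 2 {
        // Each channel pair rotates at its own frequency.
        let freq = 1.0 / theta.powf(2.0 * i as f32 / dim as f32);
        let angle = pos as f32 * freq;
        let (sin, cos) = angle.sin_cos();
        let (x0, x1) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = x0 * cos - x1 * sin;
        x[2 * i + 1] = x0 * sin + x1 * cos;
    }
}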
