Commit

style
fxmarty committed Jun 21, 2024
1 parent cfbaf65 commit 9941fcc
Showing 21 changed files with 103 additions and 197 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/build_rocm.yaml
@@ -79,7 +79,7 @@
type=semver,pattern=rocm-{{major}}.{{minor}}
type=raw,value=rocm-latest
type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
id: build-and-push-rocm
uses: docker/build-push-action@v4
@@ -98,7 +98,7 @@
labels: ${{ steps.meta-rocm.outputs.labels }}
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max

- name: Extract metadata (tags, labels) for Docker
id: meta-rocm-grpc
uses: docker/[email protected]
@@ -113,7 +113,7 @@
type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
type=raw,value=rocm-latest-grpc
type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc
- name: Build and push Docker image
id: build-and-push-rocm-grpc
uses: docker/build-push-action@v4
54 changes: 32 additions & 22 deletions backends/candle/src/lib.rs
@@ -11,17 +11,17 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
Model, NomicBertModel, NomicConfig,
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
use nohash_hasher::BuildNoHashHasher;
use serde::Deserialize;
use std::collections::HashMap;
@@ -133,7 +133,9 @@ impl CandleBackend {
}
(Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
(
Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -171,8 +173,9 @@
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
}
}
#[cfg(feature = "cuda")]
(Config::JinaBert(config), Device::Cuda(_)) => {
}
#[cfg(feature = "cuda")]
(Config::JinaBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
@@ -181,25 +184,32 @@
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
Ok(Box::new(
FlashJinaBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting JinaBertModel model on {:?}", device);
Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::JinaCodeBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
} else {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::JinaCodeBert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
&& ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
Ok(Box::new(
FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
Ok(Box::new(
JinaCodeBertModel::load(vb, &config, model_type).s()?,
))
}
}
#[cfg(feature = "cuda")]
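For context on the CUDA match arms reformatted above: the backend only takes the flash-attention model when the crate was built with a flash-attn feature, the model runs in F16, the position embeddings are absolute or ALiBi, and USE_FLASH_ATTENTION has not been set to false at runtime. A minimal sketch of that check, with the condition pulled into a free function; the function name and the local PositionEmbeddingType stand-in are illustrative, not part of the crate.

use candle::DType;

// Stand-in for the crate's PositionEmbeddingType (models/bert.rs); only the
// two variants the check cares about are mirrored here.
enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// Illustrative helper: returns true when the flash-attention model should be
// selected instead of the regular one.
fn flash_attention_enabled(dtype: DType, pos: &PositionEmbeddingType) -> bool {
    // Compile-time: one of the flash-attn features must be enabled.
    let compiled_in = cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"));
    // The flash kernels are only used for F16 with absolute or ALiBi embeddings.
    let supported = dtype == DType::F16
        && matches!(
            pos,
            PositionEmbeddingType::Absolute | PositionEmbeddingType::Alibi
        );
    // Runtime opt-out, kept because of flash attention v1 precision problems
    // (https://github.com/huggingface/text-embeddings-inference/issues/37).
    let enabled_at_runtime = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true";
    compiled_in && supported && enabled_at_runtime
}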
6 changes: 3 additions & 3 deletions backends/candle/src/models.rs
@@ -7,6 +7,7 @@ extern crate accelerate_src
mod bert;
mod distilbert;
mod jina;
mod jina_code;
mod nomic;

#[cfg(feature = "cuda")]
@@ -27,8 +28,8 @@ mod flash_distilbert;
pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use distilbert::{DistilBertConfig, DistilBertModel};
pub use jina::{JinaConfig, JinaBertModel};
pub use jina_code::{JinaCodeConfig, JinaCodeBertModel};
pub use jina::{JinaBertModel, JinaConfig};
pub use jina_code::{JinaCodeBertModel, JinaCodeConfig};
pub use nomic::{NomicBertModel, NomicConfig};
use text_embeddings_backend_core::Batch;

@@ -41,7 +42,6 @@ pub use flash_jina::FlashJinaBertModel;
#[cfg(feature = "cuda")]
pub use flash_jina_code::FlashJinaCodeBertModel;


#[cfg(feature = "cuda")]
pub use flash_nomic::FlashNomicBertModel;

2 changes: 1 addition & 1 deletion backends/candle/src/models/flash_jina.rs
@@ -2,8 +2,8 @@ use crate::alibi::alibi_head_slopes;
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::{JinaConfig, BertEmbeddings};
use crate::models::jina::BertEmbeddings;
use crate::models::jina::{BertEmbeddings, JinaConfig};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
31 changes: 23 additions & 8 deletions backends/candle/src/models/flash_jina_code.rs
@@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::bert::PositionEmbeddingType;
use crate::models::jina::{JinaCodeConfig, BertEmbeddings};
use crate::models::jina::{BertEmbeddings, JinaCodeConfig};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
@@ -28,7 +28,11 @@ struct AlibiBertAttention {
}

impl AlibiBertAttention {
pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
pub fn load(
vb: VarBuilder,
config: &JinaCodeConfig,
alibi_slopes: Option<Tensor>,
) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let hidden_size = config.hidden_size;
@@ -116,9 +120,15 @@ impl AlibiBertAttention {
new_qkv_shape.push(self.num_attention_heads);
new_qkv_shape.push(self.attention_head_size);

let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let query_layer = query_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;
let key_layer = key_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;
let value_layer = value_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;

let attention = flash_attn_varlen(
query_layer,
@@ -135,7 +145,9 @@
let attention = attention.flatten_from(candle::D::Minus2)?;

let hidden_states = self.dense.forward(&attention)?;
let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
let hidden_states = self
.layer_norm_out
.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -168,7 +180,10 @@ impl JinaBertLayer {
.pp("mlp")
.pp("down_layer")
.get((config.hidden_size, config.intermediate_size), "weight")?;
let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
let down_bias = vb
.pp("mlp")
.pp("down_layer")
.get(config.hidden_size, "bias")?;
let down_layer = Linear::new(down_weight, Some(down_bias), None);

let layer_norm_1 = LayerNorm::load(
@@ -455,4 +470,4 @@ impl Model for FlashJinaCodeBertModel {
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
}
}
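Several of the lines reformatted above (and the Python layernorm hunk further down) call layer_norm_out.forward(&hidden_states, Some(&residual)), that is, a layer norm that folds the skip connection in before normalizing. Below is a minimal candle sketch of that pattern, assuming the usual mean/variance formulation; the free function is illustrative and not the crate's actual LayerNorm implementation.

use candle::{D, Result, Tensor};

// Illustrative residual-aware layer norm: add the skip connection, normalize
// over the hidden dimension, then apply the learned scale and shift.
fn layer_norm_with_residual(
    hidden_states: &Tensor, // (..., hidden_size)
    residual: &Tensor,      // same shape as hidden_states
    weight: &Tensor,        // (hidden_size)
    bias: &Tensor,          // (hidden_size)
    eps: f64,
) -> Result<Tensor> {
    let x = (hidden_states + residual)?;
    let mean = x.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_sub(&mean)?;
    let var = x.sqr()?.mean_keepdim(D::Minus1)?;
    let std = var.affine(1.0, eps)?.sqrt()?;
    let x = x.broadcast_div(&std)?;
    x.broadcast_mul(weight)?.broadcast_add(bias)
}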
1 change: 0 additions & 1 deletion backends/candle/src/models/jina.rs
@@ -30,7 +30,6 @@ pub struct JinaConfig {
pub id2label: Option<HashMap<String, String>>,
}


#[derive(Debug)]
pub struct BertEmbeddings {
word_embeddings: Embedding,
22 changes: 16 additions & 6 deletions backends/candle/src/models/jina_code.rs
@@ -30,7 +30,6 @@ pub struct JinaCodeConfig {
pub id2label: Option<HashMap<String, String>>,
}


#[derive(Debug)]
pub struct BertEmbeddings {
word_embeddings: Embedding,
@@ -201,9 +200,15 @@ impl BertAttention {
new_qkv_shape.push(self.num_attention_heads);
new_qkv_shape.push(self.attention_head_size);

let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
let query_layer = query_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;
let key_layer = key_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;
let value_layer = value_layer
.reshape(new_qkv_shape.as_slice())?
.transpose(1, 2)?;

#[allow(unused_variables)]
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
@@ -276,7 +281,9 @@ impl BertAttention {
let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

let hidden_states = self.dense.forward(&context_layer)?;
let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
let hidden_states = self
.layer_norm_out
.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -309,7 +316,10 @@ impl JinaCodeBertLayer {
.pp("mlp")
.pp("down_layer")
.get((config.hidden_size, config.intermediate_size), "weight")?;
let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
let down_bias = vb
.pp("mlp")
.pp("down_layer")
.get(config.hidden_size, "bias")?;
let down_layer = Linear::new(down_weight, Some(down_bias), None);

let layer_norm_1 = LayerNorm::load(
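The reshape/transpose chains reformatted above are the usual head-splitting step: each of Q, K and V goes from (batch, seq_len, hidden) to (batch, num_heads, seq_len, head_size) before attention scores are computed. A small candle sketch of that step, starting from a fused QKV projection; the function name and signature are illustrative only, not the crate's API.

use candle::{Result, Tensor};

// Illustrative head-splitting helper for a fused QKV projection of shape
// (batch, seq_len, 3 * hidden_size).
fn split_heads(
    qkv: &Tensor,
    num_heads: usize,
    head_size: usize,
) -> Result<(Tensor, Tensor, Tensor)> {
    let (batch, seq_len, _) = qkv.dims3()?;
    // Split the fused projection into query, key and value along the last dim.
    let chunks = qkv.chunk(3, 2)?;
    let to_heads = |t: &Tensor| -> Result<Tensor> {
        t.contiguous()?
            // (batch, seq_len, hidden) -> (batch, seq_len, heads, head_size)
            .reshape((batch, seq_len, num_heads, head_size))?
            // -> (batch, heads, seq_len, head_size)
            .transpose(1, 2)
    };
    Ok((to_heads(&chunks[0])?, to_heads(&chunks[1])?, to_heads(&chunks[2])?))
}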
Additional changed file (path not shown)
@@ -42,4 +42,4 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
is_causal,
False,
None,
)
)
Additional changed file (path not shown)
@@ -41,7 +41,7 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
self.variance_epsilon = config.layer_norm_eps

def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
@@ -51,4 +51,4 @@ def forward(self, hidden_states, residual=None):

return hidden_states, residual
else:
raise ValueError("System not recognized")
raise ValueError("System not recognized")
Additional changed file (path not shown)
@@ -16,7 +16,7 @@ def mean_pooling(embedding, cu_seqlens, max_s):
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

embedding_padded = pad_input(embedding, indices, batch_size, max_s)

sum_embeddings = torch.sum(embedding_padded, 1)

return sum_embeddings / seqlens[:, None]
return sum_embeddings / seqlens[:, None]
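The mean_pooling hunk above sums the padded token embeddings and divides by each sequence's true length. The same computation expressed with candle tensors (Rust, to stay consistent with the backend code earlier in the diff); the names and shapes here are illustrative, and the python backend itself does this with torch.

use candle::{Result, Tensor};

// Illustrative mean pooling over padded embeddings:
//   embeddings_padded: (batch, max_seq_len, hidden), zero-padded past each length
//   seq_lens:          (batch), the true token count per sequence
fn mean_pool(embeddings_padded: &Tensor, seq_lens: &Tensor) -> Result<Tensor> {
    // Sum over the sequence dimension: (batch, hidden).
    let summed = embeddings_padded.sum(1)?;
    // Divide each row by its sequence length.
    let lens = seq_lens.to_dtype(summed.dtype())?.unsqueeze(1)?; // (batch, 1)
    summed.broadcast_div(&lens)
}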
Additional changed file (path not shown)
@@ -234,4 +234,4 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
for i in range(len(batch))
]
else:
raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")
raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")