diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml
index 8a9fde49..5971ed78 100644
--- a/.github/workflows/build_rocm.yaml
+++ b/.github/workflows/build_rocm.yaml
@@ -79,7 +79,7 @@
            type=semver,pattern=rocm-{{major}}.{{minor}}
            type=raw,value=rocm-latest
            type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}
-          
+
      - name: Build and push Docker image
        id: build-and-push-rocm
        uses: docker/build-push-action@v4
@@ -98,7 +98,7 @@
          labels: ${{ steps.meta-rocm.outputs.labels }}
          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
-          
+
      - name: Extract metadata (tags, labels) for Docker
        id: meta-rocm-grpc
        uses: docker/metadata-action@v4.3.0
@@ -113,7 +113,7 @@
            type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
            type=raw,value=rocm-latest-grpc
            type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc
-          
+
      - name: Build and push Docker image
        id: build-and-push-rocm-grpc
        uses: docker/build-push-action@v4
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 27c5d843..4af5b704 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -11,17 +11,17 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
-    Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
+    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
+    FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
-use models::BertConfig;
 use nohash_hasher::BuildNoHashHasher;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -133,7 +133,9 @@ impl CandleBackend {
             }
             (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+                Ok(Box::new(
+                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                ))
             }
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -171,8 +173,9 @@ impl CandleBackend {
                 Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
             }
         }
-            #[cfg(feature = "cuda")]
-            (Config::JinaBert(config), Device::Cuda(_)) => {
+        }
+        #[cfg(feature = "cuda")]
+        (Config::JinaBert(config), Device::Cuda(_)) => {
             if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 && dtype == DType::F16
                 && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
@@ -181,25 +184,32 @@
                && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
            {
                 tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
+                Ok(Box::new(
+                    FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                ))
             } else {
                 tracing::info!("Starting JinaBertModel model on {:?}", device);
                 Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
             }
-            #[cfg(feature = "cuda")]
-            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
-                } else {
-                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+            }
+        }
+        #[cfg(feature = "cuda")]
+        (Config::JinaCodeBert(config), Device::Cuda(_)) => {
+            if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                && dtype == DType::F16
+                && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
+                // Allow disabling because of flash attention v1 precision problems
+                // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+            {
+                tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
+                Ok(Box::new(
+                    FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                ))
+            } else {
+                tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
+                Ok(Box::new(
+                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                ))
             }
         }
         #[cfg(feature = "cuda")]
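Both Jina arms above gate the flash kernels on the same condition: a flash-attn feature compiled in, an F16 dtype, an absolute or ALiBi position-embedding type, and the `USE_FLASH_ATTENTION` environment variable not set to "false" (it defaults to "True"). A minimal, self-contained sketch of that environment check — hypothetical helper name, plain std Rust, not part of this patch:

```rust
// Illustrative sketch of the USE_FLASH_ATTENTION gating used in CandleBackend::new.
// The real dispatch also checks the position embedding type (Absolute or Alibi).
fn use_flash_attention(flash_attn_compiled: bool, dtype_is_f16: bool) -> bool {
    let enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true";
    flash_attn_compiled && dtype_is_f16 && enabled
}

fn main() {
    // With the variable unset, the default is "True", so flash attention stays enabled.
    println!("{}", use_flash_attention(true, true));
}
```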
diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs
index 7f098cfb..d64f5a35 100644
--- a/backends/candle/src/models.rs
+++ b/backends/candle/src/models.rs
@@ -7,6 +7,7 @@ extern crate accelerate_src;
 mod bert;
 mod distilbert;
 mod jina;
+mod jina_code;
 mod nomic;
 
 #[cfg(feature = "cuda")]
@@ -27,8 +28,8 @@ mod flash_distilbert;
 pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
-pub use jina::{JinaConfig, JinaBertModel};
-pub use jina_code::{JinaCodeConfig, JinaCodeBertModel};
+pub use jina::{JinaBertModel, JinaConfig};
+pub use jina_code::{JinaCodeBertModel, JinaCodeConfig};
 pub use nomic::{NomicBertModel, NomicConfig};
 use text_embeddings_backend_core::Batch;
 
@@ -41,7 +42,6 @@ pub use flash_jina::FlashJinaBertModel;
 
 #[cfg(feature = "cuda")]
 pub use flash_jina_code::FlashJinaCodeBertModel;
 
-
 #[cfg(feature = "cuda")]
 pub use flash_nomic::FlashNomicBertModel;
diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs
index e128252a..0e1d3006 100644
--- a/backends/candle/src/models/flash_jina.rs
+++ b/backends/candle/src/models/flash_jina.rs
@@ -2,8 +2,8 @@ use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::{JinaConfig, BertEmbeddings};
 use crate::models::jina::BertEmbeddings;
+use crate::models::jina::{BertEmbeddings, JinaConfig};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
diff --git a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs
index 97ca5fc0..0779c5d7 100644
--- a/backends/candle/src/models/flash_jina_code.rs
+++ b/backends/candle/src/models/flash_jina_code.rs
@@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes;
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::{JinaCodeConfig, BertEmbeddings};
+use crate::models::jina::{BertEmbeddings, JinaCodeConfig};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
@@ -28,7 +28,11 @@ struct AlibiBertAttention {
 }
 
 impl AlibiBertAttention {
-    pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
+    pub fn load(
+        vb: VarBuilder,
+        config: &JinaCodeConfig,
+        alibi_slopes: Option<Tensor>,
+    ) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
         let hidden_size = config.hidden_size;
@@ -116,9 +120,15 @@ impl AlibiBertAttention {
         new_qkv_shape.push(self.num_attention_heads);
         new_qkv_shape.push(self.attention_head_size);
 
-        let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let query_layer = query_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let key_layer = key_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let value_layer = value_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
 
         let attention = flash_attn_varlen(
             query_layer,
@@ -135,7 +145,9 @@ impl AlibiBertAttention {
         let attention = attention.flatten_from(candle::D::Minus2)?;
 
         let hidden_states = self.dense.forward(&attention)?;
-        let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
+        let hidden_states = self
+            .layer_norm_out
+            .forward(&hidden_states, Some(&residual))?;
 
         Ok(hidden_states)
     }
@@ -168,7 +180,10 @@ impl JinaBertLayer {
             .pp("mlp")
             .pp("down_layer")
             .get((config.hidden_size, config.intermediate_size), "weight")?;
-        let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
+        let down_bias = vb
+            .pp("mlp")
+            .pp("down_layer")
+            .get(config.hidden_size, "bias")?;
         let down_layer = Linear::new(down_weight, Some(down_bias), None);
 
         let layer_norm_1 = LayerNorm::load(
@@ -455,4 +470,4 @@ impl Model for FlashJinaCodeBertModel {
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
-}
\ No newline at end of file
+}
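The `flash_attn_varlen` call above operates on a packed batch: all sequences are concatenated into one token-major tensor and described by cumulative sequence lengths (`cu_seqlens`), the same offsets the Python backend's attention and pooling layers consume. A small, dependency-free sketch of how such offsets are derived from per-sequence lengths (hypothetical helper name, not code from this patch):

```rust
// Illustrative only: cumulative sequence lengths describe where each sequence
// starts and ends inside a packed (ragged) batch used by varlen attention kernels.
fn cu_seqlens(seq_lens: &[usize]) -> Vec<usize> {
    let mut cu = Vec::with_capacity(seq_lens.len() + 1);
    let mut total = 0;
    cu.push(0);
    for &len in seq_lens {
        total += len;
        cu.push(total);
    }
    cu
}

fn main() {
    // Three sequences of 5, 2 and 7 tokens packed into one 14-token buffer.
    assert_eq!(cu_seqlens(&[5, 2, 7]), vec![0, 5, 7, 14]);
    println!("ok");
}
```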
diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs
index 3f5d5916..b1d75d94 100644
--- a/backends/candle/src/models/jina.rs
+++ b/backends/candle/src/models/jina.rs
@@ -30,7 +30,6 @@ pub struct JinaConfig {
     pub id2label: Option<HashMap<String, String>>,
 }
 
-
 #[derive(Debug)]
 pub struct BertEmbeddings {
     word_embeddings: Embedding,
diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs
index cb4084d7..ec8a8c84 100644
--- a/backends/candle/src/models/jina_code.rs
+++ b/backends/candle/src/models/jina_code.rs
@@ -30,7 +30,6 @@ pub struct JinaCodeConfig {
     pub id2label: Option<HashMap<String, String>>,
 }
 
-
 #[derive(Debug)]
 pub struct BertEmbeddings {
     word_embeddings: Embedding,
@@ -201,9 +200,15 @@ impl BertAttention {
         new_qkv_shape.push(self.num_attention_heads);
         new_qkv_shape.push(self.attention_head_size);
 
-        let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
-        let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+        let query_layer = query_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let key_layer = key_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
+        let value_layer = value_layer
+            .reshape(new_qkv_shape.as_slice())?
+            .transpose(1, 2)?;
 
         #[allow(unused_variables)]
         let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
@@ -276,7 +281,9 @@ impl BertAttention {
         let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
 
         let hidden_states = self.dense.forward(&context_layer)?;
-        let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?;
+        let hidden_states = self
+            .layer_norm_out
+            .forward(&hidden_states, Some(&residual))?;
 
         Ok(hidden_states)
     }
@@ -309,7 +316,10 @@ impl JinaCodeBertLayer {
             .pp("mlp")
             .pp("down_layer")
             .get((config.hidden_size, config.intermediate_size), "weight")?;
-        let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?;
+        let down_bias = vb
+            .pp("mlp")
+            .pp("down_layer")
+            .get(config.hidden_size, "bias")?;
         let down_layer = Linear::new(down_weight, Some(down_bias), None);
 
         let layer_norm_1 = LayerNorm::load(
diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py b/backends/python/server/text_embeddings_server/layers/attention/rocm.py
index 365e5451..9ed9004c 100644
--- a/backends/python/server/text_embeddings_server/layers/attention/rocm.py
+++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py
@@ -42,4 +42,4 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
         is_causal,
         False,
         None,
-    )
\ No newline at end of file
+    )
diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py
index abd9e676..0834b734 100644
--- a/backends/python/server/text_embeddings_server/layers/layernorm.py
+++ b/backends/python/server/text_embeddings_server/layers/layernorm.py
@@ -41,7 +41,7 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
         self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
         self.variance_epsilon = config.layer_norm_eps
-        
+
     def forward(self, hidden_states, residual=None):
         if residual is not None:
             hidden_states += residual
@@ -51,4 +51,4 @@ def forward(self, hidden_states, residual=None):
 
             return hidden_states, residual
         else:
-            raise ValueError("System not recognized")
\ No newline at end of file
+            raise ValueError("System not recognized")
diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py
index 1bccbc57..7eaddb6b 100644
--- a/backends/python/server/text_embeddings_server/layers/pooling.py
+++ b/backends/python/server/text_embeddings_server/layers/pooling.py
@@ -16,7 +16,7 @@ def mean_pooling(embedding, cu_seqlens, max_s):
     indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
 
     embedding_padded = pad_input(embedding, indices, batch_size, max_s)
-    
+
     sum_embeddings = torch.sum(embedding_padded, 1)
 
-    return sum_embeddings / seqlens[:, None]
\ No newline at end of file
+    return sum_embeddings / seqlens[:, None]
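`mean_pooling` above sums the padded token embeddings along the sequence axis and divides by each sequence's true length (`sum_embeddings / seqlens[:, None]`). A rough, dependency-free sketch of the same arithmetic, written in Rust here only to make the masking explicit — the names and layout are illustrative, not the backend's actual API:

```rust
// Masked mean pooling over a padded batch: only the first `len` rows of each
// sequence are real tokens, the remainder is padding and must be ignored.
fn mean_pool(padded: &[Vec<Vec<f32>>], seq_lens: &[usize]) -> Vec<Vec<f32>> {
    padded
        .iter()
        .zip(seq_lens)
        .map(|(tokens, &len)| {
            let dim = tokens.first().map_or(0, |t| t.len());
            let mut acc = vec![0.0f32; dim];
            for tok in tokens.iter().take(len) {
                for (a, v) in acc.iter_mut().zip(tok) {
                    *a += v;
                }
            }
            acc.iter_mut().for_each(|a| *a /= len as f32);
            acc
        })
        .collect()
}

fn main() {
    // One sequence padded to 3 tokens, of which 2 are real.
    let batch = vec![vec![vec![1.0, 3.0], vec![3.0, 5.0], vec![0.0, 0.0]]];
    assert_eq!(mean_pool(&batch, &[2]), vec![vec![2.0, 4.0]]);
}
```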
diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py
index 6ebb70d4..40003013 100644
--- a/backends/python/server/text_embeddings_server/models/flash_bert.py
+++ b/backends/python/server/text_embeddings_server/models/flash_bert.py
@@ -234,4 +234,4 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
                 for i in range(len(batch))
             ]
         else:
-            raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")
\ No newline at end of file
+            raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend")
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index ef33b7d2..8b137891 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -1,119 +1 @@
-mod logging;
-mod management;
-use backend_grpc_client::Client;
-use nohash_hasher::BuildNoHashHasher;
-use std::collections::HashMap;
-use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
-};
-use tokio::runtime::Runtime;
-
-pub struct PythonBackend {
-    _backend_process: management::BackendProcess,
-    tokio_runtime: Runtime,
-    backend_client: Client,
-}
-
-impl PythonBackend {
-    pub fn new(
-        model_path: String,
-        dtype: String,
-        model_type: ModelType,
-        uds_path: String,
-        otlp_endpoint: Option<String>,
-        otlp_service_name: String,
-        pooling_mode: String,
-    ) -> Result<Self, BackendError> {
-        match model_type {
-            ModelType::Classifier => {
-                return Err(BackendError::Start(
-                    "`classifier` model type is not supported".to_string(),
-                ))
-            }
-            ModelType::Embedding(pool) => {
-                if pool != Pool::Cls && pool != Pool::Mean {
-                    return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. Please open an issue.")));
-                }
-                pool
-            }
-        };
-
-        let backend_process = management::BackendProcess::new(
-            model_path,
-            dtype,
-            &uds_path,
-            otlp_endpoint,
-            otlp_service_name,
-            pooling_mode,
-        )?;
-        let tokio_runtime = tokio::runtime::Builder::new_current_thread()
-            .enable_all()
-            .build()
-            .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;
-
-        let backend_client = tokio_runtime
-            .block_on(Client::connect_uds(uds_path))
-            .map_err(|err| {
-                BackendError::Start(format!("Could not connect to backend process: {err}"))
-            })?;
-
-        Ok(Self {
-            _backend_process: backend_process,
-            tokio_runtime,
-            backend_client,
-        })
-    }
-}
-
-impl Backend for PythonBackend {
-    fn health(&self) -> Result<(), BackendError> {
-        if self
-            .tokio_runtime
-            .block_on(self.backend_client.clone().health())
-            .is_err()
-        {
-            return Err(BackendError::Unhealthy);
-        }
-        Ok(())
-    }
-
-    fn is_padded(&self) -> bool {
-        false
-    }
-
-    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
-        if !batch.raw_indices.is_empty() {
-            return Err(BackendError::Inference(
-                "raw embeddings are not supported for the Python backend.".to_string(),
-            ));
-        }
-        let batch_size = batch.len();
-
-        let results = self
-            .tokio_runtime
-            .block_on(self.backend_client.clone().embed(
-                batch.input_ids,
-                batch.token_type_ids,
-                batch.position_ids,
-                batch.cumulative_seq_lengths,
-                batch.max_length,
-            ))
-            .map_err(|err| BackendError::Inference(err.to_string()))?;
-        let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
-
-        let mut embeddings =
-            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
-        for (i, e) in pooled_embeddings.into_iter().enumerate() {
-            embeddings.insert(i, Embedding::Pooled(e));
-        }
-
-        Ok(embeddings)
-    }
-
-    fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
-        Err(BackendError::Inference(
-            "`predict` is not implemented".to_string(),
-        ))
-    }
-}
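The removed `PythonBackend` above forwarded each batch to the Python gRPC server and returned pooled vectors keyed by their position in the batch. A simplified, dependency-free stand-in for that collection step (the real code wraps vectors in `Embedding::Pooled` and uses a no-hash hasher keyed by `usize`; this sketch uses plain types for illustration):

```rust
use std::collections::HashMap;

// Simplified stand-in for the tail of the removed embed(): each pooled vector
// is inserted under its batch index so the router can reassemble the response.
fn collect_embeddings(pooled: Vec<Vec<f32>>) -> HashMap<usize, Vec<f32>> {
    let mut embeddings = HashMap::with_capacity(pooled.len());
    for (i, e) in pooled.into_iter().enumerate() {
        embeddings.insert(i, e);
    }
    embeddings
}

fn main() {
    let out = collect_embeddings(vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
    assert_eq!(out[&1], vec![0.3, 0.4]);
}
```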
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index db27cddc..d332b4a7 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -39,7 +39,6 @@ impl Backend {
         uds_path: String,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
-        pooling_mode: String,
     ) -> Result<Self, BackendError> {
         let (backend_sender, backend_receiver) = mpsc::unbounded_channel();
 
@@ -50,7 +49,6 @@ impl Backend {
             uds_path,
             otlp_endpoint,
             otlp_service_name,
-            pooling_mode,
         )?;
         let padded_model = backend.is_padded();
         let max_batch_size = backend.max_batch_size();
@@ -140,7 +138,6 @@ fn init_backend(
     uds_path: String,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
-    pooling_mode: String,
) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
     if cfg!(feature = "candle") {
         #[cfg(feature = "candle")]
@@ -161,7 +158,6 @@ fn init_backend(
                 uds_path,
                 otlp_endpoint,
                 otlp_service_name,
-                pooling_mode,
             )
         })
         .join()
diff --git a/docs/source/en/local_amd_gpu.md b/docs/source/en/local_amd_gpu.md
index 8dc8e2de..2cfab5ac 100644
--- a/docs/source/en/local_amd_gpu.md
+++ b/docs/source/en/local_amd_gpu.md
@@ -37,4 +37,4 @@ and
 curl 127.0.0.1:80/embed \
     -X POST -d '{"inputs":"What is Deep Learning?"}' \
     -H 'Content-Type: application/json'
-```
\ No newline at end of file
+```
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 14f1dfb3..d2023515 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -105,7 +105,7 @@ pub async fn run(
         serde_json::from_str(&config).context("Failed to parse `config.json`")?;
 
     // Set model type from config
-    let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, &pooling)?;
+    let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?;
 
     // Info model type
     let model_type = match &backend_model_type {
@@ -191,11 +191,6 @@ pub async fn run(
         }
     });
 
-    let pooling_str = match inferred_pooling {
-        Some(pool) => pool.to_string(),
-        None => "none".to_string(),
-    };
-
     // Create backend
     tracing::info!("Starting model backend");
     let backend = text_embeddings_backend::Backend::new(
@@ -205,7 +200,6 @@ pub async fn run(
         uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()),
         otlp_endpoint.clone(),
         otlp_service_name.clone(),
-        pooling_str,
     )
     .context("Could not create backend")?;
     backend
@@ -312,24 +306,24 @@ pub async fn run(
 fn get_backend_model_type(
     config: &ModelConfig,
     model_root: &Path,
-    pooling: &Option<text_embeddings_backend::Pool>,
-) -> Result<(text_embeddings_backend::ModelType, Option<text_embeddings_backend::Pool>)> {
+    pooling: Option<text_embeddings_backend::Pool>,
+) -> Result<text_embeddings_backend::ModelType> {
     for arch in &config.architectures {
-        if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") {
-            return Ok((text_embeddings_backend::ModelType::Embedding(
+        if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") {
+            return Ok(text_embeddings_backend::ModelType::Embedding(
                 text_embeddings_backend::Pool::Splade,
-            ), Some(text_embeddings_backend::Pool::Splade)));
+            ));
         } else if arch.ends_with("Classification") {
             if pooling.is_some() {
                 tracing::warn!(
                     "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
                 );
             }
-            return Ok((text_embeddings_backend::ModelType::Classifier, None));
+            return Ok(text_embeddings_backend::ModelType::Classifier);
         }
     }
 
-    if Some(text_embeddings_backend::Pool::Splade) == *pooling {
+    if Some(text_embeddings_backend::Pool::Splade) == pooling {
         return Err(anyhow!(
             "Splade pooling is not supported: model is not a ForMaskedLM model"
         ));
@@ -337,7 +331,7 @@ fn get_backend_model_type(
 
     // Set pooling
     let pool = match pooling {
-        Some(pool) => pool.clone(),
+        Some(pool) => pool,
         None => {
             // Load pooling config
             let config_path = model_root.join("1_Pooling/config.json");
@@ -353,7 +347,7 @@ fn get_backend_model_type(
             }
         }
     };
-    Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool)))
+    Ok(text_embeddings_backend::ModelType::Embedding(pool))
 }
 
 #[derive(Debug, Deserialize)]
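With this change `get_backend_model_type` takes `pooling` by value and returns only the `ModelType`. The selection rules it applies are easiest to see in isolation; below is a condensed sketch with local stand-in types (not the router's actual `Pool`/`ModelType`, and the real function falls back to reading `1_Pooling/config.json` instead of defaulting):

```rust
// Condensed sketch of the rules in get_backend_model_type: Splade needs a
// *MaskedLM architecture, classifier architectures ignore --pooling, and
// everything else becomes an embedding model with the requested pool.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Pool { Cls, Mean, Splade }

#[derive(Debug, PartialEq)]
enum ModelType { Classifier, Embedding(Pool) }

fn backend_model_type(architectures: &[&str], pooling: Option<Pool>) -> Result<ModelType, String> {
    for arch in architectures {
        if pooling == Some(Pool::Splade) && arch.ends_with("MaskedLM") {
            return Ok(ModelType::Embedding(Pool::Splade));
        } else if arch.ends_with("Classification") {
            // The real router logs a warning when --pooling was set here.
            return Ok(ModelType::Classifier);
        }
    }
    if pooling == Some(Pool::Splade) {
        return Err("Splade pooling is not supported: model is not a ForMaskedLM model".into());
    }
    // The router reads 1_Pooling/config.json when no pool was given; this sketch
    // simply defaults to CLS pooling to stay self-contained.
    Ok(ModelType::Embedding(pooling.unwrap_or(Pool::Cls)))
}

fn main() {
    assert_eq!(
        backend_model_type(&["BertForSequenceClassification"], Some(Pool::Mean)),
        Ok(ModelType::Classifier)
    );
}
```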
diff --git a/tests/README.md b/tests/README.md
index c4ff5d0b..cfbf805c 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -31,4 +31,4 @@ Restart server with `USE_FLASH_ATTENTION=0`, and
 ```
 python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1
 python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3
-```
\ No newline at end of file
+```
diff --git a/tests/collect.py b/tests/collect.py
index 313c0871..640f854c 100644
--- a/tests/collect.py
+++ b/tests/collect.py
@@ -34,4 +34,4 @@
 save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt"
 
 print(f"Saving embedding of shape {embedding.shape} to {save_path}")
-torch.save(embedding, save_path)
\ No newline at end of file
+torch.save(embedding, save_path)
diff --git a/tests/conftest.py b/tests/conftest.py
index 6d8ed997..efdd6fc2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -26,7 +26,7 @@ def __init__(self, process, port: int):
 
     def _inner_health(self) -> bool:
         return self.process.poll() is None
-    
+
     def health(self, timeout: int = 60):
         assert timeout > 0
         for _ in range(timeout):
@@ -109,5 +109,5 @@ def local_launcher(
 
     if not use_flash_attention:
         del env["USE_FLASH_ATTENTION"]
-    
-    return local_launcher
\ No newline at end of file
+
+    return local_launcher
diff --git a/tests/requirements.txt b/tests/requirements.txt
index b1ee0f58..74d3b667 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -1,3 +1,3 @@
 pytest
 pytest-asyncio
-aiohttp
\ No newline at end of file
+aiohttp
diff --git a/tests/test_default_model.py b/tests/test_default_model.py
index 595fe6bf..68499928 100644
--- a/tests/test_default_model.py
+++ b/tests/test_default_model.py
@@ -25,4 +25,4 @@ async def test_single_query(default_model):
     embedding = torch.Tensor(json.loads(response.text))
 
     reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt")
-    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
\ No newline at end of file
+    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py
index 3c3fde1c..04085522 100644
--- a/tests/test_flash_bert.py
+++ b/tests/test_flash_bert.py
@@ -25,4 +25,4 @@ async def test_single_query(default_model):
     embedding = torch.Tensor(json.loads(response.text))
 
     reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt")
-    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
\ No newline at end of file
+    assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
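The integration tests above compare served embeddings against stored reference tensors with `torch.allclose(atol=1e-3, rtol=1e-3)`, i.e. element-wise `|a - b| <= atol + rtol * |b|`. A small sketch of that tolerance check, useful when reproducing the comparison outside of PyTorch (helper name and example values are illustrative):

```rust
// Element-wise closeness check mirroring torch.allclose:
// |a - b| <= atol + rtol * |b| for every pair of elements.
fn allclose(a: &[f32], b: &[f32], atol: f32, rtol: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x - y).abs() <= atol + rtol * y.abs())
}

fn main() {
    let reference = [0.1234f32, -0.5678];
    let candidate = [0.1239f32, -0.5673];
    // Within the 1e-3 absolute / 1e-3 relative tolerance used by the tests.
    assert!(allclose(&candidate, &reference, 1e-3, 1e-3));
    println!("ok");
}
```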