diff --git a/router/src/lib.rs b/router/src/lib.rs index 540e61af..37cc3a73 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -91,13 +91,20 @@ pub async fn run( // Optionally download the pooling config. if pooling.is_none() { // If a pooling config exist, download it - let _ = download_pool_config(&api_repo).await; + let _ = download_pool_config(&api_repo).await.map_err(|err| { + tracing::warn!("Download failed: {err}"); + err + }); } - // Download sentence transformers config + // Download legacy sentence transformers config + // We don't warn on failure as it is a legacy file let _ = download_st_config(&api_repo).await; // Download new sentence transformers config - let _ = download_new_st_config(&api_repo).await; + let _ = download_new_st_config(&api_repo).await.map_err(|err| { + tracing::warn!("Download failed: {err}"); + err + }); // Download model from the Hub download_artifacts(&api_repo) @@ -387,10 +394,21 @@ fn get_backend_model_type( None => { // Load pooling config let config_path = model_root.join("1_Pooling/config.json"); - let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; - let config: PoolConfig = - serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?; - Pool::try_from(config)? + + match fs::read_to_string(config_path) { + Ok(config) => { + let config: PoolConfig = serde_json::from_str(&config) + .context("Failed to parse `1_Pooling/config.json`")?; + Pool::try_from(config)? + } + Err(err) => { + if !config.model_type.to_lowercase().contains("bert") { + return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model."); + } + tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling."); + text_embeddings_backend::Pool::Cls + } + } } }; Ok(text_embeddings_backend::ModelType::Embedding(pool))