chore(config): validate supported models (#3293)

* validate supported models

* fix comments

* clean code

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Unit test for Config::validate_model_config

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
antimonyGu and autofix-ci[bot] authored Nov 1, 2024
1 parent 0934248 commit 8ffec60
Showing 1 changed file with 82 additions and 2 deletions.
84 changes: 82 additions & 2 deletions crates/tabby-common/src/config.rs
@@ -1,4 +1,4 @@
- use std::{collections::HashSet, path::PathBuf};
+ use std::{collections::HashSet, path::PathBuf, process};

use anyhow::{anyhow, Context, Result};
use derive_builder::Builder;
@@ -9,7 +9,7 @@ use tracing::debug;

use crate::{
api::code::CodeSearchParams,
- languages,
+ config, languages,
path::repositories_dir,
terminal::{HeaderFormat, InfoMessage},
};
@@ -65,6 +65,24 @@ impl Config {
.print();
}

if let Err(e) = cfg.validate_config() {
cfg = Default::default();
InfoMessage::new(
"Parsing config failed",
HeaderFormat::BoldRed,
&[
&format!(
"Warning: Could not parse the Tabby configuration at {}",
crate::path::config_file().as_path().to_string_lossy()
),
&format!("Reason: {e}"),
"Falling back to default config, please resolve the errors and restart Tabby",
],
)
.print();
process::exit(1);
}

Ok(cfg)
}

@@ -84,6 +102,30 @@ impl Config {
}
Ok(())
}

fn validate_config(&self) -> Result<()> {
Self::validate_model_config(&self.model.completion)?;
Self::validate_model_config(&self.model.chat)?;

Ok(())
}

fn validate_model_config(model_config: &Option<ModelConfig>) -> Result<()> {
if let Some(config::ModelConfig::Http(completion_http_config)) = &model_config {
if let Some(models) = &completion_http_config.supported_models {
if let Some(model_name) = &completion_http_config.model_name {
if !models.contains(model_name) {
return Err(anyhow!(
"Suppported model list does not contain model: {}",
model_name
));
}
}
}
}

Ok(())
}
}

lazy_static! {
@@ -412,6 +454,44 @@ mod tests {
debug_assert!(config.is_ok(), "{}", config.err().unwrap());
}

#[test]
fn it_parses_invalid_model_name_config() {
let toml_config = r#"
# Completion model
[model.completion.http]
kind = "llama.cpp/completion"
api_endpoint = "http://localhost:8888"
prompt_template = "<PRE> {prefix} <SUF>{suffix} <MID>" # Example prompt template for the CodeLlama model series.
supported_models = ["test"]
model_name = "wsxiaoys/StarCoder-1B"
# Chat model
[model.chat.http]
kind = "openai/chat"
api_endpoint = "http://localhost:8888"
supported_models = ["Qwen2-1.5B-Instruct"]
model_name = "Qwen2-1.5B-Instruct"
# Embedding model
[model.embedding.http]
kind = "llama.cpp/embedding"
api_endpoint = "http://localhost:8888"
model_name = "Qwen2-1.5B-Instruct"
"#;

let config: Config =
serdeconv::from_toml_str::<Config>(toml_config).expect("Failed to parse config");

let completion_result = Config::validate_model_config(&config.model.completion);
if let Err(e) = &completion_result {
println!("Final result: {}", e);
}

assert!(completion_result.is_err());
assert!(Config::validate_model_config(&config.model.chat).is_ok());
}

#[test]
fn it_parses_local_dir() {
let repo = RepositoryConfig {
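As a companion to the new negative test above, the following is a minimal sketch (not part of this commit) of the corresponding positive case: when model_name is listed in supported_models, Config::validate_model_config returns Ok. It mirrors the structure of it_parses_invalid_model_name_config and reuses the same serdeconv-based parsing; the test name, endpoints, and model names are illustrative placeholders.

#[test]
fn it_accepts_model_name_in_supported_models() {
    // Sketch of the passing case: each section's model_name appears in its
    // supported_models list, so the new validation reports no error.
    let toml_config = r#"
# Completion model
[model.completion.http]
kind = "llama.cpp/completion"
api_endpoint = "http://localhost:8888"
prompt_template = "<PRE> {prefix} <SUF>{suffix} <MID>"
supported_models = ["wsxiaoys/StarCoder-1B"]
model_name = "wsxiaoys/StarCoder-1B"

# Chat model
[model.chat.http]
kind = "openai/chat"
api_endpoint = "http://localhost:8888"
supported_models = ["Qwen2-1.5B-Instruct"]
model_name = "Qwen2-1.5B-Instruct"

# Embedding model
[model.embedding.http]
kind = "llama.cpp/embedding"
api_endpoint = "http://localhost:8888"
model_name = "Qwen2-1.5B-Instruct"
"#;

    let config: Config =
        serdeconv::from_toml_str::<Config>(toml_config).expect("Failed to parse config");

    // Both configured sections list their model_name, so validation succeeds.
    assert!(Config::validate_model_config(&config.model.completion).is_ok());
    assert!(Config::validate_model_config(&config.model.chat).is_ok());
}

In practice this is also the shape of the fix when the new startup check fires: either add the configured model_name to supported_models, or change model_name to one of the listed entries.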
