diff --git a/modules/models/Qwen.py b/modules/models/Qwen.py
index f5fc8d1b..a6cc9ad1 100644
--- a/modules/models/Qwen.py
+++ b/modules/models/Qwen.py
@@ -1,4 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import os
 from transformers.generation import GenerationConfig
 import logging
 import colorama
@@ -9,8 +10,18 @@ class Qwen_Client(BaseLLMModel):
     def __init__(self, model_name, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_METADATA[model_name]["repo_id"], trust_remote_code=True, resume_download=True)
-        self.model = AutoModelForCausalLM.from_pretrained(MODEL_METADATA[model_name]["repo_id"], device_map="auto", trust_remote_code=True, resume_download=True).eval()
+        model_source = None
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_source = f"models/{model_name}"
+        if model_source is None:
+            try:
+                model_source = MODEL_METADATA[model_name]["repo_id"]
+            except KeyError:
+                model_source = model_name
+        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True, resume_download=True)
+        self.model = AutoModelForCausalLM.from_pretrained(model_source, device_map="auto", trust_remote_code=True, resume_download=True).eval()

     def generation_config(self):
         return GenerationConfig.from_dict({
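
For reference, the model-source resolution added by this patch can be read as a standalone function. Below is a minimal sketch of the same precedence order (local models/ directory, then the repo id registered in MODEL_METADATA, then the raw model name). The stubbed MODEL_METADATA entry here is illustrative only; the real table lives elsewhere in the project.

    import os

    # Illustrative stand-in for the project's real metadata table;
    # the key and repo id below are placeholder examples.
    MODEL_METADATA = {
        "Qwen 7B": {"repo_id": "Qwen/Qwen-7B-Chat"},
    }

    def resolve_model_source(model_name: str) -> str:
        """Mirror the patch's resolution order:
        1. a matching directory under ./models (a locally stored model),
        2. the repo id registered in MODEL_METADATA,
        3. the raw model name, passed to from_pretrained() unchanged.
        """
        if os.path.exists("models") and model_name in os.listdir("models"):
            return f"models/{model_name}"
        try:
            return MODEL_METADATA[model_name]["repo_id"]
        except KeyError:
            return model_name

The KeyError fallback means a model name absent from both the local directory and the metadata table is handed to Hugging Face's from_pretrained() as-is, so a bare hub repo id still works without any registration.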