diff --git a/modules/models/DALLE3.py b/modules/models/DALLE3.py new file mode 100644 index 00000000..e9467fb8 --- /dev/null +++ b/modules/models/DALLE3.py @@ -0,0 +1,38 @@ +import re +import json +import openai +from openai import OpenAI +from .base_model import BaseLLMModel +from .. import shared +from ..config import retrieve_proxy + + +class OpenAI_DALLE3_Client(BaseLLMModel): + def __init__(self, model_name, api_key, user_name="") -> None: + super().__init__(model_name=model_name, user=user_name) + self.api_key = api_key + + def _get_dalle3_prompt(self): + prompt = self.history[-1]["content"] + if prompt.endswith("--raw"): + prompt = "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:" + prompt + return prompt + + @shared.state.switching_api_key + def get_answer_at_once(self): + prompt = self._get_dalle3_prompt() + with retrieve_proxy(): + client = OpenAI(api_key=openai.api_key) + try: + response = client.images.generate( + model="dall-e-3", + prompt=prompt, + size="1024x1024", + quality="standard", + n=1, + ) + except openai.BadRequestError as e: + msg = str(e) + match = re.search(r"'message': '([^']*)'", msg) + return match.group(1), 0 + return f' {response.data[0].revised_prompt}', 0 diff --git a/modules/models/base_model.py b/modules/models/base_model.py index eaeada99..52fcef34 100644 --- a/modules/models/base_model.py +++ b/modules/models/base_model.py @@ -153,6 +153,7 @@ class ModelType(Enum): Qwen = 15 OpenAIVision = 16 ERNIE = 17 + DALLE3 = 18 @classmethod def get_type(cls, model_name: str): @@ -195,6 +196,8 @@ def get_type(cls, model_name: str): model_type = ModelType.Qwen elif "ernie" in model_name_lower: model_type = ModelType.ERNIE + elif "dall" in model_name_lower: + model_type = ModelType.DALLE3 else: model_type = ModelType.LLaMA return model_type diff --git a/modules/models/models.py b/modules/models/models.py index c5f0767f..9e67c8cd 100644 --- a/modules/models/models.py +++ b/modules/models/models.py @@ -129,6 +129,10 @@ def get_model( elif model_type == ModelType.ERNIE: from .ERNIE import ERNIE_Client model = ERNIE_Client(model_name, api_key=os.getenv("ERNIE_APIKEY"),secret_key=os.getenv("ERNIE_SECRETKEY")) + elif model_type == ModelType.DALLE3: + from .DALLE3 import OpenAI_DALLE3_Client + access_key = os.environ.get("OPENAI_API_KEY", access_key) + model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name) elif model_type == ModelType.Unknown: raise ValueError(f"未知模型: {model_name}") logging.info(msg) diff --git a/modules/presets.py b/modules/presets.py index 8266e684..49b2894f 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -62,6 +62,7 @@ "GPT4 Vision", "川虎助理", "川虎助理 Pro", + "DALL-E 3", "GooglePaLM", "xmchat", "Azure OpenAI",