Added support for Groq, the super fast inference service.
GaiZhenbiao committed Apr 23, 2024
1 parent 70118ca commit 921af92
Showing 7 changed files with 92 additions and 0 deletions.
1 change: 1 addition & 0 deletions config_example.json
@@ -19,6 +19,7 @@
"ernie_secret_key": "",// 你的文心一言在百度云中的Secret Key,用于文心一言对话模型
"ollama_host": "", // 你的 Ollama Host,用于 Ollama 对话模型
"huggingface_auth_token": "", // 你的 Hugging Face API Token,用于访问有限制的模型
"groq_api_key": "", // 你的 Groq API Key,用于 Groq 对话模型(https://console.groq.com/)

    //== Azure ==
    "openai_api_type": "openai", // Options: azure, openai
3 changes: 3 additions & 0 deletions modules/config.py
@@ -158,6 +158,9 @@ def load_config_to_environ(key_list):
ollama_host = config.get("ollama_host", "")
os.environ["OLLAMA_HOST"] = ollama_host

groq_api_key = config.get("groq_api_key", "")
os.environ["GROQ_API_KEY"] = groq_api_key

load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])

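This mirrors the ollama_host handling just above: the value from config.json is exported as an environment variable so the Groq client added below in modules/models/Groq.py can read it. A minimal sketch of that round trip, with a hypothetical placeholder key:

# Sketch of the config -> environment -> client flow in this commit.
# The key value below is a hypothetical placeholder.
import os

config = {"groq_api_key": "gsk_placeholder"}        # as parsed from config.json
os.environ["GROQ_API_KEY"] = config.get("groq_api_key", "")

# modules/models/Groq.py (added below) then builds its client from this variable:
# Groq(api_key=os.environ.get("GROQ_API_KEY"))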
55 changes: 55 additions & 0 deletions modules/models/Groq.py
@@ -0,0 +1,55 @@
import json
import logging
import textwrap
import uuid

import os
from groq import Groq
import gradio as gr
import PIL
import requests

from modules.presets import i18n

from ..index_func import construct_index
from ..utils import count_token, construct_system
from .base_model import BaseLLMModel


class Groq_Client(BaseLLMModel):
    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        self.client = Groq(
            api_key=os.environ.get("GROQ_API_KEY"),
        )

    def _get_groq_style_input(self):
        messages = [construct_system(self.system_prompt), *self.history]
        return messages

    def get_answer_at_once(self):
        messages = self._get_groq_style_input()
        chat_completion = self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
        )
        return chat_completion.choices[0].message.content, chat_completion.usage.total_tokens

    def get_answer_stream_iter(self):
        messages = self._get_groq_style_input()
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_generation_token,
            top_p=self.top_p,
            stream=True,
            stop=self.stop_sequence,
        )

        partial_text = ""
        for chunk in completion:
            partial_text += chunk.choices[0].delta.content or ""
            yield partial_text
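Note that the api_key argument is stored but the SDK client itself is built from the GROQ_API_KEY environment variable, which modules/config.py (above) exports from config.json. A rough usage sketch follows; the system_prompt, history, temperature and related attributes are assumed to be supplied with defaults by BaseLLMModel, which is not part of this diff:

# Hedged sketch, not the app's actual call path.
from modules.models.Groq import Groq_Client

client = Groq_Client("llama3-8b-8192", api_key="", user_name="demo")
client.system_prompt = "You are a helpful assistant."          # attribute assumed from _get_groq_style_input()
client.history = [{"role": "user", "content": "Say hello."}]   # OpenAI-style message dicts

# Blocking call: returns (reply_text, total_tokens).
reply, total_tokens = client.get_answer_at_once()

# Streaming call: each yielded value is the full text accumulated so far.
for partial in client.get_answer_stream_iter():
    print(partial)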
3 changes: 3 additions & 0 deletions modules/models/base_model.py
@@ -155,6 +155,7 @@ class ModelType(Enum):
    GoogleGemini = 19
    GoogleGemma = 20
    Ollama = 21
    Groq = 22

    @classmethod
    def get_type(cls, model_name: str):
@@ -173,6 +174,8 @@ def get_type(cls, model_name: str):
            model_type = ModelType.OpenAI
        elif "chatglm" in model_name_lower:
            model_type = ModelType.ChatGLM
        elif "groq" in model_name_lower:
            model_type = ModelType.Groq
        elif "ollama" in model_name_lower:
            model_type = ModelType.Ollama
        elif "llama" in model_name_lower or "alpaca" in model_name_lower:
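Because the new branch comes before the "llama"/"alpaca" check, a display name such as "Groq LLaMA3 70B" resolves to ModelType.Groq rather than ModelType.LLaMA. A standalone sketch of that substring dispatch (illustration only, not the actual implementation):

# Simplified stand-in for ModelType.get_type, shown only to illustrate branch order.
def resolve_type(model_name: str) -> str:
    name = model_name.lower()
    if "groq" in name:                       # checked first, so Groq-hosted LLaMA stays "Groq"
        return "Groq"
    if "llama" in name or "alpaca" in name:
        return "LLaMA"
    return "Unknown"

assert resolve_type("Groq LLaMA3 70B") == "Groq"
assert resolve_type("llama3-70b-8192") == "LLaMA"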
4 changes: 4 additions & 0 deletions modules/models/models.py
@@ -61,6 +61,10 @@ def get_model(
logging.info(f"正在加载ChatGLM模型: {model_name}")
from .ChatGLM import ChatGLM_Client
model = ChatGLM_Client(model_name, user_name=user_name)
elif model_type == ModelType.Groq:
logging.info(f"正在加载Groq模型: {model_name}")
from .Groq import Groq_Client
model = Groq_Client(model_name, access_key, user_name=user_name)
elif model_type == ModelType.LLaMA and lora_model_path == "":
msg = f"现在请为 {model_name} 选择LoRA模型"
logging.info(msg)
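For orientation, the new branch boils down to the sketch below; get_model's full signature and surrounding logic are not visible in this hunk, so the parameters shown are assumptions, with access_key presumably carrying the configured Groq API key:

# Hypothetical, trimmed-down version of the dispatch above.
from modules.models.base_model import ModelType
from modules.models.Groq import Groq_Client

def load_model(model_name: str, access_key: str, user_name: str = ""):
    model_type = ModelType.get_type(model_name)
    if model_type == ModelType.Groq:
        return Groq_Client(model_name, access_key, user_name=user_name)
    raise NotImplementedError(f"no loader sketched for {model_name}")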
25 changes: 25 additions & 0 deletions modules/presets.py
@@ -70,6 +70,11 @@
"DALL-E 3",
"Gemini Pro",
"Gemini Pro Vision",
"Groq LLaMA3 8B",
"Groq LLaMA3 70B",
"Groq LLaMA2 70B",
"Groq Mixtral 8x7B",
"Groq Gemma 7B",
"GooglePaLM",
"Gemma 2B",
"Gemma 7B",
@@ -218,6 +223,26 @@
"repo_id": "google/gemma-7b-it",
"model_name": "gemma-7b-it",
"token_limit": 8192,
},
"Groq LLaMA3 8B": {
"model_name": "llama3-8b-8192",
"token_limit": 8192,
},
"Groq LLaMA3 70B": {
"model_name": "llama3-70b-8192",
"token_limit": 8192,
},
"Groq LLaMA2 70B": {
"model_name": "llama2-70b-4096",
"token_limit": 4096,
},
"Groq Mixtral 8x7B": {
"model_name": "mixtral-8x7b-32768",
"token_limit": 32768,
},
"Groq Gemma 7B": {
"model_name": "gemma-7b-it",
"token_limit": 8192,
}
}
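These entries map the display names added to the model list above onto the identifiers the Groq API expects, together with a context-window size. The name of the dict that holds them is not visible in this hunk, so MODEL_METADATA in the sketch below is an assumption:

# Hypothetical lookup; the dict name is assumed, the entry contents come from the diff above.
MODEL_METADATA = {
    "Groq LLaMA3 70B": {"model_name": "llama3-70b-8192", "token_limit": 8192},
}

meta = MODEL_METADATA["Groq LLaMA3 70B"]
api_model = meta["model_name"]        # value actually sent to the Groq API
context_window = meta["token_limit"]  # presumably used to bound prompt + completion tokens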

1 change: 1 addition & 0 deletions requirements.txt
@@ -12,6 +12,7 @@ langchain==0.1.14
langchain-openai
langchainhub
langchain_community
groq
markdown
PyPDF2
pdfplumber
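The only new dependency is the official Groq Python SDK; a minimal sanity check, using the same import as the new client module:

# Assumes the updated requirements have been installed (e.g. pip install -r requirements.txt).
from groq import Groq  # same import used in modules/models/Groq.py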
