-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for Groq, the super fast inference service.
- Loading branch information
1 parent
70118ca
commit 921af92
Showing
7 changed files
with
92 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import json | ||
import logging | ||
import textwrap | ||
import uuid | ||
|
||
import os | ||
from groq import Groq | ||
import gradio as gr | ||
import PIL | ||
import requests | ||
|
||
from modules.presets import i18n | ||
|
||
from ..index_func import construct_index | ||
from ..utils import count_token, construct_system | ||
from .base_model import BaseLLMModel | ||
|
||
|
||
class Groq_Client(BaseLLMModel):
    """Chat model that sends the conversation to the Groq chat-completions API."""

    def __init__(self, model_name, api_key, user_name="") -> None:
        """Create the Groq client for ``model_name``.

        Args:
            model_name: Model identifier forwarded to Groq on every request.
            api_key: Groq API key. When empty or ``None``, the
                ``GROQ_API_KEY`` environment variable is used instead.
            user_name: User name forwarded to the base model as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        # Prefer the key handed in by the caller; previously it was stored
        # but ignored, and only the environment variable reached the client.
        self.client = Groq(
            api_key=api_key or os.environ.get("GROQ_API_KEY"),
        )

    def _get_groq_style_input(self):
        """Return the message list: the system message followed by the history."""
        return [construct_system(self.system_prompt), *self.history]

    def _create_completion(self, stream):
        """Issue one chat-completion request with the current sampling settings.

        Shared by the streaming and non-streaming paths so both send the same
        temperature, token limit, top-p and stop sequence.
        """
        return self.client.chat.completions.create(
            model=self.model_name,
            messages=self._get_groq_style_input(),
            temperature=self.temperature,
            max_tokens=self.max_generation_token,
            top_p=self.top_p,
            stream=stream,
            stop=self.stop_sequence,
        )

    def get_answer_at_once(self):
        """Request a complete answer in one call.

        Returns:
            A ``(content, total_tokens)`` tuple taken from the first choice
            and the usage block of the response.
        """
        chat_completion = self._create_completion(stream=False)
        return (
            chat_completion.choices[0].message.content,
            chat_completion.usage.total_tokens,
        )

    def get_answer_stream_iter(self):
        """Stream an answer, yielding the accumulated text after each chunk."""
        completion = self._create_completion(stream=True)

        partial_text = ""
        for chunk in completion:
            # A chunk's delta content may be None; treat that as no new text.
            partial_text += chunk.choices[0].delta.content or ""
            yield partial_text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ langchain==0.1.14 | |
langchain-openai | ||
langchainhub | ||
langchain_community | ||
groq | ||
markdown | ||
PyPDF2 | ||
pdfplumber | ||
|