diff --git a/.gitignore b/.gitignore
index 33b021f..e08e5cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,10 @@ archived_logs/
build/
snowchat.egg-info/
-chroma_db
\ No newline at end of file
+chroma_db
+
+pplx.py
+
+test.json
+test.*
+app.py
\ No newline at end of file
diff --git a/chain.py b/chain.py
index 7505f36..edd9b8b 100644
--- a/chain.py
+++ b/chain.py
@@ -1,8 +1,7 @@
from typing import Any, Callable, Dict, Optional
-import boto3
import streamlit as st
-from langchain.chat_models import BedrockChat, ChatOpenAI
+from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import SupabaseVectorStore
@@ -34,7 +33,7 @@ class ModelConfig(BaseModel):
@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
- if v not in ["gpt", "codellama", "mistral"]:
+ if v not in ["gpt", "mistral", "gemini"]:
raise ValueError(f"Unsupported model type: {v}")
return v
@@ -53,8 +52,8 @@ def __init__(self, config: ModelConfig):
def setup(self):
if self.model_type == "gpt":
self.setup_gpt()
- elif self.model_type == "codellama":
- self.setup_codellama()
+ elif self.model_type == "gemini":
+ self.setup_gemini()
elif self.model_type == "mistral":
self.setup_mixtral()
@@ -63,7 +62,7 @@ def setup_gpt(self):
model_name="gpt-3.5-turbo-0125",
temperature=0.2,
api_key=self.secrets["OPENAI_API_KEY"],
- max_tokens=500,
+ max_tokens=1000,
callbacks=[self.callback_handler],
streaming=True,
base_url=self.gateway_url,
@@ -71,51 +70,30 @@ def setup_gpt(self):
def setup_mixtral(self):
self.llm = ChatOpenAI(
- model_name="mistralai/mistral-medium",
+ model_name="mixtral-8x7b-32768",
temperature=0.2,
- api_key=self.secrets["OPENROUTER_API_KEY"],
- max_tokens=500,
+ api_key=self.secrets["GROQ_API_KEY"],
+ max_tokens=3000,
callbacks=[self.callback_handler],
streaming=True,
- base_url="https://openrouter.ai/api/v1",
+ base_url="https://api.groq.com/openai/v1",
)
- def setup_codellama(self):
+ def setup_gemini(self):
self.llm = ChatOpenAI(
- model_name="codellama/codellama-70b-instruct",
+ model_name="google/gemini-pro",
temperature=0.2,
api_key=self.secrets["OPENROUTER_API_KEY"],
- max_tokens=500,
+ max_tokens=1200,
callbacks=[self.callback_handler],
streaming=True,
base_url="https://openrouter.ai/api/v1",
+ default_headers={
+ "HTTP-Referer": "https://snowchat.streamlit.app/",
+ "X-Title": "Snowchat",
+ },
)
- # def setup_claude(self):
- # bedrock_runtime = boto3.client(
- # service_name="bedrock-runtime",
- # aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
- # aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
- # region_name="us-east-1",
- # )
- # parameters = {
- # "max_tokens_to_sample": 1000,
- # "stop_sequences": [],
- # "temperature": 0,
- # "top_p": 0.9,
- # }
- # self.q_llm = BedrockChat(
- # model_id="anthropic.claude-instant-v1", client=bedrock_runtime
- # )
-
- # self.llm = BedrockChat(
- # model_id="anthropic.claude-instant-v1",
- # client=bedrock_runtime,
- # callbacks=[self.callback_handler],
- # streaming=True,
- # model_kwargs=parameters,
- # )
-
def get_chain(self, vectorstore):
def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
@@ -153,12 +131,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
query_name="v_match_documents",
)
- if "codellama" in model_name.lower():
- model_type = "codellama"
- elif "GPT-3.5" in model_name:
+ if "GPT-3.5" in model_name:
model_type = "gpt"
elif "mistral" in model_name.lower():
model_type = "mistral"
+ elif "gemini" in model_name.lower():
+ model_type = "gemini"
else:
raise ValueError(f"Unsupported model name: {model_name}")
diff --git a/main.py b/main.py
index 157e63a..894db25 100644
--- a/main.py
+++ b/main.py
@@ -14,16 +14,43 @@
chat_history = []
snow_ddl = Snowddl()
-st.title("snowChat")
+gradient_text_html = """
+
+
snowChat
+"""
+
+st.markdown(gradient_text_html, unsafe_allow_html=True)
+
st.caption("Talk your way through data")
model = st.radio(
"",
- options=["β¨ GPT-3.5", "βΎοΈ codellama", "π Mistral"],
+ options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"],
index=0,
horizontal=True,
)
st.session_state["model"] = model
+if "toast_shown" not in st.session_state:
+ st.session_state["toast_shown"] = False
+
+# Show the toast only if it hasn't been shown before
+if not st.session_state["toast_shown"]:
+ st.toast("The snowflake data retrieval is disabled for now.", icon="π")
+ st.session_state["toast_shown"] = True
+
+if st.session_state["model"] == "π Mistral 8x7B - Groq":
+ st.warning("This is highly rate-limited. Please use it sparingly", icon="β οΈ")
+
INITIAL_MESSAGE = [
{"role": "user", "content": "Hi!"},
{
@@ -38,10 +65,8 @@
with open("ui/styles.md", "r") as styles_file:
styles_content = styles_file.read()
-# Display the DDL for the selected table
st.sidebar.markdown(sidebar_content)
-# Create a sidebar with a dropdown menu
selected_table = st.sidebar.selectbox(
"Select a table:", options=list(snow_ddl.ddl_dict.keys())
)
@@ -81,9 +106,10 @@
message["content"],
True if message["role"] == "user" else False,
True if message["role"] == "data" else False,
+ model,
)
-callback_handler = StreamlitUICallbackHandler()
+callback_handler = StreamlitUICallbackHandler(model)
chain = load_chain(st.session_state["model"], callback_handler)
diff --git a/ui/styles.md b/ui/styles.md
index de84444..79238e4 100644
--- a/ui/styles.md
+++ b/ui/styles.md
@@ -11,8 +11,14 @@
background-color: white;
z-index: 100;
}
- h1 {
- font-family: 'Roboto Slab', serif;
+ h1, h2 {
+ font-weight: bold;
+ background: -webkit-linear-gradient(left, red, orange);
+ background: linear-gradient(to right, red, orange);
+ -webkit-background-clip: text;
+ -webkit-text-fill-color: transparent;
+ display: inline;
+ font-size: 3em;
}
.user-avatar {
float: right;
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index d630bb4..30e3446 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -4,6 +4,24 @@
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
+image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
+gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z"
+mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png"
+openai_url = (
+ image_url
+ + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
+)
+
+
+def get_model_url(model_name):
+ if "gpt" in model_name.lower():
+ return openai_url
+ elif "gemini" in model_name.lower():
+ return gemini_url
+ elif "mistral" in model_name.lower():
+ return mistral_url
+ return mistral_url
+
def format_message(text):
"""
@@ -26,7 +44,7 @@ def format_message(text):
return formatted_text
-def message_func(text, is_user=False, is_df=False):
+def message_func(text, is_user=False, is_df=False, model="gpt"):
"""
This function is used to display the messages in the chatbot UI.
@@ -35,6 +53,9 @@ def message_func(text, is_user=False, is_df=False):
is_user (bool): Whether the message is from the user or not.
is_df (bool): Whether the message is a dataframe or not.
"""
+ model_url = get_model_url(model)
+
+ avatar_url = model_url
if is_user:
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned"
message_alignment = "flex-end"
@@ -45,13 +66,12 @@ def message_func(text, is_user=False, is_df=False):
{text} \n
-

+
""",
unsafe_allow_html=True,
)
else:
- avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
message_bg_color = "#71797E"
avatar_class = "bot-avatar"
@@ -60,7 +80,7 @@ def message_func(text, is_user=False, is_df=False):
st.write(
f"""
-

+
""",
unsafe_allow_html=True,
@@ -73,8 +93,8 @@ def message_func(text, is_user=False, is_df=False):
st.write(
f"""
-

-
+

+
{text} \n
""",
@@ -83,11 +103,13 @@ def message_func(text, is_user=False, is_df=False):
class StreamlitUICallbackHandler(BaseCallbackHandler):
- def __init__(self):
+ def __init__(self, model):
self.token_buffer = []
self.placeholder = st.empty()
self.has_streaming_ended = False
self.has_streaming_started = False
+ self.model = model
+ self.avatar_url = get_model_url(model)
def start_loading_message(self):
loading_message_content = self._get_bot_message_container("Thinking...")
@@ -109,17 +131,11 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
def _get_bot_message_container(self, text):
"""Generate the bot's message container style for the given text."""
- avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
- message_alignment = "flex-start"
- message_bg_color = "#71797E"
- avatar_class = "bot-avatar"
- formatted_text = format_message(
- text
- ) # Ensure this handles "Thinking..." appropriately.
+ formatted_text = format_message(text)
container_content = f"""
-
-

-
+
+

+
{formatted_text} \n
"""
@@ -129,14 +145,13 @@ def display_dataframe(self, df):
"""
Display the dataframe in Streamlit UI within the chat container.
"""
- avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
avatar_class = "bot-avatar"
st.write(
f"""
-

+
""",
unsafe_allow_html=True,