From cf6418a04f453d8c3b642f2517968d20beb8cc13 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Sat, 1 Feb 2025 10:24:11 +1300 Subject: [PATCH] o3 mini --- .streamlit/config.toml | 2 ++ agent.py | 31 ++++++++++++++----------------- main.py | 31 ++++++++++++++++++------------- requirements.txt | 6 ++---- ui/styles.md | 15 ++++++++------- utils/snowchat_ui.py | 6 ++++-- 6 files changed, 48 insertions(+), 43 deletions(-) create mode 100644 .streamlit/config.toml diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 0000000..ca4e7d7 --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,2 @@ +[theme] +base = "dark" \ No newline at end of file diff --git a/agent.py b/agent.py index 1cbd3f0..be7bea7 100644 --- a/agent.py +++ b/agent.py @@ -1,18 +1,15 @@ -import os import streamlit as st from dataclasses import dataclass from typing import Annotated, Sequence, Optional from langchain.callbacks.base import BaseCallbackHandler -from langchain_anthropic import ChatAnthropic from langchain_core.messages import SystemMessage from langchain_openai import ChatOpenAI from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph import START, StateGraph +from langgraph.graph import START, END, StateGraph from langgraph.prebuilt import ToolNode, tools_condition from langgraph.graph.message import add_messages from langchain_core.messages import BaseMessage -from langchain_community.llms import Replicate from tools import retriever_tool from tools import search, sql_executor_tool @@ -35,13 +32,13 @@ class ModelConfig: model_configurations = { - "gpt-4o": ModelConfig( - model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"] + "o3-mini": ModelConfig( + model_name="o3-mini", api_key=st.secrets["OPENAI_API_KEY"] ), - "Gemini Flash 1.5": ModelConfig( - model_name="gemini-1.5-flash", - api_key=st.secrets["GEMINI_API_KEY"], - base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + "Deepseek R1": ModelConfig( + 
model_name="deepseek-r1-distill-llama-70b", + api_key=st.secrets["GROQ_API_KEY"], + base_url="https://api.groq.com/openai/v1", ), # "Mistral 7B": ModelConfig( # model_name="mistralai/mistral-7b-v0.1", api_key=st.secrets["REPLICATE_API_TOKEN"] @@ -59,14 +56,13 @@ class ModelConfig: } sys_msg = SystemMessage( content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools: + ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE. + ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE. - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code. - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code. - - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way. 
- - Make sure to always return both the SQL code and the result of the query """ ) -tools = [retriever_tool, search, sql_executor_tool] +tools = [retriever_tool, search] def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph: config = model_configurations.get(model_name) @@ -82,7 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> Stat callbacks=[callback_handler], streaming=True, base_url=config.base_url, - temperature=0.01, + # temperature=0.1, default_headers={"HTTP-Referer": "https://snowchat.streamlit.app/", "X-Title": "Snowchat"}, ) @@ -94,14 +90,15 @@ def llm_agent(state: MessagesState): builder = StateGraph(MessagesState) builder.add_node("llm_agent", llm_agent) builder.add_node("tools", ToolNode(tools)) + builder.add_edge(START, "llm_agent") builder.add_conditional_edges("llm_agent", tools_condition) builder.add_edge("tools", "llm_agent") - + builder.add_edge("llm_agent", END) react_graph = builder.compile(checkpointer=memory) # png_data = react_graph.get_graph(xray=True).draw_mermaid_png() - # with open("graph.png", "wb") as f: + # with open("graph_2.png", "wb") as f: # f.write(png_data) # image = Image.open(BytesIO(png_data)) diff --git a/main.py b/main.py index 2b7c2c9..3055721 100644 --- a/main.py +++ b/main.py @@ -17,17 +17,22 @@ gradient_text_html = """ -
snowChat
+
snowChat
""" st.markdown(gradient_text_html, unsafe_allow_html=True) @@ -35,10 +40,10 @@ st.caption("Talk your way through data") model_options = { - "gpt-4o": "GPT-4o", + "o3-mini": "o3-mini", "Qwen 2.5": "Qwen 2.5", "Gemini Exp 1206": "Gemini Exp 1206", - "Gemini Flash 1.5": "Gemini Flash 1.5", + "Deepseek R1": "Deepseek R1", } model = st.radio( @@ -69,8 +74,8 @@ st.toast("Probably rate limited.. Go easy folks", icon="⚠️") st.session_state["rate-limit"] = False -if st.session_state["model"] == "Mixtral 8x7B": - st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️") +if st.session_state["model"] == "Deepseek R1": + st.warning("Deepseek R1 is highly rate limited. Please use it sparingly", icon="⚠️") INITIAL_MESSAGE = [ {"role": "user", "content": "Hi!"}, diff --git a/requirements.txt b/requirements.txt index 40338d5..6f7b221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ langchain==0.3.3 -langchain_anthropic==0.2.3 langchain_community==0.3.2 -langchain_core==0.3.12 -langchain_openai==0.2.2 -langchain-google-genai==2.0.1 +langchain_core==0.3.33 +langchain_openai==0.3.3 langgraph==0.2.38 pydantic==2.9.2 Requests==2.32.3 diff --git a/ui/styles.md b/ui/styles.md index 79238e4..f9a5990 100644 --- a/ui/styles.md +++ b/ui/styles.md @@ -12,13 +12,14 @@ z-index: 100; } h1, h2 { - font-weight: bold; - background: -webkit-linear-gradient(left, red, orange); - background: linear-gradient(to right, red, orange); - -webkit-background-clip: text; - -webkit-text-fill-color: transparent; - display: inline; - font-size: 3em; + font-family: 'Poppins', sans-serif; + font-weight: 900; + font-size: 3em; + background: linear-gradient(90deg, #ff6a00, #ee0979); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + text-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3); + display: inline; } .user-avatar { float: right; diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 05db4b1..d396e7a 100644 --- a/utils/snowchat_ui.py +++ 
b/utils/snowchat_ui.py @@ -19,7 +19,7 @@ meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z" snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z" qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z" - +deepseek_url = image_url + "deepseek-color.png" def get_model_url(model_name): if "qwen" in model_name.lower(): @@ -32,10 +32,12 @@ def get_model_url(model_name): return gemini_url elif "arctic" in model_name.lower(): return snow_url - elif "gpt" in model_name.lower(): + elif "gpt" in model_name.lower() or "o3" in model_name.lower(): return openai_url elif "gemini" in model_name.lower(): return gemini_url + elif "deepseek" in model_name.lower(): + return deepseek_url return mistral_url