o3 mini
kaarthik108 committed Jan 31, 2025
1 parent 999635b commit cf6418a
Showing 6 changed files with 48 additions and 43 deletions.
2 changes: 2 additions & 0 deletions .streamlit/config.toml
@@ -0,0 +1,2 @@
[theme]
base = "dark"
31 changes: 14 additions & 17 deletions agent.py
@@ -1,18 +1,15 @@
import os
import streamlit as st
from dataclasses import dataclass
from typing import Annotated, Sequence, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph import START, END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage
from langchain_community.llms import Replicate

from tools import retriever_tool
from tools import search, sql_executor_tool
@@ -35,13 +32,13 @@ class ModelConfig:


model_configurations = {
"gpt-4o": ModelConfig(
model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
"o3-mini": ModelConfig(
model_name="o3-mini", api_key=st.secrets["OPENAI_API_KEY"]
),
"Gemini Flash 1.5": ModelConfig(
model_name="gemini-1.5-flash",
api_key=st.secrets["GEMINI_API_KEY"],
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
"Deepseek R1": ModelConfig(
model_name="deepseek-r1-distill-llama-70b",
api_key=st.secrets["GROQ_API_KEY"],
base_url=f"https://api.groq.com/openai/v1",
),
# "Mistral 7B": ModelConfig(
# model_name="mistralai/mistral-7b-v0.1", api_key=st.secrets["REPLICATE_API_TOKEN"]
@@ -59,14 +56,13 @@ class ModelConfig:
}
sys_msg = SystemMessage(
content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
- Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
- Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
- Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way.
Make sure to always return both the SQL code and the result of the query
"""
)
tools = [retriever_tool, search, sql_executor_tool]
tools = [retriever_tool, search]

def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
config = model_configurations.get(model_name)
@@ -82,7 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
callbacks=[callback_handler],
streaming=True,
base_url=config.base_url,
temperature=0.01,
# temperature=0.1,
default_headers={"HTTP-Referer": "https://snowchat.streamlit.app/", "X-Title": "Snowchat"},
)

@@ -94,14 +90,15 @@ def llm_agent(state: MessagesState):
builder = StateGraph(MessagesState)
builder.add_node("llm_agent", llm_agent)
builder.add_node("tools", ToolNode(tools))

builder.add_edge(START, "llm_agent")
builder.add_conditional_edges("llm_agent", tools_condition)
builder.add_edge("tools", "llm_agent")

builder.add_edge("llm_agent", END)
react_graph = builder.compile(checkpointer=memory)

# png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
# with open("graph.png", "wb") as f:
# with open("graph_2.png", "wb") as f:
# f.write(png_data)

# image = Image.open(BytesIO(png_data))
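For orientation, a minimal sketch of how agent.py consumes the two new entries. The ModelConfig fields are inferred from the constructor calls in this diff, the API keys are placeholders for the st.secrets values, and the callback handler is omitted; this is an assumption-laden sketch, not the committed implementation:

```python
# Sketch only: ModelConfig's fields are inferred from the calls shown in this diff,
# and the API keys below are placeholders for st.secrets values.
from dataclasses import dataclass
from typing import Optional

from langchain_openai import ChatOpenAI


@dataclass
class ModelConfig:
    model_name: str
    api_key: str
    base_url: Optional[str] = None


# Both new models ride the OpenAI-compatible client: o3-mini natively,
# DeepSeek R1 through Groq's OpenAI-compatible endpoint selected via base_url.
model_configurations = {
    "o3-mini": ModelConfig(model_name="o3-mini", api_key="sk-placeholder"),
    "Deepseek R1": ModelConfig(
        model_name="deepseek-r1-distill-llama-70b",
        api_key="gsk-placeholder",
        base_url="https://api.groq.com/openai/v1",
    ),
}

config = model_configurations["o3-mini"]
llm = ChatOpenAI(
    model=config.model_name,
    api_key=config.api_key,
    streaming=True,
    base_url=config.base_url,
    default_headers={"HTTP-Referer": "https://snowchat.streamlit.app/", "X-Title": "Snowchat"},
    # temperature is left unset: OpenAI's reasoning models such as o3-mini reject it,
    # which is presumably why the commit comments the parameter out above.
)
```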
31 changes: 18 additions & 13 deletions main.py
@@ -17,28 +17,33 @@

gradient_text_html = """
<style>
.gradient-text {
font-weight: bold;
background: -webkit-linear-gradient(left, red, orange);
background: linear-gradient(to right, red, orange);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
display: inline;
font-size: 3em;
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@700;900&display=swap');
.snowchat-title {
font-family: 'Poppins', sans-serif;
font-weight: 900;
font-size: 4em;
background: linear-gradient(90deg, #ff6a00, #ee0979);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
margin: 0;
padding: 20px 0;
text-align: center;
}
</style>
<div class="gradient-text">snowChat</div>
<div class="snowchat-title">snowChat</div>
"""

st.markdown(gradient_text_html, unsafe_allow_html=True)

st.caption("Talk your way through data")

model_options = {
"gpt-4o": "GPT-4o",
"o3-mini": "o3-mini",
"Qwen 2.5": "Qwen 2.5",
"Gemini Exp 1206": "Gemini Exp 1206",
"Gemini Flash 1.5": "Gemini Flash 1.5",
"Deepseek R1": "Deepseek R1",
}

model = st.radio(
@@ -69,8 +74,8 @@
st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
st.session_state["rate-limit"] = False

if st.session_state["model"] == "Mixtral 8x7B":
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")
if st.session_state["model"] == "Deepseek R1":
st.warning("Deepseek R1 is highly rate limited. Please use it sparingly", icon="⚠️")

INITIAL_MESSAGE = [
{"role": "user", "content": "Hi!"},
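For context, a rough sketch of how the updated model_options above typically drives the selector and the new Deepseek R1 warning. The actual st.radio call is collapsed in this view, so its arguments here are assumptions:

```python
# Assumed wiring only: the real st.radio call is collapsed in this diff.
import streamlit as st

model_options = {
    "o3-mini": "o3-mini",
    "Qwen 2.5": "Qwen 2.5",
    "Gemini Exp 1206": "Gemini Exp 1206",
    "Deepseek R1": "Deepseek R1",
}

# The keys double as display labels and as the value stored in session state,
# so later code can branch on the selected backend by name.
st.session_state["model"] = st.radio(
    "Choose your model:",
    options=list(model_options.keys()),
    format_func=lambda key: model_options[key],
)

if st.session_state["model"] == "Deepseek R1":
    st.warning("Deepseek R1 is highly rate limited. Please use it sparingly", icon="⚠️")
```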
6 changes: 2 additions & 4 deletions requirements.txt
@@ -1,9 +1,7 @@
langchain==0.3.3
langchain_anthropic==0.2.3
langchain_community==0.3.2
langchain_core==0.3.12
langchain_openai==0.2.2
langchain-google-genai==2.0.1
langchain_core==0.3.33
langchain_openai==0.3.3
langgraph==0.2.38
pydantic==2.9.2
Requests==2.32.3
15 changes: 8 additions & 7 deletions ui/styles.md
@@ -12,13 +12,14 @@
z-index: 100;
}
h1, h2 {
font-weight: bold;
background: -webkit-linear-gradient(left, red, orange);
background: linear-gradient(to right, red, orange);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
display: inline;
font-size: 3em;
font-family: 'Poppins', sans-serif;
font-weight: 900;
font-size: 3em;
background: linear-gradient(90deg, #ff6a00, #ee0979);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
display: inline;
}
.user-avatar {
float: right;
6 changes: 4 additions & 2 deletions utils/snowchat_ui.py
@@ -19,7 +19,7 @@
meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"
qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z"

deepseek_url = image_url + "/deepseek-color.png"

def get_model_url(model_name):
if "qwen" in model_name.lower():
@@ -32,10 +32,12 @@ def get_model_url(model_name):
return gemini_url
elif "arctic" in model_name.lower():
return snow_url
elif "gpt" in model_name.lower():
elif "gpt" in model_name.lower() or "o3" in model_name.lower():
return openai_url
elif "gemini" in model_name.lower():
return gemini_url
elif "deepseek" in model_name.lower():
return deepseek_url
return mistral_url
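A quick, hypothetical check of the updated avatar routing, assuming get_model_url and the URL constants are importable as module-level names from utils/snowchat_ui.py and that none of the branches collapsed above matches these names first:

```python
# Hypothetical usage check; names are assumed to be module-level in utils/snowchat_ui.py.
from utils.snowchat_ui import deepseek_url, get_model_url, openai_url

assert get_model_url("o3-mini") == openai_url        # "o3" now routes to the OpenAI avatar
assert get_model_url("Deepseek R1") == deepseek_url  # new DeepSeek avatar
```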


