diff --git a/.streamlit/config.toml b/.streamlit/config.toml
new file mode 100644
index 0000000..ca4e7d7
--- /dev/null
+++ b/.streamlit/config.toml
@@ -0,0 +1,2 @@
+[theme]
+base = "dark"
\ No newline at end of file
diff --git a/agent.py b/agent.py
index 1cbd3f0..be7bea7 100644
--- a/agent.py
+++ b/agent.py
@@ -1,18 +1,15 @@
-import os
 import streamlit as st
 from dataclasses import dataclass
 from typing import Annotated, Sequence, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import SystemMessage
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.memory import MemorySaver
-from langgraph.graph import START, StateGraph
+from langgraph.graph import START, END, StateGraph
 from langgraph.prebuilt import ToolNode, tools_condition
 from langgraph.graph.message import add_messages
 from langchain_core.messages import BaseMessage
-from langchain_community.llms import Replicate
 
 from tools import retriever_tool
 from tools import search, sql_executor_tool
@@ -35,13 +32,13 @@ class ModelConfig:
 
 model_configurations = {
-    "gpt-4o": ModelConfig(
-        model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
+    "o3-mini": ModelConfig(
+        model_name="o3-mini", api_key=st.secrets["OPENAI_API_KEY"]
     ),
-    "Gemini Flash 1.5": ModelConfig(
-        model_name="gemini-1.5-flash",
-        api_key=st.secrets["GEMINI_API_KEY"],
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+    "Deepseek R1": ModelConfig(
+        model_name="deepseek-r1-distill-llama-70b",
+        api_key=st.secrets["GROQ_API_KEY"],
+        base_url=f"https://api.groq.com/openai/v1",
     ),
     # "Mistral 7B": ModelConfig(
     #     model_name="mistralai/mistral-7b-v0.1", api_key=st.secrets["REPLICATE_API_TOKEN"]
@@ -59,14 +56,13 @@ class ModelConfig:
 }
 
 sys_msg = SystemMessage(
     content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
+    ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
+    ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
     - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
     - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
-    - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way.
-    - Make sure to always return both the SQL code and the result of the query
     """
 )
 
-tools = [retriever_tool, search, sql_executor_tool]
+tools = [retriever_tool, search]
 
 
 def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
     config = model_configurations.get(model_name)
@@ -82,7 +78,7 @@ def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> Stat
         callbacks=[callback_handler],
         streaming=True,
         base_url=config.base_url,
-        temperature=0.01,
+        # temperature=0.1,
         default_headers={"HTTP-Referer": "https://snowchat.streamlit.app/", "X-Title": "Snowchat"},
     )
 
@@ -94,14 +90,15 @@ def llm_agent(state: MessagesState):
 
     builder = StateGraph(MessagesState)
     builder.add_node("llm_agent", llm_agent)
    builder.add_node("tools", ToolNode(tools))
+    builder.add_edge(START, "llm_agent")
     builder.add_conditional_edges("llm_agent", tools_condition)
     builder.add_edge("tools", "llm_agent")
-
+    builder.add_edge("llm_agent", END)
     react_graph = builder.compile(checkpointer=memory)
 
     # png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
-    # with open("graph.png", "wb") as f:
+    # with open("graph_2.png", "wb") as f:
     #     f.write(png_data)
 
     # image = Image.open(BytesIO(png_data))
diff --git a/main.py b/main.py
index 2b7c2c9..3055721 100644
--- a/main.py
+++ b/main.py
@@ -17,17 +17,22 @@ gradient_text_html = """
-