diff --git a/.gitignore b/.gitignore index e08e5cf..4f78449 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,9 @@ pplx.py test.json test.* -app.py \ No newline at end of file +app.py + +snowchat/ + +env/ +.venv/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 48d8756..a5ae177 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -22,5 +22,6 @@ "titleBar.inactiveBackground": "#51103e99", "titleBar.inactiveForeground": "#e7e7e799" }, - "peacock.color": "#51103e" + "peacock.color": "#51103e", + "makefile.configureOnOpen": false } \ No newline at end of file diff --git a/agent.py b/agent.py index be7bea7..1c76a6a 100644 --- a/agent.py +++ b/agent.py @@ -35,29 +35,25 @@ class ModelConfig: "o3-mini": ModelConfig( model_name="o3-mini", api_key=st.secrets["OPENAI_API_KEY"] ), - "Deepseek R1": ModelConfig( - model_name="deepseek-r1-distill-llama-70b", - api_key=st.secrets["GROQ_API_KEY"], - base_url=f"https://api.groq.com/openai/v1", + "Grok 2": ModelConfig( + model_name="grok-2-latest", + api_key=st.secrets["XAI_API_KEY"], + base_url="https://api.x.ai/v1", ), - # "Mistral 7B": ModelConfig( - # model_name="mistralai/mistral-7b-v0.1", api_key=st.secrets["REPLICATE_API_TOKEN"] - # ), "Qwen 2.5": ModelConfig( model_name="accounts/fireworks/models/qwen2p5-coder-32b-instruct", api_key=st.secrets["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1", ), - "Gemini Exp 1206": ModelConfig( - model_name="gemini-exp-1206", + "Gemini 2.0 Flash": ModelConfig( + model_name="gemini-2.0-flash", api_key=st.secrets["GEMINI_API_KEY"], base_url="https://generativelanguage.googleapis.com/v1beta/openai/", ), } sys_msg = SystemMessage( content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools: - ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE. - ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE. + ALWAYS USE THE Database_Schema TOOL TO GET THE SCHEMA OF THE TABLE BEFORE GENERATING SQL CODE. - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code. - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code. """ @@ -94,7 +90,7 @@ def llm_agent(state: MessagesState): builder.add_edge(START, "llm_agent") builder.add_conditional_edges("llm_agent", tools_condition) builder.add_edge("tools", "llm_agent") - builder.add_edge("llm_agent", END) + # builder.add_edge("llm_agent", END) react_graph = builder.compile(checkpointer=memory) # png_data = react_graph.get_graph(xray=True).draw_mermaid_png() diff --git a/main.py b/main.py index 3055721..78456c7 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,8 @@ import warnings import streamlit as st -from langchain_core.messages import HumanMessage from snowflake.snowpark.exceptions import SnowparkSQLException - +from langchain_core.messages import HumanMessage from agent import MessagesState, create_agent # from utils.snow_connect import SnowflakeConnection @@ -42,8 +41,8 @@ model_options = { "o3-mini": "o3-mini", "Qwen 2.5": "Qwen 2.5", - "Gemini Exp 1206": "Gemini Exp 1206", - "Deepseek R1": "Deepseek R1", + "Gemini 2.0 Flash": "Gemini 2.0 Flash", + "Grok 2": "Grok 2", } model = st.radio( diff --git a/requirements.txt b/requirements.txt index 6f7b221..5a7d766 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ langchain==0.3.3 langchain_community==0.3.2 -langchain_core==0.3.33 +# langchain_core==0.3.33 langchain_openai==0.3.3 langgraph==0.2.38 pydantic==2.9.2 Requests==2.32.3 -snowflake_connector_python==3.1.0 -snowflake_snowpark_python==1.5.0 +snowflake_connector_python +snowflake_snowpark_python streamlit==1.33.0 websocket_client==1.7.0 duckduckgo_search==6.3.0 diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index d396e7a..e466fab 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -20,6 +20,7 @@ snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z" qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z" deepseek_url = image_url + "/deepseek-color.png" +grok_url = image_url + "/xAI-logo.jpg" def get_model_url(model_name): if "qwen" in model_name.lower(): @@ -38,6 +39,8 @@ def get_model_url(model_name): return gemini_url elif "deepseek" in model_name.lower(): return deepseek_url + elif "grok" in model_name.lower(): + return grok_url return mistral_url