diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 97e0250..6c7256e 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -4,26 +4,40 @@ on:
push:
branches:
- main
- - prod
pull_request:
branches:
- main
- - prod
jobs:
lint:
+ name: Lint and Format Code
runs-on: ubuntu-latest
+
steps:
- name: Check out repository
uses: actions/checkout@v3
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.9"
+
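+      # Key the cache to requirements.txt; restore-keys lets a run fall back to the newest OS-level cache.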
+ - name: Cache pip dependencies
+ uses: actions/cache@v3
with:
- python-version: 3.9
+ path: ~/.cache/pip
+ key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-
- name: Install dependencies
- run: pip install black
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install black ruff mypy codespell
- - name: Lint with black
- run: black --check .
+ - name: Run Formatting and Linting
+ run: |
+ make format
+ make lint
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..a768e53
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,57 @@
+.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix
+
+# Define a variable for Python and notebook files.
+PYTHON_FILES=src/
+MYPY_CACHE=.mypy_cache
+
+######################
+# LINTING AND FORMATTING
+######################
+
+lint format: PYTHON_FILES=.
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
+lint_package: PYTHON_FILES=src
+lint_tests: PYTHON_FILES=tests
+lint_tests: MYPY_CACHE=.mypy_cache_test
+
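+# Each target overrides PYTHON_FILES for its scope; lint_tests also gets its own mypy cache so src and test runs don't collide.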
+lint lint_diff lint_package lint_tests:
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check $(PYTHON_FILES)
+ [ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+
+format format_diff:
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --fix $(PYTHON_FILES)
+
+spell_check:
+ codespell --toml pyproject.toml
+
+spell_fix:
+ codespell --toml pyproject.toml -w
+
+######################
+# RUN ALL
+######################
+
+all: format lint spell_check
+
+######################
+# HELP
+######################
+
+help:
+ @echo '----'
+	@echo 'format      - run code formatters'
+	@echo 'lint        - run linters'
+	@echo 'spell_check - run spell check'
+	@echo 'all         - run all tasks'
+	@echo 'lint-fix    - run formatters, then lint'
+
+######################
+# LINT-FIX TARGET
+######################
+
+lint-fix: format lint
+ @echo "Linting and fixing completed successfully."
diff --git a/agent.py b/agent.py
new file mode 100644
index 0000000..a7a7462
--- /dev/null
+++ b/agent.py
@@ -0,0 +1,103 @@
+import os
+from dataclasses import dataclass
+from typing import Annotated, Sequence, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain_anthropic import ChatAnthropic
+from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import START, StateGraph
+from langgraph.graph.message import add_messages
+from langgraph.prebuilt import ToolNode, tools_condition
+
+from template import TEMPLATE
+from tools import retriever_tool
+
+
+@dataclass
+class MessagesState:
+ messages: Annotated[Sequence[BaseMessage], add_messages]
+
+
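+# In-memory checkpointer: conversation state persists across graph invocations, but only for the lifetime of the process.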
+memory = MemorySaver()
+
+
+@dataclass
+class ModelConfig:
+ model_name: str
+ api_key: str
+ base_url: Optional[str] = None
+
+
+def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
+ model_configurations = {
+ "gpt-4o-mini": ModelConfig(
+ model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")
+ ),
+ "gemma2-9b": ModelConfig(
+ model_name="gemma2-9b-it",
+ api_key=os.getenv("GROQ_API_KEY"),
+ base_url="https://api.groq.com/openai/v1",
+ ),
+ "claude3-haiku": ModelConfig(
+ model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
+ ),
+ "mixtral-8x22b": ModelConfig(
+ model_name="accounts/fireworks/models/mixtral-8x22b-instruct",
+ api_key=os.getenv("FIREWORKS_API_KEY"),
+ base_url="https://api.fireworks.ai/inference/v1",
+ ),
+ "llama-3.1-405b": ModelConfig(
+ model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
+ api_key=os.getenv("FIREWORKS_API_KEY"),
+ base_url="https://api.fireworks.ai/inference/v1",
+ ),
+ }
+ config = model_configurations.get(model_name)
+ if not config:
+ raise ValueError(f"Unsupported model name: {model_name}")
+
+ sys_msg = SystemMessage(
+        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. Keep your responses friendly and conversational, the way a friend or tutor would.
+        Call the "Database_Schema" tool to search for database schema details whenever they are needed to generate SQL.
+        """
+ )
+
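+    # Groq and Fireworks expose OpenAI-compatible APIs, so ChatOpenAI with a custom base_url covers every model here except Claude.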
+ llm = (
+ ChatOpenAI(
+ model=config.model_name,
+ api_key=config.api_key,
+ callbacks=[callback_handler],
+ streaming=True,
+ base_url=config.base_url,
+ )
+ if config.model_name != "claude-3-haiku-20240307"
+ else ChatAnthropic(
+ model=config.model_name,
+ api_key=config.api_key,
+ callbacks=[callback_handler],
+ streaming=True,
+ )
+ )
+
+ tools = [retriever_tool]
+
+ llm_with_tools = llm.bind_tools(tools)
+
+ def reasoner(state: MessagesState):
+ return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
+
+ # Build the graph
+ builder = StateGraph(MessagesState)
+ builder.add_node("reasoner", reasoner)
+ builder.add_node("tools", ToolNode(tools))
+ builder.add_edge(START, "reasoner")
+ builder.add_conditional_edges("reasoner", tools_condition)
+ builder.add_edge("tools", "reasoner")
+
+ react_graph = builder.compile(checkpointer=memory)
+
+ return react_graph
diff --git a/chain.py b/chain.py
index a16b862..a0c2595 100644
--- a/chain.py
+++ b/chain.py
@@ -1,155 +1,154 @@
-from typing import Any, Callable, Dict, Optional
-
-import streamlit as st
-from langchain_community.chat_models import ChatOpenAI
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.llms import OpenAI
-from langchain.vectorstores import SupabaseVectorStore
-from pydantic import BaseModel, validator
-from supabase.client import Client, create_client
-
-from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
-
-from operator import itemgetter
-
-from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import format_document
-from langchain_core.messages import get_buffer_string
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from langchain_anthropic import ChatAnthropic
-
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
-
-supabase_url = st.secrets["SUPABASE_URL"]
-supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
-supabase: Client = create_client(supabase_url, supabase_key)
-
-
-class ModelConfig(BaseModel):
- model_type: str
- secrets: Dict[str, Any]
- callback_handler: Optional[Callable] = None
-
-
-class ModelWrapper:
- def __init__(self, config: ModelConfig):
- self.model_type = config.model_type
- self.secrets = config.secrets
- self.callback_handler = config.callback_handler
- self.llm = self._setup_llm()
-
- def _setup_llm(self):
- model_config = {
- "gpt-4o-mini": {
- "model_name": "gpt-4o-mini",
- "api_key": self.secrets["OPENAI_API_KEY"],
- },
- "gemma2-9b": {
- "model_name": "gemma2-9b-it",
- "api_key": self.secrets["GROQ_API_KEY"],
- "base_url": "https://api.groq.com/openai/v1",
- },
- "claude3-haiku": {
- "model_name": "claude-3-haiku-20240307",
- "api_key": self.secrets["ANTHROPIC_API_KEY"],
- },
- "mixtral-8x22b": {
- "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
- "api_key": self.secrets["FIREWORKS_API_KEY"],
- "base_url": "https://api.fireworks.ai/inference/v1",
- },
- "llama-3.1-405b": {
- "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
- "api_key": self.secrets["FIREWORKS_API_KEY"],
- "base_url": "https://api.fireworks.ai/inference/v1",
- },
- }
-
- config = model_config[self.model_type]
-
- return (
- ChatOpenAI(
- model_name=config["model_name"],
- temperature=0.1,
- api_key=config["api_key"],
- max_tokens=700,
- callbacks=[self.callback_handler],
- streaming=True,
- base_url=config["base_url"]
- if config["model_name"] != "gpt-4o-mini"
- else None,
- default_headers={
- "HTTP-Referer": "https://snowchat.streamlit.app/",
- "X-Title": "Snowchat",
- },
- )
- if config["model_name"] != "claude-3-haiku-20240307"
- else (
- ChatAnthropic(
- model=config["model_name"],
- temperature=0.1,
- max_tokens=700,
- timeout=None,
- max_retries=2,
- callbacks=[self.callback_handler],
- streaming=True,
- )
- )
- )
-
- def get_chain(self, vectorstore):
- def _combine_documents(
- docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
- ):
- doc_strings = [format_document(doc, document_prompt) for doc in docs]
- return document_separator.join(doc_strings)
-
- _inputs = RunnableParallel(
- standalone_question=RunnablePassthrough.assign(
- chat_history=lambda x: get_buffer_string(x["chat_history"])
- )
- | CONDENSE_QUESTION_PROMPT
- | OpenAI()
- | StrOutputParser(),
- )
- _context = {
- "context": itemgetter("standalone_question")
- | vectorstore.as_retriever()
- | _combine_documents,
- "question": lambda x: x["standalone_question"],
- }
- conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
-
- return conversational_qa_chain
-
-
-def load_chain(model_name="qwen", callback_handler=None):
- embeddings = OpenAIEmbeddings(
- openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
- )
- vectorstore = SupabaseVectorStore(
- embedding=embeddings,
- client=supabase,
- table_name="documents",
- query_name="v_match_documents",
- )
-
- model_type_mapping = {
- "gpt-4o-mini": "gpt-4o-mini",
- "gemma2-9b": "gemma2-9b",
- "claude3-haiku": "claude3-haiku",
- "mixtral-8x22b": "mixtral-8x22b",
- "llama-3.1-405b": "llama-3.1-405b",
- }
-
- model_type = model_type_mapping.get(model_name.lower())
- if model_type is None:
- raise ValueError(f"Unsupported model name: {model_name}")
-
- config = ModelConfig(
- model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
- )
- model = ModelWrapper(config)
- return model.get_chain(vectorstore)
+# from dataclasses import dataclass, field
+# from operator import itemgetter
+# from typing import Any, Callable, Dict, Optional
+
+# import streamlit as st
+# from langchain.embeddings.openai import OpenAIEmbeddings
+# from langchain.llms import OpenAI
+# from langchain.prompts.prompt import PromptTemplate
+# from langchain.schema import format_document
+# from langchain.vectorstores import SupabaseVectorStore
+# from langchain_anthropic import ChatAnthropic
+# from langchain_community.chat_models import ChatOpenAI
+# from langchain_core.messages import get_buffer_string
+# from langchain_core.output_parsers import StrOutputParser
+# from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+# from supabase.client import Client, create_client
+# from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+# DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+# supabase_url = st.secrets["SUPABASE_URL"]
+# supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+# supabase: Client = create_client(supabase_url, supabase_key)
+
+
+# @dataclass
+# class ModelConfig:
+# model_type: str
+# secrets: Dict[str, Any]
+# callback_handler: Optional[Callable] = field(default=None)
+
+
+# class ModelWrapper:
+# def __init__(self, config: ModelConfig):
+# self.model_type = config.model_type
+# self.secrets = config.secrets
+# self.callback_handler = config.callback_handler
+# self.llm = self._setup_llm()
+
+# def _setup_llm(self):
+# model_config = {
+# "gpt-4o-mini": {
+# "model_name": "gpt-4o-mini",
+# "api_key": self.secrets["OPENAI_API_KEY"],
+# },
+# "gemma2-9b": {
+# "model_name": "gemma2-9b-it",
+# "api_key": self.secrets["GROQ_API_KEY"],
+# "base_url": "https://api.groq.com/openai/v1",
+# },
+# "claude3-haiku": {
+# "model_name": "claude-3-haiku-20240307",
+# "api_key": self.secrets["ANTHROPIC_API_KEY"],
+# },
+# "mixtral-8x22b": {
+# "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
+# "api_key": self.secrets["FIREWORKS_API_KEY"],
+# "base_url": "https://api.fireworks.ai/inference/v1",
+# },
+# "llama-3.1-405b": {
+# "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
+# "api_key": self.secrets["FIREWORKS_API_KEY"],
+# "base_url": "https://api.fireworks.ai/inference/v1",
+# },
+# }
+
+# config = model_config[self.model_type]
+
+# return (
+# ChatOpenAI(
+# model_name=config["model_name"],
+# temperature=0.1,
+# api_key=config["api_key"],
+# max_tokens=700,
+# callbacks=[self.callback_handler],
+# streaming=True,
+# base_url=config["base_url"]
+# if config["model_name"] != "gpt-4o-mini"
+# else None,
+# default_headers={
+# "HTTP-Referer": "https://snowchat.streamlit.app/",
+# "X-Title": "Snowchat",
+# },
+# )
+# if config["model_name"] != "claude-3-haiku-20240307"
+# else (
+# ChatAnthropic(
+# model=config["model_name"],
+# temperature=0.1,
+# max_tokens=700,
+# timeout=None,
+# max_retries=2,
+# callbacks=[self.callback_handler],
+# streaming=True,
+# )
+# )
+# )
+
+# def get_chain(self, vectorstore):
+# def _combine_documents(
+# docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+# ):
+# doc_strings = [format_document(doc, document_prompt) for doc in docs]
+# return document_separator.join(doc_strings)
+
+# _inputs = RunnableParallel(
+# standalone_question=RunnablePassthrough.assign(
+# chat_history=lambda x: get_buffer_string(x["chat_history"])
+# )
+# | CONDENSE_QUESTION_PROMPT
+# | OpenAI()
+# | StrOutputParser(),
+# )
+# _context = {
+# "context": itemgetter("standalone_question")
+# | vectorstore.as_retriever()
+# | _combine_documents,
+# "question": lambda x: x["standalone_question"],
+# }
+# conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+# return conversational_qa_chain
+
+
+# def load_chain(model_name="qwen", callback_handler=None):
+# embeddings = OpenAIEmbeddings(
+# openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+# )
+# vectorstore = SupabaseVectorStore(
+# embedding=embeddings,
+# client=supabase,
+# table_name="documents",
+# query_name="v_match_documents",
+# )
+
+# model_type_mapping = {
+# "gpt-4o-mini": "gpt-4o-mini",
+# "gemma2-9b": "gemma2-9b",
+# "claude3-haiku": "claude3-haiku",
+# "mixtral-8x22b": "mixtral-8x22b",
+# "llama-3.1-405b": "llama-3.1-405b",
+# }
+
+# model_type = model_type_mapping.get(model_name.lower())
+# if model_type is None:
+# raise ValueError(f"Unsupported model name: {model_name}")
+
+# config = ModelConfig(
+# model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
+# )
+# model = ModelWrapper(config)
+# return model.get_chain(vectorstore)
diff --git a/ingest.py b/ingest.py
index 67de0f3..c6669f3 100644
--- a/ingest.py
+++ b/ingest.py
@@ -6,6 +6,7 @@
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import SupabaseVectorStore
from pydantic import BaseModel
+
from supabase.client import Client, create_client
diff --git a/main.py b/main.py
index 1a491be..97a6ddc 100644
--- a/main.py
+++ b/main.py
@@ -2,9 +2,10 @@
import warnings
import streamlit as st
+from langchain_core.messages import HumanMessage
from snowflake.snowpark.exceptions import SnowparkSQLException
-from chain import load_chain
+from agent import MessagesState, create_agent
# from utils.snow_connect import SnowflakeConnection
from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
@@ -50,6 +51,9 @@
)
st.session_state["model"] = model
+if "assistant_response_processed" not in st.session_state:
+    st.session_state["assistant_response_processed"] = True  # no assistant response pending at startup
+
if "toast_shown" not in st.session_state:
st.session_state["toast_shown"] = False
@@ -76,6 +80,8 @@
"content": "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍",
},
]
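+# A fixed thread_id means every session shares one conversation history; per-session IDs would keep them separate.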
+config = {"configurable": {"thread_id": "42"}}
with open("ui/sidebar.md", "r") as sidebar_file:
sidebar_content = sidebar_file.read()
@@ -118,18 +123,23 @@
# Prompt for user input and save
if prompt := st.chat_input():
st.session_state.messages.append({"role": "user", "content": prompt})
+ st.session_state["assistant_response_processed"] = (
+ False # Assistant response not yet processed
+ )
-for message in st.session_state.messages:
+messages_to_display = st.session_state.messages.copy()
+
+for message in messages_to_display:
message_func(
message["content"],
- True if message["role"] == "user" else False,
- True if message["role"] == "data" else False,
- model,
+ is_user=(message["role"] == "user"),
+ is_df=(message["role"] == "data"),
+ model=model,
)
callback_handler = StreamlitUICallbackHandler(model)
-chain = load_chain(st.session_state["model"], callback_handler)
+react_graph = create_agent(callback_handler, st.session_state["model"])
def append_chat_history(question, answer):
@@ -148,20 +163,21 @@ def append_message(content, role="assistant"):
def handle_sql_exception(query, conn, e, retries=2):
- append_message("Uh oh, I made an error, let me try to fix it..")
- error_message = (
- "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n"
- + query
- + "\n```\n Error message: \n "
- + str(e)
- )
- new_query = chain({"question": error_message, "chat_history": ""})["answer"]
- append_message(new_query)
- if get_sql(new_query) and retries > 0:
- return execute_sql(get_sql(new_query), conn, retries - 1)
- else:
- append_message("I'm sorry, I couldn't fix the error. Please try again.")
- return None
+ # append_message("Uh oh, I made an error, let me try to fix it..")
+ # error_message = (
+ # "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n"
+ # + query
+ # + "\n```\n Error message: \n "
+ # + str(e)
+ # )
+ # new_query = chain({"question": error_message, "chat_history": ""})["answer"]
+ # append_message(new_query)
+ # if get_sql(new_query) and retries > 0:
+ # return execute_sql(get_sql(new_query), conn, retries - 1)
+ # else:
+ # append_message("I'm sorry, I couldn't fix the error. Please try again.")
+ # return None
+ pass
def execute_sql(query, conn, retries=2):
@@ -176,20 +192,26 @@ def execute_sql(query, conn, retries=2):
if (
"messages" in st.session_state
- and st.session_state["messages"][-1]["role"] != "assistant"
+ and st.session_state["messages"][-1]["role"] == "user"
+ and not st.session_state["assistant_response_processed"]
):
user_input_content = st.session_state["messages"][-1]["content"]
if isinstance(user_input_content, str):
+ # Start loading animation
callback_handler.start_loading_message()
- result = chain.invoke(
- {
- "question": user_input_content,
- "chat_history": [h for h in st.session_state["history"]],
- }
- )
- append_message(result.content)
+ messages = [HumanMessage(content=user_input_content)]
+
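+    # MemorySaver restores earlier turns for this thread_id, so only the newest user message needs to be sent.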
+ state = MessagesState(messages=messages)
+ result = react_graph.invoke(state, config=config)
+
+ if result["messages"]:
+ assistant_message = callback_handler.final_message
+ append_message(assistant_message)
+ st.session_state["assistant_response_processed"] = True
+
if (
st.session_state["model"] == "Mixtral 8x7B"
diff --git a/template.py b/template.py
index c8cd086..5cc1759 100644
--- a/template.py
+++ b/template.py
@@ -1,4 +1,3 @@
-from langchain.prompts.prompt import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
template = """You are an AI chatbot having a conversation with a human.
diff --git a/tools.py b/tools.py
new file mode 100644
index 0000000..5b5a450
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,28 @@
+import streamlit as st
+from langchain.tools.retriever import create_retriever_tool
+from langchain_community.vectorstores import SupabaseVectorStore
+from langchain_openai import OpenAIEmbeddings
+from supabase.client import Client, create_client
+
+supabase_url = st.secrets["SUPABASE_URL"]
+supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+supabase: Client = create_client(supabase_url, supabase_key)
+
+
+embeddings = OpenAIEmbeddings(
+ openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+)
+vectorstore = SupabaseVectorStore(
+ embedding=embeddings,
+ client=supabase,
+ table_name="documents",
+ query_name="v_match_documents",
+)
+
+
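+# Wrap the vector store as a retriever tool so the agent can look up schema definitions on demand.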
+retriever_tool = create_retriever_tool(
+ vectorstore.as_retriever(),
+ name="Database_Schema",
+ description="Search for database schema details",
+)
diff --git a/utils/snow_connect.py b/utils/snow_connect.py
index 2268c8b..d0b396a 100644
--- a/utils/snow_connect.py
+++ b/utils/snow_connect.py
@@ -2,7 +2,6 @@
import streamlit as st
from snowflake.snowpark.session import Session
-from snowflake.snowpark.version import VERSION
class SnowflakeConnection:
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index 03f5f58..98a6337 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -1,10 +1,10 @@
import html
import re
+import textwrap
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
-
image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
mistral_url = (
@@ -61,7 +61,7 @@ def format_message(text):
def message_func(text, is_user=False, is_df=False, model="gpt"):
"""
- This function is used to display the messages in the chatbot UI.
+ This function displays messages in the chatbot UI, ensuring proper alignment and avatar positioning.
Parameters:
text (str): The text to be displayed.
@@ -69,52 +69,36 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):
is_df (bool): Whether the message is a dataframe or not.
"""
model_url = get_model_url(model)
+ avatar_url = user_url if is_user else model_url
+ message_bg_color = (
+ "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" if is_user else "#71797E"
+ )
+ avatar_class = "user-avatar" if is_user else "bot-avatar"
+ alignment = "flex-end" if is_user else "flex-start"
+ margin_side = "margin-left" if is_user else "margin-right"
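+    # Escape the text before interpolating it into the HTML bubble so message content can't inject markup.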
+    message_text = html.escape(text.strip()).replace('\n', '<br>')
- avatar_url = model_url
if is_user:
- avatar_url = user_url
- message_alignment = "flex-end"
- message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
- avatar_class = "user-avatar"
- st.write(
- f"""
-