diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 97e0250..6c7256e 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -4,26 +4,39 @@ on:
   push:
     branches:
       - main
-      - prod
   pull_request:
     branches:
       - main
-      - prod
 
 jobs:
   lint:
+    name: Lint and Format Code
     runs-on: ubuntu-latest
+
     steps:
       - name: Check out repository
         uses: actions/checkout@v3
 
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.9"
+
+      - name: Cache pip dependencies
+        uses: actions/cache@v3
         with:
-          python-version: 3.9
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
 
       - name: Install dependencies
-        run: pip install black
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install black ruff mypy codespell
 
-      - name: Lint with black
-        run: black --check .
+      - name: Run Formatting and Linting
+        run: |
+          make format
+          make lint
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..a768e53
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,57 @@
+.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix
+
+# Define a variable for Python and notebook files.
+PYTHON_FILES=src/
+MYPY_CACHE=.mypy_cache
+
+######################
+# LINTING AND FORMATTING
+######################
+
+lint format: PYTHON_FILES=.
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
+lint_package: PYTHON_FILES=src
+lint_tests: PYTHON_FILES=tests
+lint_tests: MYPY_CACHE=.mypy_cache_test
+
+lint lint_diff lint_package lint_tests:
+	python -m ruff check .
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
+	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || python -m mypy --strict $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+
+format format_diff:
+	ruff format $(PYTHON_FILES)
+	ruff check --fix $(PYTHON_FILES)
+
+spell_check:
+	codespell --toml pyproject.toml
+
+spell_fix:
+	codespell --toml pyproject.toml -w
+
+######################
+# RUN ALL
+######################
+
+all: format lint spell_check
+
+######################
+# HELP
+######################
+
+help:
+	@echo '----'
+	@echo 'format - run code formatters'
+	@echo 'lint - run linters'
+	@echo 'spell_check - run spell check'
+	@echo 'all - run all tasks'
+	@echo 'lint-fix - run lint and fix issues'
+
+######################
+# LINT-FIX TARGET
+######################
+
+lint-fix: format lint
+	@echo "Linting and fixing completed successfully."
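For context on what the new `ruff check --select I,F401 --fix` pass in the Makefile actually touches: rule group `I` enforces isort-style import ordering and `F401` removes imports that are never used, so only import hygiene is auto-fixed during `make lint`. A toy module (the file name and contents below are hypothetical, not part of this change) that the pass would rewrite:

```python
# lint_example.py -- hypothetical module, only to illustrate rules I and F401
import sys
import json  # F401: imported but never used; deleted by `ruff check --fix`

import streamlit as st

import os  # I001: un-sorted import block; moved into the stdlib group by the I rules


def entrypoint() -> None:
    # sys, os and st are all used, so only the json import is removed.
    st.write(f"running from {os.getcwd()} with {len(sys.argv)} args")
```
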
diff --git a/agent.py b/agent.py new file mode 100644 index 0000000..a7a7462 --- /dev/null +++ b/agent.py @@ -0,0 +1,102 @@ +import os +from dataclasses import dataclass +from typing import Annotated, Sequence, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain_anthropic import ChatAnthropic +from langchain_core.messages import SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import START, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition +from langgraph.graph.message import add_messages +from langchain_core.messages import BaseMessage + +from template import TEMPLATE +from tools import retriever_tool + + +@dataclass +class MessagesState: + messages: Annotated[Sequence[BaseMessage], add_messages] + + +memory = MemorySaver() + + +@dataclass +class ModelConfig: + model_name: str + api_key: str + base_url: Optional[str] = None + + +def create_agent(callback_handler: BaseCallbackHandler, model_name: str): + model_configurations = { + "gpt-4o-mini": ModelConfig( + model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY") + ), + "gemma2-9b": ModelConfig( + model_name="gemma2-9b-it", + api_key=os.getenv("GROQ_API_KEY"), + base_url="https://api.groq.com/openai/v1", + ), + "claude3-haiku": ModelConfig( + model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY") + ), + "mixtral-8x22b": ModelConfig( + model_name="accounts/fireworks/models/mixtral-8x22b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), + "llama-3.1-405b": ModelConfig( + model_name="accounts/fireworks/models/llama-v3p1-405b-instruct", + api_key=os.getenv("FIREWORKS_API_KEY"), + base_url="https://api.fireworks.ai/inference/v1", + ), + } + config = model_configurations.get(model_name) + if not config: + raise ValueError(f"Unsupported model name: {model_name}") + + sys_msg = SystemMessage( + content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. + Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code. 
+ """ + ) + + llm = ( + ChatOpenAI( + model=config.model_name, + api_key=config.api_key, + callbacks=[callback_handler], + streaming=True, + base_url=config.base_url, + ) + if config.model_name != "claude-3-haiku-20240307" + else ChatAnthropic( + model=config.model_name, + api_key=config.api_key, + callbacks=[callback_handler], + streaming=True, + ) + ) + + tools = [retriever_tool] + + llm_with_tools = llm.bind_tools(tools) + + def reasoner(state: MessagesState): + return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]} + + # Build the graph + builder = StateGraph(MessagesState) + builder.add_node("reasoner", reasoner) + builder.add_node("tools", ToolNode(tools)) + builder.add_edge(START, "reasoner") + builder.add_conditional_edges("reasoner", tools_condition) + builder.add_edge("tools", "reasoner") + + react_graph = builder.compile(checkpointer=memory) + + return react_graph diff --git a/chain.py b/chain.py index a16b862..a0c2595 100644 --- a/chain.py +++ b/chain.py @@ -1,155 +1,154 @@ -from typing import Any, Callable, Dict, Optional - -import streamlit as st -from langchain_community.chat_models import ChatOpenAI -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms import OpenAI -from langchain.vectorstores import SupabaseVectorStore -from pydantic import BaseModel, validator -from supabase.client import Client, create_client - -from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT - -from operator import itemgetter - -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import format_document -from langchain_core.messages import get_buffer_string -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableParallel, RunnablePassthrough -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_anthropic import ChatAnthropic - -DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") - -supabase_url = st.secrets["SUPABASE_URL"] -supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] -supabase: Client = create_client(supabase_url, supabase_key) - - -class ModelConfig(BaseModel): - model_type: str - secrets: Dict[str, Any] - callback_handler: Optional[Callable] = None - - -class ModelWrapper: - def __init__(self, config: ModelConfig): - self.model_type = config.model_type - self.secrets = config.secrets - self.callback_handler = config.callback_handler - self.llm = self._setup_llm() - - def _setup_llm(self): - model_config = { - "gpt-4o-mini": { - "model_name": "gpt-4o-mini", - "api_key": self.secrets["OPENAI_API_KEY"], - }, - "gemma2-9b": { - "model_name": "gemma2-9b-it", - "api_key": self.secrets["GROQ_API_KEY"], - "base_url": "https://api.groq.com/openai/v1", - }, - "claude3-haiku": { - "model_name": "claude-3-haiku-20240307", - "api_key": self.secrets["ANTHROPIC_API_KEY"], - }, - "mixtral-8x22b": { - "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct", - "api_key": self.secrets["FIREWORKS_API_KEY"], - "base_url": "https://api.fireworks.ai/inference/v1", - }, - "llama-3.1-405b": { - "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct", - "api_key": self.secrets["FIREWORKS_API_KEY"], - "base_url": "https://api.fireworks.ai/inference/v1", - }, - } - - config = model_config[self.model_type] - - return ( - ChatOpenAI( - model_name=config["model_name"], - temperature=0.1, - api_key=config["api_key"], - max_tokens=700, - callbacks=[self.callback_handler], - streaming=True, - base_url=config["base_url"] - if 
config["model_name"] != "gpt-4o-mini" - else None, - default_headers={ - "HTTP-Referer": "https://snowchat.streamlit.app/", - "X-Title": "Snowchat", - }, - ) - if config["model_name"] != "claude-3-haiku-20240307" - else ( - ChatAnthropic( - model=config["model_name"], - temperature=0.1, - max_tokens=700, - timeout=None, - max_retries=2, - callbacks=[self.callback_handler], - streaming=True, - ) - ) - ) - - def get_chain(self, vectorstore): - def _combine_documents( - docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" - ): - doc_strings = [format_document(doc, document_prompt) for doc in docs] - return document_separator.join(doc_strings) - - _inputs = RunnableParallel( - standalone_question=RunnablePassthrough.assign( - chat_history=lambda x: get_buffer_string(x["chat_history"]) - ) - | CONDENSE_QUESTION_PROMPT - | OpenAI() - | StrOutputParser(), - ) - _context = { - "context": itemgetter("standalone_question") - | vectorstore.as_retriever() - | _combine_documents, - "question": lambda x: x["standalone_question"], - } - conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm - - return conversational_qa_chain - - -def load_chain(model_name="qwen", callback_handler=None): - embeddings = OpenAIEmbeddings( - openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" - ) - vectorstore = SupabaseVectorStore( - embedding=embeddings, - client=supabase, - table_name="documents", - query_name="v_match_documents", - ) - - model_type_mapping = { - "gpt-4o-mini": "gpt-4o-mini", - "gemma2-9b": "gemma2-9b", - "claude3-haiku": "claude3-haiku", - "mixtral-8x22b": "mixtral-8x22b", - "llama-3.1-405b": "llama-3.1-405b", - } - - model_type = model_type_mapping.get(model_name.lower()) - if model_type is None: - raise ValueError(f"Unsupported model name: {model_name}") - - config = ModelConfig( - model_type=model_type, secrets=st.secrets, callback_handler=callback_handler - ) - model = ModelWrapper(config) - return model.get_chain(vectorstore) +# from dataclasses import dataclass, field +# from operator import itemgetter +# from typing import Any, Callable, Dict, Optional + +# import streamlit as st +# from langchain.embeddings.openai import OpenAIEmbeddings +# from langchain.llms import OpenAI +# from langchain.prompts.prompt import PromptTemplate +# from langchain.schema import format_document +# from langchain.vectorstores import SupabaseVectorStore +# from langchain_anthropic import ChatAnthropic +# from langchain_community.chat_models import ChatOpenAI +# from langchain_core.messages import get_buffer_string +# from langchain_core.output_parsers import StrOutputParser +# from langchain_core.runnables import RunnableParallel, RunnablePassthrough +# from langchain_openai import ChatOpenAI, OpenAIEmbeddings + +# from supabase.client import Client, create_client +# from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT + +# DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") + +# supabase_url = st.secrets["SUPABASE_URL"] +# supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] +# supabase: Client = create_client(supabase_url, supabase_key) + + +# @dataclass +# class ModelConfig: +# model_type: str +# secrets: Dict[str, Any] +# callback_handler: Optional[Callable] = field(default=None) + + +# class ModelWrapper: +# def __init__(self, config: ModelConfig): +# self.model_type = config.model_type +# self.secrets = config.secrets +# self.callback_handler = config.callback_handler +# self.llm = self._setup_llm() + +# def _setup_llm(self): 
+# model_config = { +# "gpt-4o-mini": { +# "model_name": "gpt-4o-mini", +# "api_key": self.secrets["OPENAI_API_KEY"], +# }, +# "gemma2-9b": { +# "model_name": "gemma2-9b-it", +# "api_key": self.secrets["GROQ_API_KEY"], +# "base_url": "https://api.groq.com/openai/v1", +# }, +# "claude3-haiku": { +# "model_name": "claude-3-haiku-20240307", +# "api_key": self.secrets["ANTHROPIC_API_KEY"], +# }, +# "mixtral-8x22b": { +# "model_name": "accounts/fireworks/models/mixtral-8x22b-instruct", +# "api_key": self.secrets["FIREWORKS_API_KEY"], +# "base_url": "https://api.fireworks.ai/inference/v1", +# }, +# "llama-3.1-405b": { +# "model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct", +# "api_key": self.secrets["FIREWORKS_API_KEY"], +# "base_url": "https://api.fireworks.ai/inference/v1", +# }, +# } + +# config = model_config[self.model_type] + +# return ( +# ChatOpenAI( +# model_name=config["model_name"], +# temperature=0.1, +# api_key=config["api_key"], +# max_tokens=700, +# callbacks=[self.callback_handler], +# streaming=True, +# base_url=config["base_url"] +# if config["model_name"] != "gpt-4o-mini" +# else None, +# default_headers={ +# "HTTP-Referer": "https://snowchat.streamlit.app/", +# "X-Title": "Snowchat", +# }, +# ) +# if config["model_name"] != "claude-3-haiku-20240307" +# else ( +# ChatAnthropic( +# model=config["model_name"], +# temperature=0.1, +# max_tokens=700, +# timeout=None, +# max_retries=2, +# callbacks=[self.callback_handler], +# streaming=True, +# ) +# ) +# ) + +# def get_chain(self, vectorstore): +# def _combine_documents( +# docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" +# ): +# doc_strings = [format_document(doc, document_prompt) for doc in docs] +# return document_separator.join(doc_strings) + +# _inputs = RunnableParallel( +# standalone_question=RunnablePassthrough.assign( +# chat_history=lambda x: get_buffer_string(x["chat_history"]) +# ) +# | CONDENSE_QUESTION_PROMPT +# | OpenAI() +# | StrOutputParser(), +# ) +# _context = { +# "context": itemgetter("standalone_question") +# | vectorstore.as_retriever() +# | _combine_documents, +# "question": lambda x: x["standalone_question"], +# } +# conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm + +# return conversational_qa_chain + + +# def load_chain(model_name="qwen", callback_handler=None): +# embeddings = OpenAIEmbeddings( +# openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" +# ) +# vectorstore = SupabaseVectorStore( +# embedding=embeddings, +# client=supabase, +# table_name="documents", +# query_name="v_match_documents", +# ) + +# model_type_mapping = { +# "gpt-4o-mini": "gpt-4o-mini", +# "gemma2-9b": "gemma2-9b", +# "claude3-haiku": "claude3-haiku", +# "mixtral-8x22b": "mixtral-8x22b", +# "llama-3.1-405b": "llama-3.1-405b", +# } + +# model_type = model_type_mapping.get(model_name.lower()) +# if model_type is None: +# raise ValueError(f"Unsupported model name: {model_name}") + +# config = ModelConfig( +# model_type=model_type, secrets=st.secrets, callback_handler=callback_handler +# ) +# model = ModelWrapper(config) +# return model.get_chain(vectorstore) diff --git a/ingest.py b/ingest.py index 67de0f3..c6669f3 100644 --- a/ingest.py +++ b/ingest.py @@ -6,6 +6,7 @@ from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import SupabaseVectorStore from pydantic import BaseModel + from supabase.client import Client, create_client diff --git a/main.py b/main.py index 1a491be..97a6ddc 100644 --- a/main.py +++ b/main.py @@ 
-2,9 +2,10 @@ import warnings import streamlit as st +from langchain_core.messages import HumanMessage from snowflake.snowpark.exceptions import SnowparkSQLException -from chain import load_chain +from agent import MessagesState, create_agent # from utils.snow_connect import SnowflakeConnection from utils.snowchat_ui import StreamlitUICallbackHandler, message_func @@ -50,6 +51,9 @@ ) st.session_state["model"] = model +if "assistant_response_processed" not in st.session_state: + st.session_state["assistant_response_processed"] = True # Initialize to True + if "toast_shown" not in st.session_state: st.session_state["toast_shown"] = False @@ -76,6 +80,7 @@ "content": "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍", }, ] +config = {"configurable": {"thread_id": "42"}} with open("ui/sidebar.md", "r") as sidebar_file: sidebar_content = sidebar_file.read() @@ -118,18 +123,28 @@ # Prompt for user input and save if prompt := st.chat_input(): st.session_state.messages.append({"role": "user", "content": prompt}) + st.session_state["assistant_response_processed"] = ( + False # Assistant response not yet processed + ) -for message in st.session_state.messages: +messages_to_display = st.session_state.messages.copy() +# if not st.session_state["assistant_response_processed"]: +# # Exclude the last assistant message if assistant response not yet processed +# if messages_to_display and messages_to_display[-1]["role"] == "assistant": +# print("\n\nthis is messages_to_display \n\n", messages_to_display) +# messages_to_display = messages_to_display[:-1] + +for message in messages_to_display: message_func( message["content"], - True if message["role"] == "user" else False, - True if message["role"] == "data" else False, - model, + is_user=(message["role"] == "user"), + is_df=(message["role"] == "data"), + model=model, ) callback_handler = StreamlitUICallbackHandler(model) -chain = load_chain(st.session_state["model"], callback_handler) +react_graph = create_agent(callback_handler, st.session_state["model"]) def append_chat_history(question, answer): @@ -148,20 +163,21 @@ def append_message(content, role="assistant"): def handle_sql_exception(query, conn, e, retries=2): - append_message("Uh oh, I made an error, let me try to fix it..") - error_message = ( - "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n" - + query - + "\n```\n Error message: \n " - + str(e) - ) - new_query = chain({"question": error_message, "chat_history": ""})["answer"] - append_message(new_query) - if get_sql(new_query) and retries > 0: - return execute_sql(get_sql(new_query), conn, retries - 1) - else: - append_message("I'm sorry, I couldn't fix the error. Please try again.") - return None + # append_message("Uh oh, I made an error, let me try to fix it..") + # error_message = ( + # "You gave me a wrong SQL. FIX The SQL query by searching the schema definition: \n```sql\n" + # + query + # + "\n```\n Error message: \n " + # + str(e) + # ) + # new_query = chain({"question": error_message, "chat_history": ""})["answer"] + # append_message(new_query) + # if get_sql(new_query) and retries > 0: + # return execute_sql(get_sql(new_query), conn, retries - 1) + # else: + # append_message("I'm sorry, I couldn't fix the error. 
Please try again.") + # return None + pass def execute_sql(query, conn, retries=2): @@ -176,20 +192,25 @@ def execute_sql(query, conn, retries=2): if ( "messages" in st.session_state - and st.session_state["messages"][-1]["role"] != "assistant" + and st.session_state["messages"][-1]["role"] == "user" + and not st.session_state["assistant_response_processed"] ): user_input_content = st.session_state["messages"][-1]["content"] if isinstance(user_input_content, str): + # Start loading animation callback_handler.start_loading_message() - result = chain.invoke( - { - "question": user_input_content, - "chat_history": [h for h in st.session_state["history"]], - } - ) - append_message(result.content) + messages = [HumanMessage(content=user_input_content)] + + state = MessagesState(messages=messages) + result = react_graph.invoke(state, config=config) + + if result["messages"]: + assistant_message = callback_handler.final_message + append_message(assistant_message) + st.session_state["assistant_response_processed"] = True + if ( st.session_state["model"] == "Mixtral 8x7B" diff --git a/template.py b/template.py index c8cd086..5cc1759 100644 --- a/template.py +++ b/template.py @@ -1,4 +1,3 @@ -from langchain.prompts.prompt import PromptTemplate from langchain_core.prompts import ChatPromptTemplate template = """You are an AI chatbot having a conversation with a human. diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..5b5a450 --- /dev/null +++ b/tools.py @@ -0,0 +1,28 @@ +import streamlit as st +from langchain.prompts.prompt import PromptTemplate +from supabase.client import Client, create_client +from langchain.tools.retriever import create_retriever_tool +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import SupabaseVectorStore + +supabase_url = st.secrets["SUPABASE_URL"] +supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] +supabase: Client = create_client(supabase_url, supabase_key) + + +embeddings = OpenAIEmbeddings( + openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" +) +vectorstore = SupabaseVectorStore( + embedding=embeddings, + client=supabase, + table_name="documents", + query_name="v_match_documents", +) + + +retriever_tool = create_retriever_tool( + vectorstore.as_retriever(), + name="Database_Schema", + description="Search for database schema details", +) diff --git a/utils/snow_connect.py b/utils/snow_connect.py index 2268c8b..d0b396a 100644 --- a/utils/snow_connect.py +++ b/utils/snow_connect.py @@ -2,7 +2,6 @@ import streamlit as st from snowflake.snowpark.session import Session -from snowflake.snowpark.version import VERSION class SnowflakeConnection: diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 03f5f58..98a6337 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -1,10 +1,10 @@ import html import re +import textwrap import streamlit as st from langchain.callbacks.base import BaseCallbackHandler - image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/" gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z" mistral_url = ( @@ -61,7 +61,7 @@ def format_message(text): def message_func(text, is_user=False, is_df=False, model="gpt"): """ - This function is used to display the messages in the chatbot UI. + This function displays messages in the chatbot UI, ensuring proper alignment and avatar positioning. Parameters: text (str): The text to be displayed. 
@@ -69,52 +69,36 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):
     is_df (bool): Whether the message is a dataframe or not.
     """
     model_url = get_model_url(model)
+    avatar_url = user_url if is_user else model_url
+    message_bg_color = (
+        "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" if is_user else "#71797E"
+    )
+    avatar_class = "user-avatar" if is_user else "bot-avatar"
+    alignment = "flex-end" if is_user else "flex-start"
+    margin_side = "margin-left" if is_user else "margin-right"
+    message_text = html.escape(text.strip()).replace('\n', '<br>')
 
-    avatar_url = model_url
     if is_user:
-        avatar_url = user_url
-        message_alignment = "flex-end"
-        message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
-        avatar_class = "user-avatar"
-        st.write(
-            f"""
-            <div style="display: flex; align-items: center; justify-content: {message_alignment};">
-                <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">{text} \n</div>
-                <img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
-            </div>
-            """,
-            unsafe_allow_html=True,
-        )
+        container_html = f"""
+        <div style="display: flex; align-items: center; justify-content: {alignment};">
+            <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; {margin_side}: 5px; max-width: 75%;">
+                {message_text}
+            </div>
+            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
+        </div>
+        """
     else:
-        message_alignment = "flex-start"
-        message_bg_color = "#71797E"
-        avatar_class = "bot-avatar"
+        container_html = f"""
+        <div style="display: flex; align-items: center; justify-content: {alignment};">
+            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
+            <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; {margin_side}: 5px; max-width: 75%;">
+                {message_text}
+            </div>
+        </div>
+        """
 
-        if is_df:
-            st.write(
-                f"""
-                <div style="display: flex; align-items: center; justify-content: {message_alignment};">
-                    <img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
-                </div>
-                """,
-                unsafe_allow_html=True,
-            )
-            st.write(text)
-            return
-        else:
-            text = format_message(text)
+    st.write(container_html, unsafe_allow_html=True)
 
-        st.write(
-            f"""
-            <div style="display: flex; align-items: center; justify-content: {message_alignment};">
-                <img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
-                <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-left: 5px; max-width: 75%;">{text} \n</div>
-            </div>
-            """,
-            unsafe_allow_html=True,
-        )
 
 class StreamlitUICallbackHandler(BaseCallbackHandler):
@@ -125,6 +109,7 @@ def __init__(self, model):
         self.has_streaming_started = False
         self.model = model
         self.avatar_url = get_model_url(model)
+        self.final_message = ""
 
     def start_loading_message(self):
         loading_message_content = self._get_bot_message_container("Thinking...")
@@ -138,6 +123,7 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
         complete_message = "".join(self.token_buffer)
         container_content = self._get_bot_message_container(complete_message)
         self.placeholder.markdown(container_content, unsafe_allow_html=True)
+        self.final_message = "".join(self.token_buffer)
 
     def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
         self.token_buffer = []
@@ -146,16 +132,20 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
 
     def _get_bot_message_container(self, text):
         """Generate the bot's message container style for the given text."""
-        formatted_text = format_message(text)
+        formatted_text = format_message(text.strip())
         container_content = f"""
             <div style="display: flex; align-items: center;">
                 <img src="{self.avatar_url}" class="bot-avatar" alt="avatar" />
-                <div style="background: #71797E; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">{formatted_text} \n</div>
+                <div style="background: #71797E; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">
+                    {formatted_text}
+                </div>
             </div>
         """
         return container_content
 
     def display_dataframe(self, df):
         """
         Display the dataframe in Streamlit UI within the chat container.
@@ -165,13 +155,14 @@ def display_dataframe(self, df):
         st.write(
             f"""
             <div style="display: flex; align-items: center;">
                 <img src="{self.avatar_url}" class="bot-avatar" alt="avatar" />
             """,
             unsafe_allow_html=True,
         )
         st.write(df)
+
     def __call__(self, *args, **kwargs):
         pass