diff --git a/README.md b/README.md
index 04df43e..187f1ee 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 [](https://streamlit.io/)
 [](https://openai.com/)
 [](https://www.snowflake.com/en/)
+[](https://www.supabase.io/)
 [](https://snowchat.streamlit.app/)
@@ -18,7 +19,7 @@
 - Interactive and user-friendly interface
 - Integration with Snowflake Data Warehouse
 - Utilizes OpenAI's GPT-3.5-turbo-16k and text-embedding-ada-002
-- Uses In-memory Vector Database FAISS for storing and searching through vectors
+- Uses the Supabase pgvector vector database for storing and searching through vectors
 
 ## 🛠️ Installation
 
@@ -29,13 +30,15 @@
 cd snowchat
 
 pip install -r requirements.txt
 
-3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA` and `WAREHOUSE` in project directory `secrets.toml`. If you don't have access to GPT-4 change the script in chain.py replace gpt-4 in model_name to gpt-3.5-turbo
+3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` and `SUPABASE_SERVICE_KEY` in the `secrets.toml` file in the project directory.
 
-4. Make you're schema.md that matches you're database.
+4. Create markdown schema files that match your database and store them in the `docs` folder.
 
-5. Run `python ingest.py` to get convert to embeddings and store as an index file.
+5. Create the Supabase extension, table and function by running `supabase/scripts.sql` (see the SQL sketch below).
 
-6. Run the Streamlit app to start chatting:
+6. Run `python ingest.py` to convert the schema documents into embeddings and store them in the Supabase vector store.
+
+7. Run the Streamlit app to start chatting:
 
 streamlit run main.py
 
 ## 📚 Usage
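Step 5 refers to `supabase/scripts.sql`, which is the authoritative setup script in this branch. As a rough, assumed sketch of what such a script contains when following LangChain's `SupabaseVectorStore` defaults (a `documents` table, 1536-dimensional `text-embedding-ada-002` vectors, and a `match_documents` function), it might look like this:

```sql
-- Assumed sketch only; the supabase/scripts.sql shipped in this PR is authoritative.
-- Enable pgvector for embedding storage and similarity search.
create extension if not exists vector;

-- Default table used by LangChain's SupabaseVectorStore.
create table if not exists documents (
    id bigserial primary key,
    content text,           -- chunk text from the markdown schema docs
    metadata jsonb,          -- e.g. source file path
    embedding vector(1536)   -- text-embedding-ada-002 output size
);

-- Similarity-search function the vector store calls at query time.
create or replace function match_documents(
    query_embedding vector(1536),
    match_count int
)
returns table (id bigint, content text, metadata jsonb, similarity float)
language plpgsql
as $$
begin
    return query
    select
        documents.id,
        documents.content,
        documents.metadata,
        1 - (documents.embedding <=> query_embedding) as similarity
    from documents
    order by documents.embedding <=> query_embedding
    limit match_count;
end;
$$;
```

Running a script of this shape in the Supabase SQL editor before step 6 gives `ingest.py` somewhere to write the embeddings.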
""" - q_llm = OpenAI(temperature=0, openai_api_key=st.secrets["OPENAI_API_KEY"], model_name='gpt-3.5-turbo-16k') - - llm = OpenAI( - model_name='gpt-3.5-turbo', + q_llm = OpenAI( temperature=0, - openai_api_key=st.secrets["OPENAI_API_KEY"] - ) - - question_generator = LLMChain( - llm=q_llm, - prompt=condense_question_prompt + openai_api_key=st.secrets["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo-16k", ) - - doc_chain = load_qa_chain( - llm=llm, - chain_type="stuff", - prompt=QA_PROMPT + + llm = OpenAI( + model_name="gpt-3.5-turbo", + temperature=0, + openai_api_key=st.secrets["OPENAI_API_KEY"], ) + + question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt) + + doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT) chain = ConversationalRetrievalChain( - retriever=vectorstore.as_retriever(), - combine_docs_chain=doc_chain, - question_generator=question_generator - ) - return chain \ No newline at end of file + retriever=vectorstore.as_retriever(), + combine_docs_chain=doc_chain, + question_generator=question_generator, + ) + return chain diff --git a/docs/customer_details.md b/docs/customer_details.md new file mode 100644 index 0000000..7f4e758 --- /dev/null +++ b/docs/customer_details.md @@ -0,0 +1,10 @@ +**Table 1: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS** (Stores customer information) + +This table contains the personal information of customers who have made purchases on the platform. + +- CUSTOMER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for customers +- FIRST_NAME: Varchar (255) - First name of the customer +- LAST_NAME: Varchar (255) - Last name of the customer +- EMAIL: Varchar (255) - Email address of the customer +- PHONE: Varchar (20) - Phone number of the customer +- ADDRESS: Varchar (255) - Physical address of the customer \ No newline at end of file diff --git a/docs/order_details.md b/docs/order_details.md new file mode 100644 index 0000000..f461304 --- /dev/null +++ b/docs/order_details.md @@ -0,0 +1,8 @@ +**Table 2: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS** (Stores order information) + +This table contains information about orders placed by customers, including the date and total amount of each order. + +- ORDER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for orders +- CUSTOMER_ID: Number (38,0) [Foreign Key - CUSTOMER_DETAILS(CUSTOMER_ID)] - Customer who made the order +- ORDER_DATE: Date - Date when the order was made +- TOTAL_AMOUNT: Number (10,2) - Total amount of the order \ No newline at end of file diff --git a/docs/payments.md b/docs/payments.md new file mode 100644 index 0000000..5c8908e --- /dev/null +++ b/docs/payments.md @@ -0,0 +1,8 @@ +**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information) + +This table contains information about payments made by customers for their orders, including the date and amount of each payment. 
diff --git a/docs/payments.md b/docs/payments.md
new file mode 100644
index 0000000..5c8908e
--- /dev/null
+++ b/docs/payments.md
@@ -0,0 +1,8 @@
+**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information)
+
+This table contains information about payments made by customers for their orders, including the date and amount of each payment.
+
+- PAYMENT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for payments
+- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the payment
+- PAYMENT_DATE: Date - Date when the payment was made
+- AMOUNT: Number (10,2) - Amount of the payment
\ No newline at end of file
diff --git a/docs/products.md b/docs/products.md
new file mode 100644
index 0000000..130341b
--- /dev/null
+++ b/docs/products.md
@@ -0,0 +1,8 @@
+**Table 4: STREAM_HACKATHON.STREAMLIT.PRODUCTS** (Stores product information)
+
+This table contains information about the products available for purchase on the platform, including their name, category, and price.
+
+- PRODUCT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for products
+- PRODUCT_NAME: Varchar (255) - Name of the product
+- CATEGORY: Varchar (255) - Category of the product
+- PRICE: Number (10,2) - Price of the product
\ No newline at end of file
diff --git a/docs/transactions.md b/docs/transactions.md
new file mode 100644
index 0000000..265e56a
--- /dev/null
+++ b/docs/transactions.md
@@ -0,0 +1,9 @@
+**Table 5: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS** (Stores transaction information)
+
+This table contains information about individual transactions that occur when customers purchase products, including the associated order, product, quantity, and price.
+
+- TRANSACTION_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for transactions
+- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the transaction
+- PRODUCT_ID: Number (38,0) - Product involved in the transaction
+- QUANTITY: Number (38,0) - Quantity of the product in the transaction
+- PRICE: Number (10,2) - Price of the product in the transaction
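Taken together, the five schema documents above are what `ingest.py` embeds and what the chain retrieves when writing SQL. As an illustration of the kind of query the assistant is expected to produce against this schema (not taken from the repo, and using only columns documented above):

```sql
-- Revenue by product category, joining the documented tables.
SELECT
    p.CATEGORY,
    SUM(t.QUANTITY * t.PRICE) AS TOTAL_REVENUE
FROM STREAM_HACKATHON.STREAMLIT.TRANSACTIONS AS t
JOIN STREAM_HACKATHON.STREAMLIT.PRODUCTS AS p
    ON t.PRODUCT_ID = p.PRODUCT_ID
GROUP BY p.CATEGORY
ORDER BY TOTAL_REVENUE DESC;
```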
diff --git a/faiss_index/index.faiss b/faiss_index/index.faiss
deleted file mode 100644
index 9eb051f..0000000
Binary files a/faiss_index/index.faiss and /dev/null differ
diff --git a/faiss_index/index.pkl b/faiss_index/index.pkl
deleted file mode 100644
index e14084f..0000000
Binary files a/faiss_index/index.pkl and /dev/null differ
diff --git a/ingest.py b/ingest.py
index 606bde4..eaf470a 100644
--- a/ingest.py
+++ b/ingest.py
@@ -1,20 +1,57 @@
-
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.document_loaders import UnstructuredMarkdownLoader
-from langchain.vectorstores import FAISS
+from pydantic import BaseModel
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import SupabaseVectorStore
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.document_loaders import DirectoryLoader
 import streamlit as st
+from supabase.client import Client, create_client
+from typing import Any, Dict
+
+
+class Secrets(BaseModel):
+    SUPABASE_URL: str
+    SUPABASE_SERVICE_KEY: str
+    OPENAI_API_KEY: str
+
+
+class Config(BaseModel):
+    chunk_size: int = 1000
+    chunk_overlap: int = 0
+    docs_dir: str = "docs/"
+    docs_glob: str = "**/*.md"
+
+
+class DocumentProcessor:
+    def __init__(self, secrets: Secrets, config: Config):
+        self.client: Client = create_client(
+            secrets.SUPABASE_URL, secrets.SUPABASE_SERVICE_KEY
+        )
+        self.loader = DirectoryLoader(config.docs_dir, glob=config.docs_glob)
+        self.text_splitter = CharacterTextSplitter(
+            chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap
+        )
+        self.embeddings = OpenAIEmbeddings(openai_api_key=secrets.OPENAI_API_KEY)
 
-loader = UnstructuredMarkdownLoader('schema.md')
-data = loader.load()
+    def process(self) -> Dict[str, Any]:
+        data = self.loader.load()
+        texts = self.text_splitter.split_documents(data)
+        vector_store = SupabaseVectorStore.from_documents(
+            texts, self.embeddings, client=self.client
+        )
+        return vector_store
 
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
-texts = text_splitter.split_documents(data)
-embeddings = OpenAIEmbeddings(openai_api_key = st.secrets["OPENAI_API_KEY"])
-docsearch = FAISS.from_documents(texts, embeddings)
+def run():
+    secrets = Secrets(
+        SUPABASE_URL=st.secrets["SUPABASE_URL"],
+        SUPABASE_SERVICE_KEY=st.secrets["SUPABASE_SERVICE_KEY"],
+        OPENAI_API_KEY=st.secrets["OPENAI_API_KEY"],
+    )
+    config = Config()
+    doc_processor = DocumentProcessor(secrets, config)
+    result = doc_processor.process()
+    return result
 
-docsearch.save_local("faiss_index")
-# with open("vectors.pkl", "wb") as f:
-#     pickle.dump(docsearch, f)
+
+if __name__ == "__main__":
+    run()
diff --git a/main.py b/main.py
index 0158c4e..b1345ab 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,3 @@
-
 import openai
 import streamlit as st
 import warnings
@@ -6,42 +5,59 @@
 from langchain.embeddings.openai import OpenAIEmbeddings
 from streamlit import components
 from utils.snowflake import query_data_warehouse
-from langchain.vectorstores import FAISS
+from langchain.vectorstores import SupabaseVectorStore
 from utils.snowddl import Snowddl
-from utils.snowchat_ui import reset_chat_history, extract_code, message_func, is_sql_query
+from utils.snowchat_ui import (
+    reset_chat_history,
+    extract_code,
+    message_func,
+    is_sql_query,
+)
 from snowflake.connector.errors import ProgrammingError
 
-warnings.filterwarnings('ignore')
+from supabase.client import Client, create_client
+
+warnings.filterwarnings("ignore")
 
 openai.api_key = st.secrets["OPENAI_API_KEY"]
 MAX_INPUTS = 1
 
 chat_history = []
 
+supabase_url = st.secrets["SUPABASE_URL"]
+supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
+supabase: Client = create_client(supabase_url, supabase_key)
+
 st.set_page_config(
     page_title="snowChat",
     page_icon="❄️",
     layout="centered",
     initial_sidebar_state="auto",
     menu_items={
-        'Report a bug': "https://github.com/kaarthik108/snowChat",
-        'About': '''snowChat is a chatbot designed to help you with Snowflake Database. It is built using OpenAI's GPT-4 and Streamlit.
+        "Report a bug": "https://github.com/kaarthik108/snowChat",
+        "About": """snowChat is a chatbot designed to help you with Snowflake Database. It is built using OpenAI's GPT-4 and Streamlit.
         Go to the GitHub repo to learn more about the project.
         https://github.com/kaarthik108/snowChat
-        '''
-    }
+        """,
+    },
 )
 
+
 def load_chain():
-    '''
+    """
     Load the chain from the local file system
 
     Returns:
         chain (Chain): The chain object
-    '''
+    """
 
-    embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"])
-    vectorstore = FAISS.load_local("faiss_index", embeddings)
+    embeddings = OpenAIEmbeddings(
+        openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
+    )
+    vectorstore = SupabaseVectorStore(
+        embedding=embeddings, client=supabase, table_name="documents"
+    )
     return get_chain(vectorstore)
 
+
 snow_ddl = Snowddl()
 st.title("snowChat")
@@ -58,25 +74,28 @@ def load_chain():
 
 # Create a sidebar with a dropdown menu
 selected_table = st.sidebar.selectbox(
-    "Select a table:", options=list(snow_ddl.ddl_dict.keys()))
+    "Select a table:", options=list(snow_ddl.ddl_dict.keys())
+)
 st.sidebar.markdown(f"### DDL for {selected_table} table")
 st.sidebar.code(snow_ddl.ddl_dict[selected_table], language="sql")
 
 st.write(styles_content, unsafe_allow_html=True)
 
-if 'generated' not in st.session_state:
-    st.session_state['generated'] = [
-        "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍"]
-if 'past' not in st.session_state:
-    st.session_state['past'] = ["Hey!"]
+if "generated" not in st.session_state:
+    st.session_state["generated"] = [
+        "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍"
+    ]
+if "past" not in st.session_state:
+    st.session_state["past"] = ["Hey!"]
 if "input" not in st.session_state:
     st.session_state["input"] = ""
 if "stored_session" not in st.session_state:
     st.session_state["stored_session"] = []
-if 'messages' not in st.session_state:
-    st.session_state['messages'] = [
-        ("Hello! I'm a chatbot designed to help you with Snowflake Database.")]
+if "messages" not in st.session_state:
+    st.session_state["messages"] = [
+        ("Hello! I'm a chatbot designed to help you with Snowflake Database.")
+    ]
 if "query_count" not in st.session_state:
     st.session_state["query_count"] = 0
@@ -84,26 +103,32 @@ def load_chain():
 
 RESET = True
 messages_container = st.container()
 
-with st.form(key='my_form'):
-    query = st.text_input("Query: ", key="input", value="",
-                          placeholder="Type your query here...", label_visibility="hidden")
-    submit_button = st.form_submit_button(label='Submit')
+with st.form(key="my_form"):
+    query = st.text_input(
+        "Query: ",
+        key="input",
+        value="",
+        placeholder="Type your query here...",
+        label_visibility="hidden",
+    )
+    submit_button = st.form_submit_button(label="Submit")
 
 col1, col2 = st.columns([1, 3.2])
 reset_button = col1.button("Reset Chat History")
 
-if reset_button or st.session_state['query_count'] >= MAX_INPUTS and RESET:
+if reset_button or st.session_state["query_count"] >= MAX_INPUTS and RESET:
     RESET = False
-    st.session_state['query_count'] = 0
+    st.session_state["query_count"] = 0
     reset_chat_history()
 
-if 'messages' not in st.session_state:
-    st.session_state['messages'] = []
+if "messages" not in st.session_state:
+    st.session_state["messages"] = []
+
+
 def update_progress_bar(value, prefix, progress_bar=None):
     if progress_bar is None:
         progress_bar = st.empty()
 
-    key = f'{prefix}_progress_bar_value'
+    key = f"{prefix}_progress_bar_value"
     if key not in st.session_state:
         st.session_state[key] = 0
@@ -113,50 +138,58 @@ def update_progress_bar(value, prefix, progress_bar=None):
         st.session_state[key] = 0
         progress_bar.empty()
 
+
 chain = load_chain()
 
 if len(query) > 2 and submit_button:
     submit_progress_bar = st.empty()
-    messages = st.session_state['messages']
-    update_progress_bar(33, 'submit', submit_progress_bar)
+    messages = st.session_state["messages"]
+    update_progress_bar(33, "submit", submit_progress_bar)
     result = chain({"question": query, "chat_history": chat_history})
-    update_progress_bar(66, 'submit', submit_progress_bar)
-    st.session_state['query_count'] += 1
+    update_progress_bar(66, "submit", submit_progress_bar)
+    st.session_state["query_count"] += 1
     messages.append((query, result["answer"]))
     st.session_state.past.append(query)
-    st.session_state.generated.append(result['answer'])
-    update_progress_bar(100, 'submit', submit_progress_bar)
+    st.session_state.generated.append(result["answer"])
+    update_progress_bar(100, "submit", submit_progress_bar)
+
 
 def self_heal(df, to_extract, i):
-    '''
+    """
     If the query fails, try to fix it by extracting the code from the error message and running it again.
-
+
     Args:
         df (pandas.DataFrame): The dataframe generated from the query
         to_extract (str): The query
         i (int): The index of the query in the chat history
-
+
     Returns:
         df (pandas.DataFrame): The dataframe generated from the query
-    '''
-
+    """
+
     error_message = str(df)
-    error_message = "I have an SQL query that's causing an error. FIX The SQL query by searching the schema definition: \n```sql\n" + to_extract + "\n```\n Error message: \n " + error_message
+    error_message = (
+        "I have an SQL query that's causing an error. FIX The SQL query by searching the schema definition: \n```sql\n"
+        + to_extract
+        + "\n```\n Error message: \n "
+        + error_message
+    )
     recover = chain({"question": error_message, "chat_history": ""})
-    message_func(recover['answer'])
-    to_extract = extract_code(recover['answer'])
-    st.session_state["generated"][i] = recover['answer']
+    message_func(recover["answer"])
+    to_extract = extract_code(recover["answer"])
+    st.session_state["generated"][i] = recover["answer"]
     if is_sql_query(to_extract):
         df = query_data_warehouse(to_extract)
-
+
     return df
 
+
 def generate_df(to_extract: str, i: int):
-    '''
+    """
     Generate a dataframe from the query by querying the data warehouse.
 
     Args:
@@ -165,7 +198,7 @@
 
     Returns:
         df (pandas.DataFrame): The dataframe generated from the query
-    '''
+    """
     df = query_data_warehouse(to_extract)
     if isinstance(df, ProgrammingError) and is_sql_query(to_extract):
         message_func("uh oh, I made an error, let me try to fix it")
@@ -175,9 +208,9 @@
 
 with messages_container:
-    if st.session_state['generated']:
-        for i in range(len(st.session_state['generated'])):
-            message_func(st.session_state['past'][i], is_user=True)
+    if st.session_state["generated"]:
+        for i in range(len(st.session_state["generated"])):
+            message_func(st.session_state["past"][i], is_user=True)
             message_func(st.session_state["generated"][i])
             if i > 0 and is_sql_query(st.session_state["generated"][i]):
                 code = extract_code(st.session_state["generated"][i])
                 try:
                 except:  # noqa: E722
                     pass
 
-if st.session_state['query_count'] == MAX_INPUTS and RESET:
+if st.session_state["query_count"] == MAX_INPUTS and RESET:
     st.warning(
-        "You have reached the maximum number of inputs. The chat history will be cleared after the next input.")
+        "You have reached the maximum number of inputs. The chat history will be cleared after the next input."
+    )
 
 col2.markdown(
     f'