diff --git a/chain.py b/chain.py index b3f3f5d..e684e88 100644 --- a/chain.py +++ b/chain.py @@ -8,7 +8,7 @@ from langchain.prompts.prompt import PromptTemplate # from langchain.chat_models import ChatOpenAI -TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. Be a little bit creative and humorous. +TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. display the sql code in the code format and the answer in the markdown format. If you don't know the answer, just say "Hmm, I'm not sure. I am trained only to answer sql related queries. Please try again." Don't try to make up an answer. Use snowflake database documentation https://docs.snowflake.com/sql-reference-commands for writing sql code. @@ -22,7 +22,7 @@ def get_chain(vectorstore): """ Get a chain for chatting with a vector database. """ - llm = ChatOpenAI(model_name='gpt-4', temperature=0.2) + llm = ChatOpenAI(model_name='gpt-4', temperature=0.8) chat_vector_db_chain = ChatVectorDBChain.from_llm(llm=llm, vectorstore=vectorstore, qa_prompt=QA_PROMPT, verbose=True) return chat_vector_db_chain diff --git a/ingest.py b/ingest.py index f35b43f..17313ab 100644 --- a/ingest.py +++ b/ingest.py @@ -2,15 +2,13 @@ import os from langchain.embeddings import OpenAIEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.document_loaders import DirectoryLoader +from langchain.document_loaders import DirectoryLoader, UnstructuredMarkdownLoader from langchain.vectorstores import FAISS from dotenv import load_dotenv # import pandas as pd load_dotenv() -PERSIST_DIRECTORY = 'store' - -loader = DirectoryLoader('./', glob="*.txt") +loader = UnstructuredMarkdownLoader('schema.md') data = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) diff --git a/main.py b/main.py index c03b5ff..a065fa7 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import pickle -import re +import html import os from langchain import FAISS import openai @@ -12,47 +12,52 @@ from streamlit import components from utils import query_data_warehouse load_dotenv() +from functools import lru_cache openai.api_key = os.getenv('OPENAI_API_KEY') -with open("vectors.pkl", "rb") as f: - print('Loading model...') - vectorstore = pickle.load(f) +# @st.cache_resource +def load_chain(): + with open("vectors.pkl", "rb") as f: + print('Loading model...') + vectorstore = pickle.load(f) -chain = get_chain(vectorstore) + return get_chain(vectorstore) + +chain = load_chain() st.title("snowChat") st.subheader("Chat with Snowflake Database") -ddl_transactions = ''' -CREATE OR REPLACE TABLE TRANSACTIONS ( - TRANSACTION_ID NUMBER(38,0) NOT NULL, - ORDER_ID NUMBER(38,0), - PRODUCT_ID NUMBER(38,0), - QUANTITY NUMBER(38,0), - PRICE NUMBER(10,2), - PRIMARY KEY (TRANSACTION_ID), - FOREIGN KEY (ORDER_ID) REFERENCES STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID) -); -''' - -# Add more DDLs for other tables here -# ddl_table2 = '''...''' -# ddl_table3 = '''...''' - -# Create a dictionary to store the table names and their corresponding DDLs -ddl_dict = { - "TRANSACTIONS": ddl_transactions, - # "TABLE2": ddl_table2, - # "TABLE3": ddl_table3, -} +class SnowChat: + def __init__(self): + self.ddl_dict = self.load_ddls() + + @staticmethod + def load_ddls(): + ddl_files = { + "TRANSACTIONS": "sql/ddl_transactions.sql", + "ORDER_DETAILS": "sql/ddl_orders.sql", + "PAYMENTS": "sql/ddl_payments.sql", + "PRODUCTS": "sql/ddl_products.sql", + "CUSTOMER_DETAILS": "sql/ddl_customer.sql" + } + + ddl_dict = {} + for table_name, file_name in ddl_files.items(): + with open(file_name, "r") as f: + ddl_dict[table_name] = f.read() + # print(f"DDL for table loaded. {ddl_dict[table_name]} ") + return ddl_dict + +snow_chat = SnowChat() # Create a sidebar with a dropdown menu -selected_table = st.sidebar.selectbox("Select a table:", options=list(ddl_dict.keys())) +selected_table = st.sidebar.selectbox("Select a table:", options=list(snow_chat.ddl_dict.keys())) # Display the DDL for the selected table st.sidebar.markdown(f"### DDL for {selected_table} table") -st.sidebar.code(ddl_dict[selected_table], language="sql") +st.sidebar.code(snow_chat.ddl_dict[selected_table], language="sql") st.write(""" @@ -106,6 +111,7 @@ if 'messages' not in st.session_state: st.session_state['messages'] = [("Hello! I'm a chatbot designed to help you with Snowflake Database.")] +@st.cache_resource def extract_code(text): # Use OpenAI's GPT-3 to extract the SQL code response = openai.ChatCompletion.create( @@ -123,17 +129,10 @@ def extract_code(text): messages_container = st.container() -# Add a button inside the container to get the value of the text input widget -# Input container -# input_container = st.container() -# form = st.form() - -# query = input_container.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden") with st.form(key='my_form'): query = st.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden") submit_button = st.form_submit_button(label='Submit') -# messages = [] if 'messages' not in st.session_state: st.session_state['messages'] = [] @@ -151,35 +150,35 @@ def extract_code(text): st.session_state.generated.append(result['answer']) def message(text, is_user=False, key=None, avatar_style="Adventurer"): + text = html.escape(text) if is_user: avatar_url = f"https://avataaars.io/?avatarStyle=Circle&topType=ShortHairShortFlat&accessoriesType=Blank&hairColor=BrownDark&facialHairType=Blank&clotheType=Hoodie&clotheColor=Blue03&eyeType=Default&eyebrowType=Default&mouthType=Default&skinColor=Light" message_alignment = "flex-end" message_bg_color = "linear-gradient(135deg, #ff5f6d 0%, #ffc371 100%)" avatar_class = "user-avatar" st.write(f""" -
-
- {text} -
- avatar - -
- """, unsafe_allow_html=True) +
+
+ {text} +
+ avatar + +
+ """, unsafe_allow_html=True) else: avatar_url = f"https://avataaars.io/?avatarStyle=Circle&topType=LongHairBun&accessoriesType=Blank&hairColor=BrownDark&facialHairType=Blank&clotheType=BlazerShirt&eyeType=Default&eyebrowType=Default&mouthType=Default&skinColor=Light" message_alignment = "flex-start" message_bg_color = "linear-gradient(135deg, #36d1dc 0%, #5b86e5 100%)" avatar_class = "bot-avatar" st.write(f""" -
- avatar -
- {text} -
-
- """, unsafe_allow_html=True) - - +
+ avatar +
+ {text} +
+
+ """, unsafe_allow_html=True) + with messages_container: if st.session_state['generated']: @@ -188,11 +187,12 @@ def message(text, is_user=False, key=None, avatar_style="Adventurer"): message(st.session_state["generated"][i], key=str(i), avatar_style="Adventurer") op = extract_code(st.session_state["generated"][i]) try: - if len(op) > 6: - print("op is", op) - df = query_data_warehouse(op) - st.spinner("Loading data...") - st.dataframe(df) + if len(op) > 2: + with st.spinner("In progress..."): + print("op is", op) + df = query_data_warehouse(op) + + st.dataframe(df) except: pass diff --git a/schema.md b/schema.md new file mode 100644 index 0000000..eeacb3b --- /dev/null +++ b/schema.md @@ -0,0 +1,37 @@ +**Table 1: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS** (Stores customer information) + +- CUSTOMER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for customers +- FIRST_NAME: Varchar (255) - First name of the customer +- LAST_NAME: Varchar (255) - Last name of the customer +- EMAIL: Varchar (255) - Email address of the customer +- PHONE: Varchar (20) - Phone number of the customer +- ADDRESS: Varchar (255) - Physical address of the customer + +**Table 2: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS** (Stores order information) + +- ORDER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for orders +- CUSTOMER_ID: Number (38,0) [Foreign Key - CUSTOMER_DETAILS(CUSTOMER_ID)] - Customer who made the order +- ORDER_DATE: Date - Date when the order was made +- TOTAL_AMOUNT: Number (10,2) - Total amount of the order + +**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information) + +- PAYMENT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for payments +- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the payment +- PAYMENT_DATE: Date - Date when the payment was made +- AMOUNT: Number (10,2) - Amount of the payment + +**Table 4: STREAM_HACKATHON.STREAMLIT.PRODUCTS** (Stores product information) + +- PRODUCT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for products +- PRODUCT_NAME: Varchar (255) - Name of the product +- CATEGORY: Varchar (255) - Category of the product +- PRICE: Number (10,2) - Price of the product + +**Table 5: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS** (Stores transaction information) + +- TRANSACTION_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for transactions +- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the transaction +- PRODUCT_ID: Number (38,0) - Product involved in the transaction +- QUANTITY: Number (38,0) - Quantity of the product in the transaction +- PRICE: Number (10,2) - Price of the product in the transaction diff --git a/schema.txt b/schema.txt deleted file mode 100644 index 7aa5a91..0000000 --- a/schema.txt +++ /dev/null @@ -1,33 +0,0 @@ - -tablename: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS -CUSTOMER_ID: Number (38,0), not null, primary key -FIRST_NAME: Varchar (255) -LAST_NAME: Varchar (255) -EMAIL: Varchar (255) -PHONE: Varchar (20) -ADDRESS: Varchar (255) - -tablename: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS -ORDER_ID: Number (38,0), not null, primary key -CUSTOMER_ID: Number (38,0), foreign key references CUSTOMER_DETAILS(CUSTOMER_ID) -ORDER_DATE: Date -TOTAL_AMOUNT: Number (10,2) - -tablename: STREAM_HACKATHON.STREAMLIT.PAYMENTS -PAYMENT_ID: Number (38,0), not null, primary key -ORDER_ID: Number (38,0), foreign key references ORDER_DETAILS(ORDER_ID) -PAYMENT_DATE: Date -AMOUNT: Number (10,2) - -tablename: STREAM_HACKATHON.STREAMLIT.PRODUCTS -PRODUCT_ID: Number (38,0), not null, primary key -PRODUCT_NAME: Varchar (255) -CATEGORY: Varchar (255) -PRICE: Number (10,2) - -tablename: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS -TRANSACTION_ID: Number (38,0), not null, primary key -ORDER_ID: Number (38,0), foreign key references ORDER_DETAILS(ORDER_ID) -PRODUCT_ID: Number (38,0) -QUANTITY: Number (38,0) -PRICE: Number (10,2) \ No newline at end of file diff --git a/sql/ddl_customer.sql b/sql/ddl_customer.sql new file mode 100644 index 0000000..2b734ad --- /dev/null +++ b/sql/ddl_customer.sql @@ -0,0 +1,9 @@ +create or replace TABLE CUSTOMER_DETAILS ( + CUSTOMER_ID NUMBER(38,0) NOT NULL, + FIRST_NAME VARCHAR(255), + LAST_NAME VARCHAR(255), + EMAIL VARCHAR(255), + PHONE VARCHAR(20), + ADDRESS VARCHAR(255), + primary key (CUSTOMER_ID) +); \ No newline at end of file diff --git a/sql/ddl_orders.sql b/sql/ddl_orders.sql new file mode 100644 index 0000000..f7cef21 --- /dev/null +++ b/sql/ddl_orders.sql @@ -0,0 +1,8 @@ +create or replace TABLE ORDER_DETAILS ( + ORDER_ID NUMBER(38,0) NOT NULL, + CUSTOMER_ID NUMBER(38,0), + ORDER_DATE DATE, + TOTAL_AMOUNT NUMBER(10,2), + primary key (ORDER_ID), + foreign key (CUSTOMER_ID) references STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS(CUSTOMER_ID) +); \ No newline at end of file diff --git a/sql/ddl_payments.sql b/sql/ddl_payments.sql new file mode 100644 index 0000000..0d04254 --- /dev/null +++ b/sql/ddl_payments.sql @@ -0,0 +1,8 @@ +create or replace TABLE PAYMENTS ( + PAYMENT_ID NUMBER(38,0) NOT NULL, + ORDER_ID NUMBER(38,0), + PAYMENT_DATE DATE, + AMOUNT NUMBER(10,2), + primary key (PAYMENT_ID), + foreign key (ORDER_ID) references STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID) +); \ No newline at end of file diff --git a/sql/ddl_products.sql b/sql/ddl_products.sql new file mode 100644 index 0000000..5dc6aa7 --- /dev/null +++ b/sql/ddl_products.sql @@ -0,0 +1,7 @@ +create or replace TABLE PRODUCTS ( + PRODUCT_ID NUMBER(38,0) NOT NULL, + PRODUCT_NAME VARCHAR(255), + CATEGORY VARCHAR(255), + PRICE NUMBER(10,2), + primary key (PRODUCT_ID) +); \ No newline at end of file diff --git a/sql/ddl_transactions.sql b/sql/ddl_transactions.sql new file mode 100644 index 0000000..3d576d9 --- /dev/null +++ b/sql/ddl_transactions.sql @@ -0,0 +1,9 @@ +CREATE OR REPLACE TABLE TRANSACTIONS ( + TRANSACTION_ID NUMBER(38,0) NOT NULL, + ORDER_ID NUMBER(38,0), + PRODUCT_ID NUMBER(38,0), + QUANTITY NUMBER(38,0), + PRICE NUMBER(10,2), + PRIMARY KEY (TRANSACTION_ID), + FOREIGN KEY (ORDER_ID) REFERENCES STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID) +); \ No newline at end of file diff --git a/utils.py b/utils.py index 7d37788..ac67c05 100644 --- a/utils.py +++ b/utils.py @@ -37,6 +37,7 @@ def query_data_warehouse(sql: str, parameters=None) -> any: try: cur.execute("USE DATABASE " + os.getenv("DATABASE")) + cur.execute("USE SCHEMA " + os.getenv("SCHEMA")) cur.execute(query, parameters) print("executing query") all_rows = cur.fetchall() diff --git a/vectors.pkl b/vectors.pkl index 3bf017d..114cfd3 100644 Binary files a/vectors.pkl and b/vectors.pkl differ