diff --git a/chain.py b/chain.py
index b3f3f5d..e684e88 100644
--- a/chain.py
+++ b/chain.py
@@ -8,7 +8,7 @@
from langchain.prompts.prompt import PromptTemplate
# from langchain.chat_models import ChatOpenAI
-TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. Be a little bit creative and humorous.
+TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. display the sql code in the code format and the answer in the markdown format.
If you don't know the answer, just say "Hmm, I'm not sure. I am trained only to answer sql related queries. Please try again." Don't try to make up an answer.
Use snowflake database documentation https://docs.snowflake.com/sql-reference-commands for writing sql code.
@@ -22,7 +22,7 @@ def get_chain(vectorstore):
"""
Get a chain for chatting with a vector database.
"""
- llm = ChatOpenAI(model_name='gpt-4', temperature=0.2)
+ llm = ChatOpenAI(model_name='gpt-4', temperature=0.8)
chat_vector_db_chain = ChatVectorDBChain.from_llm(llm=llm, vectorstore=vectorstore, qa_prompt=QA_PROMPT, verbose=True)
return chat_vector_db_chain
diff --git a/ingest.py b/ingest.py
index f35b43f..17313ab 100644
--- a/ingest.py
+++ b/ingest.py
@@ -2,15 +2,13 @@
import os
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.document_loaders import DirectoryLoader
+from langchain.document_loaders import DirectoryLoader, UnstructuredMarkdownLoader
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
# import pandas as pd
load_dotenv()
-PERSIST_DIRECTORY = 'store'
-
-loader = DirectoryLoader('./', glob="*.txt")
+loader = UnstructuredMarkdownLoader('schema.md')
data = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
diff --git a/main.py b/main.py
index c03b5ff..a065fa7 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
import pickle
-import re
+import html
import os
from langchain import FAISS
import openai
@@ -12,47 +12,52 @@
from streamlit import components
from utils import query_data_warehouse
load_dotenv()
+from functools import lru_cache
openai.api_key = os.getenv('OPENAI_API_KEY')
-with open("vectors.pkl", "rb") as f:
- print('Loading model...')
- vectorstore = pickle.load(f)
+@st.cache_resource
+def load_chain():
+    """Load the pickled vector store and build the chat chain from it.
+
+    Decorated with ``st.cache_resource`` so the pickle is read and the
+    chain is built once and then reused, instead of being redone each
+    time Streamlit re-executes the script.
+
+    Returns:
+        Whatever ``get_chain(vectorstore)`` returns.
+
+    Raises:
+        OSError: if ``vectors.pkl`` cannot be opened.
+    """
+    # NOTE(review): pickle.load can execute arbitrary code embedded in the
+    # file; vectors.pkl must only ever come from a trusted source.
+    with open("vectors.pkl", "rb") as f:
+        print('Loading model...')
+        vectorstore = pickle.load(f)
-chain = get_chain(vectorstore)
+    return get_chain(vectorstore)
+
+chain = load_chain()
st.title("snowChat")
st.subheader("Chat with Snowflake Database")
-ddl_transactions = '''
-CREATE OR REPLACE TABLE TRANSACTIONS (
- TRANSACTION_ID NUMBER(38,0) NOT NULL,
- ORDER_ID NUMBER(38,0),
- PRODUCT_ID NUMBER(38,0),
- QUANTITY NUMBER(38,0),
- PRICE NUMBER(10,2),
- PRIMARY KEY (TRANSACTION_ID),
- FOREIGN KEY (ORDER_ID) REFERENCES STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID)
-);
-'''
-
-# Add more DDLs for other tables here
-# ddl_table2 = '''...'''
-# ddl_table3 = '''...'''
-
-# Create a dictionary to store the table names and their corresponding DDLs
-ddl_dict = {
- "TRANSACTIONS": ddl_transactions,
- # "TABLE2": ddl_table2,
- # "TABLE3": ddl_table3,
-}
+class SnowChat:
+    """Container for the table DDL statements loaded from the ``sql/`` files."""
+
+    def __init__(self):
+        # Maps table name -> DDL text, read from disk once at construction.
+        self.ddl_dict = self.load_ddls()
+
+    @staticmethod
+    def load_ddls():
+        """Read each table's DDL file and return ``{table_name: ddl_text}``.
+
+        The file paths are relative, so they are resolved against the
+        current working directory of the process.
+
+        Raises:
+            OSError: if any of the listed files cannot be opened.
+        """
+        ddl_files = {
+            "TRANSACTIONS": "sql/ddl_transactions.sql",
+            "ORDER_DETAILS": "sql/ddl_orders.sql",
+            "PAYMENTS": "sql/ddl_payments.sql",
+            "PRODUCTS": "sql/ddl_products.sql",
+            "CUSTOMER_DETAILS": "sql/ddl_customer.sql",
+        }
+
+        ddl_dict = {}
+        for table_name, file_name in ddl_files.items():
+            # Explicit encoding so the result does not depend on the
+            # platform's default locale encoding.
+            with open(file_name, "r", encoding="utf-8") as f:
+                ddl_dict[table_name] = f.read()
+        return ddl_dict
+
+snow_chat = SnowChat()
# Create a sidebar with a dropdown menu
-selected_table = st.sidebar.selectbox("Select a table:", options=list(ddl_dict.keys()))
+selected_table = st.sidebar.selectbox("Select a table:", options=list(snow_chat.ddl_dict.keys()))
# Display the DDL for the selected table
st.sidebar.markdown(f"### DDL for {selected_table} table")
-st.sidebar.code(ddl_dict[selected_table], language="sql")
+st.sidebar.code(snow_chat.ddl_dict[selected_table], language="sql")
st.write("""
@@ -106,6 +111,7 @@
if 'messages' not in st.session_state:
st.session_state['messages'] = [("Hello! I'm a chatbot designed to help you with Snowflake Database.")]
+@st.cache_resource
def extract_code(text):
# Use OpenAI's GPT-3 to extract the SQL code
response = openai.ChatCompletion.create(
@@ -123,17 +129,10 @@ def extract_code(text):
messages_container = st.container()
-# Add a button inside the container to get the value of the text input widget
-# Input container
-# input_container = st.container()
-# form = st.form()
-
-# query = input_container.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden")
with st.form(key='my_form'):
query = st.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden")
submit_button = st.form_submit_button(label='Submit')
-# messages = []
if 'messages' not in st.session_state:
st.session_state['messages'] = []
@@ -151,35 +150,35 @@ def extract_code(text):
st.session_state.generated.append(result['answer'])
def message(text, is_user=False, key=None, avatar_style="Adventurer"):
+ text = html.escape(text)
if is_user:
avatar_url = f"https://avataaars.io/?avatarStyle=Circle&topType=ShortHairShortFlat&accessoriesType=Blank&hairColor=BrownDark&facialHairType=Blank&clotheType=Hoodie&clotheColor=Blue03&eyeType=Default&eyebrowType=Default&mouthType=Default&skinColor=Light"
message_alignment = "flex-end"
message_bg_color = "linear-gradient(135deg, #ff5f6d 0%, #ffc371 100%)"
avatar_class = "user-avatar"
st.write(f"""
-