Skip to content

Commit

Permalink
Update formating
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Apr 15, 2023
1 parent 451e922 commit 8cefe5a
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 97 deletions.
4 changes: 2 additions & 2 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain.prompts.prompt import PromptTemplate
# from langchain.chat_models import ChatOpenAI

TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. Be a little bit creative and humorous.
TEMPLATE = """ You're name is snowchat, and you are a senior snowflake developer. You are currently working in a snowflake database. You have to write a sql code in snowflake database based on the following question. Also you have to ignore the sql keywords and the context and give a one or two sentences about how did you arrive at that sql code. display the sql code in the code format and the answer in the markdown format.
If you don't know the answer, just say "Hmm, I'm not sure. I am trained only to answer sql related queries. Please try again." Don't try to make up an answer.
Use snowflake database documentation https://docs.snowflake.com/sql-reference-commands for writing sql code.
Expand All @@ -22,7 +22,7 @@ def get_chain(vectorstore):
"""
Get a chain for chatting with a vector database.
"""
llm = ChatOpenAI(model_name='gpt-4', temperature=0.2)
llm = ChatOpenAI(model_name='gpt-4', temperature=0.8)
chat_vector_db_chain = ChatVectorDBChain.from_llm(llm=llm, vectorstore=vectorstore, qa_prompt=QA_PROMPT, verbose=True)

return chat_vector_db_chain
6 changes: 2 additions & 4 deletions ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import os
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain.document_loaders import DirectoryLoader, UnstructuredMarkdownLoader
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
# import pandas as pd
load_dotenv()

PERSIST_DIRECTORY = 'store'

loader = DirectoryLoader('./', glob="*.txt")
loader = UnstructuredMarkdownLoader('schema.md')
data = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
Expand Down
116 changes: 58 additions & 58 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import pickle
import re
import html
import os
from langchain import FAISS
import openai
Expand All @@ -12,47 +12,52 @@
from streamlit import components
from utils import query_data_warehouse
load_dotenv()
from functools import lru_cache

openai.api_key = os.getenv('OPENAI_API_KEY')

with open("vectors.pkl", "rb") as f:
print('Loading model...')
vectorstore = pickle.load(f)
# @st.cache_resource
def load_chain():
with open("vectors.pkl", "rb") as f:
print('Loading model...')
vectorstore = pickle.load(f)

chain = get_chain(vectorstore)
return get_chain(vectorstore)

chain = load_chain()

st.title("snowChat")
st.subheader("Chat with Snowflake Database")

ddl_transactions = '''
CREATE OR REPLACE TABLE TRANSACTIONS (
TRANSACTION_ID NUMBER(38,0) NOT NULL,
ORDER_ID NUMBER(38,0),
PRODUCT_ID NUMBER(38,0),
QUANTITY NUMBER(38,0),
PRICE NUMBER(10,2),
PRIMARY KEY (TRANSACTION_ID),
FOREIGN KEY (ORDER_ID) REFERENCES STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID)
);
'''

# Add more DDLs for other tables here
# ddl_table2 = '''...'''
# ddl_table3 = '''...'''

# Create a dictionary to store the table names and their corresponding DDLs
ddl_dict = {
"TRANSACTIONS": ddl_transactions,
# "TABLE2": ddl_table2,
# "TABLE3": ddl_table3,
}
class SnowChat:
def __init__(self):
self.ddl_dict = self.load_ddls()

@staticmethod
def load_ddls():
ddl_files = {
"TRANSACTIONS": "sql/ddl_transactions.sql",
"ORDER_DETAILS": "sql/ddl_orders.sql",
"PAYMENTS": "sql/ddl_payments.sql",
"PRODUCTS": "sql/ddl_products.sql",
"CUSTOMER_DETAILS": "sql/ddl_customer.sql"
}

ddl_dict = {}
for table_name, file_name in ddl_files.items():
with open(file_name, "r") as f:
ddl_dict[table_name] = f.read()
# print(f"DDL for table loaded. {ddl_dict[table_name]} ")
return ddl_dict

snow_chat = SnowChat()

# Create a sidebar with a dropdown menu
selected_table = st.sidebar.selectbox("Select a table:", options=list(ddl_dict.keys()))
selected_table = st.sidebar.selectbox("Select a table:", options=list(snow_chat.ddl_dict.keys()))

# Display the DDL for the selected table
st.sidebar.markdown(f"### DDL for {selected_table} table")
st.sidebar.code(ddl_dict[selected_table], language="sql")
st.sidebar.code(snow_chat.ddl_dict[selected_table], language="sql")

st.write("""
<link rel="preconnect" href="https://fonts.gstatic.com">
Expand Down Expand Up @@ -106,6 +111,7 @@
if 'messages' not in st.session_state:
st.session_state['messages'] = [("Hello! I'm a chatbot designed to help you with Snowflake Database.")]

@st.cache_resource
def extract_code(text):
# Use OpenAI's GPT-3 to extract the SQL code
response = openai.ChatCompletion.create(
Expand All @@ -123,17 +129,10 @@ def extract_code(text):

messages_container = st.container()

# Add a button inside the container to get the value of the text input widget
# Input container
# input_container = st.container()
# form = st.form()

# query = input_container.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden")
with st.form(key='my_form'):
query = st.text_input("", key="input", placeholder="Type your query here...", label_visibility="hidden")
submit_button = st.form_submit_button(label='Submit')

# messages = []

if 'messages' not in st.session_state:
st.session_state['messages'] = []
Expand All @@ -151,35 +150,35 @@ def extract_code(text):
st.session_state.generated.append(result['answer'])

def message(text, is_user=False, key=None, avatar_style="Adventurer"):
text = html.escape(text)
if is_user:
avatar_url = f"https://avataaars.io/?avatarStyle=Circle&topType=ShortHairShortFlat&accessoriesType=Blank&hairColor=BrownDark&facialHairType=Blank&clotheType=Hoodie&clotheColor=Blue03&eyeType=Default&eyebrowType=Default&mouthType=Default&skinColor=Light"
message_alignment = "flex-end"
message_bg_color = "linear-gradient(135deg, #ff5f6d 0%, #ffc371 100%)"
avatar_class = "user-avatar"
st.write(f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<div style="background: {message_bg_color}; color: white; border-radius: 5px; padding: 10px; margin-right: 5px; max-width: 75%;">
{text}
</div>
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
</div>
""", unsafe_allow_html=True)
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<div style="background: {message_bg_color}; color: white; border-radius: 5px; padding: 10px; margin-right: 5px; max-width: 75%;">
{text}
</div>
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
</div>
""", unsafe_allow_html=True)
else:
avatar_url = f"https://avataaars.io/?avatarStyle=Circle&topType=LongHairBun&accessoriesType=Blank&hairColor=BrownDark&facialHairType=Blank&clotheType=BlazerShirt&eyeType=Default&eyebrowType=Default&mouthType=Default&skinColor=Light"
message_alignment = "flex-start"
message_bg_color = "linear-gradient(135deg, #36d1dc 0%, #5b86e5 100%)"
avatar_class = "bot-avatar"
st.write(f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
<div style="background: {message_bg_color}; color: white; border-radius: 5px; padding: 10px; margin-right: 5px; max-width: 75%;">
{text}
</div>
</div>
""", unsafe_allow_html=True)


<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" />
<div style="background: {message_bg_color}; color: white; border-radius: 5px; padding: 10px; margin-right: 5px; max-width: 75%;">
{text}
</div>
</div>
""", unsafe_allow_html=True)


with messages_container:
if st.session_state['generated']:
Expand All @@ -188,11 +187,12 @@ def message(text, is_user=False, key=None, avatar_style="Adventurer"):
message(st.session_state["generated"][i], key=str(i), avatar_style="Adventurer")
op = extract_code(st.session_state["generated"][i])
try:
if len(op) > 6:
print("op is", op)
df = query_data_warehouse(op)
st.spinner("Loading data...")
st.dataframe(df)
if len(op) > 2:
with st.spinner("In progress..."):
print("op is", op)
df = query_data_warehouse(op)

st.dataframe(df)
except:
pass

Expand Down
37 changes: 37 additions & 0 deletions schema.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
**Table 1: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS** (Stores customer information)

- CUSTOMER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for customers
- FIRST_NAME: Varchar (255) - First name of the customer
- LAST_NAME: Varchar (255) - Last name of the customer
- EMAIL: Varchar (255) - Email address of the customer
- PHONE: Varchar (20) - Phone number of the customer
- ADDRESS: Varchar (255) - Physical address of the customer

**Table 2: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS** (Stores order information)

- ORDER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for orders
- CUSTOMER_ID: Number (38,0) [Foreign Key - CUSTOMER_DETAILS(CUSTOMER_ID)] - Customer who made the order
- ORDER_DATE: Date - Date when the order was made
- TOTAL_AMOUNT: Number (10,2) - Total amount of the order

**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information)

- PAYMENT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for payments
- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the payment
- PAYMENT_DATE: Date - Date when the payment was made
- AMOUNT: Number (10,2) - Amount of the payment

**Table 4: STREAM_HACKATHON.STREAMLIT.PRODUCTS** (Stores product information)

- PRODUCT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for products
- PRODUCT_NAME: Varchar (255) - Name of the product
- CATEGORY: Varchar (255) - Category of the product
- PRICE: Number (10,2) - Price of the product

**Table 5: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS** (Stores transaction information)

- TRANSACTION_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for transactions
- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the transaction
- PRODUCT_ID: Number (38,0) - Product involved in the transaction
- QUANTITY: Number (38,0) - Quantity of the product in the transaction
- PRICE: Number (10,2) - Price of the product in the transaction
33 changes: 0 additions & 33 deletions schema.txt

This file was deleted.

9 changes: 9 additions & 0 deletions sql/ddl_customer.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
create or replace TABLE CUSTOMER_DETAILS (
CUSTOMER_ID NUMBER(38,0) NOT NULL,
FIRST_NAME VARCHAR(255),
LAST_NAME VARCHAR(255),
EMAIL VARCHAR(255),
PHONE VARCHAR(20),
ADDRESS VARCHAR(255),
primary key (CUSTOMER_ID)
);
8 changes: 8 additions & 0 deletions sql/ddl_orders.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
create or replace TABLE ORDER_DETAILS (
ORDER_ID NUMBER(38,0) NOT NULL,
CUSTOMER_ID NUMBER(38,0),
ORDER_DATE DATE,
TOTAL_AMOUNT NUMBER(10,2),
primary key (ORDER_ID),
foreign key (CUSTOMER_ID) references STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS(CUSTOMER_ID)
);
8 changes: 8 additions & 0 deletions sql/ddl_payments.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
create or replace TABLE PAYMENTS (
PAYMENT_ID NUMBER(38,0) NOT NULL,
ORDER_ID NUMBER(38,0),
PAYMENT_DATE DATE,
AMOUNT NUMBER(10,2),
primary key (PAYMENT_ID),
foreign key (ORDER_ID) references STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID)
);
7 changes: 7 additions & 0 deletions sql/ddl_products.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
create or replace TABLE PRODUCTS (
PRODUCT_ID NUMBER(38,0) NOT NULL,
PRODUCT_NAME VARCHAR(255),
CATEGORY VARCHAR(255),
PRICE NUMBER(10,2),
primary key (PRODUCT_ID)
);
9 changes: 9 additions & 0 deletions sql/ddl_transactions.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE OR REPLACE TABLE TRANSACTIONS (
TRANSACTION_ID NUMBER(38,0) NOT NULL,
ORDER_ID NUMBER(38,0),
PRODUCT_ID NUMBER(38,0),
QUANTITY NUMBER(38,0),
PRICE NUMBER(10,2),
PRIMARY KEY (TRANSACTION_ID),
FOREIGN KEY (ORDER_ID) REFERENCES STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS(ORDER_ID)
);
1 change: 1 addition & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def query_data_warehouse(sql: str, parameters=None) -> any:

try:
cur.execute("USE DATABASE " + os.getenv("DATABASE"))
cur.execute("USE SCHEMA " + os.getenv("SCHEMA"))
cur.execute(query, parameters)
print("executing query")
all_rows = cur.fetchall()
Expand Down
Binary file modified vectors.pkl
Binary file not shown.

0 comments on commit 8cefe5a

Please sign in to comment.