From 0b2a2c9e6640dd1eb993bf05ff5019e79fbdd580 Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Sun, 25 Jun 2023 14:59:25 +1200
Subject: [PATCH] Add supabase as vector DB

---
 README.md                | 13 ++--
 chain.py                 | 46 ++++++------
 docs/customer_details.md | 10 +++
 docs/order_details.md    |  8 +++
 docs/payments.md         |  8 +++
 docs/products.md         |  8 +++
 docs/transactions.md     |  9 +++
 faiss_index/index.faiss  | Bin 18477 -> 0 bytes
 faiss_index/index.pkl    | Bin 3085 -> 0 bytes
 ingest.py                | 65 +++++++++++++----
 main.py                  | 147 ++++++++++++++++++++++++---------------
 requirements.txt         |  3 +-
 schema.md                | 47 -------------
 streamlit-hack.yml       | 20 ------
 supabase/scripts.sql     | 36 ++++++++++
 utils/snowchat_ui.py     | 86 ++++++++++++++++-------
 utils/snowddl.py         | 11 +--
 utils/snowflake.py       |  8 +--
 18 files changed, 320 insertions(+), 205 deletions(-)
 create mode 100644 docs/customer_details.md
 create mode 100644 docs/order_details.md
 create mode 100644 docs/payments.md
 create mode 100644 docs/products.md
 create mode 100644 docs/transactions.md
 delete mode 100644 faiss_index/index.faiss
 delete mode 100644 faiss_index/index.pkl
 delete mode 100644 schema.md
 delete mode 100644 streamlit-hack.yml
 create mode 100644 supabase/scripts.sql

diff --git a/README.md b/README.md
index 04df43e..187f1ee 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 [![Streamlit](https://img.shields.io/badge/-Streamlit-FF4B4B?style=flat-square&logo=streamlit&logoColor=white)](https://streamlit.io/)
 [![OpenAI](https://img.shields.io/badge/-OpenAI-412991?style=flat-square&logo=openai&logoColor=white)](https://openai.com/)
 [![Snowflake](https://img.shields.io/badge/-Snowflake-29BFFF?style=flat-square&logo=snowflake&logoColor=white)](https://www.snowflake.com/en/)
+[![Supabase](https://img.shields.io/badge/-Supabase-00C04A?style=flat-square&logo=supabase&logoColor=white)](https://www.supabase.io/)
 
 [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://snowchat.streamlit.app/)
 
@@ -18,7 +19,7 @@
 - Interactive and user-friendly interface
 - Integration with Snowflake Data Warehouse
 - Utilizes OpenAI's GPT-3.5-turbo-16k and text-embedding-ada-002
-- Uses In-memory Vector Database FAISS for storing and searching through vectors
+- Uses the Supabase pgvector database for storing and searching embeddings
 
 ## 🛠️ Installation
 
@@ -29,13 +30,15 @@
 cd snowchat
 
 pip install -r requirements.txt
 
-3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA` and `WAREHOUSE` in project directory `secrets.toml`. If you don't have access to GPT-4 change the script in chain.py replace gpt-4 in model_name to gpt-3.5-turbo
+3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL` and `SUPABASE_SERVICE_KEY` in `secrets.toml` in the project directory.
 
-4. Make you're schema.md that matches you're database.
+4. Write markdown files describing your schemas and store them in the docs folder so they match your database.
 
-5. Run `python ingest.py` to get convert to embeddings and store as an index file.
+5. Create the Supabase extension, table and function by running supabase/scripts.sql.
 
-6. Run the Streamlit app to start chatting:
+6. Run `python ingest.py` to convert the schema docs to embeddings and store them in Supabase.
+
+7. Run the Streamlit app to start chatting:
 
    streamlit run main.py
 
 ## 📚 Usage
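Once the secrets are set and supabase/scripts.sql has been run, the ingested embeddings can be spot-checked outside the app. The sketch below is illustrative only and is not part of this patch; it reuses the same `SupabaseVectorStore` wiring that `ingest.py` and `main.py` introduce further down, and assumes `secrets.toml` is readable through `st.secrets`. The query text and `k` value are arbitrary placeholders.

```python
# Illustrative sketch -- not included in this patch.
import streamlit as st
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import SupabaseVectorStore
from supabase.client import create_client

supabase = create_client(
    st.secrets["SUPABASE_URL"], st.secrets["SUPABASE_SERVICE_KEY"]
)
embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"])
vectorstore = SupabaseVectorStore(
    embedding=embeddings, client=supabase, table_name="documents"
)

# After `python ingest.py` has run, this should return chunks of the markdown
# files stored in the docs folder.
for doc in vectorstore.similarity_search("Which table stores payments?", k=2):
    print(doc.page_content)
```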
diff --git a/chain.py b/chain.py
index 129dcda..341d8c9 100644
--- a/chain.py
+++ b/chain.py
@@ -1,8 +1,5 @@
 from langchain.prompts.prompt import PromptTemplate
-from langchain.chains import (
-    ConversationalRetrievalChain,
-    LLMChain
-)
+from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.llms import OpenAI
 import streamlit as st
@@ -15,7 +12,7 @@
 {question}
 \"""
 Standalone question:"""
-    
+
 condense_question_prompt = PromptTemplate.from_template(template)
 
 TEMPLATE = """ You're a helpful AI assistant who is specialized in data analysis using SQL. You have to write sql code in snowflake database based on the following question. Give a one or two sentences about how did you arrive at that sql code. (do not assume anything if the column is not available then say it is not available, do not make up code). Write the sql code in markdown format.
@@ -25,7 +22,7 @@
 
 Answer:
 
-"""
+"""
 
 QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
 
@@ -33,27 +30,24 @@ def get_chain(vectorstore):
     """
     Get a chain for chatting with a vector database.
     """
-    q_llm = OpenAI(temperature=0, openai_api_key=st.secrets["OPENAI_API_KEY"], model_name='gpt-3.5-turbo-16k')
-
-    llm = OpenAI(
-        model_name='gpt-3.5-turbo',
+    q_llm = OpenAI(
         temperature=0,
-        openai_api_key=st.secrets["OPENAI_API_KEY"]
-    )
-
-    question_generator = LLMChain(
-        llm=q_llm,
-        prompt=condense_question_prompt
+        openai_api_key=st.secrets["OPENAI_API_KEY"],
+        model_name="gpt-3.5-turbo-16k",
     )
-
-    doc_chain = load_qa_chain(
-        llm=llm,
-        chain_type="stuff",
-        prompt=QA_PROMPT
+
+    llm = OpenAI(
+        model_name="gpt-3.5-turbo",
+        temperature=0,
+        openai_api_key=st.secrets["OPENAI_API_KEY"],
     )
+
+    question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt)
+
+    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT)
 
     chain = ConversationalRetrievalChain(
-        retriever=vectorstore.as_retriever(),
-        combine_docs_chain=doc_chain,
-        question_generator=question_generator
-    )
-    return chain
\ No newline at end of file
+        retriever=vectorstore.as_retriever(),
+        combine_docs_chain=doc_chain,
+        question_generator=question_generator,
+    )
+    return chain
diff --git a/docs/customer_details.md b/docs/customer_details.md
new file mode 100644
index 0000000..7f4e758
--- /dev/null
+++ b/docs/customer_details.md
@@ -0,0 +1,10 @@
+**Table 1: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS** (Stores customer information)
+
+This table contains the personal information of customers who have made purchases on the platform.
+
+- CUSTOMER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for customers
+- FIRST_NAME: Varchar (255) - First name of the customer
+- LAST_NAME: Varchar (255) - Last name of the customer
+- EMAIL: Varchar (255) - Email address of the customer
+- PHONE: Varchar (20) - Phone number of the customer
+- ADDRESS: Varchar (255) - Physical address of the customer
\ No newline at end of file
diff --git a/docs/order_details.md b/docs/order_details.md
new file mode 100644
index 0000000..f461304
--- /dev/null
+++ b/docs/order_details.md
@@ -0,0 +1,8 @@
+**Table 2: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS** (Stores order information)
+
+This table contains information about orders placed by customers, including the date and total amount of each order.
+ +- ORDER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for orders +- CUSTOMER_ID: Number (38,0) [Foreign Key - CUSTOMER_DETAILS(CUSTOMER_ID)] - Customer who made the order +- ORDER_DATE: Date - Date when the order was made +- TOTAL_AMOUNT: Number (10,2) - Total amount of the order \ No newline at end of file diff --git a/docs/payments.md b/docs/payments.md new file mode 100644 index 0000000..5c8908e --- /dev/null +++ b/docs/payments.md @@ -0,0 +1,8 @@ +**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information) + +This table contains information about payments made by customers for their orders, including the date and amount of each payment. + +- PAYMENT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for payments +- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the payment +- PAYMENT_DATE: Date - Date when the payment was made +- AMOUNT: Number (10,2) - Amount of the payment \ No newline at end of file diff --git a/docs/products.md b/docs/products.md new file mode 100644 index 0000000..130341b --- /dev/null +++ b/docs/products.md @@ -0,0 +1,8 @@ +**Table 4: STREAM_HACKATHON.STREAMLIT.PRODUCTS** (Stores product information) + +This table contains information about the products available for purchase on the platform, including their name, category, and price. + +- PRODUCT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for products +- PRODUCT_NAME: Varchar (255) - Name of the product +- CATEGORY: Varchar (255) - Category of the product +- PRICE: Number (10,2) - Price of the product \ No newline at end of file diff --git a/docs/transactions.md b/docs/transactions.md new file mode 100644 index 0000000..265e56a --- /dev/null +++ b/docs/transactions.md @@ -0,0 +1,9 @@ +**Table 5: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS** (Stores transaction information) + +This table contains information about individual transactions that occur when customers purchase products, including the associated order, product, quantity, and price. 
+
+- TRANSACTION_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for transactions
+- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the transaction
+- PRODUCT_ID: Number (38,0) - Product involved in the transaction
+- QUANTITY: Number (38,0) - Quantity of the product in the transaction
+- PRICE: Number (10,2) - Price of the product in the transaction
zBV@|8jCxwW^TaG1UDg3e_h~;9#QXqdNt{1uf<(O)dGZbJBQxmxiRxkS48H{xt;1!~ ze+ENta?*6+wHfI)&^}RrLif2z`;c@|rCva$?wn@V4hi?_^mIps@-@#qoyAB~n7dwk zm2jkxMq*gYYalWzX#>7DbrbyvkZ%UP!$*+zM&J_n7*AE=T4gc9g8H9zA-DF3QwT$7 tyY;-1LicQmbtla{%uL1r=~*UF?!iRgZbFvzS$5wnU>n4833#hSRb`FG!_G*)5Rc7@ zhn3owJR!AG-G9|z&|lMUd;t~~)Fvq}7&G7feCNzBpZt0J@F@8#pE=mOCJb}0Wcx&f zen3myGhOQX!6IIM-Szs(V>|mZ{v$s5HNK6PpPDu&Av_h9&+E17X}O{`Rcbe>T5nDp zYSSW?Iz6pv6|GjTTC@tj)}aJp)Wz}g0l1A^>V>)6!t#NcI5y41^}`vyrY7;dP#${q z*T+Xk2DTiED(7foj5~VYe4)2`y76K-D5WnhyGE&XH8F<$&e&{s4840fL4}E|Odt|L zi7pKg_ojZ};*k5^iJ};c3lt`fWi03h{A3i(XdrwKJDbOmKV7*I=!W@-;WqFBNbhF z{7pm=x1oGK} zd(AHKUaL+Yy15&u<6`b0Y}?RcD1AHD2!?YfJt$d4Bj$S)d1)hdZ1i?}+?h;vV~yBC zZ~`OQ8!iP%^&mioqQp|5nS%+nk+oQLS`-00#EER~T_+>3 zaYzyNY!v#Ttitsp5A3IiV!~2C0G2er_eV;a*J1Eoz-I7kPl6A|vVv*0I)&AXZu>++ zvX!rpgnoj0iL4}0o_AX}VBT$6shWP>Fq$qz+Tt555>sFRmxdU6f+eSXCVW^R8ZW;{ z38LiMv51ymnI=Hrv4tsU2*mJPTwssb7zW&m;AxJXh>CwUBx!F{4xj|GLSFmQ)rYQ8@f^E7jm+~e^s6z?o)Azz6Egt`zu;V^hREi5k-WkJ??AHZUE z7UgJElzr3IS6eFQyo!-dx_)@l<6--%wK?g)w~G4zNet9WF>#u;g~(XjoB0SR7C8XoJ`oax3T04>x`Zf440Q?S)0OL zrh(OF{j{NKXSiBDt&+xBgY59d+vEIRq8$W-G1dnYy=8QVgZ1i%0rmtY@+$I9>(4!# z&$%6enryg|7Z?n}LP!+AEd}T|7k`kc?3#kIhU|zG-de9>ezJ((L>R6j;TCQ-3i|-q z*(~FKWs8kKYn<-#5?oeROFQ?SIj{~-uD)zK`C Dict[str, Any]: + data = self.loader.load() + texts = self.text_splitter.split_documents(data) + vector_store = SupabaseVectorStore.from_documents( + texts, self.embeddings, client=self.client + ) + return vector_store -text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20) -texts = text_splitter.split_documents(data) -embeddings = OpenAIEmbeddings(openai_api_key = st.secrets["OPENAI_API_KEY"]) -docsearch = FAISS.from_documents(texts, embeddings) +def run(): + secrets = Secrets( + SUPABASE_URL=st.secrets["SUPABASE_URL"], + SUPABASE_SERVICE_KEY=st.secrets["SUPABASE_SERVICE_KEY"], + OPENAI_API_KEY=st.secrets["OPENAI_API_KEY"], + ) + config = Config() + doc_processor = DocumentProcessor(secrets, config) + result = doc_processor.process() + return result -docsearch.save_local("faiss_index") -# with open("vectors.pkl", "wb") as f: -# pickle.dump(docsearch, f) +if __name__ == "__main__": + run() diff --git a/main.py b/main.py index 0158c4e..b1345ab 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,3 @@ - import openai import streamlit as st import warnings @@ -6,42 +5,59 @@ from langchain.embeddings.openai import OpenAIEmbeddings from streamlit import components from utils.snowflake import query_data_warehouse -from langchain.vectorstores import FAISS +from langchain.vectorstores import SupabaseVectorStore from utils.snowddl import Snowddl -from utils.snowchat_ui import reset_chat_history, extract_code, message_func, is_sql_query +from utils.snowchat_ui import ( + reset_chat_history, + extract_code, + message_func, + is_sql_query, +) from snowflake.connector.errors import ProgrammingError -warnings.filterwarnings('ignore') +from supabase.client import Client, create_client + +warnings.filterwarnings("ignore") openai.api_key = st.secrets["OPENAI_API_KEY"] MAX_INPUTS = 1 chat_history = [] +supabase_url = st.secrets["SUPABASE_URL"] +supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] +supabase: Client = create_client(supabase_url, supabase_key) + st.set_page_config( page_title="snowChat", page_icon="❄️", layout="centered", initial_sidebar_state="auto", menu_items={ - 'Report a bug': "https://github.com/kaarthik108/snowChat", - 'About': '''snowChat is a chatbot designed to help you with Snowflake Database. It is built using OpenAI's GPT-4 and Streamlit. 
+ "Report a bug": "https://github.com/kaarthik108/snowChat", + "About": """snowChat is a chatbot designed to help you with Snowflake Database. It is built using OpenAI's GPT-4 and Streamlit. Go to the GitHub repo to learn more about the project. https://github.com/kaarthik108/snowChat - ''' - } + """, + }, ) + def load_chain(): - ''' + """ Load the chain from the local file system Returns: chain (Chain): The chain object - ''' + """ - embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"]) - vectorstore = FAISS.load_local("faiss_index", embeddings) + embeddings = OpenAIEmbeddings( + openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" + ) + vectorstore = SupabaseVectorStore( + embedding=embeddings, client=supabase, table_name="documents" + ) return get_chain(vectorstore) + snow_ddl = Snowddl() st.title("snowChat") @@ -58,25 +74,28 @@ def load_chain(): # Create a sidebar with a dropdown menu selected_table = st.sidebar.selectbox( - "Select a table:", options=list(snow_ddl.ddl_dict.keys())) + "Select a table:", options=list(snow_ddl.ddl_dict.keys()) +) st.sidebar.markdown(f"### DDL for {selected_table} table") st.sidebar.code(snow_ddl.ddl_dict[selected_table], language="sql") st.write(styles_content, unsafe_allow_html=True) -if 'generated' not in st.session_state: - st.session_state['generated'] = [ - "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍"] -if 'past' not in st.session_state: - st.session_state['past'] = ["Hey!"] +if "generated" not in st.session_state: + st.session_state["generated"] = [ + "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍" + ] +if "past" not in st.session_state: + st.session_state["past"] = ["Hey!"] if "input" not in st.session_state: st.session_state["input"] = "" if "stored_session" not in st.session_state: st.session_state["stored_session"] = [] -if 'messages' not in st.session_state: - st.session_state['messages'] = [ - ("Hello! I'm a chatbot designed to help you with Snowflake Database.")] +if "messages" not in st.session_state: + st.session_state["messages"] = [ + ("Hello! 
I'm a chatbot designed to help you with Snowflake Database.") + ] if "query_count" not in st.session_state: st.session_state["query_count"] = 0 @@ -84,26 +103,32 @@ def load_chain(): RESET = True messages_container = st.container() -with st.form(key='my_form'): - query = st.text_input("Query: ", key="input", value="", - placeholder="Type your query here...", label_visibility="hidden") - submit_button = st.form_submit_button(label='Submit') +with st.form(key="my_form"): + query = st.text_input( + "Query: ", + key="input", + value="", + placeholder="Type your query here...", + label_visibility="hidden", + ) + submit_button = st.form_submit_button(label="Submit") col1, col2 = st.columns([1, 3.2]) reset_button = col1.button("Reset Chat History") -if reset_button or st.session_state['query_count'] >= MAX_INPUTS and RESET: +if reset_button or st.session_state["query_count"] >= MAX_INPUTS and RESET: RESET = False - st.session_state['query_count'] = 0 + st.session_state["query_count"] = 0 reset_chat_history() -if 'messages' not in st.session_state: - st.session_state['messages'] = [] +if "messages" not in st.session_state: + st.session_state["messages"] = [] + def update_progress_bar(value, prefix, progress_bar=None): if progress_bar is None: progress_bar = st.empty() - key = f'{prefix}_progress_bar_value' + key = f"{prefix}_progress_bar_value" if key not in st.session_state: st.session_state[key] = 0 @@ -113,50 +138,58 @@ def update_progress_bar(value, prefix, progress_bar=None): st.session_state[key] = 0 progress_bar.empty() + chain = load_chain() if len(query) > 2 and submit_button: submit_progress_bar = st.empty() - messages = st.session_state['messages'] - update_progress_bar(33, 'submit', submit_progress_bar) + messages = st.session_state["messages"] + update_progress_bar(33, "submit", submit_progress_bar) result = chain({"question": query, "chat_history": chat_history}) - update_progress_bar(66, 'submit', submit_progress_bar) - st.session_state['query_count'] += 1 + update_progress_bar(66, "submit", submit_progress_bar) + st.session_state["query_count"] += 1 messages.append((query, result["answer"])) st.session_state.past.append(query) - st.session_state.generated.append(result['answer']) - update_progress_bar(100, 'submit', submit_progress_bar) + st.session_state.generated.append(result["answer"]) + update_progress_bar(100, "submit", submit_progress_bar) + def self_heal(df, to_extract, i): - ''' + """ If the query fails, try to fix it by extracting the code from the error message and running it again. - + Args: df (pandas.DataFrame): The dataframe generated from the query to_extract (str): The query i (int): The index of the query in the chat history - + Returns: df (pandas.DataFrame): The dataframe generated from the query - ''' - + """ + error_message = str(df) - error_message = "I have an SQL query that's causing an error. FIX The SQL query by searching the schema definition: \n```sql\n" + to_extract + "\n```\n Error message: \n " + error_message + error_message = ( + "I have an SQL query that's causing an error. 
FIX The SQL query by searching the schema definition: \n```sql\n" + + to_extract + + "\n```\n Error message: \n " + + error_message + ) recover = chain({"question": error_message, "chat_history": ""}) - message_func(recover['answer']) - to_extract = extract_code(recover['answer']) - st.session_state["generated"][i] = recover['answer'] + message_func(recover["answer"]) + to_extract = extract_code(recover["answer"]) + st.session_state["generated"][i] = recover["answer"] if is_sql_query(to_extract): df = query_data_warehouse(to_extract) - + return df + def generate_df(to_extract: str, i: int): - ''' + """ Generate a dataframe from the query by querying the data warehouse. Args: @@ -165,7 +198,7 @@ def generate_df(to_extract: str, i: int): Returns: df (pandas.DataFrame): The dataframe generated from the query - ''' + """ df = query_data_warehouse(to_extract) if isinstance(df, ProgrammingError) and is_sql_query(to_extract): message_func("uh oh, I made an error, let me try to fix it") @@ -175,9 +208,9 @@ def generate_df(to_extract: str, i: int): with messages_container: - if st.session_state['generated']: - for i in range(len(st.session_state['generated'])): - message_func(st.session_state['past'][i], is_user=True) + if st.session_state["generated"]: + for i in range(len(st.session_state["generated"])): + message_func(st.session_state["past"][i], is_user=True) message_func(st.session_state["generated"][i]) if i > 0 and is_sql_query(st.session_state["generated"][i]): code = extract_code(st.session_state["generated"][i]) @@ -187,15 +220,17 @@ def generate_df(to_extract: str, i: int): except: # noqa: E722 pass -if st.session_state['query_count'] == MAX_INPUTS and RESET: +if st.session_state["query_count"] == MAX_INPUTS and RESET: st.warning( - "You have reached the maximum number of inputs. The chat history will be cleared after the next input.") + "You have reached the maximum number of inputs. The chat history will be cleared after the next input." + ) col2.markdown( - f'
{st.session_state["query_count"]}/{MAX_INPUTS}
', unsafe_allow_html=True) + f'
{st.session_state["query_count"]}/{MAX_INPUTS}
', + unsafe_allow_html=True, +) -st.markdown('
', - unsafe_allow_html=True) +st.markdown('
', unsafe_allow_html=True) components.v1.html( """ diff --git a/requirements.txt b/requirements.txt index 7635bb8..a53100a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -134,4 +134,5 @@ yarg==0.1.9 yarl==1.8.2 zipp==3.15.0 packaging==23.1 -BeautifulSoup4==4.12.2 \ No newline at end of file +BeautifulSoup4==4.12.2 +supabase==1.0.3 \ No newline at end of file diff --git a/schema.md b/schema.md deleted file mode 100644 index 8e852c9..0000000 --- a/schema.md +++ /dev/null @@ -1,47 +0,0 @@ -**Table 1: STREAM_HACKATHON.STREAMLIT.CUSTOMER_DETAILS** (Stores customer information) - -This table contains the personal information of customers who have made purchases on the platform. - -- CUSTOMER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for customers -- FIRST_NAME: Varchar (255) - First name of the customer -- LAST_NAME: Varchar (255) - Last name of the customer -- EMAIL: Varchar (255) - Email address of the customer -- PHONE: Varchar (20) - Phone number of the customer -- ADDRESS: Varchar (255) - Physical address of the customer - -**Table 2: STREAM_HACKATHON.STREAMLIT.ORDER_DETAILS** (Stores order information) - -This table contains information about orders placed by customers, including the date and total amount of each order. - -- ORDER_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for orders -- CUSTOMER_ID: Number (38,0) [Foreign Key - CUSTOMER_DETAILS(CUSTOMER_ID)] - Customer who made the order -- ORDER_DATE: Date - Date when the order was made -- TOTAL_AMOUNT: Number (10,2) - Total amount of the order - -**Table 3: STREAM_HACKATHON.STREAMLIT.PAYMENTS** (Stores payment information) - -This table contains information about payments made by customers for their orders, including the date and amount of each payment. - -- PAYMENT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for payments -- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the payment -- PAYMENT_DATE: Date - Date when the payment was made -- AMOUNT: Number (10,2) - Amount of the payment - -**Table 4: STREAM_HACKATHON.STREAMLIT.PRODUCTS** (Stores product information) - -This table contains information about the products available for purchase on the platform, including their name, category, and price. - -- PRODUCT_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for products -- PRODUCT_NAME: Varchar (255) - Name of the product -- CATEGORY: Varchar (255) - Category of the product -- PRICE: Number (10,2) - Price of the product - -**Table 5: STREAM_HACKATHON.STREAMLIT.TRANSACTIONS** (Stores transaction information) - -This table contains information about individual transactions that occur when customers purchase products, including the associated order, product, quantity, and price. 
- -- TRANSACTION_ID: Number (38,0) [Primary Key, Not Null] - Unique identifier for transactions -- ORDER_ID: Number (38,0) [Foreign Key - ORDER_DETAILS(ORDER_ID)] - Associated order for the transaction -- PRODUCT_ID: Number (38,0) - Product involved in the transaction -- QUANTITY: Number (38,0) - Quantity of the product in the transaction -- PRICE: Number (10,2) - Price of the product in the transaction diff --git a/streamlit-hack.yml b/streamlit-hack.yml deleted file mode 100644 index 34b2e4f..0000000 --- a/streamlit-hack.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: streamlit-hack -channels: - - defaults - - conda-forge -dependencies: - - python==3.9.16 - - pip~=21.3.1 - - pip: - - snowflake-connector-python[pandas]==3.0.0 - - streamlit - - pandas - - numpy - - python-dotenv - - matplotlib - - openai - - streamlit-chat - - langchain - - tiktoken - - unstructured - - faiss-cpu diff --git a/supabase/scripts.sql b/supabase/scripts.sql new file mode 100644 index 0000000..cac92a3 --- /dev/null +++ b/supabase/scripts.sql @@ -0,0 +1,36 @@ + +CREATE extension vector; + +CREATE TABLE documents ( + id UUID PRIMARY KEY, + content text, + metadata jsonb, + embedding vector(1536) +); + +CREATE OR REPLACE FUNCTION match_documents(query_embedding vector(1536), match_count int) + RETURNS TABLE( + id UUID, + content text, + metadata jsonb, + -- we return matched vectors to enable maximal marginal relevance searches + embedding vector(1536), + similarity float) + LANGUAGE plpgsql + AS $$ + # variable_conflict use_column + BEGIN + RETURN query + SELECT + id, + content, + metadata, + embedding, + 1 -(documents.embedding <=> query_embedding) AS similarity + FROM + documents + ORDER BY + documents.embedding <=> query_embedding + LIMIT match_count; + END; + $$; diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index 881899c..2a716bb 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -3,13 +3,14 @@ import re import html + def format_message(text): - ''' + """ This function is used to format the messages in the chatbot UI. Parameters: text (str): The text to be formatted. - ''' + """ text_blocks = re.split(r"```[\s\S]*?```", text) code_blocks = re.findall(r"```([\s\S]*?)```", text) @@ -23,22 +24,24 @@ def format_message(text): return formatted_text + def message_func(text, is_user=False): - ''' + """ This function is used to display the messages in the chatbot UI. - + Parameters: text (str): The text to be displayed. is_user (bool): Whether the message is from the user or the chatbot. key (str): The key to be used for the message. avatar_style (str): The style of the avatar to be used. - ''' + """ if is_user: avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned" message_alignment = "flex-end" message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)" avatar_class = "user-avatar" - st.write(f""" + st.write( + f"""
{text} @@ -46,52 +49,64 @@ def message_func(text, is_user=False): avatar
- """, unsafe_allow_html=True) + """, + unsafe_allow_html=True, + ) else: text = format_message(text) avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light" message_alignment = "flex-start" message_bg_color = "#71797E" avatar_class = "bot-avatar" - st.write(f""" + st.write( + f"""
avatar
{text} \n
- """, unsafe_allow_html=True) + """, + unsafe_allow_html=True, + ) def reset_chat_history(): - ''' + """ This function is used to reset the chat history. - ''' - st.session_state['generated'] = ["Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍"] - st.session_state['past'] = ["Hi..."] + """ + st.session_state["generated"] = [ + "Hey there, I'm Chatty McQueryFace, your SQL-speaking sidekick, ready to chat up Snowflake and fetch answers faster than a snowball fight in summer! ❄️🔍" + ] + st.session_state["past"] = ["Hi..."] st.session_state["stored_session"] = [] - st.session_state['messages'] = [("Hello! I'm a chatbot designed to help you with Snowflake Database.")] + st.session_state["messages"] = [ + ("Hello! I'm a chatbot designed to help you with Snowflake Database.") + ] # can be removed with better prompt def extract_code(text) -> str: - ''' + """ This function is used to extract the SQL code from the user's input. - + Parameters: text (str): The text to be processed. - + Returns: str: The SQL code extracted from the user's input. - ''' + """ if len(text) < 5: return None # Use OpenAI's GPT-3.5 to extract the SQL code response = openai.ChatCompletion.create( - model='gpt-3.5-turbo', - messages=[ - {'role': 'user', 'content': f"Extract only the code do not add text or any apostrophes or any sql keywords \n\n{text}"}, - ], - # stream=True + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"Extract only the code do not add text or any apostrophes or any sql keywords \n\n{text}", + }, + ], + # stream=True ) # Extract the SQL code from the response @@ -99,6 +114,7 @@ def extract_code(text) -> str: return sql_code + def is_sql_query(text: str) -> bool: """ Checks if the input text is likely an SQL query. @@ -108,13 +124,29 @@ def is_sql_query(text: str) -> bool: """ # Define a list of common SQL keywords keywords = [ - "SELECT", "FROM", "WHERE", "UPDATE", "INSERT", "DELETE", "JOIN", - "GROUP BY", "ORDER BY", "HAVING", "LIMIT", "OFFSET", "UNION", "CREATE", - "ALTER", "DROP", "TRUNCATE", "EXPLAIN", "WITH" + "SELECT", + "FROM", + "WHERE", + "UPDATE", + "INSERT", + "DELETE", + "JOIN", + "GROUP BY", + "ORDER BY", + "HAVING", + "LIMIT", + "OFFSET", + "UNION", + "CREATE", + "ALTER", + "DROP", + "TRUNCATE", + "EXPLAIN", + "WITH", ] # Create a single regular expression pattern to search for all keywords - pattern = r'\b(?:' + '|'.join(keywords) + r')\b' + pattern = r"\b(?:" + "|".join(keywords) + r")\b" # Check if any of the keywords are present in the input text (case-insensitive) if re.search(pattern, text, re.IGNORECASE): diff --git a/utils/snowddl.py b/utils/snowddl.py index 7867383..c347b99 100644 --- a/utils/snowddl.py +++ b/utils/snowddl.py @@ -1,13 +1,14 @@ class Snowddl: - ''' + """ Snowddl class loads DDL files for various tables in a database. - + Attributes: ddl_dict (dict): dictionary of DDL files for various tables in a database. - + Methods: load_ddls: loads DDL files for various tables in a database. 
- ''' + """ + def __init__(self): self.ddl_dict = self.load_ddls() @@ -18,7 +19,7 @@ def load_ddls(): "ORDER_DETAILS": "sql/ddl_orders.sql", "PAYMENTS": "sql/ddl_payments.sql", "PRODUCTS": "sql/ddl_products.sql", - "CUSTOMER_DETAILS": "sql/ddl_customer.sql" + "CUSTOMER_DETAILS": "sql/ddl_customer.sql", } ddl_dict = {} diff --git a/utils/snowflake.py b/utils/snowflake.py index c6c861a..77d2169 100644 --- a/utils/snowflake.py +++ b/utils/snowflake.py @@ -24,22 +24,22 @@ def query_data_warehouse(sql: str, parameters=None) -> any: :param sql: sql query to be executed :param parameters: named parameters used in the sql query (defaulted as None) :return: dataframe - """ + """ if parameters is None: parameters = {} query = sql - + try: cur.execute("USE DATABASE " + st.secrets["DATABASE"]) cur.execute("USE SCHEMA " + st.secrets["SCHEMA"]) cur.execute(query, parameters) all_rows = cur.fetchall() field_names = [i[0] for i in cur.description] - + except snowflake.connector.errors.ProgrammingError as e: # print(f"Error in query_data_warehouse: {e}") return e - + finally: print("closing cursor")
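For reference, the match_documents function created by supabase/scripts.sql above is what LangChain's SupabaseVectorStore calls during retrieval. The sketch below is not part of this patch; it shows the equivalent direct call through the Supabase RPC interface, where the question text and match_count are arbitrary placeholders and the secrets are assumed to be the same ones used by main.py.

```python
# Illustrative sketch -- not included in this patch.
import streamlit as st
from langchain.embeddings.openai import OpenAIEmbeddings
from supabase.client import create_client

supabase = create_client(
    st.secrets["SUPABASE_URL"], st.secrets["SUPABASE_SERVICE_KEY"]
)
embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"])

# text-embedding-ada-002 produces 1536-dimensional vectors, matching the
# vector(1536) column in the documents table.
query_embedding = embeddings.embed_query("Which table stores payments?")

# Call the SQL function defined in supabase/scripts.sql.
response = supabase.rpc(
    "match_documents", {"query_embedding": query_embedding, "match_count": 3}
).execute()
for row in response.data:
    print(row["similarity"], row["content"])
```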