From 17fe28a2ce893ad32ae1a40e9604558f9a7a85dc Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Sun, 1 Oct 2023 08:54:05 +1300 Subject: [PATCH] Add Claude from Bedrock --- .gitignore | 4 +- chain.py | 252 ++++++++++++++++++++----------------------- main.py | 18 ++-- requirements.txt | 7 +- supabase/scripts.sql | 57 +++++----- template.py | 64 +++++++++++ utils/snowchat_ui.py | 3 + 7 files changed, 222 insertions(+), 183 deletions(-) create mode 100644 template.py diff --git a/.gitignore b/.gitignore index afcf150a..33b021f9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,6 @@ secrets.toml archived_logs/ build/ -snowchat.egg-info/ \ No newline at end of file +snowchat.egg-info/ + +chroma_db \ No newline at end of file diff --git a/chain.py b/chain.py index 6bba2b65..3093bd55 100644 --- a/chain.py +++ b/chain.py @@ -1,161 +1,141 @@ +from typing import Any, Callable, Dict, Optional + +import boto3 import streamlit as st from langchain.chains import ConversationalRetrievalChain, LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms import OpenAI, Replicate -from langchain.prompts.prompt import PromptTemplate +from langchain.llms.bedrock import Bedrock from langchain.vectorstores import SupabaseVectorStore +from pydantic import BaseModel, validator from supabase.client import Client, create_client -template = """You are an AI chatbot having a conversation with a human. - -Chat History:\""" -{chat_history} -\""" -Human Input: \""" -{question} -\""" -AI:""" - -condense_question_prompt = PromptTemplate.from_template(template) - -TEMPLATE = """ -You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. - -When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries. - -Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema. -**You are only required to write one SQL query per question.** - -If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. - -When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries. - -If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question." - -Write your response in markdown format. - -Question: ```{question}``` -{context} - -Answer: -""" -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -LLAMA_TEMPLATE = """ -You're specialized with Snowflake SQL. When providing answers, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. - -If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. - -If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question." - -Write SQL code for this Question based on the below context details: {question} - -<> -context: \n {context} -<> - -write responses in markdown format - -Answer: - -""" - -LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST - -QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"]) -LLAMA_PROMPT = PromptTemplate( - template=LLAMA_TEMPLATE, input_variables=["question", "context"] -) +from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT supabase_url = st.secrets["SUPABASE_URL"] supabase_key = st.secrets["SUPABASE_SERVICE_KEY"] supabase: Client = create_client(supabase_url, supabase_key) -VERSION = "be553392065353425e0f0193d2a896d6a5ff201549f5d7cd9180c8dfdeac39ed" +VERSION = "1f01a52ff933873dff339d5fb5e1fd6f24f77456836f514fa05e91c1a42699c7" LLAMA = "meta/codellama-13b-instruct:{}".format(VERSION) - -def get_chain_replicate(vectorstore, callback_handler=None): - """ - Get a chain for chatting with a vector database. - """ - q_llm = Replicate( - model=LLAMA, - input={"temperature": 0.2, "max_length": 200, "top_p": 1}, - replicate_api_token=st.secrets["REPLICATE_API_TOKEN"], - ) - llm = Replicate( - streaming=True, - callbacks=[callback_handler], - model=LLAMA, - input={"temperature": 0.2, "max_length": 300, "top_p": 1}, - replicate_api_token=st.secrets["REPLICATE_API_TOKEN"], - ) - - question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt) - - doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT) - conv_chain = ConversationalRetrievalChain( - callbacks=[callback_handler], - retriever=vectorstore.as_retriever(), - combine_docs_chain=doc_chain, - question_generator=question_generator, - ) - - return conv_chain - - -def get_chain_gpt(vectorstore, callback_handler=None): - """ - Get a chain for chatting with a vector database. - """ - q_llm = OpenAI( - temperature=0.1, - openai_api_key=st.secrets["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo-16k", - max_tokens=500, - ) - - llm = ChatOpenAI( - model_name="gpt-3.5-turbo", - temperature=0.5, - openai_api_key=st.secrets["OPENAI_API_KEY"], - max_tokens=500, - callbacks=[callback_handler], - streaming=True, - ) - question_generator = LLMChain(llm=q_llm, prompt=condense_question_prompt) - - doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=QA_PROMPT) - conv_chain = ConversationalRetrievalChain( - retriever=vectorstore.as_retriever(), - combine_docs_chain=doc_chain, - question_generator=question_generator, - ) - - return conv_chain +class ModelConfig(BaseModel): + model_type: str + secrets: Dict[str, Any] + callback_handler: Optional[Callable] = None + + @validator("model_type", pre=True, always=True) + def validate_model_type(cls, v): + if v not in ["code-llama", "gpt", "claude"]: + raise ValueError(f"Unsupported model type: {v}") + return v + + +class ModelWrapper: + def __init__(self, config: ModelConfig): + self.model_type = config.model_type + self.secrets = config.secrets + self.callback_handler = config.callback_handler + self.setup() + + def setup(self): + if self.model_type == "code-llama": + self.setup_llama() + elif self.model_type == "gpt": + self.setup_gpt() + elif self.model_type == "claude": + self.setup_claude() + + def setup_llama(self): + self.q_llm = Replicate( + model=LLAMA, + input={"temperature": 0.2, "max_length": 200, "top_p": 1}, + replicate_api_token=self.secrets["REPLICATE_API_TOKEN"], + ) + self.llm = Replicate( + streaming=True, + callbacks=[self.callback_handler], + model=LLAMA, + input={"temperature": 0.2, "max_length": 300, "top_p": 1}, + replicate_api_token=self.secrets["REPLICATE_API_TOKEN"], + ) + + def setup_gpt(self): + self.q_llm = OpenAI( + temperature=0.1, + openai_api_key=self.secrets["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo-16k", + max_tokens=500, + ) + + self.llm = ChatOpenAI( + model_name="gpt-3.5-turbo-16k", + temperature=0.5, + openai_api_key=self.secrets["OPENAI_API_KEY"], + max_tokens=500, + callbacks=[self.callback_handler], + streaming=True, + ) + + def setup_claude(self): + bedrock_runtime = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"], + region_name="us-east-1", + ) + parameters = { + "max_tokens_to_sample": 1000, + "stop_sequences": [], + "temperature": 0, + "top_p": 0.9, + } + self.q_llm = Bedrock( + model_id="anthropic.claude-instant-v1", client=bedrock_runtime + ) + self.llm = Bedrock( + model_id="anthropic.claude-instant-v1", + client=bedrock_runtime, + callbacks=[self.callback_handler], + streaming=True, + model_kwargs=parameters, + ) + + def get_chain(self, vectorstore): + if not self.q_llm or not self.llm: + raise ValueError("Models have not been properly initialized.") + question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT) + doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT) + conv_chain = ConversationalRetrievalChain( + retriever=vectorstore.as_retriever(), + combine_docs_chain=doc_chain, + question_generator=question_generator, + ) + return conv_chain def load_chain(model_name="GPT-3.5", callback_handler=None): - """ - Load the chain from the local file system - - Returns: - chain (Chain): The chain object - - """ - embeddings = OpenAIEmbeddings( openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002" ) vectorstore = SupabaseVectorStore( - embedding=embeddings, client=supabase, table_name="documents" + embedding=embeddings, + client=supabase, + table_name="documents", + query_name="v_match_documents", ) - return ( - get_chain_gpt(vectorstore, callback_handler=callback_handler) - if "GPT-3.5" in model_name - else get_chain_replicate(vectorstore, callback_handler=callback_handler) + + if "claude" in model_name.lower(): + model_type = "claude" + elif "GPT-3.5" in model_name: + model_type = "gpt" + else: + model_type = "code-llama" + + config = ModelConfig( + model_type=model_type, secrets=st.secrets, callback_handler=callback_handler ) + model = ModelWrapper(config) + return model.get_chain(vectorstore) diff --git a/main.py b/main.py index ec3f3839..0e7903e7 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ st.caption("Talk your way through data") model = st.radio( "", - options=["✨ GPT-3.5", "🐐 code-LLama"], + options=["✨ GPT-3.5", "🐐 code-LLama", "♾️ Claude"], index=0, horizontal=True, ) @@ -94,8 +94,6 @@ def get_sql(text): def append_message(content, role="assistant", display=False): message = {"role": role, "content": content} - if model == "LLama-2": # unable to get streaming working with LLama-2 - message_func(content, False, display) st.session_state.messages.append(message) if role != "data": append_chat_history(st.session_state.messages[-2]["content"], content) @@ -138,11 +136,11 @@ def execute_sql(query, conn, retries=2): result = chain( {"question": content, "chat_history": st.session_state["history"]} )["answer"] - # print(result) + print(result) append_message(result) - if get_sql(result): - conn = SnowflakeConnection().get_session() - df = execute_sql(get_sql(result), conn) - if df is not None: - callback_handler.display_dataframe(df) - append_message(df, "data", True) + # if get_sql(result): + # conn = SnowflakeConnection().get_session() + # df = execute_sql(get_sql(result), conn) + # if df is not None: + # callback_handler.display_dataframe(df) + # append_message(df, "data", True) diff --git a/requirements.txt b/requirements.txt index 63d014d3..94f91515 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ -langchain==0.0.266 +langchain==0.0.305 pandas==1.5.0 pydantic==1.10.8 snowflake_snowpark_python==1.5.0 snowflake-snowpark-python[pandas] -streamlit==1.24.0 +streamlit==1.27.1 supabase==1.0.3 unstructured==0.7.12 tiktoken==0.4.0 openai==0.27.8 black==23.3.0 -replicate==0.8.4 \ No newline at end of file +replicate==0.8.4 +boto3==1.28.57 \ No newline at end of file diff --git a/supabase/scripts.sql b/supabase/scripts.sql index cac92a3d..5cfd0553 100644 --- a/supabase/scripts.sql +++ b/supabase/scripts.sql @@ -1,36 +1,27 @@ - CREATE extension vector; - CREATE TABLE documents ( - id UUID PRIMARY KEY, - content text, - metadata jsonb, - embedding vector(1536) + id UUID PRIMARY KEY, + content text, + metadata jsonb, + embedding vector(1536) ); - -CREATE OR REPLACE FUNCTION match_documents(query_embedding vector(1536), match_count int) - RETURNS TABLE( - id UUID, - content text, - metadata jsonb, - -- we return matched vectors to enable maximal marginal relevance searches - embedding vector(1536), - similarity float) - LANGUAGE plpgsql - AS $$ - # variable_conflict use_column - BEGIN - RETURN query - SELECT - id, - content, - metadata, - embedding, - 1 -(documents.embedding <=> query_embedding) AS similarity - FROM - documents - ORDER BY - documents.embedding <=> query_embedding - LIMIT match_count; - END; - $$; +CREATE INDEX ON documents USING hnsw (embedding vector_ip_ops); +CREATE FUNCTION v_match_documents ( + query_embedding vector (1536), + filter jsonb default '{}' +) RETURNS table ( + id uuid, + content text, + metadata jsonb, + similarity float +) language plpgsql as $$ #variable_conflict use_column +begin return query +select id, + content, + metadata, + 1 - (documents.embedding <=> query_embedding) as similarity +from documents +where metadata @> filter +order by documents.embedding <=> query_embedding; +END; +$$; \ No newline at end of file diff --git a/template.py b/template.py new file mode 100644 index 00000000..031af36a --- /dev/null +++ b/template.py @@ -0,0 +1,64 @@ +from langchain.prompts.prompt import PromptTemplate + +template = """You are an AI chatbot having a conversation with a human. + +Chat History:\""" +{chat_history} +\""" +Human: \""" +{question} +\""" +Assistant:""" + + +TEMPLATE = """ +You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. + +When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries. + +Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema. +**You are only required to write one SQL query per question.** + +If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. + +When the user expresses gratitude or says "Thanks", interpret it as a signal to conclude the conversation. Respond with an appropriate closing statement without generating further SQL queries. + +If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question." + +Write your response in markdown format. + +Human: ```{question}``` +{context} + +Assistant: +""" +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +LLAMA_TEMPLATE = """ +You're specialized with Snowflake SQL. When providing answers, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. + +If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries. + +If you don't know the answer, simply state, "I'm sorry, I don't know the answer to your question." + +Write SQL code for this Question based on the below context details: {question} + +<> +context: \n {context} +<> + +write responses in markdown format + +Answer: + +""" + +LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST + +CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template) + +QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"]) +LLAMA_PROMPT = PromptTemplate( + template=LLAMA_TEMPLATE, input_variables=["question", "context"] +) diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index e2570c6c..b13dc621 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -143,3 +143,6 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs): """ self.token_buffer = [] # Reset the buffer self.has_streaming_ended = True + + def __call__(self, *args, **kwargs): + pass