From 25f231d6b4ff1a554598b6856b77c02398b45efd Mon Sep 17 00:00:00 2001 From: Shyam Sudhakaran Date: Wed, 24 May 2023 15:29:29 -0700 Subject: [PATCH] added return_row flag, default llm, and added query method in __init__ for easier use --- README.md | 15 ++++++++++----- sayql/__init__.py | 8 ++++++++ sayql/sayql.py | 8 +++++--- tests/geo.py | 14 ++++++++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 tests/geo.py diff --git a/README.md b/README.md index 92c7bfc..57190ff 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,18 @@ Talk to your data # Usage ```python +import pandas as pd +from langchain.chat_models import ChatOpenAI +import sayql + df = pd.read_csv('https://raw.githubusercontent.com/scpike/us-state-county-zip/master/geo-data.csv') -llm = OpenAIChat(model="gpt-3.5-turbo", temperature=0, openai_api_key=OPENAI_API_KEY) +llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) -# Interface subject to change! -SayQL(df, llm).query("What is the Zipcode of Clanton. AL? Return the whole row") +query = "What is the Zipcode of Clanton. AL?" +print(f"Query: {query}") ->>> state_fips state state_abbr zipcode county city -0 1 Alabama AL 35045 Chilton Clanton +# Interface subject to change! +print("Resulting df:") +print(sayql.query("What is the Zipcode of Clanton. AL?", return_row=True, df = df, llm = llm)) ``` \ No newline at end of file diff --git a/sayql/__init__.py b/sayql/__init__.py index 6682638..dcd6c97 100644 --- a/sayql/__init__.py +++ b/sayql/__init__.py @@ -1 +1,9 @@ from sayql.sayql import SayQL + +def query( + query: str, + df, + llm = None, + return_row: bool = True +): + return SayQL(df, llm).query(query, return_row=return_row) diff --git a/sayql/sayql.py b/sayql/sayql.py index d0e6db7..267c45d 100644 --- a/sayql/sayql.py +++ b/sayql/sayql.py @@ -1,7 +1,7 @@ import duckdb import pandas as pd from langchain import LLMChain, PromptTemplate - +from langchain.chat_models import ChatOpenAI from sayql.prompt import DEFAULT_PROMPT # Will need to abstract for multiple datastores @@ -16,7 +16,7 @@ class SayQL: - def __init__(self, df: pd.DataFrame, llm): + def __init__(self, df: pd.DataFrame, llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)): self.df = df self.llm = llm @@ -34,7 +34,9 @@ def _to_sql(self, query: str) -> str: llm_chain = LLMChain(llm=self.llm, prompt=prompt) return llm_chain.predict(query=query, schema_str=schema_str) - def query(self, query: str) -> pd.DataFrame: + def query(self, query: str, return_row: bool = True) -> pd.DataFrame: + if return_row: + query = f"{query} Return the whole row" sql: str = self._to_sql(query) # global naming for duckdb df = self.df diff --git a/tests/geo.py b/tests/geo.py new file mode 100644 index 0000000..ee414dc --- /dev/null +++ b/tests/geo.py @@ -0,0 +1,14 @@ +import pandas as pd +from langchain.chat_models import ChatOpenAI +import sayql + +df = pd.read_csv('https://raw.githubusercontent.com/scpike/us-state-county-zip/master/geo-data.csv') + +llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) + +query = "What is the Zipcode of Clanton. AL?" +print(f"Query: {query}") + +# Interface subject to change! +print("Resulting df:") +print(sayql.query("What is the Zipcode of Clanton. AL?", return_row=True, df = df, llm = llm))