Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added return_row flag, default llm, and added query method in __init__ for easier use #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@ Talk to your data
# Usage

```python
import pandas as pd
from langchain.chat_models import ChatOpenAI
import sayql

df = pd.read_csv('https://raw.githubusercontent.com/scpike/us-state-county-zip/master/geo-data.csv')

llm = OpenAIChat(model="gpt-3.5-turbo", temperature=0, openai_api_key=OPENAI_API_KEY)
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Interface subject to change!
SayQL(df, llm).query("What is the Zipcode of Clanton. AL? Return the whole row")
query = "What is the Zipcode of Clanton. AL?"
print(f"Query: {query}")

>>> state_fips state state_abbr zipcode county city
0 1 Alabama AL 35045 Chilton Clanton
# Interface subject to change!
print("Resulting df:")
print(sayql.query("What is the Zipcode of Clanton. AL?", return_row=True, df = df, llm = llm))
```
8 changes: 8 additions & 0 deletions sayql/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
from sayql.sayql import SayQL

def query(
query: str,
df,
llm = None,
return_row: bool = True
):
return SayQL(df, llm).query(query, return_row=return_row)
8 changes: 5 additions & 3 deletions sayql/sayql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import duckdb
import pandas as pd
from langchain import LLMChain, PromptTemplate

from langchain.chat_models import ChatOpenAI
from sayql.prompt import DEFAULT_PROMPT

# Will need to abstract for multiple datastores
Expand All @@ -16,7 +16,7 @@


class SayQL:
def __init__(self, df: pd.DataFrame, llm):
def __init__(self, df: pd.DataFrame, llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)):
self.df = df
self.llm = llm

Expand All @@ -34,7 +34,9 @@ def _to_sql(self, query: str) -> str:
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
return llm_chain.predict(query=query, schema_str=schema_str)

def query(self, query: str) -> pd.DataFrame:
def query(self, query: str, return_row: bool = True) -> pd.DataFrame:
if return_row:
query = f"{query} Return the whole row"
sql: str = self._to_sql(query)
# global naming for duckdb
df = self.df
Expand Down
14 changes: 14 additions & 0 deletions tests/geo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pandas as pd
from langchain.chat_models import ChatOpenAI
import sayql

df = pd.read_csv('https://raw.githubusercontent.com/scpike/us-state-county-zip/master/geo-data.csv')

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

query = "What is the Zipcode of Clanton. AL?"
print(f"Query: {query}")

# Interface subject to change!
print("Resulting df:")
print(sayql.query("What is the Zipcode of Clanton. AL?", return_row=True, df = df, llm = llm))