add price handling
mlpierce22 committed Feb 13, 2024
1 parent a0a8d4d commit afad2e8
Showing 2 changed files with 53 additions and 19 deletions.
64 changes: 46 additions & 18 deletions app/agent.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Dict, Union

import tiktoken
from component_creation import (
create_loading_component,
create_steep_component,
@@ -17,6 +18,7 @@
pour_tag,
set_import,
)
from langchain_community.callbacks.openai_info import get_openai_token_cost_for_model
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import AIMessageChunk
@@ -39,13 +41,38 @@ class TeaAgent:

def __init__(self, llm: Union[BaseLLM, BaseChatModel] = None):
self.llm = llm
self.input_prompt = None
self.model_response = None

def print_chunk(self, chunk: str, end: str = ""):
if file_log:
print(chunk, end=end, flush=True, file=file_log)
else:
print(chunk, end=end, flush=True)

def print_costs(self):
"""
Prints the cost of the text in USD
"""
input_cost = self._get_cost(self.input_prompt)
output_cost = self._get_cost(self.model_response)
log.info(
f"The prompt costs: {input_cost} and the response costs: ${output_cost} USD."
)
log.info(f"Total cost: ${input_cost + output_cost} USD.")

def _get_cost(self, text: str) -> float:
"""
Returns the cost of the text as a float
"""
try:
enc = tiktoken.encoding_for_model(self.llm.name)
num_tokens = len(enc.encode(text))
return get_openai_token_cost_for_model(self.llm.name, num_tokens)
except Exception:
log.info(f"Can't get cost for model called {self.llm.name}")
return 0.00

def _process_response(
self, chain: RunnableSerializable, args: Union[Dict, str]
) -> str:
@@ -59,8 +86,9 @@ def _process_response(
response += chunk.content

self.print_chunk("\n-------\n")
log.debug(response)
return response
self.model_response = response
log.debug(self.model_response)
return self.model_response

def pour(self, component_name: str, ctx: SteepContext):
log.info("Pouring tea...")
@@ -70,20 +98,20 @@ def pour(self, component_name: str, ctx: SteepContext):
component_location_parser
)
paths = get_paths_from_tsconfig(ctx.root_directory)
log.debug("Using the following prompt:")
log.debug(
component_location_prompt.format(
**{
"component_name": component_name,
"path_aliases": paths,
"root_files": os.listdir(ctx.root_directory),
"parent_component_path": ctx.file_path,
"root_path": ctx.root_directory,
"logical_path_examples": LOGICAL_PATH_EXAMPLES,
"import_statement_examples": IMPORT_STATEMENT_EXAMPLES,
}
)
self.input_prompt = component_location_prompt.format(
**{
"component_name": component_name,
"path_aliases": paths,
"root_files": os.listdir(ctx.root_directory),
"parent_component_path": ctx.file_path,
"root_path": ctx.root_directory,
"logical_path_examples": LOGICAL_PATH_EXAMPLES,
"import_statement_examples": IMPORT_STATEMENT_EXAMPLES,
}
)
log.debug("Using the following prompt:")
log.debug(self.input_prompt)

chain = component_location_prompt | self.llm

def handle_response(retries=2):
@@ -168,7 +196,7 @@ def steep(self, ctx: SteepContext):
tea_component = create_tea_component()
available_components = get_available_components(ctx.root_directory)

prompt = write_component_prompt(
self.input_prompt = write_component_prompt(
user_query=ctx.tea_tag.children,
steep_component_content=ctx.steep_content,
parent_file_content=ctx.file_content,
Expand All @@ -177,7 +205,7 @@ def steep(self, ctx: SteepContext):
source_file=ctx.steep_path,
)
log.debug("Steeping with the following prompt:")
log.debug(prompt)
log.debug(self.input_prompt)

# Add the import to the top of the file
modified = False
@@ -202,7 +230,7 @@ def steep(self, ctx: SteepContext):
# Now the component is heating, this is where we ask the llm for code
log.info("Creating component. This could take a while...")

full_response = self._process_response(self.llm, prompt)
full_response = self._process_response(self.llm, self.input_prompt)

try:
# Grab the code from between the backticks
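
For reference, a minimal standalone sketch of the cost estimation the new _get_cost helper performs, assuming an OpenAI model name that tiktoken recognizes; the helper name estimate_cost, the model name, and the sample text are illustrative, not part of the commit:

import tiktoken
from langchain_community.callbacks.openai_info import get_openai_token_cost_for_model

def estimate_cost(text: str, model_name: str = "gpt-3.5-turbo") -> float:
    # Tokenize with the encoding tiktoken maps to this model, then price the
    # token count using langchain_community's per-model rate table.
    encoding = tiktoken.encoding_for_model(model_name)
    num_tokens = len(encoding.encode(text))
    return get_openai_token_cost_for_model(model_name, num_tokens)

print(f"${estimate_cost('Write a Vue component that renders a tea menu.')} USD")

The same helper also accepts an is_completion flag for pricing output tokens at the completion rate.
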
8 changes: 7 additions & 1 deletion app/main.py
@@ -94,9 +94,11 @@ def process_tea_tag(self, tea_tag: TeaTag, ctx: FileContext = None):
if pour:
log.info(f"Pouring {pour} component...")
self.tea_agent.pour(component_name=pour, ctx=steep_ctx)
self.tea_agent.print_costs()
else:
log.info("Steeping new component...")
self.tea_agent.steep(ctx=steep_ctx)
self.tea_agent.print_costs()

def process_file(self, file_path: str, root_directory=None):
"""
@@ -176,14 +178,18 @@ def get_config_from_environment():
else config.model
)
llm = ChatOpenAI(
name=model,
model=model,
temperature=config.temperature,
api_key=config.openai_key,
max_tokens=1000,
)
else:
llm = Ollama(
model=config.model, temperature=config.temperature, base_url=config.base_url
name=config.model,
model=config.model,
temperature=config.temperature,
base_url=config.base_url,
)

main = Main(llm=llm, config=config)
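
A hedged usage sketch of how the pieces fit together: passing name= to ChatOpenAI and Ollama is what lets TeaAgent._get_cost read self.llm.name when it looks up the tokenizer and price table. The import paths, dummy API key, and sample prompt/response strings below are assumptions for illustration; no request is actually sent:

from langchain_openai import ChatOpenAI  # assumed import path; main.py may import it differently
from agent import TeaAgent               # assumed module path for app/agent.py

llm = ChatOpenAI(name="gpt-4", model="gpt-4", temperature=0.5, api_key="sk-placeholder", max_tokens=1000)
agent = TeaAgent(llm=llm)
agent.input_prompt = "Create a pricing table component."  # normally set by steep()/pour()
agent.model_response = "<template>...</template>"          # normally set by _process_response()
agent.print_costs()  # logs the prompt cost, the response cost, and their total in USD
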
