Skip to content

Commit

Permalink
fix output parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
drudilorenzo committed Jun 3, 2024
1 parent 2529d08 commit 0db6021
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions reverie/backend_server/persona/prompt_template/run_gpt_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,20 @@ def create_prompt_input(persona, task, duration, test_input=None):
prompt_input += [persona.scratch.get_str_firstname()]
return prompt_input

def __func_clean_up(gpt_response, prompt=""):
print (gpt_response)
print ("-==- -==- -==- ")
def extract_numeric_part(k):
# Use regular expression to find all digits
numeric_part = re.findall(r'\d+', k)
# Join all found digits into a single string and convert to int
return int(numeric_part[0])

def __func_clean_up(gpt_response, prompt=""):

debug = True

if debug:
print (gpt_response)
print ("-==- -==- -==- ")
print("(cleanup func): Enter function")
# TODO SOMETHING HERE sometimes fails... See screenshot
temp = [i.strip() for i in gpt_response.split("\n")]
_cr = []
Expand All @@ -370,22 +380,49 @@ def __func_clean_up(gpt_response, prompt=""):
else:
_cr += [i]
for count, i in enumerate(_cr):
if debug:
print("(cleanup func) Unpacking: ", i)

# Original version
# k = [j.strip() for j in i.split("(duration in minutes:")]

# Sometimes the simulation fails because it doesn't contain
# `duration in minutes` but only `duration`.
if "duration in minutes" in i:
k = [j.strip() for j in i.split("(duration in minutes:")]
else:
k = [j.strip() for j in i.split("(duration:")]
# k = [j.strip() for j in i.split("(duration in minutes:")]

if debug:
print("(cleanup func) Unpacked(k): ", k)
task = k[0]
if task[-1] == ".":
task = task[:-1]
duration = int(k[1].split(",")[0].strip())
minutes = k[1].split(",")[0]

if debug:
print("(cleanup func): Minutes: ", minutes)
duration = extract_numeric_part(minutes)
if debug:
print("(cleanup func): Duration: ", duration)

# Original version
# duration = int(k[1].split(",")[0].strip())

cr += [[task, duration]]

if debug:
print("(cleanup func) Unpacked(cr): ", cr)

if debug:
print("(cleanup func) Prompt:", prompt)

total_expected_min = int(prompt.split("(total duration in minutes")[-1]
.split("):")[0].strip())

if debug:
print("(cleanup func) Expected Minutes:", total_expected_min)

# TODO -- now, you need to make sure that this is the same as the sum of
# the current action sequence.
curr_min_slot = [["dummy", -1],] # (task_name, task_index)
Expand Down Expand Up @@ -439,8 +476,6 @@ def get_fail_safe():
prompt = generate_prompt(prompt_input, prompt_template)
fail_safe = get_fail_safe()

print ("?????")
print (prompt)
output = safe_generate_response(prompt, gpt_param, 5, get_fail_safe(),
__func_validate, __func_clean_up)

Expand All @@ -454,11 +489,12 @@ def get_fail_safe():
IndexError: list index out of range
"""

print ("IMPORTANT VVV DEBUG")

# print (prompt_input)
# Some debugging prints
# print ("DEBUG")
# print("PROMPT:")
# print (prompt)
print (output)
# print("\nOUTPUT:")
# print (output)

fin_output = []
time_sum = 0
Expand Down Expand Up @@ -2913,23 +2949,4 @@ def get_fail_safe():
gpt_param = {"engine": "gpt-35-turbo-0125", "max_tokens": 50,
"temperature": 0, "top_p": 1, "stream": False,
"frequency_penalty": 0, "presence_penalty": 0, "stop": None}
return output, [output, prompt, gpt_param, prompt_input, fail_safe]



















return output, [output, prompt, gpt_param, prompt_input, fail_safe]

0 comments on commit 0db6021

Please sign in to comment.