-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
94 lines (78 loc) · 3.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import functools
import os
import readline
import shlex
from tempfile import NamedTemporaryFile
from time import sleep
from typing import Callable

import typer
from click import BadParameter
from rich.progress import Progress, SpinnerColumn, TextColumn

# from openai_client import OpenAIClient
def loading_spinner(func: Callable) -> Callable:
    """
    Decorator that shows a spinner while the wrapped function runs.

    Intended for slow calls such as requests to the OpenAI API. The spinner
    and its "Consulting with robots..." label are drawn with ``rich``. Because
    the display is created with ``transient=True``, it is cleared from the
    terminal once the call returns.

    :param func: Function to wrap.
    :return: Wrapped function that returns whatever ``func`` returns. It keeps
        ``func``'s name, docstring and introspectable signature, so CLI
        frameworks that inspect it (e.g. typer) still see the original.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The label is static markup with no ``{task...}`` placeholder, so the
        # task description below is never displayed.
        text = TextColumn("[green]Consulting with robots...")
        with Progress(SpinnerColumn(), text, transient=True) as progress:
            # Progress renders one row per task; this placeholder task is what
            # makes the spinner row appear.
            progress.add_task("request")
            return func(*args, **kwargs)

    return wrapper
def get_edited_prompt() -> str:
    """
    Open the user's editor on an empty temporary file and return what they saved.

    The editor command comes from ``$EDITOR``, falling back to ``vim`` when
    the variable is unset or empty. It is run through the shell, so values
    that include arguments (e.g. ``"code --wait"``) work. The temporary file
    is always deleted, even if launching the editor or reading the file fails.

    :return: The text saved in the editor, unmodified.
    :raises BadParameter: If the saved text is empty or whitespace-only.
    """
    # delete=False keeps the file after this handle closes, so the editor can
    # open it by name. Closing first matters on Windows, where an open
    # NamedTemporaryFile cannot be reopened.
    with NamedTemporaryFile(suffix=".txt", delete=False) as file:
        file_path = file.name
    try:
        editor = os.environ.get("EDITOR") or "vim"
        # Quote the path so temp directories containing spaces survive the
        # shell. shlex.quote targets POSIX shells, so use double quotes elsewhere.
        if os.name == "posix":
            quoted_path = shlex.quote(file_path)
        else:
            quoted_path = f'"{file_path}"'
        # os.system waits for the command to finish. Its exit status is
        # ignored; an empty result is rejected below.
        os.system(f"{editor} {quoted_path}")
        with open(file_path, "r", encoding="utf-8") as file:
            output = file.read()
    finally:
        os.remove(file_path)
    if not output.strip():
        raise BadParameter("Couldn't get valid PROMPT from $EDITOR")
    return output
def typer_writer(text: str, code: bool = False, shell: bool = False, animate: bool = True) -> None:
    """
    Print text to the console, optionally one character at a time.

    Shell commands and code are printed in a single bold magenta call and are
    never animated. Any other text is typed out character by character, with
    a 5 ms pause after each, when ``animate`` is true.

    :param text: Text to output.
    :param code: True if ``text`` is code.
    :param shell: True if ``text`` is a shell command.
    :param animate: Enable/disable the typewriter animation.
    :return: None
    """
    highlighted = shell or code
    if highlighted or not animate:
        typer.secho(text, fg="magenta" if highlighted else None, bold=highlighted)
        return
    # Only plain text reaches this point, so ``highlighted`` is falsy and no
    # color applies.
    for character in text:
        typer.secho(character, nl=False, fg=None, bold=highlighted)
        sleep(0.005)
    # Finish the line so the shell prompt doesn't show a stray "%" marker.
    typer.echo("")
# def echo_chat_messages(chat_id: str) -> None:
# # Prints all messages from a specified chat ID to the console.
# for index, message in enumerate(OpenAIClient.chat_cache.show(chat_id)):
# color = "cyan" if index % 2 == 0 else "green"
# typer.secho(message, fg=color)
# def echo_chat_ids() -> None:
# # Prints all existing chat IDs to the console.
# for chat_id in OpenAIClient.chat_cache.list():
# typer.echo(chat_id)
def get_tmp_env_var_name(env_var2val: dict, prefix: str = "TEM_ENV_VAR_") -> str:
    """
    Return the next unused numbered variable name for ``prefix``.

    Keys of ``env_var2val`` shaped like ``<prefix><digits>`` are scanned. The
    result is ``<prefix><largest index + 1>``, or ``<prefix>1`` when no key
    matches. Keys that merely contain ``prefix`` somewhere else, or whose
    remainder after the prefix is not a plain decimal number, are ignored
    rather than raising ``ValueError``.

    :param env_var2val: Mapping whose keys are existing variable names; only
        the keys are read.
    :param prefix: Name prefix to number.
    :return: A name of the form ``f"{prefix}{n}"`` that is not already a key.
    """
    start = len(prefix)
    indices = (
        int(name[start:])
        for name in env_var2val
        # isdecimal (not isdigit) accepts exactly the characters int() parses.
        if name.startswith(prefix) and name[start:].isdecimal()
    )
    return f"{prefix}{max(indices, default=0) + 1}"