-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathutil.py
204 lines (177 loc) Β· 8.88 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import sys
import json
from typing import Dict, List
import requests
from dotenv import find_dotenv, load_dotenv
from rich.padding import Padding
from rich.console import Console
from rich.json import JSON
from openai import (
APIConnectionError,
APITimeoutError,
AuthenticationError,
BadRequestError,
ConflictError,
InternalServerError,
NotFoundError,
PermissionDeniedError,
RateLimitError,
APIError,
UnprocessableEntityError,
OpenAI,
)
from openai import AzureOpenAI
import time
from jockey.model_config import AZURE_DEPLOYMENTS, OPENAI_MODELS
from langchain_core.messages.ai import AIMessageChunk
# Environment variables that every Jockey deployment must define.
REQUIRED_ENVIRONMENT_VARIABLES = {"TWELVE_LABS_API_KEY", "HOST_PUBLIC_DIR", "LLM_PROVIDER"}
# Variables required only when LLM_PROVIDER is AZURE.
AZURE_ENVIRONMENT_VARIABLES = {"AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY", "OPENAI_API_VERSION"}
# Variables required only when LLM_PROVIDER is OPENAI.
OPENAI_ENVIRONMENT_VARIABLES = {"OPENAI_API_KEY"}
# Union of all variables Jockey may consult, across both providers.
ALL_JOCKEY_ENVIRONMENT_VARIABLES = REQUIRED_ENVIRONMENT_VARIABLES | AZURE_ENVIRONMENT_VARIABLES | OPENAI_ENVIRONMENT_VARIABLES
# Base URL of a locally running LangGraph server.
LOCAL_LANGGRAPH_URL = "http://localhost:8000"
# Shared Rich console used for all terminal output in this module.
console = Console()
async def parse_langchain_events_terminal(event: dict):
    """Used to parse events emitted from Jockey when called as an API.

    Pretty-prints a single LangChain/LangGraph event to the terminal,
    color-coded by the node that produced it, and appends the raw event
    to ``event_log.txt`` for debugging.

    Args:
        event: An event dict as produced by LangChain's ``astream_events``,
            with at least ``event`` and ``data`` keys (``tags``/``metadata``
            are read for some event types).
    """
    # Debug trail: keep a raw copy of every event we receive.
    with open("event_log.txt", "a") as f:
        f.write(f"{event}\n")

    if event["event"] == "on_chat_model_stream":
        # Chunks may arrive either as plain dicts or as message-chunk objects.
        chunk = event["data"]["chunk"]
        content = chunk["content"] if isinstance(chunk, dict) else chunk.content
        # Color-code the streamed text by the node it came from.
        if content and "instructor" in event["tags"]:
            console.print(f"[red]{content}", end="")
        elif content and "planner" in event["tags"]:
            console.print(f"[yellow]{content}", end="")
        elif content and "supervisor" in event["tags"]:
            console.print(f"[white]{content}", end="")
        elif content and "reflect" in event["tags"]:
            console.print(f"[cyan]{content}", end="")
    elif event["event"] == "on_tool_start":
        tool = event["name"]
        console.print(Padding(f"[cyan]π Using: {tool}", (1, 0, 0, 2)))
        console.print(Padding("[cyan]π Inputs:", (0, 2)))
        # Bind before the try block so the except-fallback below can always
        # reference it (previously a failed lookup caused a NameError there).
        input_data = event["data"].get("input")
        try:
            if isinstance(input_data, dict):
                # Stringify non-serializable objects so json.dumps succeeds.
                serializable_input = {k: str(v) if hasattr(v, "__dict__") else v for k, v in input_data.items()}
            else:
                # Not a dictionary: convert the whole thing to a string.
                serializable_input = str(input_data)
            console.print(Padding(JSON(json.dumps(serializable_input), indent=2), (1, 6)))
        except Exception:
            # Fallback to the plain string representation if JSON fails.
            console.print(Padding(str(input_data), (1, 6)))
    elif event["event"] == "on_tool_end":
        tool = event["name"]
        console.print(Padding(f"[cyan]π Finished Using: {tool}", (0, 2)))
        console.print(Padding("[cyan]π Outputs:", (0, 2)))
        output = event["data"]["output"]
        try:
            console.print(Padding(JSON(output, indent=2), (1, 6)))
        except (json.decoder.JSONDecodeError, TypeError):
            # Output was not a JSON string; print it verbatim.
            console.print(Padding(str(output), (0, 6)))
    elif event["event"] == "on_chat_model_start":
        # Print a colored header for the node that is about to stream.
        if "instructor" in event["tags"]:
            console.print(Padding("[red]π Instructor: ", (1, 0)), end="")
        elif "planner" in event["tags"]:
            console.print(Padding("[yellow]π Planner: ", (1, 0)), end="")
        elif "reflect" in event["tags"]:
            console.print()
            console.print("[cyan]π Jockey: ", end="")
    elif event["event"] == "on_chain_end":
        # Only process events emitted by the planner node.
        metadata = event.get("metadata", {})
        if metadata.get("langgraph_node") != "planner":
            return
        # The planner emits more than one on_chain_end event; only the one
        # whose output is a plain string carries the final plan text.
        output = event["data"].get("output")
        if not isinstance(output, str):
            return
        console.print(Padding(f"[yellow]π Planner: {output}", (1, 0)), end="")
def check_environment_variables():
    """Check that a .env file contains the required environment variables.

    Searches the current working directory tree for a .env file, loads it,
    and exits the process with an error message when any mandatory variable
    is absent, or when neither provider-specific variable set is complete.
    """
    # The .env file is assumed to live somewhere on the cwd tree.
    load_dotenv(find_dotenv(usecwd=True))
    env_keys = os.environ.keys()

    missing_required = REQUIRED_ENVIRONMENT_VARIABLES - env_keys
    if missing_required:
        print(f"Expected the following environment variables:\n\t{', '.join(REQUIRED_ENVIRONMENT_VARIABLES)}")
        print(f"Missing:\n\t{', '.join(missing_required)}")
        sys.exit("Missing required environment variables.")

    azure_complete = AZURE_ENVIRONMENT_VARIABLES <= env_keys
    openai_complete = OPENAI_ENVIRONMENT_VARIABLES <= env_keys
    # At least one provider's full variable set must be present.
    if not azure_complete and not openai_complete:
        missing_azure = AZURE_ENVIRONMENT_VARIABLES - env_keys
        missing_openai = OPENAI_ENVIRONMENT_VARIABLES - env_keys
        print(f"If using Azure, Expected the following environment variables:\n\t{', '.join(AZURE_ENVIRONMENT_VARIABLES)}")
        print(f"Missing:\n\t{', '.join(missing_azure)}")
        print(f"If using Open AI, Expected the following environment variables:\n\t{', '.join(OPENAI_ENVIRONMENT_VARIABLES)}")
        print(f"Missing:\n\t{', '.join(missing_openai)}")
        sys.exit("Missing Azure or Open AI environment variables.")
def preflight_checks():
    """Smoke-test every configured model before starting Jockey.

    Sends a tiny chat completion to each model for the configured
    LLM provider, once non-streaming and once streaming, and verifies
    that a non-empty response comes back.

    Returns:
        str: A success message when every model responds, otherwise a
        description of the first failure encountered.

    Exits:
        Calls ``sys.exit`` when LLM_PROVIDER is neither AZURE nor OPENAI.
    """
    print("Performing preflight checks...")
    load_dotenv()

    llm_provider = os.getenv("LLM_PROVIDER")
    if llm_provider == "OPENAI":
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        models = list(OPENAI_MODELS.values())
    elif llm_provider == "AZURE":
        client = AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        )
        # Azure addresses requests by deployment name rather than model name.
        models = [config["deployment_name"] for config in AZURE_DEPLOYMENTS.values()]
    else:
        print("Invalid LLM_PROVIDER. Must be one of: [AZURE, OPENAI]")
        sys.exit("Invalid LLM_PROVIDER environment variable.")

    for model in models:
        print(f"[DEBUG] Testing model: {model}")
        for stream in (False, True):
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=[{"role": "system", "content": "Test message"}],
                    temperature=0,
                    max_tokens=2048,
                    stream=stream,
                    timeout=10,  # per-request timeout, in seconds
                )
                if stream:
                    # Wall-clock guard in case chunks trickle in indefinitely.
                    start_time = time.time()
                    timeout_seconds = 10
                    has_content = False
                    for chunk in response:
                        if time.time() - start_time > timeout_seconds:
                            return f"Timeout occurred while processing stream. Model: {model}"
                        if chunk.choices and chunk.choices[0].delta.content is not None:
                            has_content = True
                            break  # one content-bearing chunk is proof enough
                    if not has_content:
                        return f"API request failed. Streaming: {stream}. Model: {model}. Check your API key or usage limits."
                elif not response.choices[0].message.content:
                    return f"API request failed. Streaming: {stream}. Model: {model}. Check your API key or usage limits."
            except (
                APIConnectionError,
                APITimeoutError,
                AuthenticationError,
                BadRequestError,
                ConflictError,
                InternalServerError,
                NotFoundError,
                PermissionDeniedError,
                RateLimitError,
                APIError,
                UnprocessableEntityError,
                requests.exceptions.Timeout,
            ) as e:
                return f"{type(e).__name__} occurred. Model: {model}. Error: {str(e)}"
    return "Preflight checks passed. All models functioning correctly."