#!/usr/bin/env python3
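"""
Command-line harness for experimenting with the CheckThat! Task 2 data
(claim normalization) using LLM backends, currently Together AI.

Example invocations (sketch; "baseline" is a hypothetical profile name that
must match a .jinja2 file in the profiles/ directory, and TOGETHER_API_KEY
must be set in the environment):

    ./task2tool together-chat "What is 2 + 2?"
    ./task2tool checkthat-task2 together-ai baseline --init-only
"""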
import argparse
import os
import time
from queue import Queue
from typing import override  # requires Python 3.12+
import jinja2
import nltk
import numpy
import pandas
from together import Together
from together.types.chat_completions import ChatCompletionResponse
FREE_MODEL: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
def noop(_: argparse.Namespace) -> int:
print("No operation was given.")
return 0
def get_together_client() -> Together | None:
if api_key := os.environ.get('TOGETHER_API_KEY'):
return Together(api_key=api_key)
else:
print("Failed to get the API key, which is required for the operation. Set TOGETHER_API_KEY in the environment.")
return None
def together_prompt(model: str, prompt: str) -> str:
if client := get_together_client():
response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
)
return response.choices[0].message.content
else:
print("Error: Failed to get the client.")
return ""
def together_chat_cmd(args: argparse.Namespace) -> int:
print("Result:", together_prompt(args.model, args.prompt))
return 0
class CheckThatTask2Data(object):
dev_ds: pandas.DataFrame
test_ds: pandas.DataFrame
train_ds: pandas.DataFrame
    def __init__(self, repository: str | os.PathLike):
        base_path: str = os.path.join(repository, "task2", "data")
self.dev_ds = pandas.read_csv(os.path.join(base_path, "dev", "dev-eng.csv"))
self.test_ds = pandas.read_csv(os.path.join(base_path, "test", "test-eng.csv"))
self.train_ds = pandas.read_csv(os.path.join(base_path, "train", "train-eng.csv"))
class LLMBackend(object):
"""
Interface for whatever LLM we are using to implement the task.
"""
def initialize(self):
raise NotImplementedError()
def train(self, _: CheckThatTask2Data, **__):
"""
Run training on this model.
"""
raise NotImplementedError()
def query(self, _: str) -> str:
"""
Query the LLM.
"""
raise NotImplementedError()
class TogetherLLMBackend(LLMBackend):
"""
    LLM backend for models served by the Together AI service.
    The free model (the default) allows at most 10 requests per minute, so
    queries are throttled accordingly.
    Fine-tuning is possible in principle but is not implemented; it would
    also cost a considerable amount of money.
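
    Example (hypothetical usage; requires TOGETHER_API_KEY in the environment):

        backend = TogetherLLMBackend()
        backend.initialize()
        answer = backend.query("Normalize this claim: ...")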
"""
# Configure the throttling (max 10 requests per minute)
THROTTLE_TIME: float = 60.0
THROTTLE_MAX: int = 10
# This queue is used to throttle queries to Together.
_tq: Queue
_model: str
together_client: Together
def __init__(self, model: str = FREE_MODEL):
self._tq = Queue(TogetherLLMBackend.THROTTLE_MAX)
self._model = model
@override
    def initialize(self):
        client = get_together_client()
        if client is None:
            raise RuntimeError("TOGETHER_API_KEY must be set to use the Together backend.")
        self.together_client = client
    @override
    def train(self, _: CheckThatTask2Data, **__):
        # Fine-tuning is intentionally not implemented for this backend.
        pass
@override
def query(self, querytext: str) -> str:
        # Throttling: if we have already issued THROTTLE_MAX requests and the
        # oldest of them happened less than THROTTLE_TIME seconds ago, sleep
        # until that oldest request is at least THROTTLE_TIME seconds old.
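        # Example (illustrative numbers): with THROTTLE_MAX = 10 and
        # THROTTLE_TIME = 60.0, if requests 1-10 were issued at t = 0..9 s and
        # request 11 arrives at t = 30 s, the oldest timestamp (t = 0) is only
        # 30 s old, so we sleep 60 - 30 = 30 s before issuing request 11.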
if self._tq.qsize() >= TogetherLLMBackend.THROTTLE_MAX:
last_time: float = self._tq.get_nowait()
delta: float = time.monotonic() - last_time
if delta < TogetherLLMBackend.THROTTLE_TIME:
                time.sleep(TogetherLLMBackend.THROTTLE_TIME - delta)
# Put in the time we're running the query into the queue.
self._tq.put_nowait(time.monotonic())
response: ChatCompletionResponse = self.together_client.chat.completions.create(
model=self._model,
messages=[{
"role": "user",
"content": querytext,
}]
)
if type(response) is not ChatCompletionResponse:
raise NotImplementedError("Together API returned completion chunks.")
        if len(response.choices) > 1:
            print("Warning: more than one choice was returned; using the first.")
        elif len(response.choices) == 0:
            raise RuntimeError("Together API returned zero choices.")
return response.choices[0].message.content
class CheckThatTask2(object):
"""
    Manages state for running the CheckThat! Task 2 evaluation.
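
    The eval table is a CSV with three columns: the rendered prompt
    ("input"), the LLM's answer ("output", '---' until filled in), and the
    gold normalized claim ("reference").

    Typical lifecycle (sketch; names are placeholders, mirroring
    checkthat_task2_cmd below):

        ctt2 = CheckThatTask2("some-profile", TogetherLLMBackend(), template)
        ctt2.initialize_eval_table()
        ctt2.train()
        ctt2.fill_eval_table()
        score = ctt2.calculate_meteor_score_avg()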
"""
prompt_template: str | None = None
eval_frame: pandas.DataFrame | None = None
data: CheckThatTask2Data
backend: LLMBackend
query_tmpl: jinja2.Template
profile_name: str
OUTPUT_COLUMN_EMPTY: str = '---'
    # Extra kwargs forwarded to the backend's train(). Note: this is a
    # class-level attribute, shared across instances.
    trainargs: dict = {}
    _cache_path: str
    def __init__(
        self,
        profile_name: str,
        backend: LLMBackend,
        query_tmpl: jinja2.Template,
        repository_path: str = os.path.join(os.curdir, "checkthat_data"),
        cache_path: str = os.path.join(os.curdir, '.checkthat_cache'),
    ):
# Fail here if the path to the CheckThat! data repository is not found.
if not os.path.isdir(repository_path):
print("ERROR: You must clone the CheckThat! data repository first! Use the provided makefile.")
raise FileNotFoundError(repository_path)
# Read the data from CSV files in the Git repository.
self.data = CheckThatTask2Data(repository_path)
if not os.path.isdir(cache_path):
            try:
                os.mkdir(cache_path)
            except OSError:
                print("ERROR: Failed to create cache directory.")
                raise
# Set paths.
self._cache_path = cache_path
self.profile_name = profile_name
self.backend = backend
self.query_tmpl = query_tmpl
# Determine where in the cache directory to store the eval table.
        self.eval_file: str = os.path.join(self._cache_path, f"{self.profile_name}-eval.csv")
# Initialize backend.
self.backend.initialize()
def delete_eval_table_file(self):
if os.path.isfile(self.eval_file):
os.remove(self.eval_file)
else:
print("Note: Eval file did not exist. Didn't change anything...")
def initialize_eval_table(self):
if not os.path.isfile(self.eval_file):
queries = []
results = []
refopts = []
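            # A profile template (a .jinja2 file under profiles/) receives the
            # variables passed to render() below. A minimal sketch (the wording
            # is illustrative, not an actual profile from this repository):
            #
            #   Normalize the claim made in the final post.
            #   {% for _, r in train_rows.iterrows() %}
            #   Post: {{ r['post'] }}
            #   Claim: {{ r['normalized claim'] }}
            #   {% endfor %}
            #   Post: {{ input }}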
for _, row in self.data.dev_ds.iterrows():
queries.append(self.query_tmpl.render(
train_rows=self.data.train_ds,
input=row['post'],
))
results.append(CheckThatTask2.OUTPUT_COLUMN_EMPTY)
refopts.append(row['normalized claim'])
self.eval_frame = pandas.DataFrame({"input": queries, "output": results, "reference": refopts})
print(f'init: Created new eval sheet with {len(queries)} test prompts.')
self.save_eval_table()
else:
print('init: Eval sheet already exists.')
self.eval_frame = pandas.read_csv(self.eval_file)
def save_eval_table(self):
        self.eval_frame[['input', 'output', 'reference']].to_csv(self.eval_file, index=False)
def train(self):
"""
If we wanted to further train/refine an LLM using datasets, this function
serves as a place to do that.
Arbitrary kwargs may be passed in for parameterization of this process.
"""
self.backend.train(self.data, **self.trainargs)
def fill_eval_table(self):
rows_to_fill = []
        for key, row in self.eval_frame.iterrows():
            opt_value: float | str = row['output']
            # Empty cells come back from read_csv as NaN (a float), so treat
            # NaN and placeholder/empty strings as "not yet filled".
            if isinstance(opt_value, float) or opt_value in ("", CheckThatTask2.OUTPUT_COLUMN_EMPTY):
                rows_to_fill.append(key)
        print(f'fill-eval-table: Querying {len(rows_to_fill)} rows. This may take a long time.')
for c in rows_to_fill:
row: pandas.Series = self.eval_frame.loc[c]
query_result = self.backend.query(row['input'])
self.eval_frame.at[c, 'output'] = query_result
# Save the file each time we get a response to memoize them eagerly.
self.save_eval_table()
    def calculate_meteor_score_avg(self) -> float:
        """
        Compute the average METEOR score of the filled-in outputs against
        their reference claims, rounded to four decimal places. METEOR
        aligns hypothesis and reference tokens (exact, stem, and synonym
        matches via WordNet) and combines unigram precision and recall
        with a fragmentation penalty.
        """
        tokenizer = nltk.tokenize.NLTKWordTokenizer()
        def tokenize_custom(text: str) -> list[str]:
            # Keep only purely alphanumeric tokens; tokens containing any other
            # character (punctuation, hyphenated words, ...) are dropped.
            return [tok for tok in tokenizer.tokenize(text) if tok.isalnum()]
        meteors = [
            nltk.meteor(references=[tokenize_custom(row['reference'])], hypothesis=tokenize_custom(row['output']))
            for _, row in self.eval_frame.iterrows()
            if type(row['output']) is str
        ]
        print(f"calculate_meteor_score_avg: Calculated an individual METEOR score for {len(meteors)} rows.")
        if len(meteors) > 0:
            return round(float(numpy.average(meteors)), 4)
        else:
            return 0.0
def checkthat_task2_cmd(args: argparse.Namespace) -> int:
llm: LLMBackend
match args.backend:
case "together-ai":
llm = TogetherLLMBackend()
        case _:
            raise NotImplementedError(f"Unknown backend: {args.backend}")
# Load profile from disk.
with open(os.path.join(os.curdir, "profiles", f"{args.profile}.jinja2")) as templ_file:
templ_str = templ_file.read()
# Construct the task 2 class.
ctt2 = CheckThatTask2(args.profile, llm, jinja2.Template(templ_str))
# Delete eval table if flag is set.
if args.clear_eval_table:
ctt2.delete_eval_table_file()
ctt2.initialize_eval_table()
# Bail out early if the user has told us to stop after initializing the evaluation table.
if args.init_only:
print("init-only mode: Not doing anything further.")
return 0
if not args.no_query:
ctt2.train()
ctt2.fill_eval_table()
else:
print("no-query mode: Not training or filling eval table for LLM.")
    # Use the eval table to compute the average METEOR score across rows.
meteor_score = ctt2.calculate_meteor_score_avg()
ctt2.save_eval_table()
# Print the evaluation result.
print(f'Meteor score for profile "{ctt2.profile_name}": {meteor_score}')
return 0
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run AI things.")
parser.set_defaults(cmd=noop)
subp = parser.add_subparsers()
together_chat_req = subp.add_parser('together-chat')
together_chat_req.add_argument('prompt', type=str, help='Prompt to submit to Together.AI')
    together_chat_req.add_argument('-m', '--model', type=str, default=FREE_MODEL, help='Model ID to use.')
together_chat_req.set_defaults(cmd=together_chat_cmd)
checkthat_task2 = subp.add_parser('checkthat-task2', description='Harness to run the CheckThat! Task 2 and run tests on data.')
checkthat_task2.add_argument('backend', type=str, choices=['together-ai'], help='LLM backend to use.')
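    # The available profiles are discovered at parse time from the .jinja2
    # files in the profiles/ directory, so that directory must exist when
    # the tool is invoked.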
checkthat_task2.add_argument('profile', type=str, choices=[
os.path.splitext(x.name)[0]
for x in os.scandir(os.path.join(os.curdir, 'profiles'))
if os.path.splitext(x.name)[1] == '.jinja2'
], help='Profile to use.')
checkthat_task2.add_argument('-c', '--clear-eval-table', action='store_true', help='Deletes the eval table when provided.')
checkthat_task2.add_argument('-i', '--init-only', action='store_true', help='Only run the initialization of the eval table.')
checkthat_task2.add_argument('-nq', '--no-query', action='store_true', help='Do not query the LLM.')
checkthat_task2.set_defaults(cmd=checkthat_task2_cmd)
return parser.parse_args()
def main() -> int:
args = parse_args()
    nltk.download('wordnet', quiet=True)  # WordNet is needed by nltk.meteor.
return args.cmd(args)
if __name__ == "__main__":
exit(main())