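"""incubate.py — turn a natural-language instruction into a text classifier.

The script prompts an instruction-tuned "incubator" LLM to emit JSON samples,
keeps the samples that share the most common label schema, fine-tunes a
`Classifier` on them, saves the result, and runs two demo predictions.
"""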
import argparse
import json
import re
from collections import Counter

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from classifier import Classifier
parser = argparse.ArgumentParser(description="Parser for Incubator.")
parser.add_argument("--n_epoch", type=int, help="Number of classifier training epochs.")
parser.add_argument("--batch_size", type=int, help="Classifier training batch size.")
parser.add_argument("--device", type=int, help="CUDA device index.")
parser.add_argument("--n_sample", type=int, help="Number of samples to generate from the incubator.")
parser.add_argument("--max_new_tokens", type=int, help="Generation budget per sample.")
parser.add_argument("--nli_finetune_epoch", type=int, help="Parsed but not used in this script.")
parser.add_argument("--instruction", type=str, help="Natural-language description of the desired classifier.")
parser.add_argument("--incubator", type=str, help="Name or path of the incubator (causal LM) checkpoint.")
parser.add_argument("--classifier", type=str, help="Name or path of the classifier backbone.")
parser.add_argument("--save_path", type=str, help="Directory to save the trained classifier.")
args = parser.parse_args()
n_epoch = args.n_epoch
batch_size = args.batch_size
device = args.device
n_sample = args.n_sample
max_new_tokens = args.max_new_tokens
instruction = args.instruction
incubator = args.incubator
classifier_name = args.classifier  # renamed: `classifier` is reused for the Classifier instance below
save_path = args.save_path
# Load the incubator LLM in fp16 on the requested GPU.
tokenizer = AutoTokenizer.from_pretrained(incubator)
model = AutoModelForCausalLM.from_pretrained(incubator, torch_dtype=torch.float16)
model = model.to(f"cuda:{device}")

# [INST] ... [/INST] is the Llama-2/Mistral-style instruction format.
prompt = f"[INST] {instruction} [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(f"cuda:{device}")
# Sample n_sample generations and keep those containing parseable JSON.
dataset = []
with torch.no_grad():
    for _ in tqdm(range(n_sample)):
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
        decoded = tokenizer.decode(outputs[0])
        try:
            # Take the first single-line {...} span in the generation and parse it.
            data = json.loads(re.findall(r"({.*?})", decoded)[0])
            dataset.append(data)
        except (IndexError, json.JSONDecodeError):
            # Skip generations with no parseable JSON object.
            continue
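# Each parsed sample is expected to map label names to example texts, e.g.
# (hypothetical illustration; the actual schema depends on the instruction):
#   {"positive": "I love this movie!", "negative": "What a waste of time."}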
labels = ["#".join(list(data.keys())) for data in dataset]
label_texts = Counter(labels).most_common(1)[0][0].split("#")
new_dataset = []
for data in dataset:
if list(data.keys()) == label_texts:
for label in data:
new_dataset.append({"text": data[label], "label": label})
# Fine-tune the classifier on the induced dataset and save it.
classifier = Classifier(model_name=classifier_name, device=f"cuda:{device}",
                        num_labels=len(label_texts), label_texts=label_texts)
for epoch in range(n_epoch):
    classifier.train(new_dataset, batch_size)
classifier.tok.save_pretrained(save_path)
classifier.classifier.save_pretrained(save_path)
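# A minimal sketch of reloading the saved model later, assuming `Classifier`
# wraps a standard Hugging Face tokenizer and sequence-classification head
# (which the save_pretrained calls above suggest):
#
#   from transformers import AutoTokenizer, AutoModelForSequenceClassification
#   tok = AutoTokenizer.from_pretrained(save_path)
#   clf = AutoModelForSequenceClassification.from_pretrained(save_path)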
# Quick smoke test on two hand-written inputs.
for input_text in ["I love 'Spiderman 2'!", "I ate a delicious pudding!"]:
    label = label_texts[classifier.predict(input_text).argmax().item()]
    print("Input:", input_text)
    print("Predicted Label:", label)