# stage.py
# Template mining and filling: score tokens with a local causal LM, blank out
# the low-potential ones, then ask an OpenAI chat model (legacy openai<1.0
# `ChatCompletion` API) to re-fill the template with the target label.
import torch
from torch.nn.functional import one_hot
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import openai
import json
import re
import numpy as np

def fill_in_template(prompt, passage, style):
    # Render a single-turn chat transcript in Gemma-style <start_of_turn> format,
    # with the passage placed as the model's answer.
    text = f'''<start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model
{style}: {passage}<end_of_turn>'''
    return text

def create_query(template, label, style, icl=False):
    # Build the instruction sent to the chat model. The non-ICL variant asks for
    # any completion of the template; the ICL variant asks for another passage
    # with the target label, in the same writing style.
    if not icl:
        query = f'''Template:
<start_of_sequence> {template} <end_of_sequence>
Fill in the blanks in the template to produce a {style}.'''
    else:
        query = f'''Template:
<start_of_sequence> {template} <end_of_sequence>
Fill in the blanks in the template to produce another **{label}** {style} in the same writing style.'''
    return query

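# Illustrative only (hypothetical inputs): create_query("the movie was _ and _",
# "positive", "movie review", icl=True) renders as:
#   Template:
#   <start_of_sequence> the movie was _ and _ <end_of_sequence>
#   Fill in the blanks in the template to produce another **positive** movie review in the same writing style.
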
def template_potential(prompt_1, prompt_2, passage, style, tokenizer, model):
    # Per-token "potential": how much more likely each passage token becomes when
    # the label is named in the prompt. logits[:, t] predicts token t+1, so the
    # slice logits[:, len_prompt-1:-2] lines up with input_ids[:, len_prompt:-1];
    # the one_hot mask then picks out each position's logit for the observed token.
    with torch.no_grad():
        # +1 accounts for the BOS token the tokenizer prepends on encoding.
        len_prompt = 1 + len(tokenizer.tokenize(f'''<start_of_turn>user
{prompt_1}<end_of_turn>
<start_of_turn>model
{style}:'''))
        inputs = tokenizer(fill_in_template(prompt_1, passage, style), return_tensors="pt").to("cuda")
        logits = model(**inputs).logits
        logits_1 = logits[:, len_prompt-1:-2][one_hot(inputs.input_ids[:, len_prompt:-1], len(tokenizer)).bool()]

        len_prompt = 1 + len(tokenizer.tokenize(f'''<start_of_turn>user
{prompt_2}<end_of_turn>
<start_of_turn>model
{style}:'''))
        inputs = tokenizer(fill_in_template(prompt_2, passage, style), return_tensors="pt").to("cuda")
        logits = model(**inputs).logits
        logits_2 = logits[:, len_prompt-1:-2][one_hot(inputs.input_ids[:, len_prompt:-1], len(tokenizer)).bool()]

        # Positive entries mark tokens the label-aware prompt makes more likely.
        diff = logits_2 - logits_1
    return diff

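# A minimal loading sketch for the local scorer, assuming a Gemma-family
# checkpoint (the <start_of_turn> chat markers above follow Gemma's format).
# The model id and the 4-bit quantization config are assumptions, not pinned
# by this file; they only illustrate why BitsAndBytesConfig is imported.
def load_scorer(model_id="google/gemma-2b-it"):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map="cuda",
    )
    model.eval()
    return tokenizer, model
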
def template_mine(dataset, label_name, style, k, tokenizer, model):
    # Turn each passage into a template: keep the top-k fraction of tokens by
    # potential (those the label-aware prompt boosts most) and replace the rest
    # with the "_" placeholder. Each template is scored by the mean potential
    # of its kept tokens.
    template_dataset = []
    prompt_1 = f"Please write a short {style}."
    prompt_2 = f"Please write a short **{label_name}** {style}."
    for data in tqdm(dataset):
        sentence = data["text"]
        if 5 <= len(tokenizer.tokenize(sentence)) <= 512:
            diff = template_potential(prompt_1, prompt_2, sentence, style, tokenizer, model)
            # Threshold at the (1 - k) quantile of the per-token potentials.
            mask = diff > diff[diff.argsort()[int(diff.shape[0] * (1 - k))]]
            blank_id = tokenizer.convert_tokens_to_ids(["_"])[0]  # "_" placeholder token
            template = tokenizer.decode([
                idx if m else blank_id
                for m, idx in zip(mask, tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" " + sentence)))
            ])
            score = diff[diff.argsort()[int(diff.shape[0] * (1 - k)):]].mean().item()
            template_dataset.append({"score": score, "template": template, "original": sentence})
    return template_dataset

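# Illustrative only (hypothetical record): a mined entry might look like
#   {"score": 1.8,
#    "template": "the movie was _ delightful _ _ charming _",
#    "original": "the movie was a delightful and utterly charming ride"}
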
def template_fill(template_dataset, label_name, style, model_engine, t):
    # Keep the top-t fraction of templates by score and ask the chat model to
    # re-fill each one as a one-shot exchange: (unlabelled query -> original
    # passage) demonstrates the format, then the labelled query asks for a new
    # passage.
    top_t = int(t * len(template_dataset))
    grafted_dataset = []
    bar = tqdm(np.argsort([data["score"] for data in template_dataset])[::-1][:top_t])
    for idx in bar:
        data = template_dataset[idx]
        messages = [
            {"role": "user", "content": create_query(data["template"], None, style, False)},
            # The original passage serves as the demonstration answer.
            {"role": "assistant", "content": f"<start_of_sequence> {data['original']} <end_of_sequence>"},
            {"role": "user", "content": create_query(data["template"], label_name, style, True)},
        ]
        grafted = openai.ChatCompletion.create(  # legacy openai<1.0 API
            model=model_engine,
            temperature=0.0,
            messages=messages,
        ).choices[0]["message"]["content"]
        # Non-greedy match between the sequence markers, then strip leftover blanks.
        grafted = re.findall(r"<start_of_sequence> (.*?) <end_of_sequence>", grafted.replace("\n", " "))[0].replace("_", "").lower()
        grafted_dataset.append({**data, "grafted": grafted})
        bar.set_description(f"#Data={len(grafted_dataset)}")
    return grafted_dataset
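

# A minimal end-to-end sketch, not part of the original file: the dataset,
# fractions, and the OpenAI engine name are assumptions for illustration, and
# load_scorer is the hypothetical helper sketched above. Assumes openai.api_key
# is already configured.
if __name__ == "__main__":
    from datasets import load_dataset

    tokenizer, model = load_scorer()
    # Any dataset with "text"/"label" fields works; SST-2-style reviews assumed.
    dataset = load_dataset("rotten_tomatoes", split="train").select(range(100))
    mined = template_mine(dataset, label_name="positive", style="movie review",
                          k=0.3, tokenizer=tokenizer, model=model)
    grafted = template_fill(mined, label_name="positive", style="movie review",
                            model_engine="gpt-3.5-turbo", t=0.5)
    print(grafted[0]["grafted"])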