-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathcreate-finetune-data.py
71 lines (59 loc) · 2.08 KB
/
create-finetune-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import json
import pathlib
import datasets
def create_dataset(
    preset_name,
    dataset_name,
    dataset_config_name,
    prompt_template,
    response_template,
    output_dir,
):
    """Load a Hugging Face dataset and write one JSONL file per split.

    Every example of every split is rendered through ``prompt_template`` and
    ``response_template`` with ``str.format(**example)``, so each ``{field}``
    placeholder in the templates must name a column of the dataset. Each
    output line is a JSON object ``{"prompt": ..., "response": ...}``.

    Args:
        preset_name: Prefix for the output file names
            (``<preset_name>-<split>.jsonl``).
        dataset_name: Dataset path/name passed to ``datasets.load_dataset``.
        dataset_config_name: Configuration name passed to
            ``datasets.load_dataset``.
        prompt_template: ``str.format`` template producing the prompt text.
        response_template: ``str.format`` template producing the response.
        output_dir: Directory the JSONL files are written to; it is created
            (including parents) if it does not exist yet.

    Raises:
        KeyError: If a template references a field an example does not have.
    """
    raw_datasets = datasets.load_dataset(dataset_name, dataset_config_name)
    out_dir = pathlib.Path(output_dir)
    # The output directory may not exist yet (e.g. on a fresh checkout);
    # without this, ``open`` below fails with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)
    for split, dataset in raw_datasets.items():
        outpath = out_dir / f"{preset_name}-{split}.jsonl"
        print(outpath)
        # json.dumps escapes non-ASCII by default, so the file content is
        # ASCII; the explicit encoding only removes the locale dependence.
        with open(outpath, "w", encoding="utf-8") as f:
            for example in dataset:
                record = {
                    "prompt": prompt_template.format(**example),
                    "response": response_template.format(**example),
                }
                f.write(json.dumps(record) + "\n")
# Named export presets: which Hugging Face dataset to load and how to render
# each of its examples into a prompt/response pair. The templates are filled
# with ``str.format(**example)``, so every ``{placeholder}`` must be a column
# of the corresponding dataset. Insertion order of the presets is preserved.
# NOTE(review): confirm that "main" is a valid config name for
# b-mc2/sql-create-context and GEM/viggo (it is for gsm8k).
presets = {
    "gsm8k": {
        "dataset_name": "gsm8k",
        "dataset_config_name": "main",
        "prompt": (
            "<<SYS>>\nAnswer the following Grade School Math problem.\n<</SYS>>\n"
            "[INST] {question} [/INST]\n"
        ),
        "response": "{answer}",
    },
    "sqlctx": {
        "dataset_name": "b-mc2/sql-create-context",
        "dataset_config_name": "main",
        "prompt": (
            "<<SYS>>\nGenerate a correct SQL query from the following database schema.\n"
            "{context}\n<</SYS>>\n"
            "[INST] {question} [/INST]\n"
        ),
        "response": "{answer}",
    },
    "viggo": {
        "dataset_name": "GEM/viggo",
        "dataset_config_name": "main",
        "prompt": (
            "<<SYS>>\nGenerate a description based on the following representation.\n<</SYS>>\n"
            "[INST] {meaning_representation} [/INST]\n"
        ),
        "response": "{target}",
    },
}
def main(argv=None):
    """Command-line entry point: export the JSONL files for one preset.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list lets the script be driven from Python code or tests.
    """
    data_dir = pathlib.Path(__file__).parent / "data"
    parser = argparse.ArgumentParser(
        description=(
            "Convert a Hugging Face dataset preset into "
            "prompt/response JSONL files."
        )
    )
    parser.add_argument(
        "--preset",
        choices=presets.keys(),
        required=True,
        help="Which dataset preset to export.",
    )
    parser.add_argument(
        "--output_dir",
        default=str(data_dir),
        help=(
            "Directory for the <preset>-<split>.jsonl files "
            "(default: %(default)s)."
        ),
    )
    args = parser.parse_args(argv)
    preset = presets[args.preset]
    create_dataset(
        preset_name=args.preset,
        dataset_name=preset["dataset_name"],
        dataset_config_name=preset["dataset_config_name"],
        prompt_template=preset["prompt"],
        response_template=preset["response"],
        output_dir=args.output_dir,
    )


if __name__ == "__main__":
    main()