"""Script for preprocessing state-tactic pairs into the format required by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)."""

import argparse
import json
import random

from loguru import logger


def main() -> None:
    """Convert traced tactics into shuffled instruction-tuning records.

    Command-line arguments:
        --data-path: JSON file holding a list of theorems. Each theorem must
            have a "traced_tactics" list whose items provide "state_before"
            and "tactic".
        --dst-path: Where the resulting JSON list is written.

    Every traced tactic becomes one record with the keys "instruction"
    (the proof state wrapped in ``[GOAL]`` / ``[PROOFSTEP]`` markers),
    "input" (always empty), and "output" (the tactic). Records are written
    in random order.

    Raises:
        OSError: If the input cannot be read or the output cannot be written.
        json.JSONDecodeError: If the input is not valid JSON.
        KeyError: If a theorem or tactic lacks one of the expected keys.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-path",
        type=str,
        default="./data/leandojo_benchmark_4/random/train.json",
    )
    parser.add_argument("--dst-path", type=str, default="state_tactic_pairs.json")
    args = parser.parse_args()
    logger.info(args)

    # Proof states contain non-ASCII symbols, so decode explicitly as UTF-8
    # instead of relying on the platform's locale; `with` closes the handle.
    with open(args.data_path, encoding="utf-8") as inp:
        theorems = json.load(inp)

    pairs = []
    for thm in theorems:
        for tac in thm["traced_tactics"]:
            pairs.append({"state": tac["state_before"], "output": tac["tactic"]})
    logger.info(f"Read {len(pairs)} state-tactic pairs from {args.data_path}")

    random.shuffle(pairs)
    data = [
        {
            "instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
            "input": "",
            "output": pair["output"],
        }
        for pair in pairs
    ]

    # Show one sample for a quick sanity check; an input without any traced
    # tactics must not crash the script with an IndexError.
    if data:
        logger.info(data[0])
    else:
        logger.warning(f"No state-tactic pairs found in {args.data_path}")

    # The context manager guarantees the output is flushed and closed before
    # the success message below is logged.
    with open(args.dst_path, "wt", encoding="utf-8") as oup:
        json.dump(data, oup)
    logger.info(f"Preprocessed data saved to {args.dst_path}")


if __name__ == "__main__":
    main()