From 822dcb1c623aa288d3ad62aaa2255f8ab4f7d0ac Mon Sep 17 00:00:00 2001
From: Kaiyu Yang
Date: Thu, 4 Apr 2024 22:21:43 +0000
Subject: [PATCH] update

---
 README.md                                     |  4 ++--
 generator/confs/cli_lean4_novel_premises.yaml | 17 +++++++++++++++--
 generator/confs/cli_lean4_random.yaml         | 17 +++++++++++++++--
 prover/evaluate.py                            |  1 +
 prover/proof_search.py                        |  6 +++---
 scripts/download_data.py                      |  4 ++--
 6 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index 78f7840..1b5e820 100644
--- a/README.md
+++ b/README.md
@@ -306,8 +306,8 @@ After the tactic generator is trained, we combine it with best-first search to p
 For models without retrieval, run:
 
 ```bash
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-cpus 8 --with-gpus
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-cpus 8 --with-gpus
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
 ```
 
 For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
diff --git a/generator/confs/cli_lean4_novel_premises.yaml b/generator/confs/cli_lean4_novel_premises.yaml
index f720237..f6377a5 100644
--- a/generator/confs/cli_lean4_novel_premises.yaml
+++ b/generator/confs/cli_lean4_novel_premises.yaml
@@ -17,6 +17,19 @@ trainer:
     - class_path: pytorch_lightning.callbacks.LearningRateMonitor
       init_args:
         logging_interval: step
+    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+      init_args:
+        verbose: true
+        save_top_k: 1
+        save_last: true
+        monitor: Pass@1_val
+        mode: max
+    - class_path: pytorch_lightning.callbacks.EarlyStopping
+      init_args:
+        monitor: Pass@1_val
+        patience: 2
+        mode: max
+        verbose: true
 
 model:
   model_name: google/byt5-small
@@ -26,9 +39,9 @@ model:
   length_penalty: 0.0
   ret_ckpt_path: null
   eval_num_retrieved: 100
-  eval_num_workers: 8
+  eval_num_workers: 6  # Lower this number if you don't have 80GB GPU memory.
   eval_num_gpus: 1
-  eval_num_theorems: 400
+  eval_num_theorems: 300  # Lowering this number makes validation faster (but noisier).
 
 data:
   data_path: data/leandojo_benchmark_4/novel_premises/
diff --git a/generator/confs/cli_lean4_random.yaml b/generator/confs/cli_lean4_random.yaml
index 565184a..58f8527 100644
--- a/generator/confs/cli_lean4_random.yaml
+++ b/generator/confs/cli_lean4_random.yaml
@@ -17,6 +17,19 @@ trainer:
     - class_path: pytorch_lightning.callbacks.LearningRateMonitor
       init_args:
         logging_interval: step
+    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+      init_args:
+        verbose: true
+        save_top_k: 1
+        save_last: true
+        monitor: Pass@1_val
+        mode: max
+    - class_path: pytorch_lightning.callbacks.EarlyStopping
+      init_args:
+        monitor: Pass@1_val
+        patience: 2
+        mode: max
+        verbose: true
 
 model:
   model_name: google/byt5-small
@@ -26,9 +39,9 @@ model:
   length_penalty: 0.0
   ret_ckpt_path: null
   eval_num_retrieved: 100
-  eval_num_workers: 8
+  eval_num_workers: 6  # Lower this number if you don't have 80GB GPU memory.
   eval_num_gpus: 1
-  eval_num_theorems: 400
+  eval_num_theorems: 300  # Lowering this number makes validation faster (but noisier).
 
 data:
   data_path: data/leandojo_benchmark_4/random/
diff --git a/prover/evaluate.py b/prover/evaluate.py
index 96f4f41..e5c5125 100644
--- a/prover/evaluate.py
+++ b/prover/evaluate.py
@@ -208,6 +208,7 @@ def main() -> None:
     args = parser.parse_args()
 
     assert args.ckpt_path or args.tactic
+    assert args.num_gpus <= args.num_workers
 
     logger.info(f"PID: {os.getpid()}")
     logger.info(args)
diff --git a/prover/proof_search.py b/prover/proof_search.py
index 33caf1c..5df6a50 100644
--- a/prover/proof_search.py
+++ b/prover/proof_search.py
@@ -100,8 +100,8 @@ def search(
         with torch.no_grad():
             try:
                 self._best_first_search()
-            except DojoCrashError:
-                logger.warning(f"Dojo crashed when proving {thm}")
+            except DojoCrashError as ex:
+                logger.warning(f"Dojo crashed with {ex} when proving {thm}")
                 pass
 
         if self.root.status == Status.PROVED:
@@ -389,7 +389,7 @@ def __init__(
         if ckpt_path is None:
            tac_gen = FixedTacticGenerator(tactic, module)
         else:
-            device = torch.device("cuda") if with_gpus else torch.device("cpu")
+            device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
             tac_gen = RetrievalAugmentedGenerator.load(
                 ckpt_path, device=device, freeze=True
             )
diff --git a/scripts/download_data.py b/scripts/download_data.py
index 7b546c0..4de6a55 100644
--- a/scripts/download_data.py
+++ b/scripts/download_data.py
@@ -7,10 +7,10 @@

 LEANDOJO_BENCHMARK_4_URL = (
-    "https://zenodo.org/records/10823489/files/leandojo_benchmark_4.tar.gz"
+    "https://zenodo.org/records/10929138/files/leandojo_benchmark_4.tar.gz?download=1"
 )
 DOWNLOADS = {
-    LEANDOJO_BENCHMARK_4_URL: "c45383c1a94b0ab17395401fc8b03f36",
+    LEANDOJO_BENCHMARK_4_URL: "84a75ce552b31731165d55542b1aaca9",
 }
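For readers applying the config changes above by hand, the two new callback entries in both generator YAMLs correspond to the following PyTorch Lightning objects. This is a minimal sketch, assuming the `Pass@1_val` metric is logged during validation as the configs indicate; the trainer wiring at the end is illustrative only, since the actual run goes through `LightningCLI` and the YAML files.

```python
# Sketch of what the new ModelCheckpoint / EarlyStopping entries resolve to.
# The init_args values are copied from the patched YAML configs.
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    verbose=True,
    save_top_k=1,       # keep only the best checkpoint by Pass@1_val
    save_last=True,     # also keep the most recent checkpoint
    monitor="Pass@1_val",
    mode="max",
)
early_stop_cb = EarlyStopping(
    monitor="Pass@1_val",
    patience=2,         # stop after 2 validations without improvement
    mode="max",
    verbose=True,
)

# Illustrative wiring only (the real setup is built from the YAML configs):
# trainer = pytorch_lightning.Trainer(callbacks=[checkpoint_cb, early_stop_cb], ...)
```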