Commit

update
Kaiyu Yang committed Apr 10, 2024
1 parent 822dcb1 commit d69dc30
Showing 5 changed files with 20 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -306,8 +306,8 @@ After the tactic generator is trained, we combine it with best-first search to p

For models without retrieval, run:
```bash
- python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
- python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
+ python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 12 --num-gpus 1
+ python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 12 --num-gpus 1
```

For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
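For context on the step above: "indexing the corpus" means embedding every premise once with the trained retriever and caching the result, so that retrieval at proof-search time reduces to a similarity lookup. Below is a minimal sketch of the idea in plain PyTorch; `encode` and `index_corpus` are hypothetical names, not the repository's API, and the actual indexing command is omitted in this excerpt.

```python
# Conceptual sketch of "indexing the corpus": embed every premise once and
# cache the result so retrieval later is a single similarity lookup.
# `encode` is a hypothetical stand-in for the trained retriever's premise
# encoder; the repository's actual indexing script is not shown in this diff.
from typing import Callable, List

import torch
import torch.nn.functional as F


def index_corpus(
    premises: List[str],
    encode: Callable[[str], torch.Tensor],
    out_path: str,
) -> None:
    embeddings = torch.stack([encode(p) for p in premises])  # (num_premises, dim)
    embeddings = F.normalize(embeddings, dim=-1)  # unit norm: dot product = cosine similarity
    torch.save({"premises": premises, "embeddings": embeddings}, out_path)
```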
4 changes: 2 additions & 2 deletions generator/confs/cli_lean4_novel_premises.yaml
@@ -39,9 +39,9 @@ model:
length_penalty: 0.0
ret_ckpt_path: null
eval_num_retrieved: 100
- eval_num_workers: 6 # Lower this number if you don't have 80GB GPU memory.
+ eval_num_workers: 5 # Lower this number if you don't have 80GB GPU memory.
eval_num_gpus: 1
- eval_num_theorems: 300 # Lowering this number will make validation faster (but noisier).
+ eval_num_theorems: 250 # Lowering this number will make validation faster (but noisier).

data:
data_path: data/leandojo_benchmark_4/novel_premises/
4 changes: 2 additions & 2 deletions generator/confs/cli_lean4_random.yaml
@@ -39,9 +39,9 @@ model:
length_penalty: 0.0
ret_ckpt_path: null
eval_num_retrieved: 100
- eval_num_workers: 6 # Lower this number if you don't have 80GB GPU memory.
+ eval_num_workers: 5 # Lower this number if you don't have 80GB GPU memory.
eval_num_gpus: 1
- eval_num_theorems: 300 # Lowering this number will make validation faster (but noisier).
+ eval_num_theorems: 250 # Lowering this number will make validation faster (but noisier).

data:
data_path: data/leandojo_benchmark_4/random/
6 changes: 5 additions & 1 deletion generator/model.py
@@ -1,5 +1,6 @@
"""Lightning module for the tactic generator."""

+ import os
import torch
import openai
import pickle
@@ -243,7 +244,7 @@ def on_validation_epoch_end(self) -> None:

from prover.evaluate import evaluate # Avoid circular import.

- ckpt_path = f"{self.trainer.log_dir}/checkpoints/last.ckpt"
+ ckpt_path = f"{self.trainer.log_dir}/checkpoints/last-tmp.ckpt"
self.trainer.save_checkpoint(ckpt_path)
logger.info(f"Saved checkpoint to {ckpt_path}. Evaluating...")
torch.cuda.empty_cache()
@@ -278,6 +279,9 @@ def on_validation_epoch_end(self) -> None:
self.log("Pass@1_val", acc, on_step=False, on_epoch=True, sync_dist=True)
logger.info(f"Pass@1: {acc}")

+ if os.path.exists(ckpt_path):
+     os.remove(ckpt_path)

##############
# Prediction #
##############
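Taken together, the `generator/model.py` changes save the current weights to a temporary checkpoint at validation time, evaluate that checkpoint with the prover, and then delete the file. Below is a minimal sketch of the resulting pattern, assuming a PyTorch Lightning module; the class name and the arguments passed to `evaluate` are illustrative, since they are not shown in this diff.

```python
import os

import torch
import pytorch_lightning as pl


class TacticGeneratorModule(pl.LightningModule):
    def on_validation_epoch_end(self) -> None:
        from prover.evaluate import evaluate  # Avoid circular import.

        # Write the current weights to a temporary checkpoint
        # (this commit renames it from last.ckpt to last-tmp.ckpt).
        ckpt_path = f"{self.trainer.log_dir}/checkpoints/last-tmp.ckpt"
        self.trainer.save_checkpoint(ckpt_path)
        torch.cuda.empty_cache()

        # Run the prover on the saved checkpoint; keyword arguments are illustrative.
        acc = evaluate(ckpt_path=ckpt_path)

        self.log("Pass@1_val", acc, on_step=False, on_epoch=True, sync_dist=True)

        # Remove the temporary checkpoint so it does not accumulate across epochs.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
```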
19 changes: 9 additions & 10 deletions retrieval/datamodule.py
@@ -232,16 +232,15 @@ def prepare_data(self) -> None:
pass

def setup(self, stage: Optional[str] = None) -> None:
- if stage in (None, "fit"):
-     self.ds_train = RetrievalDataset(
-         [os.path.join(self.data_path, "train.json")],
-         self.corpus,
-         self.num_negatives,
-         self.num_in_file_negatives,
-         self.max_seq_len,
-         self.tokenizer,
-         is_train=True,
-     )
+ self.ds_train = RetrievalDataset(
+     [os.path.join(self.data_path, "train.json")],
+     self.corpus,
+     self.num_negatives,
+     self.num_in_file_negatives,
+     self.max_seq_len,
+     self.tokenizer,
+     is_train=True,
+ )

if stage in (None, "fit", "validate"):
self.ds_val = RetrievalDataset(
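For context on the `retrieval/datamodule.py` change: Lightning calls `setup(stage=...)` with "fit", "validate", or "test", so a dataset built only under `if stage in (None, "fit")` does not exist when the datamodule is set up for validation or testing alone; after this commit, `ds_train` is built for every stage. A self-contained toy illustration of that difference (not the repository's code):

```python
from typing import Optional

from torch.utils.data import Dataset


class ToyDataset(Dataset):
    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> int:
        return idx


class ToyDataModule:
    """Mimics the relevant shape of a LightningDataModule's setup()."""

    def setup(self, stage: Optional[str] = None) -> None:
        # Built unconditionally, so it also exists when stage == "validate" or "test".
        self.ds_train = ToyDataset(100)
        # Validation data is still gated on the stage.
        if stage in (None, "fit", "validate"):
            self.ds_val = ToyDataset(10)


dm = ToyDataModule()
dm.setup(stage="validate")
assert hasattr(dm, "ds_train")  # would fail if ds_train were built only for stage == "fit"
```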
