Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Jul 6, 2024
1 parent 777265a commit abe325e
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 125 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,23 @@ Check out [Lean Copilot](https://github.com/lean-dojo/LeanCopilot) if you want t
1. Download and install [Miniconda Python 3](https://docs.conda.io/en/latest/miniconda.html) (Anaconda should also work).
2. Create the conda environment and install Python dependencies:
```bash
conda create --yes --name ReProver python=3.10 ipython numpy
conda create --yes --name ReProver python=3.11 ipython
conda activate ReProver
pip install torch --index-url https://download.pytorch.org/whl/cu121 # Depending on your CUDA version; see https://pytorch.org/.
pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers tensorboard openai rank_bm25 lean-dojo
pip install torch # Depending on your CUDA version; see https://pytorch.org/.
pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers wandb openai rank_bm25 lean-dojo vllm
pip install git+https://github.com/pytorch/torchtune
```
3. Prepend the repo's root to the `PYTHONPATH` environment variable.
4. Make sure `wget` and `tar` are available. Then, run `python scripts/download_data.py` to download [LeanDojo Benchmark 4](https://zenodo.org/doi/10.5281/zenodo.8040109). The data will be saved to `./data`.
5. Satisfy the requirements of [LeanDojo](https://github.com/lean-dojo/LeanDojo#requirements).
6. Use [LeanDojo](https://github.com/lean-dojo/LeanDojo) to trace all repos in the datasets: `python scripts/trace_repos.py`. This step may take some time. Please refer to [LeanDojo's documentation](https://leandojo.readthedocs.io/en/latest/) if you encounter any issues.
7. Log in to Weights & Biases and set its log directory.
```bash
wandb login
export WANDB_DIR=$(pwd)/tmp
mkdir -p tmp/wandb
mkdir -p logs
```


## Premise Selection
Expand Down
34 changes: 5 additions & 29 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict,
)
from transformers import get_cosine_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from typing import Optional, List, Dict, Any, Tuple, Generator
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
Expand Down Expand Up @@ -413,45 +412,22 @@ def format_augmented_state(


def get_optimizers(
parameters, trainer: pl.Trainer, lr: float, warmup_steps: int
) -> Dict[str, Any]:
parameters, trainer: pl.Trainer, lr: float) -> Dict[str, Any]:
"""Return an AdamW optimizer with cosine warmup learning rate schedule."""
strategy = trainer.strategy

if isinstance(strategy, DeepSpeedStrategy):
if "offload_optimizer" in strategy.config["zero_optimization"]:
logger.info("Optimizing with DeepSpeedCPUAdam")
optimizer = DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
return DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
else:
logger.info("Optimizing with FusedAdam")
optimizer = FusedAdam(parameters, lr=lr, adam_w_mode=True)
return FusedAdam(parameters, lr=lr, adam_w_mode=True)
else:
logger.info("Optimizing with AdamW")
optimizer = torch.optim.AdamW(parameters, lr=lr)
return torch.optim.AdamW(parameters, lr=lr)

if trainer.max_steps != -1:
max_steps = trainer.max_steps
else:
assert trainer.max_epochs is not None
max_steps = (
trainer.max_epochs
* len(trainer.datamodule.train_dataloader())
// trainer.accumulate_grad_batches
)

scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=max_steps,
)

return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
},
}



def _is_deepspeed_checkpoint(path: str):
Expand Down
5 changes: 5 additions & 0 deletions generator/confs/cli_lean4_novel_premises.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: generator_novel_premises
gradient_clip_val: 1.0
max_steps: 500000
check_val_every_n_epoch: 1
Expand Down
6 changes: 6 additions & 0 deletions generator/confs/cli_lean4_random.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: generator_random
save_dir: logs/generator_random
gradient_clip_val: 1.0
max_steps: 500000
check_val_every_n_epoch: 1
Expand Down
51 changes: 43 additions & 8 deletions generator/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import openai
import pickle
from vllm import LLM
from lean_dojo import Pos
from loguru import logger
import pytorch_lightning as pl
Expand Down Expand Up @@ -164,28 +165,30 @@ def training_step(self, batch, batch_idx: int):
return loss

def configure_optimizers(self) -> Dict[str, Any]:
return get_optimizers(
self.parameters(), self.trainer, self.lr, self.warmup_steps
)
return get_optimizers(self.parameters(), self.trainer, self.lr)

def _log_io_texts(
self,
split: str,
state_ids: torch.LongTensor,
tactic_ids: torch.LongTensor,
) -> None:
tb = self.logger.experiment
inp = self.tokenizer.decode(state_ids[0], skip_special_tokens=True)
oup_ids = torch.where(
tactic_ids[0] == -100, self.tokenizer.pad_token_id, tactic_ids[0]
)
oup = self.tokenizer.decode(oup_ids, skip_special_tokens=True)
tb.add_text(f"{split}_state", f"```\n{inp}\n```", self.global_step)
tb.add_text(f"{split}_tactic", f"`{oup}`", self.global_step)
self.logger.log_text(
f"{split}_samples",
["state", "tactic"],
[[inp, oup]],
step=self.global_step,
)

def on_fit_start(self) -> None:
if self.logger is not None:
self.logger.log_hyperparams(self.hparams)
self.logger.watch(self.generator)
assert self.trainer is not None
logger.info(f"Logging to {self.trainer.log_dir}")

Expand Down Expand Up @@ -223,9 +226,8 @@ def validation_step(self, batch: Dict[str, Any], _) -> None:
for i in range(batch_size)
]

tb = self.logger.experiment
msg = "\n".join(tactics_pred[0])
tb.add_text(f"preds_val", f"```\n{msg}\n```", self.global_step)
self.logger.log_text("preds_val", ["tactics"], [[msg]], step=self.global_step)

# Log the topk accuracies.
for k in range(1, self.num_beams + 1):
Expand Down Expand Up @@ -524,3 +526,36 @@ def batch_generate(
self.generate(s, f, tfn, tp, num_samples)
for s, f, tfn, tp in zip(state, file_path, theorem_full_name, theorem_pos)
]


class SyncVllmGenerator(TacticGenerator):
    """Tactic generator backed by a vLLM ``LLM`` engine.

    Only engine construction is implemented so far; ``generate`` and
    ``batch_generate`` are placeholders that raise ``NotImplementedError``.

    The methods are deliberately NOT marked ``@abstractmethod``: this class is
    meant to be instantiated, and abstract methods on an ABC-based class make
    instantiation fail with ``TypeError`` before ``__init__`` even runs.
    """

    def __init__(self, model_path: str, num_gpus: int) -> None:
        """Create the vLLM engine.

        Args:
            model_path: Model identifier or path handed to ``vllm.LLM``.
            num_gpus: Forwarded to vLLM as ``tensor_parallel_size``.
        """
        self.llm = LLM(model_path, tensor_parallel_size=num_gpus)

    def generate(
        self,
        state: str,
        file_path: str,
        theorem_full_name: str,
        theorem_pos: Pos,
        num_samples: int,
    ) -> List[Tuple[str, float]]:
        """Generate tactic candidates with scores for a single proof state.

        Raises:
            NotImplementedError: Always; vLLM-based generation is not
                implemented yet.
        """
        # Fail fast instead of dropping into an interactive debugger, which
        # would block forever in non-interactive (e.g. distributed) runs.
        raise NotImplementedError(
            "SyncVllmGenerator.generate is not implemented yet"
        )

    def batch_generate(
        self,
        state: List[str],
        file_path: List[str],
        theorem_full_name: List[str],
        theorem_pos: List[Pos],
        num_samples: int,
    ) -> List[List[Tuple[str, float]]]:
        """Generate tactic candidates with scores for a batch of proof states.

        Raises:
            NotImplementedError: Always; vLLM-based generation is not
                implemented yet.
        """
        raise NotImplementedError(
            "SyncVllmGenerator.batch_generate is not implemented yet"
        )
10 changes: 7 additions & 3 deletions prover/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def evaluate(
timeout: int = 600,
num_workers: int = 1,
num_gpus: int = 0,
save_results: bool = False,
verbose: bool = False,
) -> float:
set_logger(verbose)
Expand Down Expand Up @@ -149,9 +150,10 @@ def evaluate(
# Save the results.
if exp_id is None:
exp_id = str(uuid.uuid4())
pickle_path = f"{exp_id}_results.pickle"
pickle.dump(results, open(pickle_path, "wb"))
logger.info(f"Results saved to {pickle_path}")
if save_results:
pickle_path = f"{exp_id}_results.pickle"
pickle.dump(results, open(pickle_path, "wb"))
logger.info(f"Results saved to {pickle_path}")

return pass_1

Expand Down Expand Up @@ -208,6 +210,7 @@ def main() -> None:
parser.add_argument(
"--num-gpus", type=int, default=0, help="The number of GPUs for proof search."
)
parser.add_argument("--save-results", action="store_true")
parser.add_argument(
"--verbose", action="store_true", help="Set the logging level to DEBUG."
)
Expand Down Expand Up @@ -235,6 +238,7 @@ def main() -> None:
args.timeout,
args.num_workers,
args.num_gpus,
args.save_results,
args.verbose,
)

Expand Down
2 changes: 1 addition & 1 deletion prover/proof_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Proof search using best-first search.
"""

import os
import sys
import ray
import time
Expand Down Expand Up @@ -390,6 +389,7 @@ def __init__(
self.distributed = num_workers > 1

if not self.distributed:
assert num_gpus <= 1
if ckpt_path is None:
tac_gen = FixedTacticGenerator(tactic, module)
else:
Expand Down
2 changes: 1 addition & 1 deletion prover/search_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import total_ordering
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Iterable, Union
from typing import Optional, List, Iterable, Union


class Status(Enum):
Expand Down
6 changes: 6 additions & 0 deletions retrieval/confs/cli_lean4_novel_premises.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: retrieval_novel_premises
save_dir: ./wandb_logs/retrieval_novel_premises
gradient_clip_val: 1.0
max_steps: 800000
callbacks:
Expand Down
6 changes: 6 additions & 0 deletions retrieval/confs/cli_lean4_random.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: retrieval_random
save_dir: ./wandb_logs/retrieval_random
gradient_clip_val: 1.0
max_steps: 800000
callbacks:
Expand Down
4 changes: 1 addition & 3 deletions retrieval/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ def on_train_batch_end(self, outputs, batch, _) -> None:
self.embeddings_staled = True

def configure_optimizers(self) -> Dict[str, Any]:
return get_optimizers(
self.parameters(), self.trainer, self.lr, self.warmup_steps
)
return get_optimizers(self.parameters(), self.trainer, self.lr)

##############
# Validation #
Expand Down
77 changes: 0 additions & 77 deletions torchtune/confs/llama3-8B_full.yaml

This file was deleted.

0 comments on commit abe325e

Please sign in to comment.