From 3c6dafdb1693806966106760f48395c0dd3aa3eb Mon Sep 17 00:00:00 2001 From: Kaiyu Yang Date: Tue, 9 Jul 2024 00:48:38 +0000 Subject: [PATCH] vLLM + LLaMA works well --- common.py | 40 +++++++++---------- generator/confs/cli_lean4_novel_premises.yaml | 7 +--- generator/confs/cli_lean4_random.yaml | 8 +--- generator/confs/torchtune-llama3-8B_full.yaml | 7 ++-- generator/datamodule.py | 29 +++----------- generator/full_finetune_distributed.py | 7 ++++ generator/model.py | 29 +++++++------- generator/preprocess_data.py | 22 +++++++--- generator/template.py | 3 +- prover/proof_search.py | 8 +++- retrieval/model.py | 2 +- 11 files changed, 79 insertions(+), 83 deletions(-) diff --git a/common.py b/common.py index cd90f0b..7be2a79 100644 --- a/common.py +++ b/common.py @@ -13,6 +13,7 @@ from pytorch_lightning.utilities.deepspeed import ( convert_zero_checkpoint_to_fp32_state_dict, ) +from transformers import get_constant_schedule_with_warmup from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam from typing import Optional, List, Dict, Any, Tuple, Generator from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy @@ -353,18 +354,8 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]: return list(all_pos_premises) -_SPACES_REGEX = re.compile(r"\s+", re.DOTALL) - - -def normalize_spaces(s: str) -> str: - """Repalce any consecutive block of whitespace characters in ``s`` with a single whitespace.""" - return _SPACES_REGEX.sub(" ", s).strip() - - -def format_tactic(annot_tac: str, provenances, normalize: bool) -> str: +def format_tactic(annot_tac: str, provenances) -> str: """Use full names for the all ....""" - if normalize: - annot_tac = normalize_spaces(annot_tac) if len(provenances) == 0: return annot_tac @@ -412,22 +403,30 @@ def format_augmented_state( def get_optimizers( - parameters, trainer: pl.Trainer, lr: float) -> Dict[str, Any]: + parameters, trainer: pl.Trainer, lr: float, warmup_steps: int +) -> Dict[str, Any]: """Return an AdamW optimizer with cosine warmup learning rate schedule.""" strategy = trainer.strategy if isinstance(strategy, DeepSpeedStrategy): if "offload_optimizer" in strategy.config["zero_optimization"]: logger.info("Optimizing with DeepSpeedCPUAdam") - return DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True) + optimizer = DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True) else: logger.info("Optimizing with FusedAdam") - return FusedAdam(parameters, lr=lr, adam_w_mode=True) + optimizer = FusedAdam(parameters, lr=lr, adam_w_mode=True) else: logger.info("Optimizing with AdamW") - return torch.optim.AdamW(parameters, lr=lr) + optimizer = torch.optim.AdamW(parameters, lr=lr) - + scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + }, + } def _is_deepspeed_checkpoint(path: str): @@ -438,14 +437,13 @@ def _is_deepspeed_checkpoint(path: str): def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool): """Handle DeepSpeed checkpoints in model loading.""" - if not _is_deepspeed_checkpoint(ckpt_path): - model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device) - else: + if _is_deepspeed_checkpoint(ckpt_path): with tempfile.TemporaryDirectory() as dirname: path = os.path.join(dirname, "lightning.cpkt") convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, path) - model = model_cls.load_from_checkpoint(path, strict=False) - model = model.to(device) + model = model_cls.load_from_checkpoint(path, 
strict=False).to(device) + else: # PyTorch Ligthning checkpoints + model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device) if freeze: model.freeze() return model diff --git a/generator/confs/cli_lean4_novel_premises.yaml b/generator/confs/cli_lean4_novel_premises.yaml index 567193e..50128af 100644 --- a/generator/confs/cli_lean4_novel_premises.yaml +++ b/generator/confs/cli_lean4_novel_premises.yaml @@ -12,9 +12,8 @@ trainer: logger: class_path: pytorch_lightning.loggers.WandbLogger init_args: - project: ReProver - name: generator_novel_premises - gradient_clip_val: 1.0 + name: null + save_dir: null max_steps: 500000 check_val_every_n_epoch: 1 num_sanity_val_steps: 0 @@ -51,12 +50,10 @@ model: data: data_path: data/leandojo_benchmark_4/novel_premises/ corpus_path: data/leandojo_benchmark_4/corpus.jsonl - keep_marks: true preds_path: null batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices eval_batch_size: 64 max_inp_seq_len: 2300 max_oup_seq_len: 512 p_drop: 0.5 - normalize_tactics: true num_workers: 2 diff --git a/generator/confs/cli_lean4_random.yaml b/generator/confs/cli_lean4_random.yaml index b5a6555..6f03a45 100644 --- a/generator/confs/cli_lean4_random.yaml +++ b/generator/confs/cli_lean4_random.yaml @@ -12,10 +12,8 @@ trainer: logger: class_path: pytorch_lightning.loggers.WandbLogger init_args: - project: ReProver - name: generator_random - save_dir: logs/generator_random - gradient_clip_val: 1.0 + name: null + save_dir: null max_steps: 500000 check_val_every_n_epoch: 1 num_sanity_val_steps: 0 @@ -52,12 +50,10 @@ model: data: data_path: data/leandojo_benchmark_4/random/ corpus_path: data/leandojo_benchmark_4/corpus.jsonl - keep_marks: true preds_path: null batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices eval_batch_size: 64 max_inp_seq_len: 2300 max_oup_seq_len: 512 p_drop: 0.5 - normalize_tactics: true num_workers: 2 diff --git a/generator/confs/torchtune-llama3-8B_full.yaml b/generator/confs/torchtune-llama3-8B_full.yaml index 3cd35c0..7a5c490 100644 --- a/generator/confs/torchtune-llama3-8B_full.yaml +++ b/generator/confs/torchtune-llama3-8B_full.yaml @@ -49,18 +49,19 @@ checkpointer: model-00004-of-00004.safetensors, ] recipe_checkpoint: null - output_dir: ./models/Meta-Llama-3-8B-finetuned/ + output_dir: ./models/Meta-Llama-3-8B-finetuned-lr2e-5/ model_type: LLAMA3 resume_from_checkpoint: False # Fine-tuning arguments batch_size: 4 -epochs: 1 +epochs: 5 optimizer: _component_: torch.optim.AdamW lr: 2e-5 foreach: False +warmup_steps: 2000 loss: _component_: torch.nn.CrossEntropyLoss @@ -83,6 +84,6 @@ metric_logger: _component_: torchtune.utils.metric_logging.WandBLogger project: ReProver log_dir: ${output_dir} -output_dir: ./logs/leandojo-llama3-finetune +output_dir: ./logs/Meta-Llama-3-8B-finetuned-lr2e-5/ log_every_n_steps: 1 log_peak_memory_stats: false diff --git a/generator/datamodule.py b/generator/datamodule.py index d0654f7..573ade9 100644 --- a/generator/datamodule.py +++ b/generator/datamodule.py @@ -26,43 +26,36 @@ def __init__( self, data_path: str, corpus: Corpus, - keep_marks: bool, preds: List[Dict[str, Any]], max_inp_seq_len: int, max_oup_seq_len: int, p_drop: float, - normalize_tactics: bool, tokenizer: ByT5Tokenizer, is_train: bool, ) -> None: super().__init__() self.corpus = corpus - self.keep_marks = keep_marks self.preds = preds self.max_inp_seq_len = max_inp_seq_len self.max_oup_seq_len = max_oup_seq_len self.p_drop = p_drop self.tokenizer = tokenizer self.is_train 
= is_train - self.data = self._load_data(data_path, normalize_tactics) + self.data = self._load_data(data_path) - def _load_data(self, data_path: str, normalize_tactics: bool) -> List[Example]: + def _load_data(self, data_path: str) -> List[Example]: data = [] for thm in tqdm(json.load(open(data_path))): for tac in thm["traced_tactics"]: - if "annotated_tactic" in tac: - tactic = format_tactic(*tac["annotated_tactic"], normalize_tactics) - else: - tactic = format_tactic(tac["tactic"], [], normalize_tactics) - if not self.keep_marks: - tactic = remove_marks(tactic) + tactic = remove_marks(tac["tactic"]) data.append( { "url": thm["url"], "commit": thm["commit"], "file_path": thm["file_path"], "full_name": thm["full_name"], - "state": format_state(tac["state_before"]), + # "state": format_state(tac["state_before"]), + "state": tac["state_before"], "tactic": tactic, } ) @@ -86,9 +79,7 @@ def __getitem__(self, idx: int) -> Example: self.p_drop if self.is_train else 0.0, ) - if not self.keep_marks: - ex["state"] = remove_marks(ex["state"]) - + ex["state"] = remove_marks(ex["state"]) return ex def collate(self, examples: List[Example]) -> Batch: @@ -131,14 +122,12 @@ class GeneratorDataModule(pl.LightningDataModule): def __init__( self, data_path: str, - keep_marks: bool, model_name: str, batch_size: int, eval_batch_size: int, max_inp_seq_len: int, max_oup_seq_len: int, p_drop: float, - normalize_tactics: bool, num_workers: int, corpus_path: Optional[str] = None, preds_path: Optional[str] = None, @@ -149,13 +138,11 @@ def __init__( self.corpus = Corpus(corpus_path) else: self.corpus = None - self.keep_marks = keep_marks self.batch_size = batch_size self.eval_batch_size = eval_batch_size self.max_inp_seq_len = max_inp_seq_len self.max_oup_seq_len = max_oup_seq_len self.p_drop = p_drop - self.normalize_tactics = normalize_tactics self.num_workers = num_workers self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -177,12 +164,10 @@ def setup(self, stage: Optional[str] = None) -> None: self.ds_train = GeneratorDataset( os.path.join(self.data_path, "train.json"), self.corpus, - self.keep_marks, self.preds, self.max_inp_seq_len, self.max_oup_seq_len, self.p_drop, - self.normalize_tactics, self.tokenizer, is_train=True, ) @@ -191,12 +176,10 @@ def setup(self, stage: Optional[str] = None) -> None: self.ds_val = GeneratorDataset( os.path.join(self.data_path, "val.json"), self.corpus, - self.keep_marks, self.preds, self.max_inp_seq_len, self.max_oup_seq_len, self.p_drop, - self.normalize_tactics, self.tokenizer, is_train=False, ) diff --git a/generator/full_finetune_distributed.py b/generator/full_finetune_distributed.py index 0e637cd..a0c53ec 100644 --- a/generator/full_finetune_distributed.py +++ b/generator/full_finetune_distributed.py @@ -24,6 +24,8 @@ StateDictType, ) from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from transformers import get_constant_schedule_with_warmup from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, utils @@ -226,6 +228,9 @@ def setup(self, cfg: DictConfig) -> None: ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None ), ) + self._scheduler = get_constant_schedule_with_warmup( + self._optimizer, cfg.warmup_steps + ) self._loss_fn = config.instantiate(cfg.loss) @@ -374,6 +379,7 @@ def _setup_optimizer( if self._is_rank_zero: log.info("Optimizer is initialized.") + return optimizer def _setup_data( @@ -534,6 +540,7 @@ def train(self) -> None: if (idx + 1) % 
self._gradient_accumulation_steps == 0: self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) + self._scheduler.step() # Update the number of steps when the weights are updated self.global_step += 1 diff --git a/generator/model.py b/generator/model.py index b171d5e..62a8633 100644 --- a/generator/model.py +++ b/generator/model.py @@ -2,7 +2,6 @@ import os import ray -import math import torch import shutil import openai @@ -12,7 +11,6 @@ import pytorch_lightning as pl from torchmetrics import Metric from abc import ABC, abstractmethod -from vllm import SamplingParams from typing import List, Dict, Any, Optional, Tuple from transformers import T5ForConditionalGeneration, AutoTokenizer @@ -25,6 +23,7 @@ format_augmented_state, ) from retrieval.model import PremiseRetriever +from generator.template import StateTacticPairTemplate torch.set_float32_matmul_precision("medium") @@ -167,7 +166,9 @@ def training_step(self, batch, batch_idx: int): return loss def configure_optimizers(self) -> Dict[str, Any]: - return get_optimizers(self.parameters(), self.trainer, self.lr) + return get_optimizers( + self.parameters(), self.trainer, self.lr, self.warmup_steps + ) def _log_io_texts( self, @@ -542,13 +543,12 @@ def generate( theorem_pos: Pos, num_samples: int, ) -> List[Tuple[str, float]]: - outputs = ray.get( - self.vllm_actor.generate.remote( - f"### State:\n{state}\n\n### Tactic:", num_samples - ) - ) + # prompt = StateTacticPairTemplate.format({"state": state}) + # prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + prompt = f"[GOAL]\n{state}\n[PROOFSTEP]\n" + outputs = ray.get(self.vllm_actor.generate.remote(prompt, num_samples)) return [ - (remove_marks(x.text), math.exp(x.cumulative_logprob)) + (remove_marks(x.text).strip(), x.cumulative_logprob) for x in outputs[0].outputs ] @@ -560,12 +560,11 @@ def batch_generate( theorem_pos: List[Pos], num_samples: int, ) -> List[List[Tuple[str, float]]]: - inputs = [f"### State:\n{s}\n\n### Tactic:" for s in state] - outputs = ray.get(self.vllm_actor.generate.remote(inputs, num_samples)) + # prompts = [StateTacticPairTemplate.format({"state": s}) for s in state] + # prompts = [f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{s}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" for s in state] + prompts = [f"[GOAL]\n{s}\n[PROOFSTEP]\n" for s in state] + outputs = ray.get(self.vllm_actor.generate.remote(prompts, num_samples)) return [ - [ - (remove_marks(x.text), math.exp(x.cumulative_logprob)) - for x in oup.outputs - ] + [(remove_marks(x.text).strip(), x.cumulative_logprob) for x in oup.outputs] for oup in outputs ] diff --git a/generator/preprocess_data.py b/generator/preprocess_data.py index 51cb435..665dc77 100644 --- a/generator/preprocess_data.py +++ b/generator/preprocess_data.py @@ -12,19 +12,31 @@ def main() -> None: for thm in json.load(open(data_path)): for tac in thm["traced_tactics"]: - if "annotated_tactic" in tac: - tactic = format_tactic(*tac["annotated_tactic"], normalize=True) - else: - tactic = format_tactic(tac["tactic"], [], normalize=True) - pairs.append({"state": format_state(tac["state_before"]), "output": tactic}) + # if "annotated_tactic" in tac: + # tactic = format_tactic(*tac["annotated_tactic"], normalize=True) + # else: + # tactic = format_tactic(tac["tactic"], [], normalize=True) + pairs.append({"state": tac["state_before"], "output": tac["tactic"]}) 
random.shuffle(pairs) + """ with open("state_tactic_pairs.csv", "wt") as oup: wt = csv.DictWriter(oup, fieldnames=["state", "output"]) wt.writeheader() for st in pairs: wt.writerow(st) + """ + data = [] + for pair in pairs: + data.append( + { + "instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n", + "input": "", + "output": pair["output"], + } + ) + json.dump(data, open("state_tactic_pairs.json", "wt")) if __name__ == "__main__": diff --git a/generator/template.py b/generator/template.py index aa6d227..cc5b710 100644 --- a/generator/template.py +++ b/generator/template.py @@ -3,8 +3,7 @@ class StateTacticPairTemplate(InstructTemplate): - template = "### State:\n{state}\n\n### Tactic:" - # template = "[GOAL]\n{state}\n[PROOFSTEP]\n" + template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n[PROOFSTEP]\n" @classmethod def format( diff --git a/prover/proof_search.py b/prover/proof_search.py index 8752e18..023d855 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -331,7 +331,7 @@ class VllmActor: def __init__(self, model_path: str) -> None: self.num_gpus = len(ray.get_gpu_ids()) self.model_path = model_path - + def initialize(self) -> None: logger.info("Initializing vLLM") # TODO: Try `--enable-prefix-caching` and other parameters in https://docs.vllm.ai/en/stable/models/engine_args.html#engine-args. @@ -341,7 +341,11 @@ def generate( self, inputs: Union[str, List[str]], num_samples: int ) -> List[RequestOutput]: sampling_params = SamplingParams( - n=num_samples, temperature=0, use_beam_search=True, early_stopping=False + n=num_samples, + temperature=0, + length_penalty=0, + use_beam_search=True, + early_stopping=False, ) outputs = self.llm.generate(inputs, sampling_params, use_tqdm=False) if isinstance(inputs, str): diff --git a/retrieval/model.py b/retrieval/model.py index e912e27..75356ad 100644 --- a/retrieval/model.py +++ b/retrieval/model.py @@ -157,7 +157,7 @@ def on_train_batch_end(self, outputs, batch, _) -> None: self.embeddings_staled = True def configure_optimizers(self) -> Dict[str, Any]: - return get_optimizers(self.parameters(), self.trainer, self.lr) + return get_optimizers(self.parameters(), self.trainer, self.lr, self.warmup_steps) ############## # Validation #
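
A minimal standalone sketch (not part of the patch) of how the new "[GOAL] ... [PROOFSTEP]" prompt and the beam-search sampling parameters introduced above fit together, mirroring VllmActor.generate; the checkpoint path and max_tokens value are illustrative assumptions, not values taken from this patch.

from vllm import LLM, SamplingParams

# Assumed local path to the fine-tuned checkpoint produced by the torchtune recipe.
llm = LLM(model="./models/Meta-Llama-3-8B-finetuned-lr2e-5/")

state = "n : ℕ\n⊢ n + 0 = n"
prompt = f"[GOAL]\n{state}\n[PROOFSTEP]\n"  # prompt format used by this patch

sampling_params = SamplingParams(
    n=4,                   # number of tactic candidates per state
    temperature=0,         # deterministic beam search, as in VllmActor.generate
    length_penalty=0,
    use_beam_search=True,
    early_stopping=False,
    max_tokens=128,        # assumed cap on generated tactic length
)

outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
for cand in outputs[0].outputs:
    # The patch ranks candidates by raw cumulative log-probability (no exp()).
    print(cand.cumulative_logprob, cand.text.strip())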