vLLM + LLaMA works well
yangky11 committed Jul 9, 2024
1 parent b9a83f8 commit 3c6dafd
Showing 11 changed files with 79 additions and 83 deletions.
40 changes: 19 additions & 21 deletions common.py
@@ -13,6 +13,7 @@
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict,
)
from transformers import get_constant_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from typing import Optional, List, Dict, Any, Tuple, Generator
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
@@ -353,18 +354,8 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:
return list(all_pos_premises)


_SPACES_REGEX = re.compile(r"\s+", re.DOTALL)


def normalize_spaces(s: str) -> str:
"""Repalce any consecutive block of whitespace characters in ``s`` with a single whitespace."""
return _SPACES_REGEX.sub(" ", s).strip()


def format_tactic(annot_tac: str, provenances, normalize: bool) -> str:
def format_tactic(annot_tac: str, provenances) -> str:
"""Use full names for the all <a>...</a>."""
if normalize:
annot_tac = normalize_spaces(annot_tac)
if len(provenances) == 0:
return annot_tac

@@ -412,22 +403,30 @@ def format_augmented_state(


def get_optimizers(
parameters, trainer: pl.Trainer, lr: float) -> Dict[str, Any]:
parameters, trainer: pl.Trainer, lr: float, warmup_steps: int
) -> Dict[str, Any]:
"""Return an AdamW optimizer with cosine warmup learning rate schedule."""
strategy = trainer.strategy

if isinstance(strategy, DeepSpeedStrategy):
if "offload_optimizer" in strategy.config["zero_optimization"]:
logger.info("Optimizing with DeepSpeedCPUAdam")
return DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
optimizer = DeepSpeedCPUAdam(parameters, lr=lr, adamw_mode=True)
else:
logger.info("Optimizing with FusedAdam")
return FusedAdam(parameters, lr=lr, adam_w_mode=True)
optimizer = FusedAdam(parameters, lr=lr, adam_w_mode=True)
else:
logger.info("Optimizing with AdamW")
return torch.optim.AdamW(parameters, lr=lr)
optimizer = torch.optim.AdamW(parameters, lr=lr)


scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
},
}
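
For context, a minimal sketch of how this return value is consumed (the module and hyperparameter values here are illustrative, not from this repo): PyTorch Lightning accepts exactly this dictionary from configure_optimizers, and "interval": "step" makes it call scheduler.step() after every optimizer step. get_constant_schedule_with_warmup ramps the learning rate linearly from 0 to lr over warmup_steps, then holds it constant.

    import pytorch_lightning as pl
    import torch
    from transformers import get_constant_schedule_with_warmup

    class ToyModule(pl.LightningModule):  # hypothetical stand-in for the real models
        def __init__(self, lr: float = 1e-4, warmup_steps: int = 2000):
            super().__init__()
            self.lr = lr
            self.warmup_steps = warmup_steps
            self.layer = torch.nn.Linear(8, 8)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
            # Linear warmup from 0 to lr over `warmup_steps`, constant afterwards.
            scheduler = get_constant_schedule_with_warmup(optimizer, self.warmup_steps)
            # "interval": "step" -> Lightning steps the scheduler per optimizer
            # step, matching the step-based warmup above.
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
            }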


def _is_deepspeed_checkpoint(path: str):
@@ -438,14 +437,13 @@ def _is_deepspeed_checkpoint(path: str):

def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool):
"""Handle DeepSpeed checkpoints in model loading."""
if not _is_deepspeed_checkpoint(ckpt_path):
model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
else:
if _is_deepspeed_checkpoint(ckpt_path):
with tempfile.TemporaryDirectory() as dirname:
path = os.path.join(dirname, "lightning.cpkt")
convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, path)
model = model_cls.load_from_checkpoint(path, strict=False)
model = model.to(device)
model = model_cls.load_from_checkpoint(path, strict=False).to(device)
else:  # PyTorch Lightning checkpoints
model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
if freeze:
model.freeze()
return model
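
The restructured branch reads more naturally because the DeepSpeed case is the special one: ZeRO checkpoints are directories of partitioned shards, not a single Lightning file, so they must be merged before load_from_checkpoint can read them. A standalone sketch of that conversion (the checkpoint path is hypothetical):

    import os
    import tempfile

    from pytorch_lightning.utilities.deepspeed import (
        convert_zero_checkpoint_to_fp32_state_dict,
    )

    ckpt_path = "logs/epoch=2-step=40000.ckpt"  # hypothetical DeepSpeed checkpoint *directory*
    with tempfile.TemporaryDirectory() as dirname:
        fp32_path = os.path.join(dirname, "lightning.ckpt")
        # Merge the ZeRO-partitioned parameter/optimizer shards into one fp32 state dict.
        convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, fp32_path)
        # `fp32_path` can now be loaded like any ordinary Lightning checkpoint.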
7 changes: 2 additions & 5 deletions generator/confs/cli_lean4_novel_premises.yaml
@@ -12,9 +12,8 @@ trainer:
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: generator_novel_premises
gradient_clip_val: 1.0
name: null
save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -51,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/novel_premises/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
normalize_tactics: true
num_workers: 2
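
As the inline comment says, batch_size here is per device and per accumulation step. For example (with hypothetical values not set in this file), accumulate_grad_batches: 4 on devices: 2 would give an effective batch size of 8 × 4 × 2 = 64.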
8 changes: 2 additions & 6 deletions generator/confs/cli_lean4_random.yaml
@@ -12,10 +12,8 @@ trainer:
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
project: ReProver
name: generator_random
save_dir: logs/generator_random
gradient_clip_val: 1.0
name: null
save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -52,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/random/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
normalize_tactics: true
num_workers: 2
7 changes: 4 additions & 3 deletions generator/confs/torchtune-llama3-8B_full.yaml
@@ -49,18 +49,19 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
output_dir: ./models/Meta-Llama-3-8B-finetuned/
output_dir: ./models/Meta-Llama-3-8B-finetuned-lr2e-5/
model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
epochs: 1
epochs: 5

optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
foreach: False
warmup_steps: 2000

loss:
_component_: torch.nn.CrossEntropyLoss
@@ -83,6 +84,6 @@ metric_logger:
_component_: torchtune.utils.metric_logging.WandBLogger
project: ReProver
log_dir: ${output_dir}
output_dir: ./logs/leandojo-llama3-finetune
output_dir: ./logs/Meta-Llama-3-8B-finetuned-lr2e-5/
log_every_n_steps: 1
log_peak_memory_stats: false
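
A quick sanity check of what these values imply, assuming the linear-warmup-then-constant shape of transformers' get_constant_schedule_with_warmup (a sketch, not the recipe's code):

    def lr_at(step: int, lr: float = 2e-5, warmup_steps: int = 2000) -> float:
        # Linear warmup from 0 to `lr`, then constant.
        return lr * min(1.0, step / warmup_steps)

    assert lr_at(0) == 0.0
    assert lr_at(1000) == 1e-5    # halfway through warmup
    assert lr_at(2000) == 2e-5    # warmup complete
    assert lr_at(500000) == 2e-5  # constant for the rest of training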
29 changes: 6 additions & 23 deletions generator/datamodule.py
@@ -26,43 +26,36 @@ def __init__(
self,
data_path: str,
corpus: Corpus,
keep_marks: bool,
preds: List[Dict[str, Any]],
max_inp_seq_len: int,
max_oup_seq_len: int,
p_drop: float,
normalize_tactics: bool,
tokenizer: ByT5Tokenizer,
is_train: bool,
) -> None:
super().__init__()
self.corpus = corpus
self.keep_marks = keep_marks
self.preds = preds
self.max_inp_seq_len = max_inp_seq_len
self.max_oup_seq_len = max_oup_seq_len
self.p_drop = p_drop
self.tokenizer = tokenizer
self.is_train = is_train
self.data = self._load_data(data_path, normalize_tactics)
self.data = self._load_data(data_path)

def _load_data(self, data_path: str, normalize_tactics: bool) -> List[Example]:
def _load_data(self, data_path: str) -> List[Example]:
data = []
for thm in tqdm(json.load(open(data_path))):
for tac in thm["traced_tactics"]:
if "annotated_tactic" in tac:
tactic = format_tactic(*tac["annotated_tactic"], normalize_tactics)
else:
tactic = format_tactic(tac["tactic"], [], normalize_tactics)
if not self.keep_marks:
tactic = remove_marks(tactic)
tactic = remove_marks(tac["tactic"])
data.append(
{
"url": thm["url"],
"commit": thm["commit"],
"file_path": thm["file_path"],
"full_name": thm["full_name"],
"state": format_state(tac["state_before"]),
# "state": format_state(tac["state_before"]),
"state": tac["state_before"],
"tactic": tactic,
}
)
@@ -86,9 +79,7 @@ def __getitem__(self, idx: int) -> Example:
self.p_drop if self.is_train else 0.0,
)

if not self.keep_marks:
ex["state"] = remove_marks(ex["state"])

ex["state"] = remove_marks(ex["state"])
return ex

def collate(self, examples: List[Example]) -> Batch:
@@ -131,14 +122,12 @@ class GeneratorDataModule(pl.LightningDataModule):
def __init__(
self,
data_path: str,
keep_marks: bool,
model_name: str,
batch_size: int,
eval_batch_size: int,
max_inp_seq_len: int,
max_oup_seq_len: int,
p_drop: float,
normalize_tactics: bool,
num_workers: int,
corpus_path: Optional[str] = None,
preds_path: Optional[str] = None,
@@ -149,13 +138,11 @@ def __init__(
self.corpus = Corpus(corpus_path)
else:
self.corpus = None
self.keep_marks = keep_marks
self.batch_size = batch_size
self.eval_batch_size = eval_batch_size
self.max_inp_seq_len = max_inp_seq_len
self.max_oup_seq_len = max_oup_seq_len
self.p_drop = p_drop
self.normalize_tactics = normalize_tactics
self.num_workers = num_workers
self.tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -177,12 +164,10 @@ def setup(self, stage: Optional[str] = None) -> None:
self.ds_train = GeneratorDataset(
os.path.join(self.data_path, "train.json"),
self.corpus,
self.keep_marks,
self.preds,
self.max_inp_seq_len,
self.max_oup_seq_len,
self.p_drop,
self.normalize_tactics,
self.tokenizer,
is_train=True,
)
@@ -191,12 +176,10 @@ def setup(self, stage: Optional[str] = None) -> None:
self.ds_val = GeneratorDataset(
os.path.join(self.data_path, "val.json"),
self.corpus,
self.keep_marks,
self.preds,
self.max_inp_seq_len,
self.max_oup_seq_len,
self.p_drop,
self.normalize_tactics,
self.tokenizer,
is_train=False,
)
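
With keep_marks gone, tactics and states are always stripped of retrieval annotations. In ReProver, premises inside annotated tactics are wrapped in <a>...</a> marks; a sketch of what remove_marks does under that assumption (the real helper lives in common.py):

    def remove_marks(s: str) -> str:
        # Drop the <a>...</a> premise markers but keep the premise names.
        return s.replace("<a>", "").replace("</a>", "")

    assert remove_marks("rw [<a>Nat.add_comm</a>]") == "rw [Nat.add_comm]"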
7 changes: 7 additions & 0 deletions generator/full_finetune_distributed.py
@@ -24,6 +24,8 @@
StateDictType,
)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import get_constant_schedule_with_warmup
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils
@@ -226,6 +228,9 @@ def setup(self, cfg: DictConfig) -> None:
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
)
self._scheduler = get_constant_schedule_with_warmup(
self._optimizer, cfg.warmup_steps
)

self._loss_fn = config.instantiate(cfg.loss)

@@ -374,6 +379,7 @@ def _setup_optimizer(

if self._is_rank_zero:
log.info("Optimizer is initialized.")

return optimizer

def _setup_data(
@@ -534,6 +540,7 @@ def train(self) -> None:
if (idx + 1) % self._gradient_accumulation_steps == 0:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._scheduler.step()

# Update the number of steps when the weights are updated
self.global_step += 1
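
Placing self._scheduler.step() inside the accumulation guard is the important detail: the schedule must advance once per weight update, not once per micro-batch, or a 2000-step warmup would effectively finish gradient_accumulation_steps times too early. A runnable sketch of the same pattern (toy model and data, with LambdaLR standing in for the HF schedule):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: min(1.0, step / 2000)  # constant schedule with warmup
    )
    grad_accum_steps = 4

    for idx in range(8):  # toy micro-batches
        x, y = torch.randn(2, 4), torch.randn(2, 4)
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / grad_accum_steps).backward()
        if (idx + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()  # one schedule tick per weight update, not per micro-batch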
29 changes: 14 additions & 15 deletions generator/model.py
@@ -2,7 +2,6 @@

import os
import ray
import math
import torch
import shutil
import openai
@@ -12,7 +11,6 @@
import pytorch_lightning as pl
from torchmetrics import Metric
from abc import ABC, abstractmethod
from vllm import SamplingParams
from typing import List, Dict, Any, Optional, Tuple
from transformers import T5ForConditionalGeneration, AutoTokenizer

@@ -25,6 +23,7 @@
format_augmented_state,
)
from retrieval.model import PremiseRetriever
from generator.template import StateTacticPairTemplate


torch.set_float32_matmul_precision("medium")
@@ -167,7 +166,9 @@ def training_step(self, batch, batch_idx: int):
return loss

def configure_optimizers(self) -> Dict[str, Any]:
return get_optimizers(self.parameters(), self.trainer, self.lr)
return get_optimizers(
self.parameters(), self.trainer, self.lr, self.warmup_steps
)

def _log_io_texts(
self,
@@ -542,13 +543,12 @@ def generate(
theorem_pos: Pos,
num_samples: int,
) -> List[Tuple[str, float]]:
outputs = ray.get(
self.vllm_actor.generate.remote(
f"### State:\n{state}\n\n### Tactic:", num_samples
)
)
# prompt = StateTacticPairTemplate.format({"state": state})
# prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
prompt = f"[GOAL]\n{state}\n[PROOFSTEP]\n"
outputs = ray.get(self.vllm_actor.generate.remote(prompt, num_samples))
return [
(remove_marks(x.text), math.exp(x.cumulative_logprob))
(remove_marks(x.text).strip(), x.cumulative_logprob)
for x in outputs[0].outputs
]

@@ -560,12 +560,11 @@ def batch_generate(
theorem_pos: List[Pos],
num_samples: int,
) -> List[List[Tuple[str, float]]]:
inputs = [f"### State:\n{s}\n\n### Tactic:" for s in state]
outputs = ray.get(self.vllm_actor.generate.remote(inputs, num_samples))
# prompts = [StateTacticPairTemplate.format({"state": s}) for s in state]
# prompts = [f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{s}\n[PROOFSTEP]\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" for s in state]
prompts = [f"[GOAL]\n{s}\n[PROOFSTEP]\n" for s in state]
outputs = ray.get(self.vllm_actor.generate.remote(prompts, num_samples))
return [
[
(remove_marks(x.text), math.exp(x.cumulative_logprob))
for x in oup.outputs
]
[(remove_marks(x.text).strip(), x.cumulative_logprob) for x in oup.outputs]
for oup in outputs
]
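
Two deliberate changes here: prompts now use the [GOAL]/[PROOFSTEP] format the Llama model was fine-tuned on, and scores are returned as raw cumulative log-probabilities instead of math.exp(...) of them. A small sketch of why the latter is safer (candidate texts and log-probs are made up):

    import math

    # Hypothetical beam-search candidates: (generated text, cumulative log-prob).
    candidates = [("  norm_num\n", -0.7), ("  simp\n", -1.2), ("  ring_nf\n", -3.5)]

    ranked = sorted(
        ((text.strip(), logp) for text, logp in candidates),
        key=lambda pair: pair[1],
        reverse=True,
    )
    # exp() is monotonic, so ranking by log-prob matches ranking by probability,
    # while math.exp(cumulative_logprob) underflows to 0.0 for long, unlikely
    # tactics and would collapse their ordering.
    assert [t for t, _ in ranked] == ["norm_num", "simp", "ring_nf"]
    assert math.exp(-0.7) > math.exp(-1.2)  # same order, just less robust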
22 changes: 17 additions & 5 deletions generator/preprocess_data.py
@@ -12,19 +12,31 @@ def main() -> None:

for thm in json.load(open(data_path)):
for tac in thm["traced_tactics"]:
if "annotated_tactic" in tac:
tactic = format_tactic(*tac["annotated_tactic"], normalize=True)
else:
tactic = format_tactic(tac["tactic"], [], normalize=True)
pairs.append({"state": format_state(tac["state_before"]), "output": tactic})
# if "annotated_tactic" in tac:
# tactic = format_tactic(*tac["annotated_tactic"], normalize=True)
# else:
# tactic = format_tactic(tac["tactic"], [], normalize=True)
pairs.append({"state": tac["state_before"], "output": tac["tactic"]})

random.shuffle(pairs)

"""
with open("state_tactic_pairs.csv", "wt") as oup:
wt = csv.DictWriter(oup, fieldnames=["state", "output"])
wt.writeheader()
for st in pairs:
wt.writerow(st)
"""
data = []
for pair in pairs:
data.append(
{
"instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
"input": "",
"output": pair["output"],
}
)
json.dump(data, open("state_tactic_pairs.json", "wt"))


if __name__ == "__main__":
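
The dumped records follow the Alpaca-style instruction/input/output schema that fine-tuning frameworks such as torchtune can consume directly. One record, with a made-up state and tactic, would look like:

    sample = {
        # The goal state is embedded in the instruction; `input` stays empty.
        "instruction": "[GOAL]\nn : ℕ\n⊢ n + 0 = n\n[PROOFSTEP]\n",
        "input": "",
        "output": "simp",  # hypothetical tactic
    }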
3 changes: 1 addition & 2 deletions generator/template.py
@@ -3,8 +3,7 @@


class StateTacticPairTemplate(InstructTemplate):
template = "### State:\n{state}\n\n### Tactic:"
# template = "[GOAL]\n{state}\n[PROOFSTEP]\n"
template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n[GOAL]\n{state}\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n[PROOFSTEP]\n"

@classmethod
def format(
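
The new template wraps the [GOAL]/[PROOFSTEP] markers in Llama 3's chat markup so fine-tuning prompts carry the chat special tokens. A plain-string sketch of what it expands to (note that model.py above currently generates with the bare [GOAL]/[PROOFSTEP] prompt, the chat variant being left commented out there):

    template = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "[GOAL]\n{state}\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
        "[PROOFSTEP]\n"
    )
    prompt = template.format(state="⊢ 1 + 1 = 2")  # hypothetical goal state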
8 changes: 6 additions & 2 deletions prover/proof_search.py
@@ -331,7 +331,7 @@ class VllmActor:
def __init__(self, model_path: str) -> None:
self.num_gpus = len(ray.get_gpu_ids())
self.model_path = model_path

def initialize(self) -> None:
logger.info("Initializing vLLM")
# TODO: Try `--enable-prefix-caching` and other parameters in https://docs.vllm.ai/en/stable/models/engine_args.html#engine-args.
@@ -341,7 +341,11 @@ def generate(
self, inputs: Union[str, List[str]], num_samples: int
) -> List[RequestOutput]:
sampling_params = SamplingParams(
n=num_samples, temperature=0, use_beam_search=True, early_stopping=False
n=num_samples,
temperature=0,
length_penalty=0,
use_beam_search=True,
early_stopping=False,
)
outputs = self.llm.generate(inputs, sampling_params, use_tqdm=False)
if isinstance(inputs, str):
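
Adding length_penalty=0 makes beam search rank candidates by raw cumulative log-probability rather than a length-normalized score, consistent with the scores the generator now returns. A minimal sketch of driving vLLM the same way outside the actor (model path hypothetical; use_beam_search existed in SamplingParams in the vLLM versions contemporary with this commit):

    from vllm import LLM, SamplingParams

    llm = LLM(model="./models/Meta-Llama-3-8B-finetuned-lr2e-5")  # hypothetical path
    sampling_params = SamplingParams(
        n=64,                 # number of tactic candidates per goal
        temperature=0,        # deterministic beam search, no sampling
        length_penalty=0,     # rank beams by raw cumulative log-prob
        use_beam_search=True,
        early_stopping=False,
    )
    outputs = llm.generate(["[GOAL]\n⊢ True\n[PROOFSTEP]\n"], sampling_params, use_tqdm=False)
    for candidate in outputs[0].outputs:
        print(repr(candidate.text), candidate.cumulative_logprob)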
2 changes: 1 addition & 1 deletion retrieval/model.py
@@ -157,7 +157,7 @@ def on_train_batch_end(self, outputs, batch, _) -> None:
self.embeddings_staled = True

def configure_optimizers(self) -> Dict[str, Any]:
return get_optimizers(self.parameters(), self.trainer, self.lr)
return get_optimizers(self.parameters(), self.trainer, self.lr, self.warmup_steps)

##############
# Validation #
