Skip to content

Commit

Permalink
Merge pull request #40 from albertqjiang/aj/vllm_enabled
Browse files Browse the repository at this point in the history
Enabling vllm-based prover
  • Loading branch information
Kaiyu Yang authored Apr 5, 2024
2 parents 822dcb1 + 655818f commit 3236b83
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 32 deletions.
102 changes: 102 additions & 0 deletions generator/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Lightning module for the tactic generator."""

import torch
import time
import openai
import pickle
import multiprocessing.pool as mpp
from lean_dojo import Pos
from loguru import logger
import pytorch_lightning as pl
Expand Down Expand Up @@ -362,6 +364,106 @@ def batch_generate(
return tactics_with_scores


def trial_completion_with_args(
    args_tuple: "Tuple[openai.Client, int, float, Dict[str, Any]]",
) -> List[Tuple[str, float]]:
    """Issue one completion request with retries; return scored completions.

    Args:
        args_tuple: ``(client, num_retries, backoff_time, completion_args)``,
            where ``completion_args`` are the keyword arguments forwarded to
            ``client.completions.create``. Packed into a single tuple so this
            function can be mapped over a thread pool.

    Returns:
        One ``(stripped_text, summed_token_logprob)`` pair per returned choice,
        or an empty list if every attempt failed. (The original fell through
        and implicitly returned ``None`` after exhausting retries, which broke
        callers that iterate/extend the result.)
    """
    client, num_retries, backoff_time, completion_args = args_tuple
    for trial in range(num_retries):
        try:
            responses = client.completions.create(**completion_args)
            texts_and_logprobs: List[Tuple[str, float]] = []
            for choice in responses.choices:
                text = choice.text.strip()
                # Score each candidate by the sum of its token log-probabilities.
                logprob = sum(choice.logprobs.token_logprobs)
                texts_and_logprobs.append((text, logprob))
            return texts_and_logprobs
        except openai.OpenAIError as e:
            logger.error(f"OpenAI API returned an error: {e}")
            # Only back off if another attempt remains; sleeping after the
            # final failure just wastes time.
            if trial < num_retries - 1:
                logger.info(f"Retrying in {backoff_time} seconds...")
                time.sleep(backoff_time)
    logger.error(f"Giving up after {num_retries} failed attempts.")
    return []

class VLLMGenerator(TacticGenerator):
    """Tactic generator backed by a vLLM server exposing the OpenAI API.

    Each requested sample is sent as its own single-prompt completion request;
    requests are fanned out over a thread pool in ``generate_from_args``.
    """

    def __init__(
        self,
        server_url: str,
        model: str,
        max_tokens: int,
        temperature: float,
        stop: List[str],
        prompt_format: str,
        num_retries: int = 3,
    ):
        """Connect to a vLLM server.

        Args:
            server_url: Host (optionally with scheme/port) of the vLLM server;
                normalized to ``http://host[:port]/v1``.
            model: Model name passed to the completions endpoint.
            max_tokens: Maximum tokens generated per completion.
            temperature: Sampling temperature; 0 means deterministic decoding.
            stop: Stop strings that terminate generation.
            prompt_format: Template containing exactly one ``TACTIC_STATE``
                placeholder, replaced by the (stripped) proof state.
            num_retries: Attempts per request before giving up.
        """
        super().__init__()
        # Normalize the URL: ensure a scheme and the OpenAI-compatible /v1 path.
        if not server_url.startswith("http"):
            server_url = f"http://{server_url}"
        if not server_url.endswith("/v1"):
            server_url = f"{server_url.rstrip('/')}/v1"
        logger.info(f"Connecting to VLLM server at {server_url}")
        self.server_url = server_url
        # vLLM does not check the API key, but the client requires one.
        self.client = openai.OpenAI(base_url=server_url, api_key="NONE")
        logger.info(f"Initialized vllm client at {self.server_url}")
        self.backoff_time = 3.0  # Seconds between retries of a failed request.
        self.model = model
        self.stop = stop
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.prompt_format = prompt_format
        assert prompt_format.count("TACTIC_STATE") == 1
        self.num_retries = num_retries

    def generate_from_args(self, args: List[Dict[str, Any]]) -> List[Tuple[str, float]]:
        """Run one completion request per element of ``args`` concurrently and
        concatenate the (text, logprob) results in request order."""
        with mpp.ThreadPool(64) as p:
            all_results: List[Tuple[str, float]] = []
            for result in p.imap(
                trial_completion_with_args,
                [(self.client, self.num_retries, self.backoff_time, arg) for arg in args],
            ):
                all_results.extend(result)
            return all_results

    def generate(
        self,
        state: str,
        file_path: str,
        theorem_full_name: str,
        theorem_pos: Pos,
        num_samples: int,
    ) -> List[Tuple[str, float]]:
        """Sample ``num_samples`` tactic candidates for a single proof state."""
        # If no stochasticity, sample one tactic only.
        # BUG FIX: the original checked ``len(num_samples)``, which raises
        # TypeError (num_samples is an int) whenever temperature == 0.
        assert self.temperature > 0 or num_samples == 1
        prompt = self.prompt_format.replace("TACTIC_STATE", state.strip())
        completion_args = self.get_completion_args(prompt)
        return self.generate_from_args([completion_args] * num_samples)

    def get_completion_args(self, prompt: str) -> dict[str, Any]:
        """Build the keyword arguments for a single-prompt completion request."""
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            # logprobs=1 makes the server return per-token log-probabilities,
            # which trial_completion_with_args sums into the tactic's score.
            "logprobs": 1,
            "top_p": 1.0,
            "echo": False,
            "stop": self.stop,
            "prompt": [prompt],
        }

    def batch_generate(
        self,
        state: List[str],
        file_path: List[str],
        theorem_full_name: List[str],
        theorem_pos: List[Pos],
        num_samples: int,
    ) -> List[List[Tuple[str, float]]]:
        """Sample ``num_samples`` candidates for each of several proof states.

        NOTE(review): despite the annotated return type, this flattens all
        results into one list (``generate_from_args`` concatenates) — same
        behavior as the original; callers should confirm which shape they need.
        """
        all_args: List[Dict[str, Any]] = []
        for s in state:
            prompt = self.prompt_format.replace("TACTIC_STATE", s.strip())
            completion_args = self.get_completion_args(prompt)
            # The same (read-only) args dict is reused for every sample.
            for _ in range(num_samples):
                all_args.append(completion_args)

        return self.generate_from_args(all_args)

class GPT4TacticGenerator(TacticGenerator):
def __init__(
self,
Expand Down
74 changes: 53 additions & 21 deletions prover/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import argparse
from loguru import logger
from lean_dojo import Theorem
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Any
from lean_dojo import LeanGitRepo, Theorem, Pos, is_available_in_cache

from common import set_logger
from common import set_logger, zip_strict
from prover.proof_search import Status, DistributedProver


Expand All @@ -24,6 +24,7 @@ def _get_theorems(
name_filter: str,
num_theorems: int,
) -> Tuple[LeanGitRepo, List[Theorem], List[Pos]]:
logger.info(f"Loading theorems from {data_path}...")
repo, theorems, positions = _get_theorems_from_files(
data_path,
split,
Expand All @@ -32,6 +33,7 @@ def _get_theorems(
name_filter,
num_theorems,
)
logger.info(f"Loaded theorems from {data_path}...")

all_repos = {thm.repo for thm in theorems}
for r in all_repos:
Expand Down Expand Up @@ -96,30 +98,48 @@ def evaluate(
tactic: Optional[str] = None,
module: Optional[str] = None,
num_sampled_tactics: int = 64,
vllm_args: Optional[dict[str, Any]] = None,
timeout: int = 600,
num_workers: int = 1,
num_gpus: int = 0,
verbose: bool = False,
progress_dir: Optional[str] = None,
) -> float:
set_logger(verbose)

repo, theorems, positions = _get_theorems(
data_path, split, file_path, full_name, name_filter, num_theorems
)

# Don't do theorems that are already done.
finished_theorem_hashes = set()
if progress_dir is not None:
os.makedirs(progress_dir, exist_ok=True)
for file in os.listdir(progress_dir):
assert file.endswith(".out")
finished_theorem_hashes.add(file[:-4])
unfinished_theorems, unfinished_positions = [], []
for theorem, position in zip(theorems, positions):
if theorem.uhash in finished_theorem_hashes:
continue
unfinished_theorems.append(theorem)
unfinished_positions.append(position)
logger.info(f"{len(unfinished_theorems)} theorems to prove")

# Search for proofs using multiple concurrent provers.
prover = DistributedProver(
ckpt_path,
indexed_corpus_path,
tactic,
module,
num_workers,
num_gpus=num_gpus,
num_gpus,
vllm_args=vllm_args,
timeout=timeout,
num_sampled_tactics=num_sampled_tactics,
debug=verbose,
)
results = prover.search_unordered(repo, theorems, positions)
results = prover.search_unordered(repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir)

# Calculate the result statistics.
num_proved = num_failed = num_discarded = 0
Expand Down Expand Up @@ -205,31 +225,43 @@ def main() -> None:
parser.add_argument(
"--verbose", action="store_true", help="Set the logging level to DEBUG."
)
parser.add_argument(
"--vllm-args-json-path", type=str, help="URL of the VLLM server."
)
parser.add_argument(
"--progress-dir", type=str, help="Progress directory"
)
args = parser.parse_args()

assert args.ckpt_path or args.tactic
assert args.ckpt_path or args.tactic or args.vllm_args_json_path
if args.vllm_args_json_path:
vllm_args = json.load(open(args.vllm_args_json_path))
else:
vllm_args = None
assert args.num_gpus <= args.num_workers

logger.info(f"PID: {os.getpid()}")
logger.info(args)

pass_1 = evaluate(
args.data_path,
args.exp_id,
args.split,
args.file_path,
args.full_name,
args.name_filter,
args.num_theorems,
args.ckpt_path,
args.indexed_corpus_path,
args.tactic,
args.module,
args.num_sampled_tactics,
args.timeout,
args.num_workers,
args.num_gpus,
args.verbose,
data_path = args.data_path,
exp_id = args.exp_id,
split = args.split,
file_path = args.file_path,
full_name = args.full_name,
name_filter = args.name_filter,
num_theorems = args.num_theorems,
ckpt_path = args.ckpt_path,
indexed_corpus_path = args.indexed_corpus_path,
tactic = args.tactic,
module = args.module,
num_sampled_tactics = args.num_sampled_tactics,
vllm_args = vllm_args,
timeout = args.timeout,
num_workers = args.num_workers,
num_gpus = args.num_gpus,
verbose = args.verbose,
progress_dir = args.progress_dir,
)

logger.info(f"Pass@1: {pass_1}")
Expand Down
Loading

0 comments on commit 3236b83

Please sign in to comment.