From eb0781428108abf2e8f90a89dcd4309c943e4510 Mon Sep 17 00:00:00 2001
From: Albert Jiang
Date: Sat, 6 Jan 2024 18:09:27 +0000
Subject: [PATCH 1/6] vllm

---
 generator/model.py     | 102 +++++++++++++++++++++++++++++++++++++++++
 prover/evaluate.py     |  50 ++++++++++++--------
 prover/proof_search.py |  37 +++++++++++++--
 3 files changed, 165 insertions(+), 24 deletions(-)

diff --git a/generator/model.py b/generator/model.py
index 9cd542d..8aa9058 100644
--- a/generator/model.py
+++ b/generator/model.py
@@ -1,5 +1,6 @@
 """Lightning module for the tactic generator."""
 import torch
+import time
 import openai
 import pickle
 from lean_dojo import Pos
@@ -356,6 +357,107 @@ def batch_generate(
         return tactics_with_scores
 
 
+class VLLMGenerator(TacticGenerator):
+    def __init__(
+        self,
+        server_url: str,
+        model: str,
+        max_tokens: int,
+        temperature: float,
+        stop: List[str],
+        prompt_format: str,
+        num_retries: int = 3,
+    ):
+        super().__init__()
+        if not server_url.startswith("http"):
+            server_url = f"http://{server_url}"
+        if not server_url.endswith("/v1"):
+            server_url = f"{server_url.rstrip('/')}/v1"
+        logger.info(f"Connecting to VLLM server at {server_url}")
+        self.server_url = server_url
+        self.client = openai.OpenAI(base_url=server_url, api_key="NONE")
+        logger.info(f"Initialized VLLM client at {self.server_url}")
+        self.backoff_time = 3.0
+        self.model = model
+        self.stop = stop
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.prompt_format = prompt_format
+        assert prompt_format.count("TACTIC_STATE") == 1
+        self.num_retries = num_retries
+
+    def trial_completion_with_args(self, completion_args: Dict[str, Any]) -> List[Tuple[str, float]]:
+        trial = 0
+        while trial < self.num_retries:
+            try:
+                responses = self.client.completions.create(**completion_args)
+                texts_and_logprobs: List[Tuple[str, float]] = []
+                for choice in responses.choices:
+                    text = choice.text.strip()
+                    logprob = sum(choice.logprobs.token_logprobs)
+                    texts_and_logprobs.append((text, logprob))
+                return texts_and_logprobs
+            except openai.OpenAIError as e:
+                logger.error(f"OpenAI API returned an error: {e}")
+                trial += 1
+                logger.info(f"Retrying in {self.backoff_time} seconds...")
+                time.sleep(self.backoff_time)
+
+    def generate(
+        self,
+        state: str,
+        file_path: str,
+        theorem_full_name: str,
+        theorem_pos: Pos,
+        num_samples: int,
+    ) -> List[Tuple[str, float]]:
+        # If no stochasticity, sample one tactic only.
+        assert self.temperature > 0 or num_samples == 1
+        prompt = self.prompt_format.replace("TACTIC_STATE", state.strip())
+        completion_args = {
+            "model": self.model,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "logprobs": 1,
+            "top_p": 1.0,
+            "echo": False,
+            "stop": self.stop,
+            "prompt": [prompt]*num_samples,
+        }
+        return self.trial_completion_with_args(completion_args)
+
+
+    def batch_generate(
+        self,
+        state: List[str],
+        file_path: List[str],
+        theorem_full_name: List[str],
+        theorem_pos: List[Pos],
+        num_samples: int,
+    ) -> List[List[Tuple[str, float]]]:
+        # If no stochasticity, sample one tactic only.
+        assert self.temperature > 0 or num_samples == 1
+        all_prompts: List[str] = []
+        for s in state:
+            prompt = self.prompt_format.replace("TACTIC_STATE", s.strip())
+            all_prompts.extend([prompt]*num_samples)
+        completion_args = {
+            "model": self.model,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "logprobs": 1,
+            "top_p": 1.0,
+            "echo": False,
+            "stop": self.stop,
+            "prompt": all_prompts
+        }
+        all_completions = self.trial_completion_with_args(completion_args)
+        assert len(all_completions) == len(state) * num_samples
+        all_tactics_with_scores: List[List[Tuple[str, float]]] = []
+        for i in range(len(state)):
+            all_tactics_with_scores.append(all_completions[i*num_samples:(i+1)*num_samples])
+        return all_tactics_with_scores
+
+
 class GPT4TacticGenerator(TacticGenerator):
     def __init__(
         self,
diff --git a/prover/evaluate.py b/prover/evaluate.py
index 21c709a..9101978 100644
--- a/prover/evaluate.py
+++ b/prover/evaluate.py
@@ -8,7 +8,7 @@
 import argparse
 from loguru import logger
 from lean_dojo import Theorem
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Any
 from lean_dojo import LeanGitRepo, Theorem, Pos, is_available_in_cache
 
 from common import set_logger
@@ -23,6 +23,7 @@ def _get_theorems(
     name_filter: str,
     num_theorems: int,
 ) -> Tuple[LeanGitRepo, List[Theorem], List[Pos]]:
+    logger.info(f"Loading theorems from {data_path}...")
     repo, theorems, positions = _get_theorems_from_files(
         data_path,
         split,
@@ -31,6 +32,7 @@ def _get_theorems(
         name_filter,
         num_theorems,
     )
+    logger.info(f"Loaded theorems from {data_path}")
 
     all_repos = {thm.repo for thm in theorems}
     for r in all_repos:
@@ -95,6 +97,7 @@ def evaluate(
     tactic: Optional[str] = None,
     module: Optional[str] = None,
     num_sampled_tactics: int = 64,
+    vllm_args: Optional[dict[str, Any]] = None,
    timeout: int = 600,
     num_cpus: int = 1,
     with_gpus: bool = False,
@@ -113,11 +116,13 @@ def evaluate(
         tactic,
         module,
         num_cpus,
+        vllm_args=vllm_args,
         with_gpus=with_gpus,
         timeout=timeout,
         num_sampled_tactics=num_sampled_tactics,
         debug=verbose,
     )
+    import pdb; pdb.set_trace()
     results = prover.search_unordered(repo, theorems, positions)
 
     # Calculate the result statistics.
@@ -204,30 +209,37 @@ def main() -> None:
     parser.add_argument(
         "--verbose", action="store_true", help="Set the logging level to DEBUG."
     )
+    parser.add_argument(
+        "--vllm-args-json-path", type=str, help="Path to a JSON file of arguments for the VLLM tactic generator."
+ ) args = parser.parse_args() - assert args.ckpt_path or args.tactic - + assert args.ckpt_path or args.tactic or args.vllm_args_json_path + if args.vllm_args_json_path: + vllm_args = json.load(open(args.vllm_args_json_path)) + else: + vllm_args = None logger.info(f"PID: {os.getpid()}") logger.info(args) pass_1 = evaluate( - args.data_path, - args.exp_id, - args.split, - args.file_path, - args.full_name, - args.name_filter, - args.num_theorems, - args.ckpt_path, - args.indexed_corpus_path, - args.tactic, - args.module, - args.num_sampled_tactics, - args.timeout, - args.num_cpus, - args.with_gpus, - args.verbose, + data_path = args.data_path, + exp_id = args.exp_id, + split = args.split, + file_path = args.file_path, + full_name = args.full_name, + name_filter = args.name_filter, + num_theorems = args.num_theorems, + ckpt_path = args.ckpt_path, + indexed_corpus_path = args.indexed_corpus_path, + tactic = args.tactic, + module = args.module, + num_sampled_tactics = args.num_sampled_tactics, + vllm_args = vllm_args, + timeout = args.timeout, + num_cpus = args.num_cpus, + with_gpus = args.with_gpus, + verbose = args.verbose, ) logger.info(f"Pass@1: {pass_1}") diff --git a/prover/proof_search.py b/prover/proof_search.py index 5fe7262..73e52f8 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -22,12 +22,12 @@ ) from loguru import logger from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Any from ray.util.actor_pool import ActorPool from common import zip_strict from prover.search_tree import * -from generator.model import RetrievalAugmentedGenerator, FixedTacticGenerator +from generator.model import RetrievalAugmentedGenerator, FixedTacticGenerator, VLLMGenerator @dataclass(frozen=True) @@ -305,11 +305,26 @@ def __init__( indexed_corpus_path: Optional[str], tactic: Optional[str], module: Optional[str], + vllm_args: Optional[dict[str, Any]], timeout: int, num_sampled_tactics: int, debug: bool, ) -> None: - if ckpt_path is None: + if vllm_args: + assert all( + key in vllm_args + for key in ["server_url", "model", "max_tokens", "temperature", "stop", "prompt_format"] + ), vllm_args + tac_gen = VLLMGenerator( + server_url=vllm_args["server_url"], + model=vllm_args["model"], + max_tokens=vllm_args["max_tokens"], + temperature=vllm_args["temperature"], + stop=vllm_args["stop"], + prompt_format=vllm_args["prompt_format"], + num_retries=vllm_args.get("num_retries", 3), + ) + elif ckpt_path is None: tac_gen = FixedTacticGenerator(tactic, module) else: tac_gen = RetrievalAugmentedGenerator.load( @@ -374,18 +389,29 @@ def __init__( module: Optional[str], num_cpus: int, with_gpus: bool, + vllm_args: Optional[dict[str, Any]], timeout: int, num_sampled_tactics: int, debug: Optional[bool] = False, ) -> None: - if ckpt_path is None: + if ckpt_path is None and vllm_args is None: assert tactic and not indexed_corpus_path else: assert not tactic and not module self.distributed = num_cpus > 1 if not self.distributed: - if ckpt_path is None: + if vllm_args: + tac_gen = VLLMGenerator( + server_url=vllm_args["server_url"], + model=vllm_args["model"], + max_tokens=vllm_args["max_tokens"], + temperature=vllm_args["temperature"], + stop=vllm_args["stop"], + prompt_format=vllm_args["prompt_format"], + num_retries=vllm_args.get("num_retries", 3), + ) + elif ckpt_path is None: tac_gen = FixedTacticGenerator(tactic, module) else: device = torch.device("cuda") if with_gpus else torch.device("cpu") @@ -423,6 +449,7 @@ def __init__( 
indexed_corpus_path, tactic, module, + vllm_args=vllm_args, timeout=timeout, num_sampled_tactics=num_sampled_tactics, debug=debug, From ad355e346c3bb42f8295bfc5cd8f6690a789f97a Mon Sep 17 00:00:00 2001 From: Albert Jiang Date: Sat, 6 Jan 2024 21:32:17 +0000 Subject: [PATCH 2/6] make it work --- generator/model.py | 75 +++++++++++++++++++++------------------------- prover/evaluate.py | 1 - 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/generator/model.py b/generator/model.py index 8aa9058..ee9c65b 100644 --- a/generator/model.py +++ b/generator/model.py @@ -3,6 +3,7 @@ import time import openai import pickle +import multiprocessing.pool as mpp from lean_dojo import Pos from loguru import logger import pytorch_lightning as pl @@ -357,6 +358,24 @@ def batch_generate( return tactics_with_scores +def trial_completion_with_args(args_tuple: Tuple[openai.Client, int, float, Dict[str, Any]]) -> List[Tuple[str, float]]: + client, num_retries, backoff_time, completion_args = args_tuple + trial = 0 + while trial < num_retries: + try: + responses = client.completions.create(**completion_args) + texts_and_logprobs: List[Tuple[str, float]] = [] + for choice in responses.choices: + text = choice.text.strip() + logprob = sum(choice.logprobs.token_logprobs) + texts_and_logprobs.append((text, logprob)) + return texts_and_logprobs + except openai.OpenAIError as e: + logger.error(f"OpenAI API returned an error: {e}") + trial += 1 + logger.info(f"Retrying in {backoff_time} seconds...") + time.sleep(backoff_time) + class VLLMGenerator(TacticGenerator): def __init__( self, @@ -386,22 +405,7 @@ def __init__( assert prompt_format.count("TACTIC_STATE") == 1 self.num_retries = num_retries - def trial_completion_with_args(self, completion_args: Dict[str, Any]) -> List[Tuple[str, float]]: - trial = 0 - while trial < self.num_retries: - try: - responses = self.client.completions.create(**completion_args) - texts_and_logprobs: List[Tuple[str, float]] = [] - for choice in responses.choices: - text = choice.text.strip() - logprob = sum(choice.logprobs.token_logprobs) - texts_and_logprobs.append((text, logprob)) - return texts_and_logprobs - except openai.OpenAIError as e: - logger.error(f"OpenAI API returned an error: {e}") - trial += 1 - logger.info(f"Retrying in {self.backoff_time} seconds...") - time.sleep(self.backoff_time) + def generate( self, @@ -422,11 +426,18 @@ def generate( "top_p": 1.0, "echo": False, "stop": self.stop, - "prompt": [prompt]*num_samples, + "prompt": [prompt], } - return self.trial_completion_with_args(completion_args) - + all_results = [] + with mpp.ThreadPool(64) as p: + for result in p.imap( + trial_completion_with_args, + [(self.client, self.num_retries, self.backoff_time, completion_args) for _ in range(num_samples)], + ): + all_results.extend(result) + return all_results + def batch_generate( self, state: List[str], @@ -435,28 +446,10 @@ def batch_generate( theorem_pos: List[Pos], num_samples: int, ) -> List[List[Tuple[str, float]]]: - # If no stochasticity, sample one tactic only. 
-        assert self.temperature > 0 or num_samples == 1
-        all_prompts: List[str] = []
-        for s in state:
-            prompt = self.prompt_format.replace("TACTIC_STATE", s.strip())
-            all_prompts.extend([prompt]*num_samples)
-        completion_args = {
-            "model": self.model,
-            "max_tokens": self.max_tokens,
-            "temperature": self.temperature,
-            "logprobs": 1,
-            "top_p": 1.0,
-            "echo": False,
-            "stop": self.stop,
-            "prompt": all_prompts
-        }
-        all_completions = self.trial_completion_with_args(completion_args)
-        assert len(all_completions) == len(state) * num_samples
-        all_tactics_with_scores: List[List[Tuple[str, float]]] = []
-        for i in range(len(state)):
-            all_tactics_with_scores.append(all_completions[i*num_samples:(i+1)*num_samples])
-        return all_tactics_with_scores
+        return [
+            self.generate(s, f, tfn, tp, num_samples)
+            for s, f, tfn, tp in zip_strict(state, file_path, theorem_full_name, theorem_pos)
+        ]
 
 class GPT4TacticGenerator(TacticGenerator):
     def __init__(
diff --git a/prover/evaluate.py b/prover/evaluate.py
index 9101978..07d8731 100644
--- a/prover/evaluate.py
+++ b/prover/evaluate.py
@@ -122,7 +122,6 @@ def evaluate(
         num_sampled_tactics=num_sampled_tactics,
         debug=verbose,
     )
-    import pdb; pdb.set_trace()
     results = prover.search_unordered(repo, theorems, positions)
 
     # Calculate the result statistics.

From b9ca6ed496fac454a96fc4f3d45071d851d764ae Mon Sep 17 00:00:00 2001
From: Albert Jiang
Date: Sun, 7 Jan 2024 18:32:05 +0000
Subject: [PATCH 3/6] d

---
 generator/model.py     |  4 ++--
 prover/evaluate.py     | 26 ++++++++++++++++++++++++--
 prover/proof_search.py | 37 +++++++++++++++++++++++++++++++++----
 3 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/generator/model.py b/generator/model.py
index ee9c65b..2bb973e 100644
--- a/generator/model.py
+++ b/generator/model.py
@@ -429,14 +429,14 @@ def generate(
             "prompt": [prompt],
         }
 
-        all_results = []
         with mpp.ThreadPool(64) as p:
+            all_results = []
             for result in p.imap(
                 trial_completion_with_args,
                 [(self.client, self.num_retries, self.backoff_time, completion_args) for _ in range(num_samples)],
             ):
                 all_results.extend(result)
-        return all_results
+            return all_results
 
     def batch_generate(
         self,
diff --git a/prover/evaluate.py b/prover/evaluate.py
index 07d8731..5133af8 100644
--- a/prover/evaluate.py
+++ b/prover/evaluate.py
@@ -11,7 +11,7 @@
 from typing import List, Tuple, Optional, Any
 from lean_dojo import LeanGitRepo, Theorem, Pos, is_available_in_cache
 
-from common import set_logger
+from common import set_logger, zip_strict
 from prover.proof_search import Status, DistributedProver
 
 
@@ -102,6 +102,7 @@ def evaluate(
     num_cpus: int = 1,
     with_gpus: bool = False,
     verbose: bool = False,
+    progress_dir: Optional[str] = None,
 ) -> float:
     set_logger(verbose)
 
@@ -109,6 +110,23 @@ def evaluate(
         data_path, split, file_path, full_name, name_filter, num_theorems
     )
 
+    # Skip theorems that are already done.
+    finished_theorem_names = set()
+    if progress_dir is not None:
+        os.makedirs(progress_dir, exist_ok=True)
+        for file in os.listdir(progress_dir):
+            assert file.endswith(".out")
+            name = file[:-4]
+            assert name not in finished_theorem_names
+            finished_theorem_names.add(name)
+    unfinished_theorems, unfinished_positions = [], []
+    for theorem, position in zip_strict(theorems, positions):
+        if theorem.uid in finished_theorem_names:
+            continue
+        unfinished_theorems.append(theorem)
+        unfinished_positions.append(position)
+    logger.info(f"{len(unfinished_theorems)} theorems to prove")
+
     # Search for proofs using multiple concurrent provers.
     prover = DistributedProver(
         ckpt_path,
@@ -122,7 +140,7 @@ def evaluate(
         num_sampled_tactics=num_sampled_tactics,
         debug=verbose,
     )
-    results = prover.search_unordered(repo, theorems, positions)
+    results = prover.search_unordered(repo, unfinished_theorems, unfinished_positions, progress_dir=progress_dir)
 
     # Calculate the result statistics.
@@ -211,6 +229,9 @@ def main() -> None:
     parser.add_argument(
         "--vllm-args-json-path", type=str, help="Path to a JSON file of arguments for the VLLM tactic generator."
     )
+    parser.add_argument(
+        "--progress-dir", type=str, help="Directory in which per-theorem progress files are saved."
+    )
     args = parser.parse_args()
 
     assert args.ckpt_path or args.tactic or args.vllm_args_json_path
@@ -239,6 +260,7 @@ def main() -> None:
         num_cpus = args.num_cpus,
         with_gpus = args.with_gpus,
         verbose = args.verbose,
+        progress_dir = args.progress_dir,
     )
 
     logger.info(f"Pass@1: {pass_1}")
diff --git a/prover/proof_search.py b/prover/proof_search.py
index 73e52f8..b5ad443 100644
--- a/prover/proof_search.py
+++ b/prover/proof_search.py
@@ -6,6 +6,7 @@
 import time
 import heapq
 import torch
+import json
 from lean_dojo import (
     Pos,
     Dojo,
@@ -45,6 +46,20 @@ class SearchResult:
     num_total_nodes: int
     num_searched_nodes: int
 
+    def serialize(self) -> str:
+        result_dict = {
+            "theorem": self.theorem.uid,
+            "status": self.status.name,
+            "proof": self.proof,
+            "actor_time": self.actor_time,
+            "environment_time": self.environment_time,
+            "total_time": self.total_time,
+            "num_total_nodes": self.num_total_nodes,
+            "num_searched_nodes": self.num_searched_nodes,
+        }
+        return json.dumps(result_dict)
+
+
 class BestFirstSearchProver:
     """A prover that uses best-first search to find proofs using a tactic generator."""
 
@@ -67,9 +82,21 @@ def __init__(
         self.total_time = None
 
     def search(
-        self, repo: LeanGitRepo, thm: Theorem, pos: Pos
+        self, repo: LeanGitRepo, thm: Theorem, pos: Pos, progress_dir: Optional[str] = None
     ) -> Optional[SearchResult]:
         logger.info(f"Proving {thm}")
+
+        theorem_uid = thm.uid
+        if progress_dir is not None:
+            assert os.path.isdir(progress_dir)
+            progress_file = os.path.join(progress_dir, theorem_uid + ".out")
+            assert not os.path.isfile(progress_file)
+            empty_placeholder_result = {
+                "theorem": thm.uid,
+                "status": Status.OPEN.name,
+                "proof": None,
+            }
+            json.dump(empty_placeholder_result, open(progress_file, "w"), ensure_ascii=False, indent=4)
 
         self.repo = repo
         self.theorem = thm
@@ -119,6 +146,8 @@ def search(
                 num_searched_nodes=self.num_expansions,
             )
             logger.info(result)
+            if progress_dir is not None:
+                json.dump(result, open(progress_file, "w"), ensure_ascii=False, indent=4)
             return result
 
         except DojoInitError as ex:
@@ -460,19 +489,19 @@ def __init__(
         self.prover_pool = ActorPool(provers)
 
     def search_unordered(
-        self, repo: LeanGitRepo, theorems: List[Theorem], positions: List[Pos]
+        self, repo: LeanGitRepo, theorems: List[Theorem], positions: List[Pos], progress_dir: Optional[str] = None
     ) -> List[SearchResult]:
         """Parallel proof search for `theorems`.
The order of the results is not guaranteed to match the order of the input.""" if not self.distributed: return [ - self.prover.search(repo, thm, pos) + self.prover.search(repo, thm, pos, progress_dir) for thm, pos in zip_strict(theorems, positions) ] try: results = list( self.prover_pool.map_unordered( - lambda p, x: p.search.remote(repo, x[0], x[1]), + lambda p, x: p.search.remote(repo, x[0], x[1], progress_dir), zip_strict(theorems, positions), ) ) From 99ec0164bd4d250ea6a96238256acd58a1107eb6 Mon Sep 17 00:00:00 2001 From: Albert Jiang Date: Sun, 7 Jan 2024 19:55:06 +0000 Subject: [PATCH 4/6] d --- prover/proof_search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prover/proof_search.py b/prover/proof_search.py index b5ad443..c74c064 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -57,7 +57,7 @@ def serialize(self) -> str: "num_total_nodes": self.num_total_nodes, "num_searched_nodes": self.num_searched_nodes, } - return json.dumps(result_dict) + return json.dumps(result_dict, ensure_ascii=False, indent=4) @@ -147,7 +147,8 @@ def search( ) logger.info(result) if progress_dir is not None: - json.dump(result, open(progress_file, "w"), ensure_ascii=False, indent=4) + with open(progress_file, "w") as f: + f.write(result.serialize()) return result except DojoInitError as ex: From b486963e08428d6b0f6889ae61c96a765c3d2af2 Mon Sep 17 00:00:00 2001 From: Albert Jiang Date: Mon, 8 Jan 2024 21:06:26 +0000 Subject: [PATCH 5/6] Update prover/proof_search.py --- prover/proof_search.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/prover/proof_search.py b/prover/proof_search.py index c74c064..fea6638 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -89,14 +89,6 @@ def search( theorem_uid = thm.uid if progress_dir is not None: assert os.path.isdir(progress_dir) - progress_file = os.path.join(progress_dir, theorem_uid + ".out") - assert not os.path.isfile(progress_file) - empty_placeholder_result = { - "theorem": thm.uid, - "status": Status.OPEN.name, - "proof": None, - } - json.dump(empty_placeholder_result, open(progress_file, "w"), ensure_ascii=False, indent=4) self.repo = repo self.theorem = thm From a05555133cb3eb59a4141ba1b4674645d83bcc43 Mon Sep 17 00:00:00 2001 From: Albert Jiang Date: Fri, 12 Jan 2024 10:46:17 +0000 Subject: [PATCH 6/6] make it work --- generator/model.py | 39 +++++++++++++++++++++++---------------- prover/evaluate.py | 10 ++++------ prover/proof_search.py | 2 +- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/generator/model.py b/generator/model.py index 2bb973e..54a5aca 100644 --- a/generator/model.py +++ b/generator/model.py @@ -405,7 +405,15 @@ def __init__( assert prompt_format.count("TACTIC_STATE") == 1 self.num_retries = num_retries - + def generate_from_args(self, args: List[Dict[str, Any]]) -> List[Tuple[str, float]]: + with mpp.ThreadPool(64) as p: + all_results = [] + for result in p.imap( + trial_completion_with_args, + [(self.client, self.num_retries, self.backoff_time, arg) for arg in args], + ): + all_results.extend(result) + return all_results def generate( self, @@ -418,7 +426,11 @@ def generate( # If no stochasticity, sample one tactic only. 
         assert self.temperature > 0 or num_samples == 1
         prompt = self.prompt_format.replace("TACTIC_STATE", state.strip())
-        completion_args = {
+        completion_args = self.get_completion_args(prompt)
+        return self.generate_from_args([completion_args] * num_samples)
+
+    def get_completion_args(self, prompt: str) -> dict[str, Any]:
+        return {
             "model": self.model,
             "max_tokens": self.max_tokens,
             "temperature": self.temperature,
@@ -428,16 +440,7 @@ def generate(
             "stop": self.stop,
             "prompt": [prompt],
         }
-
-        with mpp.ThreadPool(64) as p:
-            all_results = []
-            for result in p.imap(
-                trial_completion_with_args,
-                [(self.client, self.num_retries, self.backoff_time, completion_args) for _ in range(num_samples)],
-            ):
-                all_results.extend(result)
-            return all_results
-
+
     def batch_generate(
         self,
         state: List[str],
@@ -446,10 +449,14 @@ def batch_generate(
         theorem_pos: List[Pos],
         num_samples: int,
     ) -> List[List[Tuple[str, float]]]:
-        return [
-            self.generate(s, f, tfn, tp, num_samples)
-            for s, f, tfn, tp in zip_strict(state, file_path, theorem_full_name, theorem_pos)
-        ]
+        all_args: List[Dict[str, Any]] = []
+        for s in state:
+            prompt = self.prompt_format.replace("TACTIC_STATE", s.strip())
+            completion_args = self.get_completion_args(prompt)
+            for _ in range(num_samples):
+                all_args.append(completion_args)
+
+        return self.generate_from_args(all_args)
 
 class GPT4TacticGenerator(TacticGenerator):
     def __init__(
diff --git a/prover/evaluate.py b/prover/evaluate.py
index 5133af8..13bd55f 100644
--- a/prover/evaluate.py
+++ b/prover/evaluate.py
@@ -111,17 +111,15 @@ def evaluate(
     )
 
     # Skip theorems that are already done.
-    finished_theorem_names = set()
+    finished_theorem_hashes = set()
     if progress_dir is not None:
         os.makedirs(progress_dir, exist_ok=True)
         for file in os.listdir(progress_dir):
             assert file.endswith(".out")
-            name = file[:-4]
-            assert name not in finished_theorem_names
-            finished_theorem_names.add(name)
+            finished_theorem_hashes.add(file[:-4])
     unfinished_theorems, unfinished_positions = [], []
-    for theorem, position in zip_strict(theorems, positions):
-        if theorem.uid in finished_theorem_names:
+    for theorem, position in zip(theorems, positions):
+        if theorem.uhash in finished_theorem_hashes:
             continue
         unfinished_theorems.append(theorem)
         unfinished_positions.append(position)
diff --git a/prover/proof_search.py b/prover/proof_search.py
index fea6638..9a63092 100644
--- a/prover/proof_search.py
+++ b/prover/proof_search.py
@@ -139,7 +139,7 @@ def search(
             )
             logger.info(result)
             if progress_dir is not None:
-                with open(os.path.join(progress_dir, f"{thm.uhash}.out"), "w") as f:
+                with open(os.path.join(progress_dir, f"{thm.uhash}.out"), "w") as f:
                     f.write(result.serialize())
             return result
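
Usage note: the `--vllm-args-json-path` flag introduced by this series expects a JSON file
whose keys match the assertion in prover/proof_search.py (server_url, model, max_tokens,
temperature, stop, prompt_format, plus an optional num_retries). A minimal sketch of such a
file; the server URL, model name, stop tokens, and prompt format below are illustrative
placeholders, not values prescribed by the patches:

    {
        "server_url": "http://localhost:8000",
        "model": "my-tactic-generator",
        "max_tokens": 128,
        "temperature": 0.7,
        "stop": ["\n"],
        "prompt_format": "[GOAL]\nTACTIC_STATE\n[PROOFSTEP]\n",
        "num_retries": 3
    }

Per the generator's asserts, prompt_format must contain the substring "TACTIC_STATE" exactly
once (it is replaced by the goal state), and temperature must be positive whenever more than
one tactic is sampled per state.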