Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
yangky11 committed Jul 12, 2024
1 parent 40e07a6 commit 04c80fb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
4 changes: 3 additions & 1 deletion common.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,13 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:


def format_augmented_state(
- s: str, premises: List[Premise], max_len: int, p_drop: float
+ s: str, premises: List[Premise], max_len: Optional[int] = None, p_drop: float = 0.0
) -> str:
"""Format a state with retrieved premises and drop some of them with probability ``p_drop``."""
aug_s = ""
length = 0
+ if max_len is None:
+ max_len = 9999999999999999999999
max_premises_len = max_len - len(bytes(s.encode("utf-8")))

for p in premises:
Expand Down
13 changes: 11 additions & 2 deletions prover/proof_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,21 @@ def __init__(
ray.get(vllm_actor.initialize.remote())
tac_gen = VllmGenerator(vllm_actor)
elif indexed_corpus_path is not None:
device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
tac_gen = RetrievalAugmentedGenerator(
- gen_ckpt_path, ret_ckpt_path, indexed_corpus_path, device, max_num_retrieved=100
+ gen_ckpt_path,
+ ret_ckpt_path,
+ indexed_corpus_path,
+ device,
+ max_oup_seq_len,
+ length_penalty,
+ max_num_retrieved=100,
)
else:
device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
- tac_gen = HuggingFaceGenerator(gen_ckpt_path, device, max_oup_seq_len, length_penalty)
+ tac_gen = HuggingFaceGenerator(
+ gen_ckpt_path, device, max_oup_seq_len, length_penalty
+ )

self.distributed = num_workers > 1
if not self.distributed:
Expand Down
10 changes: 3 additions & 7 deletions retrieval/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,11 @@ def load(cls, ckpt_path: str, device, freeze: bool) -> "PremiseRetriever":

@classmethod
def load_hf(
- cls, ckpt_path: str, device: int, dtype = None, max_seq_len: Optional[int] = None
+ cls, ckpt_path: str, device: int, dtype=None, max_seq_len: Optional[int] = None
) -> "PremiseRetriever":
if max_seq_len is None:
max_seq_len = 999999999999
- model = (
- PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100)
- .to(device)
- .eval()
- )
+ model = PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100).to(device).eval()
if dtype is not None:
return model.to(dtype)
elif (
Expand Down Expand Up @@ -373,7 +369,7 @@ def retrieve(

retrieved_premises, scores = self.corpus.get_nearest_premises(
self.corpus_embeddings,
- ctx,
+ [ctx],
context_emb,
k,
)
Expand Down

0 comments on commit 04c80fb

Please sign in to comment.