Skip to content

Commit

Permalink
fix some warnings and minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyu Yang committed Nov 2, 2022
1 parent 88b07f3 commit 80290e5
Show file tree
Hide file tree
Showing 11 changed files with 29 additions and 24 deletions.
4 changes: 2 additions & 2 deletions prover/cli_ruletaker_single_shot_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
4 changes: 2 additions & 2 deletions prover/cli_ruletaker_stepwise_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
4 changes: 2 additions & 2 deletions prover/cli_task1_single_shot_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
4 changes: 2 additions & 2 deletions prover/cli_task1_stepwise_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
4 changes: 2 additions & 2 deletions prover/cli_task2_single_shot_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
4 changes: 2 additions & 2 deletions prover/cli_task2_stepwise_t5-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ model:
model_name: t5-large
num_beams: 10
topk: 10
verifier_ckpt: ""
verifier_weight: 0
verifier_ckpt: null
verifier_weight: 0.0
proof_search: false
oracle_prover: false
oracle_verifier: false
Expand Down
10 changes: 7 additions & 3 deletions prover/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,13 @@ def read_ruletaker_proofs(path: str, is_train: bool) -> List[Example]:
"all_proofs": ex["proofs"],
}
)
if ex["answer"] == "Unknown":
ans = "Unknown"
else:
ans = not ex["answer"]
data.append(
{
"answer": not ex["answer"] if ex["answer"] != "Unknown" else "Unknown",
"answer": ans,
"depth": ex["depth"],
"proof": Proof(
context,
Expand Down Expand Up @@ -133,7 +137,7 @@ def __init__(
is_train: bool,
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.is_train = is_train
Expand Down Expand Up @@ -207,7 +211,7 @@ def __init__(
is_train: bool,
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.sample_goal = sample_goal
Expand Down
13 changes: 7 additions & 6 deletions prover/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def __init__(
warmup_steps: int,
num_beams: int,
topk: int,
verifier_ckpt: str,
verifier_weight: float,
proof_search: bool,
oracle_prover: bool,
oracle_verifier: bool,
verifier_weight: float,
verifier_ckpt: Optional[str] = None,
oracle_prover: Optional[bool] = False,
oracle_verifier: Optional[bool] = False,
) -> None:
super().__init__()
self.save_hyperparameters()
Expand All @@ -129,12 +129,13 @@ def __init__(
self.oracle_prover = oracle_prover
self.oracle_verifier = oracle_verifier
if stepwise and verifier_weight > 0:
assert verifier_weight <= 1
assert verifier_weight <= 1.0
assert verifier_ckpt is not None
self.verifiers = [
EntailmentClassifier.load_from_checkpoint(verifier_ckpt)
] # Avoid making the verifier a submodule.

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
if (
model_name.startswith("t5-")
or model_name.startswith("google/t5-v1_1-")
Expand Down
2 changes: 1 addition & 1 deletion verifier/cli_entailmentbank_task1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ model:
lr: 1e-5
warmup_steps: 2000
model_name: roberta-large
pos_weight: 128
pos_weight: 128.0
data:
dataset: entailmentbank
batch_size: 128
Expand Down
2 changes: 1 addition & 1 deletion verifier/cli_entailmentbank_task2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ model:
lr: 1e-5
warmup_steps: 2500
model_name: roberta-large
pos_weight: 128
pos_weight: 128.0
data:
dataset: entailmentbank
batch_size: 128
Expand Down
2 changes: 1 addition & 1 deletion verifier/cli_ruletaker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ model:
lr: 5e-6
warmup_steps: 2500
model_name: roberta-large
pos_weight: 4
pos_weight: 4.0
data:
dataset: ruletaker
batch_size: 128
Expand Down

0 comments on commit 80290e5

Please sign in to comment.