From 80290e58bc64df832648da757d7d5f4e60536904 Mon Sep 17 00:00:00 2001 From: Kaiyu Yang Date: Tue, 1 Nov 2022 17:21:10 -0700 Subject: [PATCH] fix some warnings and minor bugs --- prover/cli_ruletaker_single_shot_t5-large.yaml | 4 ++-- prover/cli_ruletaker_stepwise_t5-large.yaml | 4 ++-- prover/cli_task1_single_shot_t5-large.yaml | 4 ++-- prover/cli_task1_stepwise_t5-large.yaml | 4 ++-- prover/cli_task2_single_shot_t5-large.yaml | 4 ++-- prover/cli_task2_stepwise_t5-large.yaml | 4 ++-- prover/datamodule.py | 10 +++++++--- prover/model.py | 13 +++++++------ verifier/cli_entailmentbank_task1.yaml | 2 +- verifier/cli_entailmentbank_task2.yaml | 2 +- verifier/cli_ruletaker.yaml | 2 +- 11 files changed, 29 insertions(+), 24 deletions(-) diff --git a/prover/cli_ruletaker_single_shot_t5-large.yaml b/prover/cli_ruletaker_single_shot_t5-large.yaml index 1e6cce5..06f60ef 100644 --- a/prover/cli_ruletaker_single_shot_t5-large.yaml +++ b/prover/cli_ruletaker_single_shot_t5-large.yaml @@ -19,8 +19,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/cli_ruletaker_stepwise_t5-large.yaml b/prover/cli_ruletaker_stepwise_t5-large.yaml index e12a931..c8d9633 100644 --- a/prover/cli_ruletaker_stepwise_t5-large.yaml +++ b/prover/cli_ruletaker_stepwise_t5-large.yaml @@ -19,8 +19,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/cli_task1_single_shot_t5-large.yaml b/prover/cli_task1_single_shot_t5-large.yaml index 0eaefb1..e89a08d 100644 --- a/prover/cli_task1_single_shot_t5-large.yaml +++ b/prover/cli_task1_single_shot_t5-large.yaml @@ -18,8 +18,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/cli_task1_stepwise_t5-large.yaml b/prover/cli_task1_stepwise_t5-large.yaml index 60d2a95..1a06c6c 100644 --- a/prover/cli_task1_stepwise_t5-large.yaml +++ b/prover/cli_task1_stepwise_t5-large.yaml @@ -18,8 +18,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/cli_task2_single_shot_t5-large.yaml b/prover/cli_task2_single_shot_t5-large.yaml index 62dd6f2..493ea07 100644 --- a/prover/cli_task2_single_shot_t5-large.yaml +++ b/prover/cli_task2_single_shot_t5-large.yaml @@ -18,8 +18,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/cli_task2_stepwise_t5-large.yaml b/prover/cli_task2_stepwise_t5-large.yaml index f9271fe..ac6fcc3 100644 --- a/prover/cli_task2_stepwise_t5-large.yaml +++ b/prover/cli_task2_stepwise_t5-large.yaml @@ -18,8 +18,8 @@ model: model_name: t5-large num_beams: 10 topk: 10 - verifier_ckpt: "" - verifier_weight: 0 + verifier_ckpt: null + verifier_weight: 0.0 proof_search: false oracle_prover: false oracle_verifier: false diff --git a/prover/datamodule.py b/prover/datamodule.py index 635572c..9127934 100644 --- a/prover/datamodule.py +++ b/prover/datamodule.py @@ -92,9 +92,13 @@ def read_ruletaker_proofs(path: str, is_train: bool) -> List[Example]: "all_proofs": ex["proofs"], } ) + if ex["answer"] == "Unknown": + ans = "Unknown" + else: + ans = not ex["answer"] data.append( { - "answer": not ex["answer"] if ex["answer"] != "Unknown" else "Unknown", + "answer": ans, "depth": ex["depth"], "proof": Proof( context, @@ -133,7 +137,7 @@ def __init__( is_train: bool, ) -> None: super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512) self.max_input_len = max_input_len self.max_output_len = max_output_len self.is_train = is_train @@ -207,7 +211,7 @@ def __init__( is_train: bool, ) -> None: super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512) self.max_input_len = max_input_len self.max_output_len = max_output_len self.sample_goal = sample_goal diff --git a/prover/model.py b/prover/model.py index 6613f10..4ae90ac 100644 --- a/prover/model.py +++ b/prover/model.py @@ -109,11 +109,11 @@ def __init__( warmup_steps: int, num_beams: int, topk: int, - verifier_ckpt: str, - verifier_weight: float, proof_search: bool, - oracle_prover: bool, - oracle_verifier: bool, + verifier_weight: float, + verifier_ckpt: Optional[str] = None, + oracle_prover: Optional[bool] = False, + oracle_verifier: Optional[bool] = False, ) -> None: super().__init__() self.save_hyperparameters() @@ -129,12 +129,13 @@ def __init__( self.oracle_prover = oracle_prover self.oracle_verifier = oracle_verifier if stepwise and verifier_weight > 0: - assert verifier_weight <= 1 + assert verifier_weight <= 1.0 + assert verifier_ckpt is not None self.verifiers = [ EntailmentClassifier.load_from_checkpoint(verifier_ckpt) ] # Avoid making the verifier a submodule. - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512) if ( model_name.startswith("t5-") or model_name.startswith("google/t5-v1_1-") diff --git a/verifier/cli_entailmentbank_task1.yaml b/verifier/cli_entailmentbank_task1.yaml index 2dfcce2..50f0aaf 100644 --- a/verifier/cli_entailmentbank_task1.yaml +++ b/verifier/cli_entailmentbank_task1.yaml @@ -10,7 +10,7 @@ model: lr: 1e-5 warmup_steps: 2000 model_name: roberta-large - pos_weight: 128 + pos_weight: 128.0 data: dataset: entailmentbank batch_size: 128 diff --git a/verifier/cli_entailmentbank_task2.yaml b/verifier/cli_entailmentbank_task2.yaml index 0383033..352925a 100644 --- a/verifier/cli_entailmentbank_task2.yaml +++ b/verifier/cli_entailmentbank_task2.yaml @@ -10,7 +10,7 @@ model: lr: 1e-5 warmup_steps: 2500 model_name: roberta-large - pos_weight: 128 + pos_weight: 128.0 data: dataset: entailmentbank batch_size: 128 diff --git a/verifier/cli_ruletaker.yaml b/verifier/cli_ruletaker.yaml index b012956..8ac6566 100644 --- a/verifier/cli_ruletaker.yaml +++ b/verifier/cli_ruletaker.yaml @@ -10,7 +10,7 @@ model: lr: 5e-6 warmup_steps: 2500 model_name: roberta-large - pos_weight: 4 + pos_weight: 4.0 data: dataset: ruletaker batch_size: 128