diff --git a/prover/datamodule.py b/prover/datamodule.py
index 63fec20..2893c20 100644
--- a/prover/datamodule.py
+++ b/prover/datamodule.py
@@ -138,7 +138,9 @@ def __init__(
     ) -> None:
         super().__init__()
         max_len = max(max_input_len, max_output_len)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_len
+        )
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.is_train = is_train
@@ -213,7 +215,9 @@ def __init__(
     ) -> None:
         super().__init__()
         max_len = max(max_input_len, max_output_len)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_len
+        )
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.sample_goal = sample_goal
diff --git a/prover/model.py b/prover/model.py
index 44b014a..ca99861 100644
--- a/prover/model.py
+++ b/prover/model.py
@@ -136,7 +136,9 @@ def __init__(
                 EntailmentClassifier.load_from_checkpoint(verifier_ckpt)
             ]  # Avoid making the verifier a submodule.

-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         if (
             model_name.startswith("t5-")
             or model_name.startswith("google/t5-v1_1-")
@@ -680,7 +682,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
         assert self.trainer is not None
         if self.trainer.max_steps != -1:
             max_steps = self.trainer.max_steps
-        else: 
+        else:
             max_steps = (
                 self.trainer.max_epochs
                 * len(self.trainer.datamodule.train_dataloader())  # type: ignore
diff --git a/verifier/datamodule.py b/verifier/datamodule.py
index a17f8c4..f9debba 100644
--- a/verifier/datamodule.py
+++ b/verifier/datamodule.py
@@ -64,7 +64,9 @@ def __init__(
         irrelevant_distractors_only: bool,
     ) -> None:
         super().__init__()
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         assert split in ("train", "val")
         self.split = split
         self.max_num_premises = max_num_premises  # The maximum number of premises used in data augmentation.
diff --git a/verifier/model.py b/verifier/model.py
index 8b4ea2a..2486b53 100644
--- a/verifier/model.py
+++ b/verifier/model.py
@@ -6,7 +6,14 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 from transformers import AutoTokenizer, AutoModel
-from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, BinaryF1Score, BinarySpecificity, BinaryRecall, BinaryPrecision
+from torchmetrics.classification import (
+    BinaryAccuracy,
+    BinaryAveragePrecision,
+    BinaryF1Score,
+    BinarySpecificity,
+    BinaryRecall,
+    BinaryPrecision,
+)


 class EntailmentClassifier(pl.LightningModule):
@@ -24,7 +31,9 @@ def __init__(
         self.warmup_steps = warmup_steps
         self.pos_weight = pos_weight
         self.max_input_len = max_input_len
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         self.encoder = AutoModel.from_pretrained(model_name)
         self.fc = nn.Linear(self.encoder.config.hidden_size, 1)
         self.metrics = {
@@ -87,7 +96,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
         assert self.trainer is not None
         if self.trainer.max_steps != -1:
             max_steps = self.trainer.max_steps
-        else: 
+        else:
             max_steps = (
                 self.trainer.max_epochs
                 * len(self.trainer.datamodule.train_dataloader())  # type: ignore
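
Note (not part of the patch): every hunk above is formatting-only; the wrapped AutoTokenizer.from_pretrained(...) calls keep the same model_max_length behavior as the original one-liners. A minimal sketch of that behavior, assuming a "t5-small" checkpoint and max_len = 1024 (neither value comes from this diff):

# Sketch only: "t5-small" and max_len are illustrative assumptions.
from transformers import AutoTokenizer

max_len = 1024
tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=max_len)

# With model_max_length set, truncation=True caps every example at max_len tokens.
batch = tokenizer(
    ["sent1: a squirrel is an animal.", "hypothesis: the squirrel is alive."],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
print(batch.input_ids.shape)  # the second dimension never exceeds max_len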