
Commit a8caec7
format code
Kaiyu Yang committed Apr 27, 2023
1 parent 0a81619 commit a8caec7
Showing 4 changed files with 25 additions and 8 deletions.
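
The commit wraps overlong lines; the resulting layout matches what the black formatter produces under its default 88-character line limit (an assumption — the commit message only says "format code" and does not name a tool). A minimal sketch of reproducing the same wrapping programmatically, assuming black is installed:

# Sketch only: `black` is an assumption; the commit does not state which formatter was used.
import black

# One of the overlong lines from the diff below (just over black's default 88-character limit).
src = "self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)\n"

# black.format_str reformats a source string; the long call is split across lines,
# matching the multi-line style seen in the hunks below.
print(black.format_str(src, mode=black.Mode()))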
8 changes: 6 additions & 2 deletions prover/datamodule.py
@@ -138,7 +138,9 @@ def __init__(
     ) -> None:
         super().__init__()
         max_len = max(max_input_len, max_output_len)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_len
+        )
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.is_train = is_train
@@ -213,7 +215,9 @@ def __init__(
     ) -> None:
         super().__init__()
         max_len = max(max_input_len, max_output_len)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_len
+        )
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.sample_goal = sample_goal
6 changes: 4 additions & 2 deletions prover/model.py
@@ -136,7 +136,9 @@ def __init__(
             EntailmentClassifier.load_from_checkpoint(verifier_ckpt)
         ] # Avoid making the verifier a submodule.

-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         if (
             model_name.startswith("t5-")
             or model_name.startswith("google/t5-v1_1-")
@@ -680,7 +682,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
         assert self.trainer is not None
         if self.trainer.max_steps != -1:
             max_steps = self.trainer.max_steps
-        else:
+        else:
             max_steps = (
                 self.trainer.max_epochs
                 * len(self.trainer.datamodule.train_dataloader()) # type: ignore
4 changes: 3 additions & 1 deletion verifier/datamodule.py
@@ -64,7 +64,9 @@ def __init__(
         irrelevant_distractors_only: bool,
     ) -> None:
         super().__init__()
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         assert split in ("train", "val")
         self.split = split
         self.max_num_premises = max_num_premises # The maximum number of premises used in data augmentation.
15 changes: 12 additions & 3 deletions verifier/model.py
@@ -6,7 +6,14 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 from transformers import AutoTokenizer, AutoModel
-from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, BinaryF1Score, BinarySpecificity, BinaryRecall, BinaryPrecision
+from torchmetrics.classification import (
+    BinaryAccuracy,
+    BinaryAveragePrecision,
+    BinaryF1Score,
+    BinarySpecificity,
+    BinaryRecall,
+    BinaryPrecision,
+)


 class EntailmentClassifier(pl.LightningModule):
@@ -24,7 +31,9 @@ def __init__(
         self.warmup_steps = warmup_steps
         self.pos_weight = pos_weight
         self.max_input_len = max_input_len
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=max_input_len)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name, model_max_length=max_input_len
+        )
         self.encoder = AutoModel.from_pretrained(model_name)
         self.fc = nn.Linear(self.encoder.config.hidden_size, 1)
         self.metrics = {
@@ -87,7 +96,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
         assert self.trainer is not None
         if self.trainer.max_steps != -1:
             max_steps = self.trainer.max_steps
-        else:
+        else:
             max_steps = (
                 self.trainer.max_epochs
                 * len(self.trainer.datamodule.train_dataloader()) # type: ignore
