rnn.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2021-05-22 14:50:34
# @Author : Chenghao Mou ([email protected])
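"""A GRU text classifier over visual-token embeddings: text is rendered into
glyph-image tokens with VTRTokenizer and classified by a bidirectional GRU,
trained with PyTorch Lightning."""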
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from datasets import load_dataset
from einops import rearrange
from text_embeddings.visual import VTRTokenizer
from torch.utils.data import DataLoader

class Model(pl.LightningModule):
    """Bidirectional GRU classifier over flattened visual-token embeddings."""

    def __init__(
        self, hidden: int = 128, learning_rate: float = 1e-3, num_labels: int = 20
    ):
        super().__init__()
        # `hidden` must equal the flattened visual-token size, i.e.
        # font_size * window_size from the tokenizer (16 * 8 = 128 by default).
        self.model = nn.GRU(
            hidden, hidden, num_layers=2, bidirectional=True, batch_first=True
        )
        self.nonlinear = nn.ReLU()
        # The GRU is bidirectional, so its feature dimension is 2 * hidden.
        self.fc = nn.Linear(hidden * 2, num_labels)
        # Labels equal to 0 are ignored by the loss.
        self.loss = nn.CrossEntropyLoss(ignore_index=0)
        self.lr = learning_rate

    def forward(self, batch):
        # input_ids holds the rendered text as token images of shape
        # (batch, seq_len, height, width); flatten each image into a vector.
        embeddings = batch["input_ids"].float()
        logits, _ = self.model(rearrange(embeddings, "b s h w -> b s (h w)"))
        # Split the bidirectional output into its two halves and concatenate
        # them back along the feature dimension (the shape is unchanged).
        logits = torch.cat(
            [
                logits[:, :, : logits.shape[-1] // 2],
                logits[:, :, logits.shape[-1] // 2 :],
            ],
            dim=-1,
        )
        # Mean-pool over the sequence, then classify.
        logits = torch.mean(logits, dim=1)
        logits = self.nonlinear(logits)
        logits = self.fc(logits)
        return logits

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        logits = self.forward(inputs)
        return {"loss": self.loss(logits, labels)}

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        logits = self.forward(inputs)
        # logger.debug(f"{labels.shape, logits.shape}")
        loss = self.loss(logits, labels)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return {"optimizer": torch.optim.Adam(self.parameters(), lr=self.lr)}

class DataModule(pl.LightningDataModule):
    """Loads a Hugging Face dataset and renders its text with VTRTokenizer."""

    def __init__(
        self,
        dataset_name: str,
        font_path="/home/chenghaomou/embeddings/Noto_Sans/NotoSans-Regular.ttf",
        font_size: int = 16,
        window_size: int = 8,
        stride: int = 5,
        batch_size: int = 8,
        subtask: Optional[str] = None,
    ):
        super().__init__()
        self.dataset = (
            load_dataset(dataset_name, subtask)
            if subtask
            else load_dataset(dataset_name)
        )
        # Renders text into sliding windows of glyph pixels (visual tokens).
        self.tokenizer = VTRTokenizer(
            font=font_path, window_size=window_size, font_size=font_size, stride=stride
        )
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train = self.dataset["train"]
        # The dataset's test split is used for validation.
        self.val = self.dataset["test"]

    def train_dataloader(self):
        return DataLoader(
            [{"text": x["text"], "label": x["label"]} for x in self.train],
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=4,
        )

    def val_dataloader(self):
        return DataLoader(
            [{"text": x["text"], "label": x["label"]} for x in self.val],
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=4,
        )

    def collate_fn(self, examples):
        # Tokenize the raw text into padded batches of visual tokens and pair
        # them with integer class labels.
        text = [e["text"] for e in examples]
        labels = [e["label"] for e in examples]
        results = self.tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation="longest_first",
            return_attention_mask=True,
            return_token_type_ids=False,
        )
        return results, torch.from_numpy(np.asarray(labels)).long()

if __name__ == "__main__":
    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(Model, DataModule)
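
# Example invocation (a sketch; the exact flags depend on the installed
# pytorch-lightning version's LightningCLI, and the dataset name below is only
# a hypothetical Hugging Face dataset with "text" and "label" columns):
#   python rnn.py --data.dataset_name ag_news --data.batch_size 8 --model.num_labels 4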