Skip to content

Commit

Permalink
Add gru attention + improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
salvacarrion committed Jun 27, 2024
1 parent 8c65571 commit 42ff301
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion autonmt/modules/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from autonmt.modules.models.transfomer import Transformer
from autonmt.modules.models.rnn import GenericRNN, GRU
from autonmt.modules.models.rnn import GenericRNN, GRU, GRUAttention

60 changes: 32 additions & 28 deletions autonmt/modules/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ def __init__(self,
assert encoder_hidden_dim == decoder_hidden_dim
assert encoder_n_layers == decoder_n_layers

@staticmethod
def get_base_rnn(architecture):
# Choose architecture
architecture = architecture.lower().strip()
if architecture == "rnn":
return nn.RNN
elif architecture == "lstm":
return nn.LSTM
elif architecture == "gru":
return nn.GRU
else:
raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

def forward_encoder(self, x):
# Encode trg: (batch, length) => (batch, length, emb_dim)
x_emb = self.src_embeddings(x)
Expand Down Expand Up @@ -109,18 +122,7 @@ class GenericRNN(BaseRNN):

def __init__(self, architecture="lstm", **kwargs):
super().__init__(architecture=architecture, **kwargs)

# Choose architecture
architecture = architecture.lower().strip()
if architecture == "rnn":
base_rnn = nn.RNN
elif architecture == "lstm":
base_rnn = nn.LSTM
elif architecture == "gru":
base_rnn = nn.GRU
else:
raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
Expand Down Expand Up @@ -205,21 +207,22 @@ def forward(self, hidden, encoder_outputs):
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"
return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)

class GenericRNNAttention(BaseRNN):
class GRUAttention(BaseRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
self.encoder_rnn = nn.GRU(input_size=self.encoder_embed_dim,
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=1,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=True, batch_first=True)
self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim + self.encoder_hidden_dim*2,
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim*2,
hidden_size=self.decoder_hidden_dim,
num_layers=1,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=False, batch_first=True)
self.attention = Attention(self.encoder_hidden_dim, self.decoder_hidden_dim)
self.enc_ffn = nn.Linear(self.encoder_hidden_dim * 2, self.decoder_hidden_dim)
self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2, self.decoder_hidden_dim * self.decoder_n_layers)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim*2, self.trg_vocab_size)

def forward_encoder(self, x):
Expand All @@ -228,14 +231,15 @@ def forward_encoder(self, x):
# hidden: (n_layers * n_directions, batch, hidden_dim)
output, hidden = super().forward_encoder(x)

# bidirectional hidden is stacked [forward_1, backward_1, forward_2, backward_2,...]
# hidden [-2, :, : ] is the last of the forwards RNN
# hidden [-1, :, : ] is the last of the backwards RNN
# Concat hidden layers (back and forward) from the last layer: (B, emb) => (B, emb*2)
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
# Reshape hidden to (batch, n_layers * n_directions * hidden_dim)
hidden = hidden.transpose(0, 1).contiguous().view(hidden.size(1), -1)

# Apply the linear transformation to each layer to match the decoder hidden size
# Apply the linear transformation
hidden = torch.tanh(self.enc_ffn(hidden))

# Reshape back to (n_layers, batch, decoder_hidden_dim)
hidden = hidden.view(self.encoder_n_layers, -1, self.decoder_hidden_dim)

return output, (hidden, output)

def forward_decoder(self, y, states):
Expand All @@ -251,18 +255,18 @@ def forward_decoder(self, y, states):
y_emb = self.trg_embeddings(y)
y_emb = self.dec_dropout(y_emb)

# Attention
attn = self.attention(hidden, enc_outputs)
# Attention (using only the top layer of hidden state)
attn = self.attention(hidden[-1], enc_outputs)
attn = attn.unsqueeze(1) # (B, L) => (B, 1, L)
weighted = torch.bmm(attn, enc_outputs) # (B, 1, L) x (B, L, H) => (B, 1, H)

# intput: (batch, 1-length, emb_dim+w_emb_dim), (1, batch, hidden_dim)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
rnn_input = torch.cat((y_emb, weighted), dim=2)
output, hidden = self.decoder_rnn(rnn_input, hidden.unsqueeze(0))
output, hidden = self.decoder_rnn(rnn_input, hidden)

# Get output: => (B, 1-length, H+H+H)
output = torch.cat((output, weighted, y_emb), dim=2)
output = self.output_layer(output) # (B, 1, H+H+H) => (B, 1, V)
return output, (hidden.squeeze(0), enc_outputs) # pass enc_outputs (trick)
return output, (hidden, enc_outputs) # pass enc_outputs (trick)
4 changes: 2 additions & 2 deletions examples/dev/0_test_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
torch.set_float32_matmul_precision("high")

from autonmt.modules.models import Transformer, GenericRNN, GRU
from autonmt.modules.models import Transformer, GenericRNN, GRU, GRUAttention
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import AutonmtTranslator
from autonmt.vocabularies import Vocabulary
Expand Down Expand Up @@ -74,7 +74,7 @@ def main():
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = GRU(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
model = GRUAttention(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
Expand Down

0 comments on commit 42ff301

Please sign in to comment.