RNNs improvements
salvacarrion committed Jun 27, 2024
1 parent 4aed30f commit ec217e3
Showing 4 changed files with 90 additions and 29 deletions.
2 changes: 1 addition & 1 deletion autonmt/modules/models/__init__.py
@@ -1,3 +1,3 @@
from autonmt.modules.models.transfomer import Transformer
from autonmt.modules.models.rnn import BaseRNN, LSTM, GRU
from autonmt.modules.models.rnn import GenericRNN, GRU

97 changes: 73 additions & 24 deletions autonmt/modules/models/rnn.py
@@ -71,10 +71,10 @@ def forward_decoder(self, y, states):
y_emb = self.trg_embeddings(y)
y_emb = self.dec_dropout(y_emb)

# (1-length, batch, emb_dim) =>
# input: (batch, 1-length, emb_dim), (n_layers * n_directions, batch, hidden_dim) =>
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
# cell: (n_layers * n_directions, batch, hidden_dim)
# cell*: (n_layers * n_directions, batch, hidden_dim)
output, states = self.decoder_rnn(y_emb, states)

# Get output: (length, batch, hidden_dim * n_directions) => (length, batch, trg_vocab_size)
@@ -104,31 +104,80 @@ def forward_enc_dec(self, x, y):
return outputs


class LSTM(BaseRNN):
class GenericRNN(BaseRNN):

def __init__(self, architecture, **kwargs):
super().__init__(architecture=architecture, **kwargs)

# Choose architecture
architecture = architecture.lower().strip()
if architecture == "rnn":
base_rnn = nn.RNN
elif architecture == "lstm":
base_rnn = nn.LSTM
elif architecture == "gru":
base_rnn = nn.GRU
else:
raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)

def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="lstm", **kwargs)
self.encoder_rnn = nn.LSTM(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = nn.LSTM(input_size=self.decoder_embed_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)

class GRU(BaseRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
self.encoder_rnn = nn.GRU(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim*2, self.trg_vocab_size)

def forward_encoder(self, x):
output, states = super().forward_encoder(x)
context = states.clone()
return output, (states, context) # (hidden, context)

def forward_decoder(self, y, states):
hidden, context = states

# Fix "y" dimensions
if len(y.shape) == 1: # (batch) => (batch, 1)
y = y.unsqueeze(1)
if len(y.shape) == 2 and y.shape[1] > 1:
y = y[:, -1].unsqueeze(1) # Get last value

# Decode trg: (batch, 1-length) => (batch, length, emb_dim)
y_emb = self.trg_embeddings(y)
y_emb = self.dec_dropout(y_emb)

# Add context
tmp_context = context.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
y_context = torch.cat((y_emb, tmp_context), dim=2)

# input: (batch, 1-length, emb_dim), (n_layers * n_directions, batch, hidden_dim)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
output, hidden = self.decoder_rnn(y_context, hidden)

# Add context
tmp_hidden = hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
output = torch.cat((y_emb, tmp_hidden, tmp_context), dim=2)

# Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
output = self.output_layer(output)
return output, (hidden, context)
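
For readers following the new GRU decoder, here is a minimal shape sketch of the context-feeding step above. It is illustrative only and not part of the commit; the batch, embedding and hidden sizes are made-up placeholders, and a single unidirectional layer is assumed, as in the code comments.

```python
import torch

# Illustrative sizes only; the benchmarks below use 512 hidden units.
batch, emb_dim, hid_dim, vocab_size = 2, 8, 16, 100

y_emb = torch.randn(batch, 1, emb_dim)      # embedded last target token: (batch, 1, emb_dim)
context = torch.randn(1, batch, hid_dim)    # encoder final hidden state, reused at every step
hidden = context.clone()                    # decoder hidden state, initialised from the context

# Collapse the layer dimension and append the fixed context to the token embedding
tmp_context = context.transpose(1, 0).sum(dim=1, keepdim=True)   # (batch, 1, hid_dim)
y_context = torch.cat((y_emb, tmp_context), dim=2)               # (batch, 1, emb_dim + hid_dim)

decoder_rnn = torch.nn.GRU(input_size=emb_dim + hid_dim, hidden_size=hid_dim, batch_first=True)
output, hidden = decoder_rnn(y_context, hidden)                  # output: (batch, 1, hid_dim)

# The projection sees embedding + new hidden + context, which matches
# nn.Linear(decoder_embed_dim + decoder_hidden_dim*2, trg_vocab_size) in the diff.
tmp_hidden = hidden.transpose(1, 0).sum(dim=1, keepdim=True)     # (batch, 1, hid_dim)
output_layer = torch.nn.Linear(emb_dim + 2 * hid_dim, vocab_size)
logits = output_layer(torch.cat((y_emb, tmp_hidden, tmp_context), dim=2))  # (batch, 1, vocab_size)
```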
14 changes: 13 additions & 1 deletion docs/references/benchmarks.md
@@ -146,12 +146,24 @@ vocab__subword_model vocab__size train__lang_pair test__lang_pair
word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 32.626757
# Simple LSTM (2 layers, 512 hidden units => 7.4M params)
# Simple LSTM (2 layers, 512 hidden units => ?-7.4M params)
vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
bytes 260 de-en de-en multi30k_de-en_original_bytes_1000 multi30k 4.688523
char 101/101 de-en de-en multi30k_de-en_original_char_1000 multi30k 4.404792
bpe 4000 de-en de-en multi30k_de-en_original_bpe_4000 multi30k 9.013423
word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 10.356350
# z+GRU (1 layer, 512 hidden units => 3-10.3 M params)
vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
bytes 260 de-en de-en multi30k_de-en_original_bytes_1000 multi30k 4.221123
char 101/101 de-en de-en multi30k_de-en_original_char_1000 multi30k 4.441055
bpe 4000 de-en de-en multi30k_de-en_original_bpe_4000 multi30k 13.061608
word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 13.239970
# z+GRU (2 layers, 512 hidden units => 3-10.3 M params)
vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 13.032078
```
------

6 changes: 3 additions & 3 deletions examples/dev/0_test_custom_model.py
@@ -4,7 +4,7 @@
import torch
torch.set_float32_matmul_precision("high")

from autonmt.modules.models import Transformer, LSTM, GRU
from autonmt.modules.models import Transformer, GenericRNN, GRU
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import AutonmtTranslator
from autonmt.vocabularies import Vocabulary
@@ -70,15 +70,15 @@ def main():
else:
raise ValueError(f"Unknown subword model: {train_ds.subword_model}")

for iters in [5]:
for iters in [10]:
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = GRU(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
run_prefix = f"{model.architecture}-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
run_prefix = f"{model.architecture}-2L-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
run_name = train_ds.get_run_name(run_prefix=run_prefix) #+ f"__{int(time.time())}"
trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab,
runs_dir=runs_dir, run_name=run_name)
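
As a usage note (not part of the commit), the same script could pick the recurrent backbone by name through the new GenericRNN class. A sketch, assuming GenericRNN accepts the same keyword arguments as GRU (both forward **kwargs to BaseRNN) and can be dropped into the loop above:

```python
from autonmt.modules.models import GenericRNN

# "rnn", "lstm" or "gru" — GenericRNN maps the name to nn.RNN / nn.LSTM / nn.GRU.
model = GenericRNN(architecture="lstm",
                   src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab),
                   padding_idx=src_vocab.pad_id)
```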
