diff --git a/autonmt/modules/models/__init__.py b/autonmt/modules/models/__init__.py
index 2357af1..f00d43e 100644
--- a/autonmt/modules/models/__init__.py
+++ b/autonmt/modules/models/__init__.py
@@ -1,3 +1,3 @@
 from autonmt.modules.models.transfomer import Transformer
-from autonmt.modules.models.rnn import BaseRNN, LSTM, GRU
+from autonmt.modules.models.rnn import GenericRNN, GRU
diff --git a/autonmt/modules/models/rnn.py b/autonmt/modules/models/rnn.py
index d44121d..4cc6515 100644
--- a/autonmt/modules/models/rnn.py
+++ b/autonmt/modules/models/rnn.py
@@ -71,10 +71,10 @@ def forward_decoder(self, y, states):
         y_emb = self.trg_embeddings(y)
         y_emb = self.dec_dropout(y_emb)
 
-        # (1-length, batch, emb_dim) =>
+        # input: (batch, 1-length, emb_dim), (n_layers * n_directions, batch, hidden_dim) =>
         # output: (batch, length, hidden_dim * n_directions)
         # hidden: (n_layers * n_directions, batch, hidden_dim)
-        # cell: (n_layers * n_directions, batch, hidden_dim)
+        # cell*: (n_layers * n_directions, batch, hidden_dim)  (*LSTM only)
         output, states = self.decoder_rnn(y_emb, states)
 
         # Get output: (length, batch, hidden_dim * n_directions) => (length, batch, trg_vocab_size)
@@ -104,31 +104,80 @@ def forward_enc_dec(self, x, y):
         return outputs
 
 
-class LSTM(BaseRNN):
+class GenericRNN(BaseRNN):
+
+    def __init__(self, architecture, **kwargs):
+        architecture = architecture.lower().strip()
+        super().__init__(architecture=architecture, **kwargs)
+
+        # Choose the recurrent cell
+        if architecture == "rnn":
+            base_rnn = nn.RNN
+        elif architecture == "lstm":
+            base_rnn = nn.LSTM
+        elif architecture == "gru":
+            base_rnn = nn.GRU
+        else:
+            raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")
+
+        self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
+                                    hidden_size=self.encoder_hidden_dim,
+                                    num_layers=self.encoder_n_layers,
+                                    dropout=self.encoder_dropout,
+                                    bidirectional=self.bidirectional, batch_first=True)
+        self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim,
+                                    hidden_size=self.decoder_hidden_dim,
+                                    num_layers=self.decoder_n_layers,
+                                    dropout=self.decoder_dropout,
+                                    bidirectional=self.bidirectional, batch_first=True)
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, architecture="lstm", **kwargs)
-        self.encoder_rnn = nn.LSTM(input_size=self.encoder_embed_dim,
-                                   hidden_size=self.encoder_hidden_dim,
-                                   num_layers=self.encoder_n_layers,
-                                   dropout=self.encoder_dropout,
-                                   bidirectional=self.bidirectional, batch_first=True)
-        self.decoder_rnn = nn.LSTM(input_size=self.decoder_embed_dim,
-                                   hidden_size=self.decoder_hidden_dim,
-                                   num_layers=self.decoder_n_layers,
-                                   dropout=self.decoder_dropout,
-                                   bidirectional=self.bidirectional, batch_first=True)
 
 class GRU(BaseRNN):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, architecture="gru", **kwargs)
         self.encoder_rnn = nn.GRU(input_size=self.encoder_embed_dim,
-                                  hidden_size=self.encoder_hidden_dim,
-                                  num_layers=self.encoder_n_layers,
-                                  dropout=self.encoder_dropout,
-                                  bidirectional=self.bidirectional, batch_first=True)
-        self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim,
-                                  hidden_size=self.decoder_hidden_dim,
-                                  num_layers=self.decoder_n_layers,
-                                  dropout=self.decoder_dropout,
-                                  bidirectional=self.bidirectional, batch_first=True)
\ No newline at end of file
+                                  hidden_size=self.encoder_hidden_dim,
+                                  num_layers=self.encoder_n_layers,
+                                  dropout=self.encoder_dropout,
+                                  bidirectional=self.bidirectional, batch_first=True)
+        self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
+                                  hidden_size=self.decoder_hidden_dim,
+                                  num_layers=self.decoder_n_layers,
+                                  dropout=self.decoder_dropout,
+                                  bidirectional=self.bidirectional, batch_first=True)
+        self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim*2, self.trg_vocab_size)
+
+    def forward_encoder(self, x):
+        output, states = super().forward_encoder(x)
+        context = states.clone()
+        return output, (states, context)  # (hidden, context)
+
+    def forward_decoder(self, y, states):
+        hidden, context = states
+
+        # Fix "y" dimensions
+        if len(y.shape) == 1:  # (batch) => (batch, 1)
+            y = y.unsqueeze(1)
+        if len(y.shape) == 2 and y.shape[1] > 1:
+            y = y[:, -1].unsqueeze(1)  # Get last value
+
+        # Decode trg: (batch, 1-length) => (batch, length, emb_dim)
+        y_emb = self.trg_embeddings(y)
+        y_emb = self.dec_dropout(y_emb)
+
+        # Add context
+        tmp_context = context.transpose(1, 0).sum(axis=1, keepdims=True)  # The paper has just 1 layer
+        y_context = torch.cat((y_emb, tmp_context), dim=2)
+
+        # input: (batch, 1-length, emb_dim), (n_layers * n_directions, batch, hidden_dim)
+        # output: (batch, length, hidden_dim * n_directions)
+        # hidden: (n_layers * n_directions, batch, hidden_dim)
+        output, hidden = self.decoder_rnn(y_context, hidden)
+
+        # Add context again at the output layer
+        tmp_hidden = hidden.transpose(1, 0).sum(axis=1, keepdims=True)  # The paper has just 1 layer
+        output = torch.cat((y_emb, tmp_hidden, tmp_context), dim=2)
+
+        # Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
+        output = self.output_layer(output)
+        return output, (hidden, context)
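For orientation: the rewritten GRU decoder appears to follow the Cho et al. (2014) encoder-decoder scheme, where the encoder's final state is kept as a fixed context vector z and concatenated both onto each decoder input (hence `input_size=self.decoder_embed_dim + self.encoder_hidden_dim`) and onto the pre-softmax features (hence the `decoder_embed_dim + decoder_hidden_dim*2` output layer). A minimal shape check of that concatenation, with illustrative sizes only (not part of the patch):

```python
import torch

# Illustrative sizes only
batch, emb_dim, hidden_dim, n_layers = 2, 8, 16, 1
y_emb = torch.randn(batch, 1, emb_dim)              # current target token (batch_first)
context = torch.randn(n_layers, batch, hidden_dim)  # encoder's final hidden state, kept fixed

# (n_layers, batch, hidden) -> (batch, 1, hidden): collapse layers, keep a length-1 time axis
tmp_context = context.transpose(1, 0).sum(axis=1, keepdims=True)
y_context = torch.cat((y_emb, tmp_context), dim=2)
assert y_context.shape == (batch, 1, emb_dim + hidden_dim)  # matches the decoder's input_size
```

The layer-sum is a pragmatic shortcut: with `n_layers == 1` (as in the paper) it is the identity, and it lets the same code run unchanged for deeper decoders.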
diff --git a/docs/references/benchmarks.md b/docs/references/benchmarks.md
index 59a3b69..e855517 100644
--- a/docs/references/benchmarks.md
+++ b/docs/references/benchmarks.md
@@ -146,12 +146,24 @@ vocab__subword_model vocab__size train__lang_pair test__lang_pair
     word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 32.626757
 
-# Simple LSTM (2 layers, 512 hidden units => 7.4M params)
+# Simple LSTM (2 layers, 512 hidden units => ?-7.4M params)
 vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
     bytes 260 de-en de-en multi30k_de-en_original_bytes_1000 multi30k 4.688523
     char 101/101 de-en de-en multi30k_de-en_original_char_1000 multi30k 4.404792
     bpe 4000 de-en de-en multi30k_de-en_original_bpe_4000 multi30k 9.013423
     word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 10.356350
+
+
+# z+GRU (1 layer, 512 hidden units => 3-10.3M params)
+vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
+    bytes 260 de-en de-en multi30k_de-en_original_bytes_1000 multi30k 4.221123
+    char 101/101 de-en de-en multi30k_de-en_original_char_1000 multi30k 4.441055
+    bpe 4000 de-en de-en multi30k_de-en_original_bpe_4000 multi30k 13.061608
+    word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 13.239970
+
+# z+GRU (2 layers, 512 hidden units => 3-10.3M params)
+vocab__subword_model vocab__size train__lang_pair test__lang_pair train_dataset test_dataset translations.beam1.sacrebleu_bleu_score
+    word 4000 de-en de-en multi30k_de-en_original_word_4000 multi30k 13.032078
 
 ```
 
 ------
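The `translations.beam1.sacrebleu_bleu_score` column is a corpus-level sacreBLEU score at beam size 1. For reference, a score on this 0-100 scale can be computed with the sacrebleu package; the file paths below are placeholders, since the benchmark's actual hypothesis/reference files are not part of this diff:

```python
import sacrebleu

# Placeholder paths (detokenized, one sentence per line)
with open("hyps.detok.txt") as f:
    hyps = [line.rstrip("\n") for line in f]
with open("refs.detok.txt") as f:
    refs = [line.rstrip("\n") for line in f]

bleu = sacrebleu.corpus_bleu(hyps, [refs])  # one reference stream
print(bleu.score)  # same scale as the tables above, e.g. 13.239970
```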
diff --git a/examples/dev/0_test_custom_model.py b/examples/dev/0_test_custom_model.py
index 25de617..d244d3a 100644
--- a/examples/dev/0_test_custom_model.py
+++ b/examples/dev/0_test_custom_model.py
@@ -4,7 +4,7 @@ import torch
 
 torch.set_float32_matmul_precision("high")
 
-from autonmt.modules.models import Transformer, LSTM, GRU
+from autonmt.modules.models import Transformer, GenericRNN, GRU
 from autonmt.preprocessing import DatasetBuilder
 from autonmt.toolkits import AutonmtTranslator
 from autonmt.vocabularies import Vocabulary
@@ -70,7 +70,7 @@ def main():
     else:
         raise ValueError(f"Unknown subword model: {train_ds.subword_model}")
 
-    for iters in [5]:
+    for iters in [10]:
         # Instantiate vocabs and model
         src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
         trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
@@ -78,7 +78,7 @@ def main():
 
         # Define trainer
         runs_dir = train_ds.get_runs_path(toolkit="autonmt")
-        run_prefix = f"{model.architecture}-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
+        run_prefix = f"{model.architecture}-2L-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
         run_name = train_ds.get_run_name(run_prefix=run_prefix)  #+ f"__{int(time.time())}"
         trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab, runs_dir=runs_dir, run_name=run_name)
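The model construction itself sits outside these hunks, so the sketch below is hypothetical: it shows how the new class could slot into the loop where the removed `LSTM` subclass was used, but every keyword except `architecture` is an assumption rather than the script's actual elided code:

```python
# Hypothetical replacement for the removed LSTM usage, placed inside the
# `for iters in [10]:` loop where src_vocab and trg_vocab are defined.
# All kwargs except `architecture` are assumed, not taken from the diff.
model = GenericRNN(architecture="lstm",  # or "rnn" / "gru"
                   src_vocab_size=len(src_vocab),
                   trg_vocab_size=len(trg_vocab))
```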