From 2d785c3526df36e92519bb82e7ec148510daa7b1 Mon Sep 17 00:00:00 2001
From: salvacarrion
Date: Wed, 26 Jun 2024 21:09:33 +0200
Subject: [PATCH] Add LSTMs

---
 autonmt/modules/models/lstm.py       | 53 ++++++++++++++++++++++------
 autonmt/modules/models/transfomer.py |  5 +++
 autonmt/modules/seq2seq.py           | 30 ++++++++++++----
 examples/dev/0_test_custom_model.py  | 19 ++++------
 4 files changed, 78 insertions(+), 29 deletions(-)

diff --git a/autonmt/modules/models/lstm.py b/autonmt/modules/models/lstm.py
index abc84da..f72948c 100644
--- a/autonmt/modules/models/lstm.py
+++ b/autonmt/modules/models/lstm.py
@@ -1,3 +1,5 @@
+import random
+import torch
 import torch.nn as nn
 
 from autonmt.modules.layers import PositionalEmbedding
@@ -15,9 +17,12 @@ def __init__(self,
                  decoder_n_layers=2,
                  encoder_dropout=0.5,
                  decoder_dropout=0.5,
+                 bidirectional=False,
+                 teacher_force_ratio=0.5,
                  padding_idx=None,
                  **kwargs):
         super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)
+        self.teacher_forcing_ratio = teacher_force_ratio
 
         # Model
         self.src_embeddings = nn.Embedding(src_vocab_size, encoder_embed_dim)
@@ -26,18 +31,18 @@ def __init__(self,
         self.encoder_dropout = nn.Dropout(encoder_dropout)
         self.decoder_dropout = nn.Dropout(decoder_dropout)
 
-        self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout)
-        self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout)
+        self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout, bidirectional=bidirectional, batch_first=True)
+        self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout, bidirectional=bidirectional, batch_first=True)
 
-        self.output_layer = nn.Linear(encoder_embed_dim, trg_vocab_size)
+        self.output_layer = nn.Linear(decoder_hidden_dim, trg_vocab_size)
 
         # Checks
         assert encoder_embed_dim == decoder_embed_dim
         assert encoder_hidden_dim == decoder_hidden_dim
         assert encoder_n_layers == decoder_n_layers
 
-    def forward_encoder(self, x):
-        # Encode src: (length, batch) => (length, batch, emb_dim)
+    def forward_encoder(self, x, **kwargs):
+        # Encode src: (batch, length) => (batch, length, emb_dim)
         x_emb = self.src_embeddings(x)
         x_emb = self.encoder_dropout(x_emb)
 
@@ -45,20 +50,46 @@ def __init__(self,
         # output: (length, batch, hidden_dim * n_directions)
         # hidden: (n_layers * n_directions, batch, hidden_dim)
         # cell: (n_layers * n_directions, batch, hidden_dim]
-        outputs, (hidden, cell) = self.encoder_rnn(x_emb)
-        return hidden, cell
+        output, (hidden, cell) = self.encoder_rnn(x_emb)
+        return output, (hidden, cell)
 
-    def forward_decoder(self, y, hidden, cell):
-        # Encode trg: (1-length, batch) => (length, batch, emb_dim)
+    def forward_decoder(self, y, hidden, cell, **kwargs):
+        # Fix y dimensions
+        if len(y.shape) == 1:
+            y = y.unsqueeze(1)
+
+        # Decode trg: (batch, 1-length) => (batch, length, emb_dim)
         y_emb = self.trg_embeddings(y)
         y_emb = self.decoder_dropout(y_emb)
 
         # (1-length, batch, emb_dim) =>
-        # output: (length, batch, hidden_dim * n_directions)
+        # output: (batch, length, hidden_dim * n_directions)
         # hidden: (n_layers * n_directions, batch, hidden_dim]
         # cell: (n_layers * n_directions, batch, hidden_dim]
         output, (hidden, cell) = self.decoder_rnn(y_emb, (hidden, cell))
 
         # Get output: (length, batch, hidden_dim * n_directions) => (length, batch, trg_vocab_size)
         output = self.output_layer(output)
-        return output
+        return output, (hidden, cell)
+
+    def forward_enc_dec(self, x, y):
+        # Run encoder
+        _, states = self.forward_encoder(x)
+
+        y_pred = y[:, 0]  # <sos>
+        outputs = []  # Doesn't contain <sos> token
+
+        # Iterate over trg tokens
+        trg_length = y.shape[1]
+        for t in range(trg_length):
+            outputs_t, states = self.forward_decoder(y_pred, *states)  # (B, 1, V)
+            outputs.append(outputs_t)  # (B, L, V)
+
+            # Next input?
+            teacher_force = random.random() < self.teacher_forcing_ratio
+            top1 = outputs_t.argmax(2)  # Get most probable next word (from the logits)
+            y_pred = y[:, t+1] if teacher_force and t+1 < trg_length else top1  # Use next ground-truth token or predicted word
+
+        # Concatenate outputs (B, 1, V) => (B, L, V)
+        outputs = torch.concat(outputs, 1)
+        return outputs
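
As a quick sanity check for the teacher-forcing loop above, a standalone shape test along the following lines can be used. This sketch is illustrative and not part of the patch: it assumes autonmt is installed, that LSTM is importable from autonmt.modules.models (as in the example script further below), and that the remaining constructor arguments keep their defaults.

import torch
from autonmt.modules.models import LSTM

model = LSTM(src_vocab_size=100, trg_vocab_size=120, padding_idx=3)
model.eval()  # disable dropout for the shape check

x = torch.randint(0, 100, (8, 15))   # (batch, src_length)
y = torch.randint(0, 120, (8, 12))   # (batch, trg_length); y[:, 0] plays the role of <sos>

logits = model.forward_enc_dec(x, y[:, :-1])   # the decoder input drops the last token, as in Seq2Seq._step
assert logits.shape == (8, 11, 120)            # (batch, trg_length - 1, trg_vocab_size)
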
diff --git a/autonmt/modules/models/transfomer.py b/autonmt/modules/models/transfomer.py
index 45c8904..e97a704 100644
--- a/autonmt/modules/models/transfomer.py
+++ b/autonmt/modules/models/transfomer.py
@@ -75,3 +75,8 @@ def forward_decoder(self, y, memory):
         output = output.transpose(0, 1)
         output = self.output_layer(output)
         return output
+
+    def forward_enc_dec(self, x, y):
+        memory = self.forward_encoder(x)
+        output = self.forward_decoder(y, memory)
+        return output
diff --git a/autonmt/modules/seq2seq.py b/autonmt/modules/seq2seq.py
index 0cc5fde..5c48427 100644
--- a/autonmt/modules/seq2seq.py
+++ b/autonmt/modules/seq2seq.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 import math
 from collections import defaultdict
 
@@ -30,6 +31,14 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, **kwargs):
         self.best_scores = defaultdict(float)
         self.validation_step_outputs = defaultdict(list)
 
+    @abstractmethod
+    def forward_encoder(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def forward_decoder(self, *args, **kwargs):
+        pass
+
     def configure_optimizers(self):
         optim_fn = {
             "adadelta": torch.optim.Adadelta,
@@ -107,17 +116,26 @@ def on_validation_epoch_end(self):
         # Free memory
         self.validation_step_outputs.clear()
 
+    def forward_enc_dec(self, x, y):
+        values = self.forward_encoder(x)
+        values = self.forward_decoder(y, **values)  # (B, L, E)
+        output = values["output"]
+        return output
+
     def _step(self, batch, batch_idx, log_prefix):
         x, y = batch
 
-        # Forward
-        output = self.forward_encoder(x)
-        output = self.forward_decoder(y, output)  # (B, L, E)
+        # Forward => (Batch, Length) => (Batch, Length, Vocab)
+        # The input of the decoder needs the <sos>, but its output is shifted as it starts with the first word, not
+        # with the <sos>. Therefore, we need to remove the last token from 'y'
+        output = self.forward_enc_dec(x, y[:, :-1])
+
+        # Remove the <sos> token from the target
+        y = y[:, 1:]
 
         # Compute loss
-        output = output.transpose(1, 2)[:, :, :-1]  # Remove last index to match shape with 'y[1:]'
-        y = y[:, 1:]  # Remove <sos>
-        loss = self.criterion_fn(output, y)
+        output = output.transpose(1, 2)  # (B, L, V) => (B, V, L)
+        loss = self.criterion_fn(output, y)  # (B, V, L) vs (B, L)
 
         # Apply regularization
         if self.regularization_fn:
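
To make the target shift in _step explicit, the following self-contained PyTorch sketch mirrors the loss computation; it assumes criterion_fn behaves like nn.CrossEntropyLoss(ignore_index=padding_idx), which is the usual setup but is not shown in this patch.

import torch
import torch.nn as nn

B, L, V, pad_id = 2, 5, 10, 3
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)   # assumed stand-in for self.criterion_fn

y = torch.randint(0, V, (B, L))       # full target sequence; y[:, 0] is <sos>
logits = torch.randn(B, L - 1, V)     # what forward_enc_dec(x, y[:, :-1]) would return

loss = criterion(logits.transpose(1, 2), y[:, 1:])   # (B, V, L-1) vs (B, L-1)
print(loss.item())
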
diff --git a/examples/dev/0_test_custom_model.py b/examples/dev/0_test_custom_model.py
index 01e6c20..e8cf9ae 100644
--- a/examples/dev/0_test_custom_model.py
+++ b/examples/dev/0_test_custom_model.py
@@ -4,7 +4,7 @@
 import torch
 torch.set_float32_matmul_precision("high")
 
-from autonmt.modules.models import Transformer
+from autonmt.modules.models import Transformer, LSTM
 from autonmt.preprocessing import DatasetBuilder
 from autonmt.toolkits import AutonmtTranslator
 from autonmt.vocabularies import Vocabulary
@@ -31,8 +31,8 @@ def main():
 
         # Set of datasets, languages, training sizes to try
         datasets=[
-            # {"name": "multi30k", "languages": ["en-es"], "sizes": [("original", None)]},
-            {"name": "europarl", "languages": ["en-es"], "sizes": [("50k", 50000), ("100k", 100000), ("original", None)]},
+            {"name": "multi30k", "languages": ["en-de"], "sizes": [("original", None)]},
+            # {"name": "europarl", "languages": ["en-es"], "sizes": [("50k", 50000), ("100k", 100000), ("original", None)]},
         ],
 
         # Set of subword models and vocab sizes to try
@@ -40,12 +40,7 @@ def main():
            # {"subword_models": ["bytes"], "vocab_sizes": [1000]},
            # {"subword_models": ["char"], "vocab_sizes": [1000]},
            # {"subword_models": ["bpe"], "vocab_sizes": [8000]},
-           # {"subword_models": ["word"], "vocab_sizes": [8000]},
-
-           {"subword_models": ["bytes"], "vocab_sizes": [1000]},
-           {"subword_models": ["char"], "vocab_sizes": [1000]},
-           {"subword_models": ["bpe"], "vocab_sizes": [16000]},
-           {"subword_models": ["word"], "vocab_sizes": [32000]},
+           {"subword_models": ["word"], "vocab_sizes": [8000]},
         ],
 
         # Preprocessing functions
@@ -75,11 +70,11 @@ def main():
     else:
         raise ValueError(f"Unknown subword model: {train_ds.subword_model}")
 
-    for iters in [100]:
+    for iters in [10]:
         # Instantiate vocabs and model
         src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
         trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
-        model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
+        model = LSTM(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
 
         # Define trainer
         runs_dir = train_ds.get_runs_path(toolkit="autonmt")
@@ -96,7 +91,7 @@ def main():
 
         # Train model
         wandb_params = None #dict(project="vocab-comparison", entity="salvacarrion", reinit=True)
-        trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=256, seed=None,
+        trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=None,
                     patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True,
                     print_samples=1, wandb_params=wandb_params)
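
If the new LSTM-specific hyperparameters need to be tuned from the example script, they can be passed explicitly to the constructor. The values below are illustrative, not part of the patch:

model = LSTM(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id,
             teacher_force_ratio=0.5,  # probability of feeding the gold token at each decoding step
             bidirectional=False)      # the decoder reuses this flag and the output layer assumes a unidirectional hidden size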