Add LSTMs
salvacarrion committed Jun 26, 2024
1 parent b860aa5 commit 2d785c3
Showing 4 changed files with 78 additions and 29 deletions.
53 changes: 42 additions & 11 deletions autonmt/modules/models/lstm.py
@@ -1,3 +1,5 @@
import random
import torch
import torch.nn as nn

from autonmt.modules.layers import PositionalEmbedding
@@ -15,9 +17,12 @@ def __init__(self,
decoder_n_layers=2,
encoder_dropout=0.5,
decoder_dropout=0.5,
bidirectional=False,
teacher_force_ratio=0.5,
padding_idx=None,
**kwargs):
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)
self.teacher_forcing_ratio = teacher_force_ratio

# Model
self.src_embeddings = nn.Embedding(src_vocab_size, encoder_embed_dim)
@@ -26,39 +31,65 @@ def __init__(self,
self.encoder_dropout = nn.Dropout(encoder_dropout)
self.decoder_dropout = nn.Dropout(decoder_dropout)

self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout)
self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout)
self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout, bidirectional=bidirectional, batch_first=True)
self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout, bidirectional=bidirectional, batch_first=True)

self.output_layer = nn.Linear(encoder_embed_dim, trg_vocab_size)
self.output_layer = nn.Linear(decoder_hidden_dim, trg_vocab_size)

# Checks
assert encoder_embed_dim == decoder_embed_dim
assert encoder_hidden_dim == decoder_hidden_dim
assert encoder_n_layers == decoder_n_layers

def forward_encoder(self, x):
# Encode src: (length, batch) => (length, batch, emb_dim)
def forward_encoder(self, x, **kwargs):
# Encode src: (batch, length) => (batch, length, emb_dim)
x_emb = self.src_embeddings(x)
x_emb = self.encoder_dropout(x_emb)

# input: (batch, length, emb_dim)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
# cell: (n_layers * n_directions, batch, hidden_dim)
outputs, (hidden, cell) = self.encoder_rnn(x_emb)
return hidden, cell
output, (hidden, cell) = self.encoder_rnn(x_emb)
return output, (hidden, cell)

def forward_decoder(self, y, hidden, cell):
# Encode trg: (1-length, batch) => (length, batch, emb_dim)
def forward_decoder(self, y, hidden, cell, **kwargs):
# Fix y dimensions
if len(y.shape) == 1:
y = y.unsqueeze(1)

# Decode trg: (batch, 1-length) => (batch, length, emb_dim)
y_emb = self.trg_embeddings(y)
y_emb = self.decoder_dropout(y_emb)

# (1-length, batch, emb_dim) =>
# output: (length, batch, hidden_dim * n_directions)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
# cell: (n_layers * n_directions, batch, hidden_dim)
output, (hidden, cell) = self.decoder_rnn(y_emb, (hidden, cell))

# Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
output = self.output_layer(output)
return output
return output, (hidden, cell)

def forward_enc_dec(self, x, y):
# Run encoder
_, states = self.forward_encoder(x)

y_pred = y[:, 0] # <sos>
outputs = [] # Doesn't contain <sos> token

# Iterate over trg tokens
trg_length = y.shape[1]
for t in range(trg_length):
outputs_t, states = self.forward_decoder(y_pred, *states) # (B, L, E)
outputs.append(outputs_t) # (B, L, V)

# Next input: ground-truth next token (teacher forcing) or the model's own prediction
teacher_force = random.random() < self.teacher_forcing_ratio
top1 = outputs_t.argmax(2)  # Most probable next token (argmax over the vocab logits)
y_pred = y[:, t + 1] if (teacher_force and t + 1 < trg_length) else top1  # Ground truth or predicted token

# Concatenate outputs (B, 1, V) => (B, L, V)
outputs = torch.concat(outputs, 1)
return outputs
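
A minimal smoke test for the new decoding loop might look like the sketch below. It is not part of the commit: the vocabulary sizes, padding index and tensor shapes are assumed, and it only checks that forward_enc_dec returns per-token logits of shape (batch, length, vocab). With teacher_force_ratio=1.0 the loop always feeds the ground-truth token back into the decoder; with 0.0 it always feeds back its own argmax prediction (free-running).

import torch
from autonmt.modules.models import LSTM

# Assumed toy values; the real sizes come from the built vocabularies
model = LSTM(src_vocab_size=1000, trg_vocab_size=1000, padding_idx=3, teacher_force_ratio=1.0)

x = torch.randint(0, 1000, (8, 20))           # (batch, src_length) source token ids
y = torch.randint(0, 1000, (8, 15))           # (batch, trg_length); y[:, 0] plays the role of <sos>

logits = model.forward_enc_dec(x, y[:, :-1])  # decoder input drops the last target token
print(logits.shape)                           # expected: torch.Size([8, 14, 1000])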
5 changes: 5 additions & 0 deletions autonmt/modules/models/transfomer.py
@@ -75,3 +75,8 @@ def forward_decoder(self, y, memory):
output = output.transpose(0, 1)
output = self.output_layer(output)
return output

def forward_enc_dec(self, x, y):
memory = self.forward_encoder(x)
output = self.forward_decoder(y, memory)
return output
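
The Transformer counterpart above decodes the whole target sequence in one pass rather than token by token, but both models now expose the same forward_enc_dec(x, y) -> (batch, length, vocab) entry point, which is what lets the shared _step in seq2seq.py stay model-agnostic. A rough interface check, with assumed sizes and the same standalone-instantiation assumption as before:

import torch
from autonmt.modules.models import Transformer, LSTM

x = torch.randint(0, 1000, (4, 12))  # (batch, src_length)
y = torch.randint(0, 1000, (4, 9))   # (batch, trg_length)
for cls in (Transformer, LSTM):
    model = cls(src_vocab_size=1000, trg_vocab_size=1000, padding_idx=3)
    # Both should yield per-token logits: torch.Size([4, 8, 1000])
    print(cls.__name__, model.forward_enc_dec(x, y[:, :-1]).shape)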
30 changes: 24 additions & 6 deletions autonmt/modules/seq2seq.py
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
import math
from collections import defaultdict

@@ -30,6 +31,14 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, **kwargs):
self.best_scores = defaultdict(float)
self.validation_step_outputs = defaultdict(list)

@abstractmethod
def forward_encoder(self, *args, **kwargs):
pass

@abstractmethod
def forward_decoder(self, *args, **kwargs):
pass

def configure_optimizers(self):
optim_fn = {
"adadelta": torch.optim.Adadelta,
@@ -107,17 +116,26 @@ def on_validation_epoch_end(self):
# Free memory
self.validation_step_outputs.clear()

def forward_enc_dec(self, x, y):
values = self.forward_encoder(x)
values = self.forward_decoder(y, **values) # (B, L, E)
output = values["output"]
return output

def _step(self, batch, batch_idx, log_prefix):
x, y = batch

# Forward
output = self.forward_encoder(x)
output = self.forward_decoder(y, output) # (B, L, E)
# Forward => (Batch, Length) => (Batch, Length, Vocab)
# The decoder input must include <sos>, but its output is shifted by one position: it starts at the first
# real word rather than at <sos>. Hence we feed 'y' without its last token and compare against 'y' without <sos>.
output = self.forward_enc_dec(x, y[:, :-1])

# Remove the <sos> token from the target
y = y[:, 1:]

# Compute loss
output = output.transpose(1, 2)[:, :, :-1] # Remove last index to match shape with 'y[1:]'
y = y[:, 1:] # Remove <sos>
loss = self.criterion_fn(output, y)
output = output.transpose(1, 2) # (B, L, V) => (B, V, L)
loss = self.criterion_fn(output, y) # (B, V, L) vs (B, L)

# Apply regularization
if self.regularization_fn:
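
The reshuffling in _step is easiest to see on a toy batch. The sketch below uses made-up sizes (not from the commit) to show the one-token offset between decoder input and target, and the (B, V, L) layout that nn.CrossEntropyLoss expects when the class scores sit in dimension 1:

import torch
import torch.nn as nn

B, L, V = 2, 5, 10
y = torch.randint(0, V, (B, L))     # full target; y[:, 0] is <sos>
decoder_in = y[:, :-1]              # keeps <sos>, drops the last token   -> (B, L-1)
target = y[:, 1:]                   # drops <sos>                         -> (B, L-1)

logits = torch.randn(B, L - 1, V)   # stand-in for forward_enc_dec(x, decoder_in), i.e. (B, L-1, V)
loss = nn.CrossEntropyLoss()(logits.transpose(1, 2), target)  # (B, V, L-1) vs (B, L-1)
print(loss.item())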
19 changes: 7 additions & 12 deletions examples/dev/0_test_custom_model.py
@@ -4,7 +4,7 @@
import torch
torch.set_float32_matmul_precision("high")

from autonmt.modules.models import Transformer
from autonmt.modules.models import Transformer, LSTM
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import AutonmtTranslator
from autonmt.vocabularies import Vocabulary
@@ -31,21 +31,16 @@ def main():

# Set of datasets, languages, training sizes to try
datasets=[
# {"name": "multi30k", "languages": ["en-es"], "sizes": [("original", None)]},
{"name": "europarl", "languages": ["en-es"], "sizes": [("50k", 50000), ("100k", 100000), ("original", None)]},
{"name": "multi30k", "languages": ["en-de"], "sizes": [("original", None)]},
# {"name": "europarl", "languages": ["en-es"], "sizes": [("50k", 50000), ("100k", 100000), ("original", None)]},
],

# Set of subword models and vocab sizes to try
encoding=[
# {"subword_models": ["bytes"], "vocab_sizes": [1000]},
# {"subword_models": ["char"], "vocab_sizes": [1000]},
# {"subword_models": ["bpe"], "vocab_sizes": [8000]},
# {"subword_models": ["word"], "vocab_sizes": [8000]},

{"subword_models": ["bytes"], "vocab_sizes": [1000]},
{"subword_models": ["char"], "vocab_sizes": [1000]},
{"subword_models": ["bpe"], "vocab_sizes": [16000]},
{"subword_models": ["word"], "vocab_sizes": [32000]},
{"subword_models": ["word"], "vocab_sizes": [8000]},
],

# Preprocessing functions
Expand Down Expand Up @@ -75,11 +70,11 @@ def main():
else:
raise ValueError(f"Unknown subword model: {train_ds.subword_model}")

for iters in [100]:
for iters in [10]:
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
model = LSTM(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
@@ -96,7 +91,7 @@ def main():

# Train model
wandb_params = None #dict(project="vocab-comparison", entity="salvacarrion", reinit=True)
trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=256, seed=None,
trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=None,
patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
wandb_params=wandb_params)

