Generalize RNNs
salvacarrion committed Jul 1, 2024
1 parent 42ff301 commit 636adb1
Showing 3 changed files with 73 additions and 65 deletions.
2 changes: 1 addition & 1 deletion autonmt/modules/models/__init__.py
@@ -1,3 +1,3 @@
from autonmt.modules.models.transfomer import Transformer
from autonmt.modules.models.rnn import GenericRNN, GRU, GRUAttention
from autonmt.modules.models.rnn import *

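For downstream code, the practical effect is that __init__.py no longer hard-codes the RNN class names, which this commit renames from GenericRNN / GRU / GRUAttention to SimpleRNN / ContextRNN / AttentionRNN (see rnn.py below). A minimal import sketch, assuming only what the diff shows:

from autonmt.modules.models import Transformer, SimpleRNN, ContextRNN, AttentionRNN  # re-exported via the wildcard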
129 changes: 68 additions & 61 deletions autonmt/modules/models/rnn.py
@@ -7,7 +7,7 @@
from autonmt.modules.seq2seq import LitSeq2Seq


class BaseRNN(LitSeq2Seq):
class SimpleRNN(LitSeq2Seq):
def __init__(self,
src_vocab_size, trg_vocab_size,
encoder_embed_dim=256,
@@ -18,10 +18,11 @@ def __init__(self,
decoder_n_layers=2,
encoder_dropout=0.5,
decoder_dropout=0.5,
bidirectional=False,
encoder_bidirectional=False,
decoder_bidirectional=False,
teacher_force_ratio=0.5,
padding_idx=None,
architecture="base_rnn",
architecture="gru",
**kwargs):
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture=architecture, **kwargs)
self.encoder_embed_dim = encoder_embed_dim
@@ -32,18 +32,34 @@ def __init__(self,
self.decoder_n_layers = decoder_n_layers
self.encoder_dropout = encoder_dropout
self.decoder_dropout = decoder_dropout
self.bidirectional = bidirectional
self.encoder_bidirectional = encoder_bidirectional
self.decoder_bidirectional = decoder_bidirectional
self.teacher_forcing_ratio = teacher_force_ratio

# Model
self.src_embeddings = nn.Embedding(src_vocab_size, encoder_embed_dim)
self.trg_embeddings = nn.Embedding(trg_vocab_size, decoder_embed_dim)
self.enc_dropout = nn.Dropout(encoder_dropout)
self.dec_dropout = nn.Dropout(decoder_dropout)
self.encoder_rnn = None
self.decoder_rnn = None
self.output_layer = nn.Linear(decoder_hidden_dim, trg_vocab_size)

# RNN
base_rnn = self.get_base_rnn(self.architecture)
if base_rnn is None:
self.encoder_rnn = None
self.decoder_rnn = None
else:
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.encoder_bidirectional, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.decoder_bidirectional, batch_first=True)

# Checks
assert encoder_embed_dim == decoder_embed_dim
assert encoder_hidden_dim == decoder_hidden_dim
Expand All @@ -60,7 +77,8 @@ def get_base_rnn(architecture):
elif architecture == "gru":
return nn.GRU
else:
raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")
return None
# raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

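With the layer construction moved into the base constructor, get_base_rnn now returns None for unrecognised architecture strings instead of raising, so the base class (or a subclass) can decide whether to build the recurrent layers or leave them as None. A rough sketch of the mapping, assuming the helper is exposed statically, as the self.get_base_rnn(self.architecture) call sites suggest:

import torch.nn as nn

base = SimpleRNN.get_base_rnn("lstm")     # -> nn.LSTM
base = SimpleRNN.get_base_rnn("gru")      # -> nn.GRU
base = SimpleRNN.get_base_rnn("made_up")  # -> None (previously raised ValueError)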
def forward_encoder(self, x):
# Encode trg: (batch, length) => (batch, length, emb_dim)
@@ -118,45 +136,35 @@ def forward_enc_dec(self, x, y):
return outputs

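The teacher_force_ratio argument is consumed by forward_enc_dec, which is collapsed in this view. As a reminder of the general pattern only (not the exact loop in this repository), a generic sketch where decoder_step stands in for forward_decoder:

import random
import torch

def decode_with_teacher_forcing(decoder_step, y_gold, states, teacher_force_ratio=0.5):
    # At each step, feed the gold previous token with probability teacher_force_ratio,
    # otherwise feed the model's own argmax prediction from the previous step.
    outputs = []
    y_t = y_gold[:, 0]  # start token
    for t in range(1, y_gold.size(1)):
        logits, states = decoder_step(y_t, states)  # logits: (batch, 1, trg_vocab_size)
        outputs.append(logits)
        use_gold = random.random() < teacher_force_ratio
        y_t = y_gold[:, t] if use_gold else logits.argmax(dim=-1).squeeze(1)
    return torch.cat(outputs, dim=1), states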

class GenericRNN(BaseRNN):

def __init__(self, architecture="lstm", **kwargs):
super().__init__(architecture=architecture, **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)


class GRU(BaseRNN):
class ContextRNN(SimpleRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
self.encoder_rnn = nn.GRU(input_size=self.encoder_embed_dim,
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
self.decoder_rnn = nn.GRU(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
bidirectional=self.encoder_bidirectional, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.bidirectional, batch_first=True)
bidirectional=self.decoder_bidirectional, batch_first=True)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim*2, self.trg_vocab_size)

def forward_encoder(self, x):
output, states = super().forward_encoder(x)
context = states.clone()
return output, (states, context) # (hidden, context)

# Clone states
if isinstance(states, tuple): # Trick to save the context (last hidden state of the encoder)
context = tuple([s.clone() for s in states])
else:
context = states.clone()

return output, (states, context) # (states, context)

def forward_decoder(self, y, states):
hidden, context = states
states, context = states

# Fix "y" dimensions
if len(y.shape) == 1: # (batch) => (batch, 1)
@@ -168,46 +176,26 @@ def forward_decoder(self, y, states):
y_emb = self.trg_embeddings(y)
y_emb = self.dec_dropout(y_emb)

# Add context
tmp_context = context.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
# Add context (reduce to 1 layer)
tmp_context = context[0] if isinstance(context, tuple) else context # Get hidden state
tmp_context = tmp_context.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
y_context = torch.cat((y_emb, tmp_context), dim=2)

# input: (batch, 1-length, emb_dim), (n_layers * n_directions, batch, hidden_dim)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
output, hidden = self.decoder_rnn(y_context, hidden)
output, states = self.decoder_rnn(y_context, states)

# Add context
tmp_hidden = hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
tmp_hidden = states[0] if isinstance(states, tuple) else states # Get hidden state
tmp_hidden = tmp_hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
output = torch.cat((y_emb, tmp_hidden, tmp_context), dim=2)

# Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
output = self.output_layer(output)
return output, (hidden, context)

class Attention(nn.Module):

def __init__(self, encoder_hidden_dim, decoder_hidden_dim, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn = nn.Linear((encoder_hidden_dim * 2) + decoder_hidden_dim, decoder_hidden_dim)
self.v = nn.Linear(decoder_hidden_dim, 1, bias=False)

def forward(self, hidden, encoder_outputs):
src_len = encoder_outputs.shape[1]

# Repeat decoder hidden state "src_len" times: (B, emb) => (B, src_len, hid_dim)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

# Compute energy
energy = torch.cat((hidden, encoder_outputs), dim=2) # => (B, L, hid_dim+hid_dim)
energy = self.attn(energy) # => (B, L, hid_dim)
energy = torch.tanh(energy)

# Compute attention
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"
return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)
return output, (states, context)

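The isinstance checks above exist because nn.LSTM carries its state as a (hidden, cell) tuple while nn.GRU and nn.RNN return a single tensor, so both the cloned context and the state fed back into the decoder must be unpacked before the layer dimension is collapsed. A shape-only sketch of one decoding step in the GRU case, with illustrative sizes (only the tensor operations mirror forward_decoder above):

import torch

batch, emb_dim, hid_dim, n_layers = 8, 256, 512, 2
y_emb = torch.randn(batch, 1, emb_dim)            # embedded previous target token
context = torch.randn(n_layers, batch, hid_dim)   # encoder's final hidden state (plain tensor for GRU/RNN)

# Collapse the layer dimension into a single context vector per example, then
# concatenate it with the token embedding before the decoder RNN step.
tmp_context = context.transpose(1, 0).sum(dim=1, keepdim=True)  # (batch, 1, hid_dim); dim/keepdim == axis/keepdims above
y_context = torch.cat((y_emb, tmp_context), dim=2)              # (batch, 1, emb_dim + hid_dim)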
class GRUAttention(BaseRNN):
class AttentionRNN(SimpleRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
@@ -221,10 +209,29 @@ def __init__(self, *args, **kwargs):
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=False, batch_first=True)
self.attention = Attention(self.encoder_hidden_dim, self.decoder_hidden_dim)

# Attention
self.attn = nn.Linear((self.encoder_hidden_dim * 2) + self.decoder_hidden_dim, self.decoder_hidden_dim)
self.v = nn.Linear(self.decoder_hidden_dim, 1, bias=False)

self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2, self.decoder_hidden_dim * self.decoder_n_layers)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim*2, self.trg_vocab_size)

def attention(self, hidden, encoder_outputs):
src_len = encoder_outputs.shape[1]

# Repeat decoder hidden state "src_len" times: (B, emb) => (B, src_len, hid_dim)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

# Compute energy
energy = torch.cat((hidden, encoder_outputs), dim=2) # => (B, L, hid_dim+hid_dim)
energy = self.attn(energy) # => (B, L, hid_dim)
energy = torch.tanh(energy)

# Compute attention
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"
return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)

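This is additive (Bahdanau-style) attention: the current decoder state is broadcast along the source length, concatenated with the encoder outputs (sized encoder_hidden_dim * 2, i.e. a bidirectional encoder), passed through attn and tanh, scored by v, and softmax-normalised into per-position weights. A shape-only sketch with illustrative sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, src_len, enc_hid, dec_hid = 8, 20, 512, 512
attn = nn.Linear(enc_hid * 2 + dec_hid, dec_hid)
v = nn.Linear(dec_hid, 1, bias=False)

hidden = torch.randn(batch, dec_hid)                        # current decoder state
encoder_outputs = torch.randn(batch, src_len, enc_hid * 2)  # bidirectional encoder outputs

hidden_rep = hidden.unsqueeze(1).repeat(1, src_len, 1)                      # (batch, src_len, dec_hid)
energy = torch.tanh(attn(torch.cat((hidden_rep, encoder_outputs), dim=2)))  # (batch, src_len, dec_hid)
weights = F.softmax(v(energy).squeeze(2), dim=1)                            # (batch, src_len), rows sum to 1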
def forward_encoder(self, x):
# input: (B, L) =>
# output: (B, L, hidden_dim * n_directions)
7 changes: 4 additions & 3 deletions examples/dev/0_test_custom_model.py
@@ -4,7 +4,7 @@
import torch
torch.set_float32_matmul_precision("high")

from autonmt.modules.models import Transformer, GenericRNN, GRU, GRUAttention
from autonmt.modules.models import *
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import AutonmtTranslator
from autonmt.vocabularies import Vocabulary
@@ -22,6 +22,7 @@
preprocess_predict_fn = lambda data, ds: preprocess_lines(data["lines"], normalize_fn=normalize_fn)

BASE_PATH = "/home/scarrion/datasets/translate" # Remote
BASE_PATH = "/Users/salvacarrion/Documents/Programming/datasets/translate"  # Local

def main():
# Create preprocessing for training
@@ -74,11 +75,11 @@ def main():
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = GRUAttention(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
model = AttentionRNN(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
run_prefix = f"{model.architecture}-2L-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
run_prefix = f"{model.architecture}-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
run_name = train_ds.get_run_name(run_prefix=run_prefix) #+ f"__{int(time.time())}"
trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab,
runs_dir=runs_dir, run_name=run_name)

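Taken together, the example script only needed the class rename (GRUAttention -> AttentionRNN) and the simplified run_prefix. A hedged construction sketch with made-up sizes (the keyword names come from the diff; vocab sizes and padding index are illustrative):

from autonmt.modules.models import SimpleRNN, ContextRNN, AttentionRNN

kwargs = dict(src_vocab_size=32000, trg_vocab_size=32000, padding_idx=3)  # illustrative; the script reads these from its vocabularies

model = SimpleRNN(architecture="lstm", **kwargs)  # "rnn", "lstm" or "gru"
model = ContextRNN(**kwargs)                      # forces architecture="gru" internally
model = AttentionRNN(**kwargs)                    # forces architecture="gru" internally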