Skip to content

Commit

Permalink
Generalize RNNs
Browse files Browse the repository at this point in the history
  • Loading branch information
salvacarrion committed Jul 1, 2024
1 parent 636adb1 commit b67f44f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
71 changes: 40 additions & 31 deletions autonmt/modules/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self,
decoder_bidirectional=False,
teacher_force_ratio=0.5,
padding_idx=None,
architecture="gru",
architecture="rnn",
**kwargs):
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture=architecture, **kwargs)
self.encoder_embed_dim = encoder_embed_dim
Expand Down Expand Up @@ -137,8 +137,8 @@ def forward_enc_dec(self, x, y):


class ContextRNN(SimpleRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
def __init__(self, *args, architecture="gru", **kwargs):
super().__init__(*args, architecture=architecture, **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
Expand Down Expand Up @@ -195,9 +195,10 @@ def forward_decoder(self, y, states):
output = self.output_layer(output)
return output, (states, context)


class AttentionRNN(SimpleRNN):
def __init__(self, *args, **kwargs):
super().__init__(*args, architecture="gru", **kwargs)
def __init__(self, *args, architecture="gru", **kwargs):
super().__init__(*args, architecture=architecture, **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
Expand All @@ -217,40 +218,31 @@ def __init__(self, *args, **kwargs):
self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2, self.decoder_hidden_dim * self.decoder_n_layers)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim*2, self.trg_vocab_size)

def attention(self, hidden, encoder_outputs):
src_len = encoder_outputs.shape[1]

# Repeat decoder hidden state "src_len" times: (B, emb) => (B, src_len, hid_dim)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

# Compute energy
energy = torch.cat((hidden, encoder_outputs), dim=2) # => (B, L, hid_dim+hid_dim)
energy = self.attn(energy) # => (B, L, hid_dim)
energy = torch.tanh(energy)

# Compute attention
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"
return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)

def forward_encoder(self, x):
# input: (B, L) =>
# output: (B, L, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
output, hidden = super().forward_encoder(x)
output, states = super().forward_encoder(x)

# Reshape hidden to (batch, n_layers * n_directions * hidden_dim)
hidden = hidden.transpose(0, 1).contiguous().view(hidden.size(1), -1)
states = states if isinstance(states, tuple) else (states,) # Get hidden state
states = list(states)
for i in range(len(states)):
states[i] = states[i].transpose(0, 1).contiguous().view(states[i].size(1), -1)

# Apply the linear transformation
hidden = torch.tanh(self.enc_ffn(hidden))
# Apply the linear transformation
states[i] = torch.tanh(self.enc_ffn(states[i]))

# Reshape back to (n_layers, batch, decoder_hidden_dim)
hidden = hidden.view(self.encoder_n_layers, -1, self.decoder_hidden_dim)
# Reshape back to (n_layers, batch, decoder_hidden_dim)
states[i] = states[i].view(self.encoder_n_layers, -1, self.decoder_hidden_dim)

# Fix states shape
states = tuple(states) if len(states) > 1 else states[0]
return output, (states, output)

return output, (hidden, output)

def forward_decoder(self, y, states):
hidden, enc_outputs = states
states, enc_outputs = states

# Fix "y" dimensions
if len(y.shape) == 1: # (batch) => (batch, 1)
Expand All @@ -263,17 +255,34 @@ def forward_decoder(self, y, states):
y_emb = self.dec_dropout(y_emb)

# Attention (using only the top layer of hidden state)
attn = self.attention(hidden[-1], enc_outputs)
attn = self.attention(states, enc_outputs)
attn = attn.unsqueeze(1) # (B, L) => (B, 1, L)
weighted = torch.bmm(attn, enc_outputs) # (B, 1, L) x (B, L, H) => (B, 1, H)

# intput: (batch, 1-length, emb_dim+w_emb_dim), (1, batch, hidden_dim)
# output: (batch, length, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
rnn_input = torch.cat((y_emb, weighted), dim=2)
output, hidden = self.decoder_rnn(rnn_input, hidden)
output, states = self.decoder_rnn(rnn_input, states)

# Get output: => (B, 1-length, H+H+H)
output = torch.cat((output, weighted, y_emb), dim=2)
output = self.output_layer(output) # (B, 1, H+H+H) => (B, 1, V)
return output, (hidden, enc_outputs) # pass enc_outputs (trick)

return output, (states, enc_outputs) # pass enc_outputs (trick)

def attention(self, states, encoder_outputs):
hidden = states[0][-1] if isinstance(states, tuple) else states[-1] # Get hidden state
src_len = encoder_outputs.shape[1]

# Repeat decoder hidden state "src_len" times: (B, emb) => (B, src_len, hid_dim)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

# Compute energy
energy = torch.cat((hidden, encoder_outputs), dim=2) # => (B, L, hid_dim+hid_dim)
energy = self.attn(energy) # => (B, L, hid_dim)
energy = torch.tanh(energy)

# Compute attention
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"
return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)
2 changes: 1 addition & 1 deletion examples/dev/0_test_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main():
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = AttentionRNN(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
model = AttentionRNN(architecture="gru", src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
Expand Down

0 comments on commit b67f44f

Please sign in to comment.