Improve kwargs
salvacarrion committed Jul 2, 2024
1 parent b67f44f commit e1a3082
Showing 7 changed files with 82 additions and 61 deletions.
6 changes: 5 additions & 1 deletion autonmt/modules/datasets/seq2seq_dataset.py
@@ -58,14 +58,18 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
print(msg.format(drop_ratio, max_tokens))
break

# Get lengths
x_len = torch.tensor([len(x) for x in x_encoded], dtype=torch.long)
y_len = torch.tensor([len(y) for y in y_encoded], dtype=torch.long)

# Pad sequence
x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=self.src_vocab.pad_id).T
y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T

# Sanity checks
assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded) # Control samples
assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens # Control max tokens
return x_padded, y_padded
return (x_padded, y_padded), (x_len, y_len)

def get_collate_fn(self, max_tokens):
return functools.partial(self.collate_fn, max_tokens=max_tokens)
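
Note: with this change, collate_fn returns the unpadded lengths alongside the padded batch. A minimal, self-contained sketch of the new batch contract (pad id 0 and the token ids are illustrative; the real code uses the vocabulary's pad_id):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy encoded batch (ids are illustrative); mirrors what collate_fn now returns.
x_encoded = [torch.tensor([5, 8, 2]), torch.tensor([7, 3])]
y_encoded = [torch.tensor([1, 9, 4, 2]), torch.tensor([1, 6, 2])]

x_len = torch.tensor([len(x) for x in x_encoded], dtype=torch.long)
y_len = torch.tensor([len(y) for y in y_encoded], dtype=torch.long)
x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=0).T
y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=0).T

batch = (x_padded, y_padded), (x_len, y_len)   # new collate_fn contract
(x, y), (x_len, y_len) = batch                 # how _step and the search loops unpack it
print(x.shape, x_len)                          # torch.Size([2, 3]) tensor([3, 2])
```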
89 changes: 53 additions & 36 deletions autonmt/modules/models/rnn.py
@@ -22,9 +22,11 @@ def __init__(self,
decoder_bidirectional=False,
teacher_force_ratio=0.5,
padding_idx=None,
packed_sequence=True,
architecture="rnn",
**kwargs):
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture=architecture, **kwargs)
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=packed_sequence,
architecture=architecture, **kwargs)
self.encoder_embed_dim = encoder_embed_dim
self.decoder_embed_dim = decoder_embed_dim
self.encoder_hidden_dim = encoder_hidden_dim
@@ -80,19 +82,27 @@ def get_base_rnn(architecture):
return None
# raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

def forward_encoder(self, x):
def forward_encoder(self, x, x_len, **kwargs):
# Encode src: (batch, length) => (batch, length, emb_dim)
x_emb = self.src_embeddings(x)
x_emb = self.enc_dropout(x_emb)

# Pack sequence
if self.packed_sequence:
x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=False)

# input: (length, batch, emb_dim)
# output: (length, batch, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
# cell: (n_layers * n_directions, batch, hidden_dim)
output, states = self.encoder_rnn(x_emb)

# Unpack sequence
if self.packed_sequence:
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
return output, states
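
Note: the pack/unpack pattern above is standard PyTorch; a toy, self-contained round trip showing why the lengths are needed (dimensions are made up):

```python
import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x_emb = torch.randn(4, 10, 8)                   # (batch, max_len, emb_dim), already padded
x_len = torch.tensor([10, 7, 5, 3])             # true (unpadded) lengths per sample

packed = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=False)
packed_out, hidden = rnn(packed)                # the RNN skips the padded timesteps
output, out_len = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(output.shape, out_len)                    # torch.Size([4, 10, 16]) tensor([10, 7, 5, 3])
```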

def forward_decoder(self, y, states):
def forward_decoder(self, y, y_len, states, **kwargs):
# Fix "y" dimensions
if len(y.shape) == 1: # (batch) => (batch, 1)
y = y.unsqueeze(1)
@@ -113,17 +123,18 @@ def forward_decoder(self, y, states):
output = self.output_layer(output)
return output, states

def forward_enc_dec(self, x, y):
def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
# Run encoder
_, states = self.forward_encoder(x)
_, states = self.forward_encoder(x, x_len)

y_pred = y[:, 0] # <sos>
outputs = [] # Doesn't contain <sos> token

# Iterate over trg tokens
x_pad_mask = (x != self.padding_idx) if self.packed_sequence else None # Mask padding
trg_length = y.shape[1]
for t in range(trg_length):
outputs_t, states = self.forward_decoder(y_pred, states) # (B, L, E)
outputs_t, states = self.forward_decoder(y=y_pred, y_len=y_len, states=states, x_pad_mask=x_pad_mask, **kwargs) # (B, L, E)
outputs.append(outputs_t) # (B, L, V)

# Next input?
@@ -141,19 +152,19 @@ def __init__(self, *args, architecture="gru", **kwargs):
super().__init__(*args, architecture=architecture, **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.encoder_bidirectional, batch_first=True)
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=self.encoder_bidirectional, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.decoder_bidirectional, batch_first=True)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim*2, self.trg_vocab_size)
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=self.decoder_bidirectional, batch_first=True)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim * 2, self.trg_vocab_size)

def forward_encoder(self, x):
output, states = super().forward_encoder(x)
def forward_encoder(self, x, x_len, **kwargs):
output, states = super().forward_encoder(x, x_len)

# Clone states
if isinstance(states, tuple): # Trick to save the context (last hidden state of the encoder)
@@ -163,7 +174,7 @@ def forward_encoder(self, x):

return output, (states, context) # (states, context)

def forward_decoder(self, y, states):
def forward_decoder(self, y, y_len, states, **kwargs):
states, context = states

# Fix "y" dimensions
@@ -188,7 +199,7 @@ def forward_decoder(self, y, states):

# Add context
tmp_hidden = states[0] if isinstance(states, tuple) else states # Get hidden state
tmp_hidden = tmp_hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
tmp_hidden = tmp_hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
output = torch.cat((y_emb, tmp_hidden, tmp_context), dim=2)

# Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
@@ -201,28 +212,30 @@ def __init__(self, *args, architecture="gru", **kwargs):
super().__init__(*args, architecture=architecture, **kwargs)
base_rnn = self.get_base_rnn(self.architecture)
self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=True, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim*2,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=False, batch_first=True)
hidden_size=self.encoder_hidden_dim,
num_layers=self.encoder_n_layers,
dropout=self.encoder_dropout,
bidirectional=True, batch_first=True)
self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim * 2,
hidden_size=self.decoder_hidden_dim,
num_layers=self.decoder_n_layers,
dropout=self.decoder_dropout,
bidirectional=False, batch_first=True)

# Attention
self.attn = nn.Linear((self.encoder_hidden_dim * 2) + self.decoder_hidden_dim, self.decoder_hidden_dim)
self.v = nn.Linear(self.decoder_hidden_dim, 1, bias=False)

self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2, self.decoder_hidden_dim * self.decoder_n_layers)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim*2, self.trg_vocab_size)
self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2,
self.decoder_hidden_dim * self.decoder_n_layers)
self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim * 2,
self.trg_vocab_size)

def forward_encoder(self, x):
def forward_encoder(self, x, x_len, **kwargs):
# input: (B, L) =>
# output: (B, L, hidden_dim * n_directions)
# hidden: (n_layers * n_directions, batch, hidden_dim)
output, states = super().forward_encoder(x)
output, states = super().forward_encoder(x, x_len)

# Reshape hidden to (batch, n_layers * n_directions * hidden_dim)
states = states if isinstance(states, tuple) else (states,) # Get hidden state
@@ -240,8 +253,7 @@ def forward_encoder(self, x):
states = tuple(states) if len(states) > 1 else states[0]
return output, (states, output)
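
Note: enc_ffn (defined above) bridges the flattened bidirectional encoder state and the decoder's expected state shape. A toy sketch of that dimension bookkeeping, assuming GRU-style single hidden tensors; the repo's exact reshape may differ:

```python
import torch
import torch.nn as nn

n_layers, batch, enc_hidden, dec_hidden = 2, 4, 16, 32
enc_ffn = nn.Linear(enc_hidden * n_layers * 2, dec_hidden * n_layers)

h = torch.randn(n_layers * 2, batch, enc_hidden)             # (n_layers * 2 directions, batch, enc_hidden)
h_flat = h.transpose(0, 1).reshape(batch, -1)                # (batch, n_layers * 2 * enc_hidden)
h_dec = enc_ffn(h_flat)                                      # (batch, n_layers * dec_hidden)
h_dec = h_dec.reshape(batch, n_layers, dec_hidden).transpose(0, 1).contiguous()
print(h_dec.shape)                                           # torch.Size([2, 4, 32]) => (n_layers, batch, dec_hidden)
```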


def forward_decoder(self, y, states):
def forward_decoder(self, y, y_len, states, x_pad_mask=None, **kwargs):
states, enc_outputs = states

# Fix "y" dimensions
@@ -255,7 +267,7 @@ def forward_decoder(self, y, states):
y_emb = self.dec_dropout(y_emb)

# Attention (using only the top layer of hidden state)
attn = self.attention(states, enc_outputs)
attn = self.attention(states, enc_outputs, x_pad_mask)
attn = attn.unsqueeze(1) # (B, L) => (B, 1, L)
weighted = torch.bmm(attn, enc_outputs) # (B, 1, L) x (B, L, H) => (B, 1, H)

@@ -271,7 +283,7 @@ def forward_decoder(self, y, states):

return output, (states, enc_outputs) # pass enc_outputs (trick)

def attention(self, states, encoder_outputs):
def attention(self, states, encoder_outputs, x_pad_mask):
hidden = states[0][-1] if isinstance(states, tuple) else states[-1] # Get hidden state
src_len = encoder_outputs.shape[1]

@@ -285,4 +297,9 @@ def attention(self, states, encoder_outputs):

# Compute attention
attention = self.v(energy).squeeze(2) # (B, L, H) => (B, L) # "weight logits"

# Mask attention
if x_pad_mask is not None:
attention = attention.masked_fill(x_pad_mask == 0, -1e10)

return F.softmax(attention, dim=1) # (B, L): normalized between 0..1 (attention)
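
Note: the new masking is the usual pre-softmax trick: padded positions receive a large negative logit so their attention weight is ~0. A standalone sketch with toy values (pad id 0 is illustrative):

```python
import torch
import torch.nn.functional as F

pad_id = 0
x = torch.tensor([[5, 8, 2, 0, 0],
                  [7, 3, 0, 0, 0]])              # (B, L) source ids, 0 = padding
attention = torch.randn(2, 5)                    # (B, L) unnormalized attention logits

x_pad_mask = (x != pad_id)                       # True on real tokens, False on padding
attention = attention.masked_fill(x_pad_mask == 0, -1e10)
attn = F.softmax(attention, dim=1)               # padded positions get ~0 weight
print(attn.sum(dim=1))                           # tensor([1., 1.])
```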
14 changes: 7 additions & 7 deletions autonmt/modules/models/transfomer.py
@@ -46,7 +46,7 @@ def __init__(self,
assert encoder_attention_heads == decoder_attention_heads
assert encoder_ffn_embed_dim == decoder_ffn_embed_dim

def forward_encoder(self, x):
def forward_encoder(self, x, x_len, **kwargs):
assert x.shape[1] <= self.max_src_positions

# Encode src
@@ -57,7 +57,7 @@ def forward_decoder(self, y, state):
state = self.transformer.encoder(src=x_emb, mask=None, src_key_padding_mask=None)
return None, state

def forward_decoder(self, y, state):
def forward_decoder(self, y, y_len, states, **kwargs):
assert y.shape[1] <= self.max_trg_positions

# Encode trg
@@ -68,15 +68,15 @@ def forward_decoder(self, y, state):
# Make trg mask
tgt_mask = self.transformer.generate_square_subsequent_mask(y_emb.shape[0]).to(y_emb.device)

output = self.transformer.decoder(tgt=y_emb, memory=state, tgt_mask=tgt_mask, memory_mask=None,
output = self.transformer.decoder(tgt=y_emb, memory=states, tgt_mask=tgt_mask, memory_mask=None,
tgt_key_padding_mask=None, memory_key_padding_mask=None)

# Get output
output = output.transpose(0, 1)
output = self.output_layer(output)
return output, state # Return state for compatibility
return output, states # Return state for compatibility

def forward_enc_dec(self, x, y):
_, states = self.forward_encoder(x)
output, _ = self.forward_decoder(y, states)
def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
_, states = self.forward_encoder(x, x_len, **kwargs)
output, _ = self.forward_decoder(y, y_len, states, **kwargs)
return output
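
Note: after this commit, RNN and Transformer models expose the same forward_enc_dec(x, x_len, y, y_len) signature. A rough usage sketch; the import path is an assumption, and the constructor call mirrors examples/dev/0_test_custom_model.py:

```python
import torch
from autonmt.modules.models import Transformer  # import path assumed

model = Transformer(src_vocab_size=100, trg_vocab_size=100, padding_idx=0)
x = torch.randint(1, 100, (8, 12))               # (batch, src_len)
y = torch.randint(1, 100, (8, 10))               # (batch, trg_len), first column stands in for <sos>
x_len = torch.full((8,), 12, dtype=torch.long)
y_len = torch.full((8,), 10, dtype=torch.long)

logits = model.forward_enc_dec(x=x, x_len=x_len, y=y[:, :-1], y_len=y_len)
print(logits.shape)                              # expected: (8, 9, trg_vocab_size)
```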
21 changes: 10 additions & 11 deletions autonmt/modules/seq2seq.py
@@ -12,11 +12,12 @@

class LitSeq2Seq(pl.LightningModule):

def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, architecture="base", **kwargs):
def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=False, architecture="base", **kwargs):
super().__init__()
self.src_vocab_size = src_vocab_size
self.trg_vocab_size = trg_vocab_size
self.padding_idx = padding_idx
self.packed_sequence = packed_sequence
self.architecture = architecture

# Hyperparams (PyTorch Lightning stuff)
@@ -33,11 +34,15 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, architecture="ba
self.validation_step_outputs = defaultdict(list)

@abstractmethod
def forward_encoder(self, *args, **kwargs):
def forward_encoder(self, x, x_len, **kwargs):
pass

@abstractmethod
def forward_decoder(self, *args, **kwargs):
def forward_decoder(self, y, y_len, states, **kwargs):
pass

@abstractmethod
def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
pass

def configure_optimizers(self):
@@ -117,19 +122,13 @@ def on_validation_epoch_end(self):
# Free memory
self.validation_step_outputs.clear()

def forward_enc_dec(self, x, y):
values = self.forward_encoder(x)
values = self.forward_decoder(y, **values) # (B, L, E)
output = values["output"]
return output

def _step(self, batch, batch_idx, log_prefix):
x, y = batch
(x, y), (x_len, y_len) = batch

# Forward => (Batch, Length) => (Batch, Length, Vocab)
# The decoder input needs the <sos> token, but its output is shifted by one (it starts with the first
# real word, not <sos>). Therefore, we remove the last token from 'y'
output = self.forward_enc_dec(x, y[:, :-1])
output = self.forward_enc_dec(x=x, x_len=x_len, y=y[:, :-1], y_len=y_len)

# Remove the <sos> token from the target
y = y[:, 1:]
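
Note: the shift in _step (decoder input keeps <sos> and drops the last token; the target drops <sos>) on a toy tensor, for reference (token ids are illustrative):

```python
import torch

# <sos>=1, <eos>=2, pad=0 (illustrative ids)
y = torch.tensor([[1, 11, 12, 13, 2],
                  [1, 21, 22, 2,  0]])

dec_input  = y[:, :-1]   # [[1, 11, 12, 13], [1, 21, 22, 2]] -> fed to forward_enc_dec
dec_target = y[:, 1:]    # [[11, 12, 13, 2], [21, 22, 2, 0]] -> compared against the logits
assert dec_input.shape == dec_target.shape
```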
4 changes: 2 additions & 2 deletions autonmt/search/beam_search.py
@@ -20,7 +20,7 @@ def beam_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_
probabilities = []
vocab_size = len(dataset.trg_vocab)
with torch.no_grad():
for x, _ in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
for (x, _), (x_len, _) in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
# Move to device
x = x.to(device)

@@ -30,7 +30,7 @@ def beam_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_
# dec_probs = torch.zeros(x.shape[0]).to(device) # Sentence probability

# Run encoder
memory = model.forward_encoder(x)
memory = model.forward_encoder(x, x_len)

# Get top k word predictions
next_probabilities = model.forward_decoder(dec_idxs, memory)[:, -1, :]
7 changes: 4 additions & 3 deletions autonmt/search/greedy_search.py
@@ -17,22 +17,23 @@ def greedy_search(model, dataset, sos_id, eos_id, pad_id, batch_size, max_tokens

with torch.no_grad():
outputs = []
for x, _ in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
for (x, _), (x_len, _) in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
max_gen_length = int(max_len_a*x.shape[1] + max_len_b)

# Run encoder
_, states = model.forward_encoder(x.to(device))
_, states = model.forward_encoder(x=x.to(device), x_len=x_len.to(device))

# Set start token <sos> and initial probabilities
y_pred = torch.full((x.shape[0], max_gen_length), pad_id, dtype=torch.long).to(device) # (B, L)
y_pred[:, 0] = sos_id

# Iterate over trg tokens
x_pad_mask = (x != pad_id) if model.packed_sequence else None # Mask padding
eos_mask = torch.zeros(x.shape[0], dtype=torch.bool).to(device)
max_iter = 0
for i in range(1, max_gen_length):
max_iter = i
outputs_t, states = model.forward_decoder(y_pred[:, :i], states)
outputs_t, states = model.forward_decoder(y=y_pred[:, :i], y_len=None, states=states, x_pad_mask=x_pad_mask)
top1 = outputs_t[:, -1, :].argmax(1) # Get most probable next-word (logits)

# Update y_pred for next iteration
2 changes: 1 addition & 1 deletion examples/dev/0_test_custom_model.py
@@ -75,7 +75,7 @@ def main():
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = AttentionRNN(architecture="gru", src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
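
Note: the RNN models now also accept the packed_sequence flag added in this commit. A hypothetical alternative to the Transformer line above (import path assumed; src_vocab and trg_vocab as built earlier in this script):

```python
from autonmt.modules.models import AttentionRNN  # import path assumed

# packed_sequence defaults to True for the RNN family per this commit; shown explicitly here.
model = AttentionRNN(architecture="gru",
                     src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab),
                     padding_idx=src_vocab.pad_id,
                     packed_sequence=True)
```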
