diff --git a/autonmt/modules/datasets/seq2seq_dataset.py b/autonmt/modules/datasets/seq2seq_dataset.py
index 08a2093..fbfb04b 100644
--- a/autonmt/modules/datasets/seq2seq_dataset.py
+++ b/autonmt/modules/datasets/seq2seq_dataset.py
@@ -58,6 +58,10 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
                 print(msg.format(drop_ratio, max_tokens))
                 break

+        # Get lengths
+        x_len = torch.tensor([len(x) for x in x_encoded], dtype=torch.long)
+        y_len = torch.tensor([len(y) for y in y_encoded], dtype=torch.long)
+
         # Pad sequence
         x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=self.src_vocab.pad_id).T
         y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T
@@ -65,7 +69,7 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
         # Check stuff
         assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded)  # Control samples
         assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens  # Control max tokens
-        return x_padded, y_padded
+        return (x_padded, y_padded), (x_len, y_len)

     def get_collate_fn(self, max_tokens):
         return functools.partial(self.collate_fn, max_tokens=max_tokens)
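
For reference, a minimal standalone sketch (toy tensors and a made-up pad_id of 0, not code from this patch) of the contract the new collate_fn establishes: lengths are measured before padding, and a batch is returned as ((x_padded, y_padded), (x_len, y_len)).

import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = 0
x_encoded = [torch.tensor([5, 6, 7, 2]), torch.tensor([5, 9, 2])]    # two "sentences" of token ids

# Lengths are taken *before* padding, exactly as in the collate_fn above
x_len = torch.tensor([len(x) for x in x_encoded], dtype=torch.long)  # tensor([4, 3])

# pad_sequence(..., batch_first=False).T yields a (batch, max_len) tensor
x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=pad_id).T
# tensor([[5, 6, 7, 2],
#         [5, 9, 2, 0]])

# A consumer then unpacks a batch as:  (x, y), (x_len, y_len) = batch
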
diff --git a/autonmt/modules/models/rnn.py b/autonmt/modules/models/rnn.py
index b1e6e6b..d0c0ef3 100644
--- a/autonmt/modules/models/rnn.py
+++ b/autonmt/modules/models/rnn.py
@@ -22,9 +22,11 @@ def __init__(self,
                  decoder_bidirectional=False, teacher_force_ratio=0.5, padding_idx=None,
+                 packed_sequence=True,
                  architecture="rnn", **kwargs):
-        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture=architecture, **kwargs)
+        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=packed_sequence,
+                         architecture=architecture, **kwargs)
         self.encoder_embed_dim = encoder_embed_dim
         self.decoder_embed_dim = decoder_embed_dim
         self.encoder_hidden_dim = encoder_hidden_dim
@@ -80,19 +82,27 @@ def get_base_rnn(architecture):
             return None  # raise ValueError(f"Invalid architecture: {architecture}. Choose: 'rnn', 'lstm' or 'gru'")

-    def forward_encoder(self, x):
+    def forward_encoder(self, x, x_len, **kwargs):
         # Encode trg: (batch, length) => (batch, length, emb_dim)
         x_emb = self.src_embeddings(x)
         x_emb = self.enc_dropout(x_emb)

+        # Pack sequence
+        if self.packed_sequence:
+            x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=False)
+
         # input: (length, batch, emb_dim)
         # output: (length, batch, hidden_dim * n_directions)
         # hidden: (n_layers * n_directions, batch, hidden_dim)
         # cell: (n_layers * n_directions, batch, hidden_dim)
         output, states = self.encoder_rnn(x_emb)
+
+        # Unpack sequence
+        if self.packed_sequence:
+            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
         return output, states

-    def forward_decoder(self, y, states):
+    def forward_decoder(self, y, y_len, states, **kwargs):
         # Fix "y" dimensions
         if len(y.shape) == 1:  # (batch) => (batch, 1)
             y = y.unsqueeze(1)
@@ -113,17 +123,18 @@ def forward_decoder(self, y, states):
         output = self.output_layer(output)
         return output, states

-    def forward_enc_dec(self, x, y):
+    def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
         # Run encoder
-        _, states = self.forward_encoder(x)
+        _, states = self.forward_encoder(x, x_len)

         y_pred = y[:, 0]  # <sos>
         outputs = []  # Doesn't contain <sos> token

         # Iterate over trg tokens
+        x_pad_mask = (x != self.padding_idx) if self.packed_sequence else None  # Mask padding
         trg_length = y.shape[1]
         for t in range(trg_length):
-            outputs_t, states = self.forward_decoder(y_pred, states)  # (B, L, E)
+            outputs_t, states = self.forward_decoder(y=y_pred, y_len=y_len, states=states, x_pad_mask=x_pad_mask, **kwargs)  # (B, L, E)
             outputs.append(outputs_t)  # (B, L, V)

             # Next input?
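
A minimal, self-contained sketch (toy shapes, nn.GRU chosen arbitrarily; not library code) of the pack/unpack round trip used above: enforce_sorted=False keeps the batch in its original order, the lengths tensor must live on the CPU, and pad_packed_sequence restores a (batch, max_len, hidden) tensor whose positions past each true length are zero.

import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

x_emb = torch.randn(2, 5, 8)                 # (batch, max_len, emb_dim); pretend row 2 has only 3 valid steps
x_len = torch.tensor([5, 3])                 # true lengths, kept on CPU for packing

packed = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'),
                                           batch_first=True, enforce_sorted=False)
packed_out, h_n = rnn(packed)                # h_n: (n_layers * n_directions, batch, hidden), taken at each true last step

# Back to a padded (batch, max_len, hidden) tensor; the lengths are returned as well
output, out_len = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
assert output.shape == (2, 5, 16) and torch.equal(out_len, x_len)
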
@@ -141,19 +152,19 @@ def __init__(self, *args, architecture="gru", **kwargs):
         super().__init__(*args, architecture=architecture, **kwargs)
         base_rnn = self.get_base_rnn(self.architecture)
         self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
-                                   hidden_size=self.encoder_hidden_dim,
-                                   num_layers=self.encoder_n_layers,
-                                   dropout=self.encoder_dropout,
-                                   bidirectional=self.encoder_bidirectional, batch_first=True)
+                                    hidden_size=self.encoder_hidden_dim,
+                                    num_layers=self.encoder_n_layers,
+                                    dropout=self.encoder_dropout,
+                                    bidirectional=self.encoder_bidirectional, batch_first=True)
         self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim,
-                                   hidden_size=self.decoder_hidden_dim,
-                                   num_layers=self.decoder_n_layers,
-                                   dropout=self.decoder_dropout,
-                                   bidirectional=self.decoder_bidirectional, batch_first=True)
-        self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim*2, self.trg_vocab_size)
+                                    hidden_size=self.decoder_hidden_dim,
+                                    num_layers=self.decoder_n_layers,
+                                    dropout=self.decoder_dropout,
+                                    bidirectional=self.decoder_bidirectional, batch_first=True)
+        self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim * 2, self.trg_vocab_size)

-    def forward_encoder(self, x):
-        output, states = super().forward_encoder(x)
+    def forward_encoder(self, x, x_len, **kwargs):
+        output, states = super().forward_encoder(x, x_len)
         # Clone states
         if isinstance(states, tuple):  # Trick to save the context (last hidden state of the encoder)
@@ -163,7 +174,7 @@ def forward_encoder(self, x):
         return output, (states, context)  # (states, context)

-    def forward_decoder(self, y, states):
+    def forward_decoder(self, y, y_len, states, **kwargs):
         states, context = states

         # Fix "y" dimensions
@@ -188,7 +199,7 @@ def forward_decoder(self, y, states):
         # Add context
         tmp_hidden = states[0] if isinstance(states, tuple) else states  # Get hidden state
-        tmp_hidden = tmp_hidden.transpose(1, 0).sum(axis=1, keepdims=True) # The paper has just 1 layer
+        tmp_hidden = tmp_hidden.transpose(1, 0).sum(axis=1, keepdims=True)  # The paper has just 1 layer
         output = torch.cat((y_emb, tmp_hidden, tmp_context), dim=2)

         # Get output: (batch, length, hidden_dim * n_directions) => (batch, length, trg_vocab_size)
@@ -201,28 +212,30 @@ def __init__(self, *args, architecture="gru", **kwargs):
         super().__init__(*args, architecture=architecture, **kwargs)
         base_rnn = self.get_base_rnn(self.architecture)
         self.encoder_rnn = base_rnn(input_size=self.encoder_embed_dim,
-                                   hidden_size=self.encoder_hidden_dim,
-                                   num_layers=self.encoder_n_layers,
-                                   dropout=self.encoder_dropout,
-                                   bidirectional=True, batch_first=True)
-        self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim*2,
-                                   hidden_size=self.decoder_hidden_dim,
-                                   num_layers=self.decoder_n_layers,
-                                   dropout=self.decoder_dropout,
-                                   bidirectional=False, batch_first=True)
+                                    hidden_size=self.encoder_hidden_dim,
+                                    num_layers=self.encoder_n_layers,
+                                    dropout=self.encoder_dropout,
+                                    bidirectional=True, batch_first=True)
+        self.decoder_rnn = base_rnn(input_size=self.decoder_embed_dim + self.encoder_hidden_dim * 2,
+                                    hidden_size=self.decoder_hidden_dim,
+                                    num_layers=self.decoder_n_layers,
+                                    dropout=self.decoder_dropout,
+                                    bidirectional=False, batch_first=True)

         # Attention
         self.attn = nn.Linear((self.encoder_hidden_dim * 2) + self.decoder_hidden_dim, self.decoder_hidden_dim)
         self.v = nn.Linear(self.decoder_hidden_dim, 1, bias=False)
-        self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2, self.decoder_hidden_dim * self.decoder_n_layers)
-        self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim*2, self.trg_vocab_size)
+        self.enc_ffn = nn.Linear(self.encoder_hidden_dim * self.encoder_n_layers * 2,
+                                 self.decoder_hidden_dim * self.decoder_n_layers)
+        self.output_layer = nn.Linear(self.decoder_embed_dim + self.decoder_hidden_dim + self.encoder_hidden_dim * 2,
+                                      self.trg_vocab_size)

-    def forward_encoder(self, x):
+    def forward_encoder(self, x, x_len, **kwargs):
         # input: (B, L) =>
         # output: (B, L, hidden_dim * n_directions)
         # hidden: (n_layers * n_directions, batch, hidden_dim)
-        output, states = super().forward_encoder(x)
+        output, states = super().forward_encoder(x, x_len)

         # Reshape hidden to (batch, n_layers * n_directions * hidden_dim)
         states = states if isinstance(states, tuple) else (states,)  # Get hidden state
@@ -240,8 +253,7 @@ def forward_encoder(self, x):
         states = tuple(states) if len(states) > 1 else states[0]
         return output, (states, output)

-
-    def forward_decoder(self, y, states):
+    def forward_decoder(self, y, y_len, states, x_pad_mask=None, **kwargs):
         states, enc_outputs = states

         # Fix "y" dimensions
@@ -255,7 +267,7 @@ def forward_decoder(self, y, states):
         y_emb = self.dec_dropout(y_emb)

         # Attention (using only the top layer of hidden state)
-        attn = self.attention(states, enc_outputs)
+        attn = self.attention(states, enc_outputs, x_pad_mask)
         attn = attn.unsqueeze(1)  # (B, L) => (B, 1, L)
         weighted = torch.bmm(attn, enc_outputs)  # (B, 1, L) x (B, L, H) => (B, 1, H)
@@ -271,7 +283,7 @@ def forward_decoder(self, y, states):

         return output, (states, enc_outputs)  # pass enc_outputs (trick)

-    def attention(self, states, encoder_outputs):
+    def attention(self, states, encoder_outputs, x_pad_mask):
         hidden = states[0][-1] if isinstance(states, tuple) else states[-1]  # Get hidden state
         src_len = encoder_outputs.shape[1]
@@ -285,4 +297,9 @@ def attention(self, states, encoder_outputs):

         # Compute attention
         attention = self.v(energy).squeeze(2)  # (B, L, H) => (B, L)  # "weight logits"
+
+        # Mask attention
+        if x_pad_mask is not None:
+            attention = attention.masked_fill(x_pad_mask == 0, -1e10)
+
         return F.softmax(attention, dim=1)  # (B, L): normalized between 0..1 (attention)
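
For intuition, a toy example (made-up numbers, outside this patch) of what the new masking step does: source positions where x_pad_mask is 0 receive a large negative logit, so after the softmax they get numerically zero attention weight.

import torch
import torch.nn.functional as F

attention = torch.tensor([[2.0, 1.0, 0.5, -0.3]])   # (B=1, L=4) raw attention logits
x_pad_mask = torch.tensor([[1, 1, 1, 0]])           # last source position is padding

attention = attention.masked_fill(x_pad_mask == 0, -1e10)
weights = F.softmax(attention, dim=1)
# weights ≈ tensor([[0.6285, 0.2312, 0.1403, 0.0000]]); the padded position gets ~0 weight
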
diff --git a/autonmt/modules/models/transfomer.py b/autonmt/modules/models/transfomer.py
index 2271480..ec7ba6f 100644
--- a/autonmt/modules/models/transfomer.py
+++ b/autonmt/modules/models/transfomer.py
@@ -46,7 +46,7 @@ def __init__(self,
         assert encoder_attention_heads == decoder_attention_heads
         assert encoder_ffn_embed_dim == decoder_ffn_embed_dim

-    def forward_encoder(self, x):
+    def forward_encoder(self, x, x_len, **kwargs):
         assert x.shape[1] <= self.max_src_positions

         # Encode src
@@ -57,7 +57,7 @@ def forward_encoder(self, x):
         state = self.transformer.encoder(src=x_emb, mask=None, src_key_padding_mask=None)
         return None, state

-    def forward_decoder(self, y, state):
+    def forward_decoder(self, y, y_len, states, **kwargs):
         assert y.shape[1] <= self.max_trg_positions

         # Encode trg
@@ -68,15 +68,15 @@ def forward_decoder(self, y, state):
         # Make trg mask
         tgt_mask = self.transformer.generate_square_subsequent_mask(y_emb.shape[0]).to(y_emb.device)

-        output = self.transformer.decoder(tgt=y_emb, memory=state, tgt_mask=tgt_mask, memory_mask=None,
+        output = self.transformer.decoder(tgt=y_emb, memory=states, tgt_mask=tgt_mask, memory_mask=None,
                                           tgt_key_padding_mask=None, memory_key_padding_mask=None)

         # Get output
         output = output.transpose(0, 1)
         output = self.output_layer(output)
-        return output, state  # Return state for compatibility
+        return output, states  # Return state for compatibility

-    def forward_enc_dec(self, x, y):
-        _, states = self.forward_encoder(x)
-        output, _ = self.forward_decoder(y, states)
+    def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
+        _, states = self.forward_encoder(x, x_len, **kwargs)
+        output, _ = self.forward_decoder(y, y_len, states, **kwargs)
         return output
diff --git a/autonmt/modules/seq2seq.py b/autonmt/modules/seq2seq.py
index f34723f..c516274 100644
--- a/autonmt/modules/seq2seq.py
+++ b/autonmt/modules/seq2seq.py
@@ -12,11 +12,12 @@ class LitSeq2Seq(pl.LightningModule):

-    def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, architecture="base", **kwargs):
+    def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=False, architecture="base", **kwargs):
         super().__init__()
         self.src_vocab_size = src_vocab_size
         self.trg_vocab_size = trg_vocab_size
         self.padding_idx = padding_idx
+        self.packed_sequence = packed_sequence
         self.architecture = architecture

         # Hyperparams (PyTorch Lightning stuff)
@@ -33,11 +34,15 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, architecture="ba
         self.validation_step_outputs = defaultdict(list)

     @abstractmethod
-    def forward_encoder(self, *args, **kwargs):
+    def forward_encoder(self, x, x_len, **kwargs):
         pass

     @abstractmethod
-    def forward_decoder(self, *args, **kwargs):
+    def forward_decoder(self, y, y_len, states, **kwargs):
+        pass
+
+    @abstractmethod
+    def forward_enc_dec(self, x, x_len, y, y_len, **kwargs):
         pass

     def configure_optimizers(self):
@@ -117,19 +122,13 @@ def on_validation_epoch_end(self):
         # Free memory
         self.validation_step_outputs.clear()

-    def forward_enc_dec(self, x, y):
-        values = self.forward_encoder(x)
-        values = self.forward_decoder(y, **values)  # (B, L, E)
-        output = values["output"]
-        return output
-
     def _step(self, batch, batch_idx, log_prefix):
-        x, y = batch
+        (x, y), (x_len, y_len) = batch

         # Forward => (Batch, Length) => (Batch, Length, Vocab)
         # The input of the decoder needs the <sos>, but its output is shifted as it starts with the first word, not
         # with the <sos>. Therefore, we need to remove the last token from 'y'
-        output = self.forward_enc_dec(x, y[:, :-1])
+        output = self.forward_enc_dec(x=x, x_len=x_len, y=y[:, :-1], y_len=y_len)

         # Remove the <sos> token from the target
         y = y[:, 1:]
diff --git a/autonmt/search/beam_search.py b/autonmt/search/beam_search.py
index 44f2db8..18c60b2 100644
--- a/autonmt/search/beam_search.py
+++ b/autonmt/search/beam_search.py
@@ -20,7 +20,7 @@ def beam_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_
     probabilities = []
     vocab_size = len(dataset.trg_vocab)
     with torch.no_grad():
-        for x, _ in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
+        for (x, _), (x_len, _) in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
             # Move to device
             x = x.to(device)

@@ -30,7 +30,7 @@ def beam_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_
             # dec_probs = torch.zeros(x.shape[0]).to(device)  # Sentence probability

             # Run encoder
-            memory = model.forward_encoder(x)
+            memory = model.forward_encoder(x, x_len)

             # Get top k word predictions
             next_probabilities = model.forward_decoder(dec_idxs, memory)[:, -1, :]
diff --git a/autonmt/search/greedy_search.py b/autonmt/search/greedy_search.py
index cc8103e..71c7a12 100644
--- a/autonmt/search/greedy_search.py
+++ b/autonmt/search/greedy_search.py
@@ -17,22 +17,23 @@ def greedy_search(model, dataset, sos_id, eos_id, pad_id, batch_size, max_tokens

     with torch.no_grad():
         outputs = []
-        for x, _ in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
+        for (x, _), (x_len, _) in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
             max_gen_length = int(max_len_a*x.shape[1] + max_len_b)

             # Run encoder
-            _, states = model.forward_encoder(x.to(device))
+            _, states = model.forward_encoder(x=x.to(device), x_len=x_len.to(device))

             # Set start token and initial probabilities
             y_pred = torch.full((x.shape[0], max_gen_length), pad_id, dtype=torch.long).to(device)  # (B, L)
             y_pred[:, 0] = sos_id

             # Iterate over trg tokens
+            x_pad_mask = (x != pad_id).to(device) if model.packed_sequence else None  # Mask padding
             eos_mask = torch.zeros(x.shape[0], dtype=torch.bool).to(device)
             max_iter = 0
             for i in range(1, max_gen_length):
                 max_iter = i
-                outputs_t, states = model.forward_decoder(y_pred[:, :i], states)
+                outputs_t, states = model.forward_decoder(y=y_pred[:, :i], y_len=None, states=states, x_pad_mask=x_pad_mask)
                 top1 = outputs_t[:, -1, :].argmax(1)  # Get most probable next-word (logits)

                 # Update y_pred for next iteration
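
An illustrative sketch of the calling convention the search code now follows (hypothetical model, x, x_len, pad_id, sos_id and device variables; only the forward_encoder/forward_decoder signatures come from this patch): the pad mask must end up on the model's device, and the decoder is called with the keyword names of the new signature, with y_len unused at inference time.

import torch

x = x.to(device)                                                  # (B, L) source ids on the model's device
x_pad_mask = (x != pad_id) if model.packed_sequence else None     # built from the device copy of x

_, states = model.forward_encoder(x=x, x_len=x_len)               # x_len is moved to CPU internally for packing

y_pred = torch.full((x.shape[0], 1), sos_id, dtype=torch.long, device=device)  # initial <sos> column
outputs_t, states = model.forward_decoder(y=y_pred, y_len=None, states=states,
                                          x_pad_mask=x_pad_mask)
next_token = outputs_t[:, -1, :].argmax(dim=1)                    # greedy next-token pick
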
diff --git a/examples/dev/0_test_custom_model.py b/examples/dev/0_test_custom_model.py
index 80edf85..1774376 100644
--- a/examples/dev/0_test_custom_model.py
+++ b/examples/dev/0_test_custom_model.py
@@ -75,7 +75,7 @@ def main():
     # Instantiate vocabs and model
     src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
     trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
-    model = AttentionRNN(architecture="gru", src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
+    model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

     # Define trainer
     runs_dir = train_ds.get_runs_path(toolkit="autonmt")
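
As a usage note (a sketch against this patch, not part of the example script, which now builds a Transformer): the RNN family still accepts the old arguments plus the new packed_sequence flag, which defaults to True for the RNN models and False in the LitSeq2Seq base class.

# Hypothetical alternative to the Transformer line above; packing can be disabled with packed_sequence=False.
model = AttentionRNN(architecture="gru",
                     src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab),
                     padding_idx=src_vocab.pad_id,
                     packed_sequence=True)
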