From c8b6f1c4f2cc148428336e3e416821ce2ed6e12a Mon Sep 17 00:00:00 2001
From: salvacarrion
Date: Tue, 2 Jul 2024 04:30:21 +0200
Subject: [PATCH] Add sort within batch

---
 autonmt/modules/datasets/seq2seq_dataset.py | 13 ++++++++++---
 autonmt/modules/models/rnn.py               |  4 +++-
 autonmt/modules/samplers/bucket.py          |  7 ++++++-
 autonmt/modules/seq2seq.py                  |  2 +-
 autonmt/toolkits/autonmt.py                 | 11 ++++++++---
 examples/dev/0_test_custom_model.py         |  2 +-
 6 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/autonmt/modules/datasets/seq2seq_dataset.py b/autonmt/modules/datasets/seq2seq_dataset.py
index 5bd9fe9..162cc13 100644
--- a/autonmt/modules/datasets/seq2seq_dataset.py
+++ b/autonmt/modules/datasets/seq2seq_dataset.py
@@ -34,7 +34,7 @@ def __getitem__(self, idx):
         src_line, trg_line = self.src_lines[idx], self.trg_lines[idx]
         return src_line, trg_line
 
-    def collate_fn(self, batch, max_tokens=None, **kwargs):
+    def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
         x_encoded, y_encoded = [], []
         x_max_len = y_max_len = 0
 
@@ -66,11 +66,18 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
         x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=self.src_vocab.pad_id).T
         y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T
 
+        # Sort tensor pairs by source length (x_len), in descending order
+        if sort_within_batch:
+            x_len, x_idx = x_len.sort(0, descending=True)
+            y_len = y_len[x_idx]
+            x_padded = x_padded[x_idx]
+            y_padded = y_padded[x_idx]
+
         # Check stuff
         assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded)  # Control samples
         assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens  # Control max tokens
         return (x_padded, y_padded), (x_len, y_len)
 
-    def get_collate_fn(self, max_tokens):
-        return functools.partial(self.collate_fn, max_tokens=max_tokens)
+    def get_collate_fn(self, max_tokens, sort_within_batch):
+        return functools.partial(self.collate_fn, max_tokens=max_tokens, sort_within_batch=sort_within_batch)
diff --git a/autonmt/modules/models/rnn.py b/autonmt/modules/models/rnn.py
index d0c0ef3..705736d 100644
--- a/autonmt/modules/models/rnn.py
+++ b/autonmt/modules/models/rnn.py
@@ -1,4 +1,6 @@
 import random
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -89,7 +91,7 @@ def forward_encoder(self, x, x_len, **kwargs):
 
         # Pack sequence
         if self.packed_sequence:
-            x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=False)
+            x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=True)
 
         # input: (length, batch, emb_dim)
         # output: (length, batch, hidden_dim * n_directions)
diff --git a/autonmt/modules/samplers/bucket.py b/autonmt/modules/samplers/bucket.py
index 3648115..1e6cc4d 100644
--- a/autonmt/modules/samplers/bucket.py
+++ b/autonmt/modules/samplers/bucket.py
@@ -4,11 +4,12 @@
 
 
 class BucketIterator(Sampler):
-    def __init__(self, data_source, batch_size, sort_key, shuffle=True):
+    def __init__(self, data_source, batch_size, sort_key, shuffle=True, sort_within_batch=False):
         super().__init__(data_source)
         self.data_source = data_source
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.sort_within_batch = sort_within_batch
 
         # Sort indices by the specified key (e.g., sequence length)
         self.sorted_indices = np.argsort([sort_key(x, y) for x, y in self.data_source])
@@ -26,6 +27,10 @@ def __iter__(self):
             shuffled_indices = torch.randperm(len(self.buckets), generator=g).tolist()
             self.buckets = [self.buckets[i] for i in shuffled_indices]
 
+        # Sort within each bucket if required (currently handled in collate_fn via 'sort_within_batch')
+        # if self.sort_within_batch:
+        #     self.buckets = [sorted(bucket, key=lambda idx: len(self.data_source[idx][0].split(' ')), reverse=True) for bucket in self.buckets]
+
         # Flatten the list of buckets into a list of indices
         indices = [idx for bucket in self.buckets for idx in bucket]
         return iter(indices)
diff --git a/autonmt/modules/seq2seq.py b/autonmt/modules/seq2seq.py
index c516274..b24b64f 100644
--- a/autonmt/modules/seq2seq.py
+++ b/autonmt/modules/seq2seq.py
@@ -17,7 +17,7 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=
         self.src_vocab_size = src_vocab_size
         self.trg_vocab_size = trg_vocab_size
         self.padding_idx = padding_idx
-        self.packed_sequence = packed_sequence
+        self.packed_sequence = packed_sequence  # Used by RNNs; also enables "sort within batch"
         self.architecture = architecture
 
         # Hyperparams (PyTorch Lightning stuff)
diff --git a/autonmt/toolkits/autonmt.py b/autonmt/toolkits/autonmt.py
index a4f549a..10b2423 100644
--- a/autonmt/toolkits/autonmt.py
+++ b/autonmt/toolkits/autonmt.py
@@ -130,9 +130,11 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
         self.model._skip_val_metrics = skip_val_metrics
 
         # Dataloader: Training
-        sampler = BucketIterator(self.train_tds, batch_size=batch_size, sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')), shuffle=True)
+        sampler = BucketIterator(self.train_tds, batch_size=batch_size,
+                                 sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
+                                 sort_within_batch=self.model.packed_sequence, shuffle=True)
         train_loader = DataLoader(self.train_tds,
-                                  collate_fn=self.train_tds.get_collate_fn(max_tokens), sampler=sampler,
+                                  collate_fn=self.train_tds.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler,
                                   num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
                                   batch_size=batch_size, shuffle=False,
                                   )  # 'sampler' option is mutually exclusive with shuffle
@@ -140,8 +142,11 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
 
         # Dataloader: Validation
         val_loaders = []
         for val_tds_i in self.val_tds:
+            sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
+                                       sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
+                                       sort_within_batch=self.model.packed_sequence, shuffle=True)
             val_loaders.append(DataLoader(val_tds_i,
-                                          collate_fn=val_tds_i.get_collate_fn(max_tokens), sampler=None,
+                                          collate_fn=val_tds_i.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler_i,
                                           num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
                                           batch_size=batch_size, shuffle=False))
diff --git a/examples/dev/0_test_custom_model.py b/examples/dev/0_test_custom_model.py
index 1774376..738adba 100644
--- a/examples/dev/0_test_custom_model.py
+++ b/examples/dev/0_test_custom_model.py
@@ -75,7 +75,7 @@ def main():
     # Instantiate vocabs and model
     src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
     trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
-    model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
+    model = AttentionRNN(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
 
     # Define trainer
     runs_dir = train_ds.get_runs_path(toolkit="autonmt")
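
Reviewer note (not part of the patch): the new sort exists because
pack_padded_sequence with enforce_sorted=True requires every batch to be
sorted by decreasing length, which the sort_within_batch path in collate_fn
now guarantees; per the PyTorch docs, enforce_sorted=True is also required
for ONNX export. Below is a minimal, self-contained Python sketch of that
contract, independent of this repo (the toy tensors and layer sizes are
made up for illustration):

    import torch
    import torch.nn as nn

    pad_id = 0
    x_padded = torch.tensor([[5, 6, 0, 0],
                             [1, 2, 3, 4],
                             [7, 0, 0, 0]])   # (batch, length), padded with pad_id
    x_len = torch.tensor([2, 4, 1])           # true (unpadded) lengths

    # Same sort the patched collate_fn applies: descending by source length
    x_len, x_idx = x_len.sort(0, descending=True)
    x_padded = x_padded[x_idx]

    x_emb = nn.Embedding(10, 8, padding_idx=pad_id)(x_padded)  # (batch, length, emb_dim)

    # enforce_sorted=True succeeds only because the batch is length-sorted;
    # an unsorted batch would raise a RuntimeError here
    packed = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'),
                                               batch_first=True, enforce_sorted=True)
    output, _ = nn.GRU(8, 16, batch_first=True)(packed)
    output, out_len = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

With enforce_sorted=False, PyTorch performs the sort/unsort bookkeeping
internally, so doing the sort once in the collate_fn moves that work off the
forward pass.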