Add sort within batch
salvacarrion committed Jul 2, 2024
1 parent db7c245 commit c8b6f1c
Showing 6 changed files with 29 additions and 10 deletions.
13 changes: 10 additions & 3 deletions autonmt/modules/datasets/seq2seq_dataset.py
@@ -34,7 +34,7 @@ def __getitem__(self, idx):
         src_line, trg_line = self.src_lines[idx], self.trg_lines[idx]
         return src_line, trg_line
 
-    def collate_fn(self, batch, max_tokens=None, **kwargs):
+    def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
         x_encoded, y_encoded = [], []
         x_max_len = y_max_len = 0
 
@@ -66,11 +66,18 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
         x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=self.src_vocab.pad_id).T
         y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T
 
+        # Sort these tensor pairs by the length of x_len
+        if sort_within_batch:
+            x_len, x_idx = x_len.sort(0, descending=True)
+            y_len = y_len[x_idx]
+            x_padded = x_padded[x_idx]
+            y_padded = y_padded[x_idx]
+
         # Check stuff
         assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded)  # Control samples
         assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens  # Control max tokens
         return (x_padded, y_padded), (x_len, y_len)
 
-    def get_collate_fn(self, max_tokens):
-        return functools.partial(self.collate_fn, max_tokens=max_tokens)
+    def get_collate_fn(self, max_tokens, sort_within_batch):
+        return functools.partial(self.collate_fn, max_tokens=max_tokens, sort_within_batch=sort_within_batch)
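For context (not part of the commit), the sort-within-batch step can be exercised in isolation. A minimal sketch with made-up tensors: sorting x_len in descending order returns the permutation indices, and reindexing every padded tensor with those same indices keeps each source/target pair aligned.

import torch

# Hypothetical padded batch of 3 pairs (values are arbitrary; pad id = 0)
x_padded = torch.tensor([[4, 7, 0, 0, 0],
                         [2, 9, 8, 3, 6],
                         [5, 1, 7, 0, 0]])
y_padded = torch.tensor([[11, 12, 0],
                         [13, 14, 15],
                         [16, 0, 0]])
x_len = torch.tensor([2, 5, 3])
y_len = torch.tensor([2, 3, 1])

# Same logic as the new collate_fn branch: sort by source length (descending)
# and apply the same permutation to every tensor so the pairs stay aligned.
x_len, x_idx = x_len.sort(0, descending=True)
y_len, x_padded, y_padded = y_len[x_idx], x_padded[x_idx], y_padded[x_idx]

print(x_len)  # tensor([5, 3, 2])
print(x_idx)  # tensor([1, 2, 0])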

4 changes: 3 additions & 1 deletion autonmt/modules/models/rnn.py
@@ -1,4 +1,6 @@
+import random
+
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -89,7 +91,7 @@ def forward_encoder(self, x, x_len, **kwargs):
 
         # Pack sequence
         if self.packed_sequence:
-            x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=False)
+            x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=True)
 
         # input: (length, batch, emb_dim)
         # output: (length, batch, hidden_dim * n_directions)
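Background on the enforce_sorted switch: PyTorch's pack_padded_sequence accepts unsorted batches when enforce_sorted=False (it sorts and restores the order internally), whereas enforce_sorted=True requires lengths already in decreasing order, which the new sort-within-batch collate now guarantees. A minimal sketch with a toy batch (shapes and module are illustrative, not the repository's model):

import torch
import torch.nn as nn

x_emb = torch.randn(2, 3, 8)   # (batch, max_len, emb_dim), batch_first=True
x_len = torch.tensor([3, 2])   # already descending, as the sorted collate provides

packed = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'),
                                           batch_first=True, enforce_sorted=True)
rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
output_packed, hidden = rnn(packed)

# Recover a padded tensor (and the lengths) if needed downstream
output, out_len = nn.utils.rnn.pad_packed_sequence(output_packed, batch_first=True)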
7 changes: 6 additions & 1 deletion autonmt/modules/samplers/bucket.py
@@ -4,11 +4,12 @@
 
 
 class BucketIterator(Sampler):
-    def __init__(self, data_source, batch_size, sort_key, shuffle=True):
+    def __init__(self, data_source, batch_size, sort_key, shuffle=True, sort_within_batch=False):
         super().__init__(data_source)
         self.data_source = data_source
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.sort_within_batch = sort_within_batch
 
         # Sort indices by the specified key (e.g., sequence length)
         self.sorted_indices = np.argsort([sort_key(x, y) for x, y in self.data_source])
@@ -26,6 +27,10 @@ def __iter__(self):
             shuffled_indices = torch.randperm(len(self.buckets), generator=g).tolist()
             self.buckets = [self.buckets[i] for i in shuffled_indices]
 
+        # Sort within each bucket if required
+        # if self.sort_within_batch:
+        #     self.buckets = [sorted(bucket, key=lambda idx: len(self.data_source[idx][0].split(' ')), reverse=True) for bucket in self.buckets]
+
         # Flatten the list of buckets into a list of indices
         indices = [idx for bucket in self.buckets for idx in bucket]
         return iter(indices)
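The commented-out block above would do the per-bucket sort at the sampler level rather than in collate_fn (where this commit actually does it). For orientation only, a rough sketch of the bucketing idea — not the repository's implementation; data_source is assumed to yield (src_line, trg_line) string pairs:

import numpy as np

def make_buckets(data_source, batch_size, sort_key):
    # Sort all sample indices by the key (e.g., total token count), then slice
    # them into contiguous batch-sized buckets of similar lengths.
    sorted_indices = np.argsort([sort_key(x, y) for x, y in data_source])
    return [sorted_indices[i:i + batch_size].tolist()
            for i in range(0, len(sorted_indices), batch_size)]

pairs = [("a b c", "x y"), ("a", "x"), ("a b c d e", "x y z"), ("a b", "x")]
buckets = make_buckets(pairs, batch_size=2,
                       sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')))
print(buckets)  # [[1, 3], [0, 2]]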
2 changes: 1 addition & 1 deletion autonmt/modules/seq2seq.py
@@ -17,7 +17,7 @@ def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=
         self.src_vocab_size = src_vocab_size
         self.trg_vocab_size = trg_vocab_size
         self.padding_idx = padding_idx
-        self.packed_sequence = packed_sequence
+        self.packed_sequence = packed_sequence  # Use for RNNs and to "sort within batches"
         self.architecture = architecture
 
         # Hyperparams (PyTorch Lightning stuff)
11 changes: 8 additions & 3 deletions autonmt/toolkits/autonmt.py
@@ -130,18 +130,23 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
         self.model._skip_val_metrics = skip_val_metrics
 
         # Dataloader: Training
-        sampler = BucketIterator(self.train_tds, batch_size=batch_size, sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')), shuffle=True)
+        sampler = BucketIterator(self.train_tds, batch_size=batch_size,
+                                 sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
+                                 sort_within_batch=self.model.packed_sequence, shuffle=True)
         train_loader = DataLoader(self.train_tds,
-                                  collate_fn=self.train_tds.get_collate_fn(max_tokens), sampler=sampler,
+                                  collate_fn=self.train_tds.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler,
                                   num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
                                   batch_size=batch_size, shuffle=False,
                                   )  # 'sampler' option is mutually exclusive with shuffle
 
         # Dataloader: Validation
         val_loaders = []
         for val_tds_i in self.val_tds:
+            sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
+                                       sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
+                                       sort_within_batch=self.model.packed_sequence, shuffle=True)
             val_loaders.append(DataLoader(val_tds_i,
-                                          collate_fn=val_tds_i.get_collate_fn(max_tokens), sampler=None,
+                                          collate_fn=val_tds_i.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler_i,
                                           num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
                                           batch_size=batch_size, shuffle=False))
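How the pieces connect (a sketch with stand-in objects, not the toolkit's classes): get_collate_fn returns a functools.partial that bakes max_tokens and sort_within_batch into collate_fn, and the resulting callable is handed to a regular torch DataLoader; when a sampler is supplied, shuffle must stay False because the two options are mutually exclusive.

import functools
from torch.utils.data import DataLoader

def collate_fn(batch, max_tokens=None, sort_within_batch=False):
    # Stand-in collate: a real one would encode, pad and (optionally) sort here.
    return batch

# Bind the extra arguments once, as get_collate_fn() does
bound_collate = functools.partial(collate_fn, max_tokens=4096, sort_within_batch=True)

dataset = [("hola mundo", "hello world"), ("adios", "bye")]  # toy (src, trg) pairs
loader = DataLoader(dataset, batch_size=2, collate_fn=bound_collate, shuffle=False)
for batch in loader:
    print(batch)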

2 changes: 1 addition & 1 deletion examples/dev/0_test_custom_model.py
@@ -75,7 +75,7 @@ def main():
     # Instantiate vocabs and model
     src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
     trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
-    model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
+    model = AttentionRNN(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
 
     # Define trainer
     runs_dir = train_ds.get_runs_path(toolkit="autonmt")
