Improve bucketing experience
salvacarrion committed Jul 2, 2024
1 parent c8b6f1c commit 93c07ed
Showing 8 changed files with 45 additions and 31 deletions.
16 changes: 8 additions & 8 deletions autonmt/modules/datasets/seq2seq_dataset.py
@@ -34,7 +34,7 @@ def __getitem__(self, idx):
src_line, trg_line = self.src_lines[idx], self.trg_lines[idx]
return src_line, trg_line

def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
def collate_fn(self, batch, max_tokens=None, **kwargs):
x_encoded, y_encoded = [], []
x_max_len = y_max_len = 0

@@ -67,17 +67,17 @@ def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T

# Sort these tensor pairs by the length of x_len
if sort_within_batch:
x_len, x_idx = x_len.sort(0, descending=True)
y_len = y_len[x_idx]
x_padded = x_padded[x_idx]
y_padded = y_padded[x_idx]
# if sort_within_batch:
# x_len, x_idx = x_len.sort(0, descending=True)
# y_len = y_len[x_idx]
# x_padded = x_padded[x_idx]
# y_padded = y_padded[x_idx]

# Check stuff
assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded) # Control samples
assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens # Control max tokens
return (x_padded, y_padded), (x_len, y_len)

def get_collate_fn(self, max_tokens, sort_within_batch):
return functools.partial(self.collate_fn, max_tokens=max_tokens, sort_within_batch=sort_within_batch)
def get_collate_fn(self, max_tokens):
return functools.partial(self.collate_fn, max_tokens=max_tokens)
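Note on the collate_fn change above: batch sorting has moved out of the collate function (it now lives in the BucketIterator), so collate_fn only encodes and pads. Below is a minimal, self-contained sketch of that padding step, using made-up pre-encoded tensors instead of the repository's vocabulary and dataset classes.

```python
# Toy illustration of the padding performed in collate_fn (dummy tensors; the real
# code encodes src/trg lines with self.src_vocab / self.trg_vocab first).
import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = 0
x_encoded = [torch.tensor([5, 6, 7]), torch.tensor([5, 6])]       # source token ids
y_encoded = [torch.tensor([8, 9]), torch.tensor([8, 9, 10, 11])]  # target token ids
x_len = torch.tensor([len(x) for x in x_encoded])
y_len = torch.tensor([len(y) for y in y_encoded])

# Pad to the longest sequence in the batch, then transpose to (batch, max_len)
x_padded = pad_sequence(x_encoded, batch_first=False, padding_value=pad_id).T
y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=pad_id).T

assert x_padded.shape == (2, 3) and y_padded.shape == (2, 4)
# No sorting happens here anymore; ordering is the sampler's job.
```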

4 changes: 2 additions & 2 deletions autonmt/modules/models/rnn.py
@@ -24,7 +24,7 @@ def __init__(self,
decoder_bidirectional=False,
teacher_force_ratio=0.5,
padding_idx=None,
packed_sequence=True,
packed_sequence=False,
architecture="rnn",
**kwargs):
super().__init__(src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=packed_sequence,
@@ -90,7 +90,7 @@ def forward_encoder(self, x, x_len, **kwargs):
x_emb = self.enc_dropout(x_emb)

# Pack sequence
if self.packed_sequence:
if self.packed_sequence: # Requires bucketing
x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=True)

# input: (length, batch, emb_dim)
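Context for the packed_sequence change above: nn.utils.rnn.pack_padded_sequence with enforce_sorted=True expects each batch to be sorted by decreasing length, which is only guaranteed when the bucketing sampler sorts within the batch; hence the new default packed_sequence=False and the "# Requires bucketing" note. A small standalone sketch (toy embeddings, not the repository's encoder):

```python
# Why packed sequences need length-sorted batches (illustrative toy tensors).
import torch
import torch.nn as nn

batch_size, max_len, emb_dim, hid_dim = 3, 5, 8, 16
x_emb = torch.randn(batch_size, max_len, emb_dim)  # padded embeddings
x_len = torch.tensor([5, 4, 2])                    # descending lengths -> OK

packed = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'),
                                            batch_first=True, enforce_sorted=True)
rnn = nn.GRU(emb_dim, hid_dim, batch_first=True)
output, hidden = rnn(packed)  # works

# With unsorted lengths, e.g. x_len = torch.tensor([2, 5, 4]), the same
# pack_padded_sequence call raises a RuntimeError, which is why
# packed_sequence=True is only allowed together with bucketing.
```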
9 changes: 5 additions & 4 deletions autonmt/modules/samplers/bucket.py
@@ -5,10 +5,11 @@

class BucketIterator(Sampler):
def __init__(self, data_source, batch_size, sort_key, shuffle=True, sort_within_batch=False):
super().__init__(data_source)
super().__init__()
self.data_source = data_source
self.batch_size = batch_size
self.shuffle = shuffle
self.sort_key = sort_key
self.sort_within_batch = sort_within_batch

# Sort indices by the specified key (e.g., sequence length)
@@ -27,9 +28,9 @@ def __iter__(self):
shuffled_indices = torch.randperm(len(self.buckets), generator=g).tolist()
self.buckets = [self.buckets[i] for i in shuffled_indices]

# Sort within each bucket if required
# if self.sort_within_batch:
# self.buckets = [sorted(bucket, key=lambda idx: len(self.data_source[idx][0].split(' ')), reverse=True) for bucket in self.buckets]
# Sort within each bucket if required
if self.sort_within_batch:
self.buckets = [sorted(bucket, key=lambda idx: self.sort_key(*self.data_source[idx]), reverse=True) for bucket in self.buckets]

# Flatten the list of buckets into a list of indices
indices = [idx for bucket in self.buckets for idx in bucket]
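For reference, a sketch of how a bucketing sampler like this one plugs into a DataLoader. Toy in-memory (src, trg) pairs stand in for the repository's Seq2SeqDataset, and the import path is assumed from the file layout above.

```python
# Illustrative usage of BucketIterator with a plain DataLoader (toy data source).
from torch.utils.data import DataLoader
from autonmt.modules.samplers.bucket import BucketIterator  # path assumed from this repo

pairs = [("a b c", "x y"), ("a b", "x"), ("a b c d e", "x y z w"), ("a", "x")]

sampler = BucketIterator(pairs, batch_size=2,
                         sort_key=lambda src, trg: len(src.split(' ')),
                         shuffle=True, sort_within_batch=True)

# 'sampler' is mutually exclusive with shuffle=True; shuffling happens inside the sampler.
loader = DataLoader(pairs, batch_size=2, sampler=sampler, shuffle=False,
                    collate_fn=lambda batch: batch)

for batch in loader:
    print(batch)  # pairs of similar source length tend to share a batch
```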
2 changes: 1 addition & 1 deletion autonmt/modules/samplers/random.py
@@ -4,7 +4,7 @@

class RandomIterator(Sampler):
def __init__(self, data_source):
super().__init__(data_source)
super().__init__()
self.data_source = data_source

def __iter__(self):
2 changes: 1 addition & 1 deletion autonmt/modules/samplers/sequential.py
@@ -4,7 +4,7 @@

class SequentialIterator(Sampler):
def __init__(self, data_source):
super().__init__(data_source)
super().__init__()
self.data_source = data_source

def __iter__(self):
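The super().__init__() changes in the three samplers above appear to track newer PyTorch, where Sampler.__init__ no longer uses a data_source argument (an assumption based on the diff; the commit message does not say). A minimal sketch of the resulting pattern:

```python
# Minimal custom sampler following the pattern above: keep your own data_source
# reference and call super().__init__() with no arguments (recent PyTorch assumed).
from torch.utils.data import Sampler


class ToySequentialSampler(Sampler):
    def __init__(self, data_source):
        super().__init__()  # no data_source passed to the base class
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


print(list(ToySequentialSampler(["a", "b", "c"])))  # [0, 1, 2]
```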
36 changes: 25 additions & 11 deletions autonmt/toolkits/autonmt.py
@@ -109,6 +109,7 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
comet_params = kwargs.get("comet_params")
print_samples = kwargs.get("print_samples")
skip_val_metrics = kwargs.get("skip_val_metrics")
use_bucketing = kwargs.get("use_bucketing")
mode_str = "min" if "loss" in monitor.lower() else "max"
ckpt_filename = "{epoch:03d}-{" + monitor.replace('/', '-') + ":.3f}"
pin_memory = False if kwargs.get('devices') == "cpu" else True
@@ -129,24 +130,37 @@
self.model._print_samples = print_samples
self.model._skip_val_metrics = skip_val_metrics

# Check padding
if not use_bucketing and self.model.packed_sequence:
raise ValueError("Packed sequence is only compatible with bucketing")

# Dataloader: Training
sampler = BucketIterator(self.train_tds, batch_size=batch_size,
sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
sort_within_batch=self.model.packed_sequence, shuffle=True)
print(f"\t- [INFO]: Preparing training dataloader... (1/1)")
sampler, shuffle = None, True
if use_bucketing:
print(f"\t\t- Preparing bucketing iterator...")
shuffle = False # 'sampler' option is mutually exclusive with shuffle (we shuffle in bucket)
sampler = BucketIterator(self.train_tds, batch_size=batch_size,
sort_key=lambda x, y: len(self.model._src_vocab.encode(x)),
sort_within_batch=self.model.packed_sequence, shuffle=True)
train_loader = DataLoader(self.train_tds,
collate_fn=self.train_tds.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler,
collate_fn=self.train_tds.get_collate_fn(max_tokens), sampler=sampler,
num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
batch_size=batch_size, shuffle=False,
) # 'sampler' option is mutually exclusive with shuffle
batch_size=batch_size, shuffle=shuffle,
)

# Dataloader: Validation
val_loaders = []
for val_tds_i in self.val_tds:
sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
sort_within_batch=self.model.packed_sequence, shuffle=True)
for i, val_tds_i in enumerate(self.val_tds):
print(f"\t- [INFO]: Preparing validation dataloader... ({i+1}/{len(self.val_tds)})")
sampler_i = None
if use_bucketing:
print(f"\t\t- Preparing bucketing iterator...")
sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
sort_key=lambda x, y: len(self.model._src_vocab.encode(x)),
sort_within_batch=self.model.packed_sequence, shuffle=True)
val_loaders.append(DataLoader(val_tds_i,
collate_fn=val_tds_i.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler_i,
collate_fn=val_tds_i.get_collate_fn(max_tokens), sampler=sampler_i,
num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
batch_size=batch_size, shuffle=False))

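The dataloader wiring in _train above boils down to: bucketing is optional, shuffling moves into the sampler when bucketing is on, and packed sequences are rejected without bucketing. Below is an isolated sketch of that logic; build_train_loader, src_vocab and the parameter list are illustrative placeholders, not repository API.

```python
# Simplified, standalone version of the new dataloader selection in _train
# (the real code uses self.train_tds, self.model and self.model._src_vocab).
from torch.utils.data import DataLoader
from autonmt.modules.samplers.bucket import BucketIterator

def build_train_loader(dataset, model, src_vocab, batch_size, max_tokens, use_bucketing):
    if not use_bucketing and model.packed_sequence:
        raise ValueError("Packed sequence is only compatible with bucketing")

    sampler, shuffle = None, True
    if use_bucketing:
        shuffle = False  # 'sampler' excludes shuffle=True; the sampler shuffles its buckets
        sampler = BucketIterator(dataset, batch_size=batch_size,
                                 sort_key=lambda x, y: len(src_vocab.encode(x)),
                                 sort_within_batch=model.packed_sequence, shuffle=True)

    return DataLoader(dataset,
                      collate_fn=dataset.get_collate_fn(max_tokens),
                      sampler=sampler, batch_size=batch_size, shuffle=shuffle)
```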
5 changes: 2 additions & 3 deletions autonmt/toolkits/base.py
@@ -130,12 +130,11 @@ def _save_config(self, fname="config.json", force_overwrite=False):
make_dir(logs_path)
save_json(self.config, savepath=os.path.join(logs_path, fname), allow_overwrite=force_overwrite)


def fit(self, train_ds, max_tokens=None, batch_size=128, max_epochs=1, patience=None,
optimizer="adam", learning_rate=0.001, weight_decay=0, gradient_clip_val=0.0, accumulate_grad_batches=1,
criterion="cross_entropy", monitor="val_loss",
devices="auto", accelerator="auto", num_workers=0,
seed=None, force_overwrite=False, **kwargs):
seed=None, force_overwrite=False, use_bucketing=False, **kwargs):
print("=> [Fit]: Started.")

# Save training config
@@ -150,7 +149,7 @@ def fit(self, train_ds, max_tokens=None, batch_size=128, max_epochs=1, patience=
optimizer=optimizer, learning_rate=learning_rate, weight_decay=weight_decay, gradient_clip_val=gradient_clip_val, accumulate_grad_batches=accumulate_grad_batches,
criterion=criterion, monitor=monitor,
devices=devices, accelerator=accelerator, num_workers=num_workers,
seed=seed, force_overwrite=force_overwrite, **kwargs)
seed=seed, force_overwrite=force_overwrite, use_bucketing=use_bucketing, **kwargs)

def predict(self, eval_datasets, metrics=None, beams=None, max_len_a=1.2, max_len_b=50,
max_tokens=None, batch_size=64,
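With the signature change above, callers opt into bucketing explicitly. A hedged usage sketch (trainer and dataset construction omitted; hyperparameter values are placeholders):

```python
# Illustrative fit() call with the new flag; 'trainer' and 'train_ds' are assumed to be
# an autonmt trainer and its training dataset, built as elsewhere in the repo examples.
trainer.fit(train_ds,
            max_epochs=10, batch_size=128, learning_rate=0.001, optimizer="adam",
            use_bucketing=True,  # enables the BucketIterator-based dataloaders
            seed=1234, num_workers=0)
```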
2 changes: 1 addition & 1 deletion examples/dev/0_test_custom_model.py
@@ -94,7 +94,7 @@ def main():
wandb_params = None #dict(project="architecture", entity="salvacarrion", reinit=True)
trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=1234,
patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
wandb_params=wandb_params)
wandb_params=wandb_params, use_bucketing=False)

# Test model
m_scores = trainer.predict(ts_datasets, metrics={"bleu"}, beams=[1], load_checkpoint="best",