diff --git a/autonmt/modules/datasets/seq2seq_dataset.py b/autonmt/modules/datasets/seq2seq_dataset.py
index 162cc13..b870502 100644
--- a/autonmt/modules/datasets/seq2seq_dataset.py
+++ b/autonmt/modules/datasets/seq2seq_dataset.py
@@ -34,7 +34,7 @@ def __getitem__(self, idx):
         src_line, trg_line = self.src_lines[idx], self.trg_lines[idx]
         return src_line, trg_line
 
-    def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
+    def collate_fn(self, batch, max_tokens=None, **kwargs):
         x_encoded, y_encoded = [], []
         x_max_len = y_max_len = 0
@@ -67,17 +67,17 @@ def collate_fn(self, batch, max_tokens=None, sort_within_batch=False, **kwargs):
         y_padded = pad_sequence(y_encoded, batch_first=False, padding_value=self.trg_vocab.pad_id).T
 
         # Sort these tensors pairs by the length of x_len
-        if sort_within_batch:
-            x_len, x_idx = x_len.sort(0, descending=True)
-            y_len = y_len[x_idx]
-            x_padded = x_padded[x_idx]
-            y_padded = y_padded[x_idx]
+        # if sort_within_batch:
+        #     x_len, x_idx = x_len.sort(0, descending=True)
+        #     y_len = y_len[x_idx]
+        #     x_padded = x_padded[x_idx]
+        #     y_padded = y_padded[x_idx]
 
         # Check stuff
         assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded)  # Control samples
         assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens  # Control max tokens
         return (x_padded, y_padded), (x_len, y_len)
 
-    def get_collate_fn(self, max_tokens, sort_within_batch):
-        return functools.partial(self.collate_fn, max_tokens=max_tokens, sort_within_batch=sort_within_batch)
+    def get_collate_fn(self, max_tokens):
+        return functools.partial(self.collate_fn, max_tokens=max_tokens)
diff --git a/autonmt/modules/models/rnn.py b/autonmt/modules/models/rnn.py
index 705736d..c5330b7 100644
--- a/autonmt/modules/models/rnn.py
+++ b/autonmt/modules/models/rnn.py
@@ -24,7 +24,7 @@ def __init__(self,
                  decoder_bidirectional=False,
                  teacher_force_ratio=0.5,
                  padding_idx=None,
-                 packed_sequence=True,
+                 packed_sequence=False,
                  architecture="rnn",
                  **kwargs):
         super().__init__(src_vocab_size, trg_vocab_size, padding_idx, packed_sequence=packed_sequence,
@@ -90,7 +90,7 @@ def forward_encoder(self, x, x_len, **kwargs):
         x_emb = self.enc_dropout(x_emb)
 
         # Pack sequence
-        if self.packed_sequence:
+        if self.packed_sequence:  # Requires bucketing
             x_emb = nn.utils.rnn.pack_padded_sequence(x_emb, x_len.to('cpu'), batch_first=True, enforce_sorted=True)
 
         # input: (length, batch, emb_dim)
diff --git a/autonmt/modules/samplers/bucket.py b/autonmt/modules/samplers/bucket.py
index 1e6cc4d..9a4547d 100644
--- a/autonmt/modules/samplers/bucket.py
+++ b/autonmt/modules/samplers/bucket.py
@@ -5,10 +5,11 @@ class BucketIterator(Sampler):
     def __init__(self, data_source, batch_size, sort_key, shuffle=True, sort_within_batch=False):
-        super().__init__(data_source)
+        super().__init__()
         self.data_source = data_source
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.sort_key = sort_key
         self.sort_within_batch = sort_within_batch
 
         # Sort indices by the specified key (e.g., sequence length)
@@ -27,9 +28,9 @@ def __iter__(self):
             shuffled_indices = torch.randperm(len(self.buckets), generator=g).tolist()
             self.buckets = [self.buckets[i] for i in shuffled_indices]
 
-        # Sort within each bucket if required
-        # if self.sort_within_batch:
-        #     self.buckets = [sorted(bucket, key=lambda idx: len(self.data_source[idx][0].split(' ')), reverse=True) for bucket in self.buckets]
+        # Sort within each bucket if required
+        if self.sort_within_batch:
+            self.buckets = [sorted(bucket, key=lambda idx: self.sort_key(*self.data_source[idx]), reverse=True) for bucket in self.buckets]
 
         # Flatten the list of buckets into a list of indices
         indices = [idx for bucket in self.buckets for idx in bucket]
diff --git a/autonmt/modules/samplers/random.py b/autonmt/modules/samplers/random.py
index 4e72caa..2bee5b1 100644
--- a/autonmt/modules/samplers/random.py
+++ b/autonmt/modules/samplers/random.py
@@ -4,7 +4,7 @@ class RandomIterator(Sampler):
     def __init__(self, data_source):
-        super().__init__(data_source)
+        super().__init__()
         self.data_source = data_source
 
     def __iter__(self):
diff --git a/autonmt/modules/samplers/sequential.py b/autonmt/modules/samplers/sequential.py
index 9ef4f3d..0840d83 100644
--- a/autonmt/modules/samplers/sequential.py
+++ b/autonmt/modules/samplers/sequential.py
@@ -4,7 +4,7 @@ class SequentialIterator(Sampler):
     def __init__(self, data_source):
-        super().__init__(data_source)
+        super().__init__()
         self.data_source = data_source
 
     def __iter__(self):
diff --git a/autonmt/toolkits/autonmt.py b/autonmt/toolkits/autonmt.py
index 10b2423..0bd79e8 100644
--- a/autonmt/toolkits/autonmt.py
+++ b/autonmt/toolkits/autonmt.py
@@ -109,6 +109,7 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
         comet_params = kwargs.get("comet_params")
         print_samples = kwargs.get("print_samples")
         skip_val_metrics = kwargs.get("skip_val_metrics")
+        use_bucketing = kwargs.get("use_bucketing")
         mode_str = "min" if "loss" in monitor.lower() else "max"
         ckpt_filename = "{epoch:03d}-{" + monitor.replace('/', '-') + ":.3f}"
         pin_memory = False if kwargs.get('devices') == "cpu" else True
@@ -129,24 +130,37 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs
         self.model._print_samples = print_samples
         self.model._skip_val_metrics = skip_val_metrics
 
+        # Check padding
+        if not use_bucketing and self.model.packed_sequence:
+            raise ValueError("Packed sequence is only compatible with bucketing")
+
         # Dataloader: Training
-        sampler = BucketIterator(self.train_tds, batch_size=batch_size,
-                                 sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
-                                 sort_within_batch=self.model.packed_sequence, shuffle=True)
+        print(f"\t- [INFO]: Preparing training dataloader... (1/1)")
+        sampler, shuffle = None, True
+        if use_bucketing:
+            print(f"\t\t- Preparing bucketing iterator...")
+            shuffle = False  # 'sampler' option is mutually exclusive with shuffle (we shuffle in bucket)
+            sampler = BucketIterator(self.train_tds, batch_size=batch_size,
+                                     sort_key=lambda x, y: len(self.model._src_vocab.encode(x)),
+                                     sort_within_batch=self.model.packed_sequence, shuffle=True)
         train_loader = DataLoader(self.train_tds,
-                                  collate_fn=self.train_tds.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler,
+                                  collate_fn=self.train_tds.get_collate_fn(max_tokens), sampler=sampler,
                                   num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
-                                  batch_size=batch_size, shuffle=False,
-                                  )  # 'sampler' option is mutually exclusive with shuffle
+                                  batch_size=batch_size, shuffle=shuffle,
+                                  )
 
         # Dataloader: Validation
         val_loaders = []
-        for val_tds_i in self.val_tds:
-            sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
-                                       sort_key=lambda x, y: len(x.split(' ')) + len(y.split(' ')),
-                                       sort_within_batch=self.model.packed_sequence, shuffle=True)
+        for i, val_tds_i in enumerate(self.val_tds):
+            print(f"\t- [INFO]: Preparing validation dataloader... ({i+1}/{len(self.val_tds)})")
+            sampler_i = None
+            if use_bucketing:
+                print(f"\t\t- Preparing bucketing iterator...")
+                sampler_i = BucketIterator(val_tds_i, batch_size=batch_size,
+                                           sort_key=lambda x, y: len(self.model._src_vocab.encode(x)),
+                                           sort_within_batch=self.model.packed_sequence, shuffle=True)
             val_loaders.append(DataLoader(val_tds_i,
-                                          collate_fn=val_tds_i.get_collate_fn(max_tokens, sort_within_batch=self.model.packed_sequence), sampler=sampler_i,
+                                          collate_fn=val_tds_i.get_collate_fn(max_tokens), sampler=sampler_i,
                                           num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
                                           batch_size=batch_size, shuffle=False))
diff --git a/autonmt/toolkits/base.py b/autonmt/toolkits/base.py
index e5f9617..e4497e4 100644
--- a/autonmt/toolkits/base.py
+++ b/autonmt/toolkits/base.py
@@ -130,12 +130,11 @@ def _save_config(self, fname="config.json", force_overwrite=False):
         make_dir(logs_path)
         save_json(self.config, savepath=os.path.join(logs_path, fname), allow_overwrite=force_overwrite)
 
-
     def fit(self, train_ds, max_tokens=None, batch_size=128, max_epochs=1, patience=None, optimizer="adam",
             learning_rate=0.001, weight_decay=0, gradient_clip_val=0.0, accumulate_grad_batches=1,
             criterion="cross_entropy", monitor="val_loss",
             devices="auto", accelerator="auto", num_workers=0,
-            seed=None, force_overwrite=False, **kwargs):
+            seed=None, force_overwrite=False, use_bucketing=False, **kwargs):
         print("=> [Fit]: Started.")
 
         # Save training config
@@ -150,7 +149,7 @@ def fit(self, train_ds, max_tokens=None, batch_size=128, max_epochs=1, patience=
                    optimizer=optimizer, learning_rate=learning_rate, weight_decay=weight_decay,
                    gradient_clip_val=gradient_clip_val, accumulate_grad_batches=accumulate_grad_batches,
                    criterion=criterion, monitor=monitor, devices=devices, accelerator=accelerator, num_workers=num_workers,
-                   seed=seed, force_overwrite=force_overwrite, **kwargs)
+                   seed=seed, force_overwrite=force_overwrite, use_bucketing=use_bucketing, **kwargs)
 
     def predict(self, eval_datasets, metrics=None, beams=None, max_len_a=1.2, max_len_b=50,
                 max_tokens=None, batch_size=64,
diff --git a/examples/dev/0_test_custom_model.py b/examples/dev/0_test_custom_model.py
index 738adba..1c297ca 100644
--- a/examples/dev/0_test_custom_model.py
+++ b/examples/dev/0_test_custom_model.py
@@ -94,7 +94,7 @@ def main():
     wandb_params = None  #dict(project="architecture", entity="salvacarrion", reinit=True)
     trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=1234,
                 patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
-                wandb_params=wandb_params)
+                wandb_params=wandb_params, use_bucketing=False)
 
     # Test model
    m_scores = trainer.predict(ts_datasets, metrics={"bleu"}, beams=[1], load_checkpoint="best",
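Usage note (illustrative sketch, not part of the patch): after this change bucketing is opt-in via the new `use_bucketing` flag on `fit()`, and a model built with `packed_sequence=True` now requires it, since `_train()` raises a ValueError otherwise. Assuming a `trainer` and `train_ds` set up as in the example script above, a call that enables the BucketIterator sampler would look roughly like:

    trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128,
                seed=1234, patience=10, num_workers=0, accelerator="auto",
                use_bucketing=True)  # builds a BucketIterator sampler; required when the model sets packed_sequence=True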