Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
salvacarrion committed Jun 26, 2024
1 parent 2d785c3 commit a3ee23b
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 19 deletions.
4 changes: 4 additions & 0 deletions autonmt/modules/datasets/seq2seq_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import functools
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
Expand Down Expand Up @@ -65,3 +66,6 @@ def collate_fn(self, batch, max_tokens=None, **kwargs):
assert x_padded.shape[0] == y_padded.shape[0] == len(x_encoded) # Control samples
assert max_tokens is None or (x_padded.numel() + y_padded.numel()) <= max_tokens # Control max tokens
return x_padded, y_padded

def get_collate_fn(self, max_tokens):
return functools.partial(self.collate_fn, max_tokens=max_tokens)
7 changes: 0 additions & 7 deletions autonmt/toolkits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
from autonmt.toolkits.autonmt import AutonmtTranslator

try:
from autonmt.toolkits.fairseq import FairseqTranslator
except ImportError as e:
print("WARNING: Fairseq toolkit could not be loaded. FairseqTranslator will not be available.")
except Exception as e:
print("WARNING: Fairseq toolkit could not be loaded. FairseqTranslator will not be available.")
print(e)
12 changes: 7 additions & 5 deletions autonmt/toolkits/autonmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# from torchnlp.samplers import BucketBatchSampler



class AutonmtTranslator(BaseTranslator): # AutoNMT Translator

def __init__(self, model, **kwargs):
Expand Down Expand Up @@ -103,17 +104,18 @@ def _train(self, train_ds, checkpoints_dir, logs_path, force_overwrite, **kwargs

# Dataloader: Training
train_loader = DataLoader(self.train_tds,
collate_fn=lambda x: self.train_tds.collate_fn(x, max_tokens=max_tokens),
num_workers=num_workers, pin_memory=pin_memory,
collate_fn=self.train_tds.get_collate_fn(max_tokens),
num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
batch_size=batch_size, shuffle=True,
)

# Dataloader: Validation
val_loaders = []
for val_tds_i in self.val_tds:
val_loaders.append(DataLoader(val_tds_i, shuffle=False,
collate_fn=lambda x: val_tds_i.collate_fn(x, max_tokens=max_tokens),
batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory))
val_loaders.append(DataLoader(val_tds_i,
collate_fn=val_tds_i.get_collate_fn(max_tokens),
num_workers=num_workers, persistent_workers=bool(num_workers), pin_memory=pin_memory,
batch_size=batch_size, shuffle=False))

# Callbacks: Checkpoint
ckpt_p = {}
Expand Down
18 changes: 12 additions & 6 deletions autonmt/toolkits/fairseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
from autonmt.bundle import utils
from autonmt.toolkits.base import BaseTranslator

import torch
import fairseq_cli
from fairseq import options
from fairseq_cli import preprocess, train, generate
from fairseq.distributed import utils as distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
try:
import fairseq_cli
from fairseq import options
from fairseq_cli import preprocess, train, generate
from fairseq.distributed import utils as distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
except ImportError as e:
print("WARNING: Fairseq toolkit could not be loaded. FairseqTranslator will not be available.")
except Exception as e:
print("WARNING: Fairseq toolkit could not be loaded. FairseqTranslator will not be available.")
print(e)



def _parse_args(**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion examples/2_fairseq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from autonmt.modules.models import Transformer
from autonmt.preprocessing import DatasetBuilder
from autonmt.toolkits import FairseqTranslator
from autonmt.toolkits.fairseq import FairseqTranslator
from autonmt.vocabularies import Vocabulary

from autonmt.bundle.report import generate_report
Expand Down

0 comments on commit a3ee23b

Please sign in to comment.