From b860aa57338e8124be7cc0c6cb6292db0a7f634d Mon Sep 17 00:00:00 2001
From: salvacarrion
Date: Mon, 24 Jun 2024 19:03:40 +0200
Subject: [PATCH] General fixes + LSTM (dev)

---
 autonmt/bundle/plots.py             |  9 ++--
 autonmt/bundle/utils.py             |  2 +-
 autonmt/modules/models/__init__.py  |  2 +
 autonmt/modules/models/lstm.py      | 64 +++++++++++++++++++++++++++++
 autonmt/preprocessing/builder.py    | 27 ++++++++----
 autonmt/preprocessing/tokenizers.py |  2 +-
 autonmt/toolkits/fairseq.py         | 35 +++++++++++++---
 7 files changed, 122 insertions(+), 19 deletions(-)
 create mode 100644 autonmt/modules/models/lstm.py

diff --git a/autonmt/bundle/plots.py b/autonmt/bundle/plots.py
index 6e46413..c41d72b 100644
--- a/autonmt/bundle/plots.py
+++ b/autonmt/bundle/plots.py
@@ -73,17 +73,18 @@ def barplot(data, x, y, output_dir, fname, title="", xlabel="x", ylabel="y", asp
         return False
 
     # Create subplot
-    fig = plt.figure(figsize=(aspect_ratio[0] * size, aspect_ratio[1] * size))
+    fig = plt.figure(figsize=(aspect_ratio[0], aspect_ratio[1]))
     sns.set(font_scale=size)
 
     # Plot barplot
-    g = sns.barplot(data=data, x=x, y=y)
+    g = sns.barplot(data=data, x=x, y=y, edgecolor="none")
 
     # Tweaks
     g.set(xlabel=xlabel, ylabel=ylabel)
     g.set_xticklabels(g.get_xticklabels(), rotation=90)
-    g.tick_params(axis='x', which='major', labelsize=8)  # *size => because of the vocabulary distribution
-    g.tick_params(axis='y', which='major', labelsize=8)  # *size => because of the vocabulary distribution
+    plt.xticks(ticks=[])  # Completely remove x-ticks
+    g.tick_params(axis='x', which='major', labelsize=12)  # *size => because of the vocabulary distribution
+    g.tick_params(axis='y', which='major', labelsize=12)  # *size => because of the vocabulary distribution
     g.yaxis.set_major_formatter(utils.human_format_int)
 
     # properties
diff --git a/autonmt/bundle/utils.py b/autonmt/bundle/utils.py
index 6f7e1e7..3e9bccb 100644
--- a/autonmt/bundle/utils.py
+++ b/autonmt/bundle/utils.py
@@ -140,7 +140,7 @@ def human_format(num, decimals=2):
 
 
 def human_format_int(x, *args, **kwargs):
-    return human_format(int(x), decimals=0)
+    return human_format(int(x), decimals=1)
 
 
 def load_json(filename):
diff --git a/autonmt/modules/models/__init__.py b/autonmt/modules/models/__init__.py
index 895af99..655937d 100644
--- a/autonmt/modules/models/__init__.py
+++ b/autonmt/modules/models/__init__.py
@@ -1 +1,3 @@
 from autonmt.modules.models.transfomer import Transformer
+from autonmt.modules.models.lstm import LSTM
+
diff --git a/autonmt/modules/models/lstm.py b/autonmt/modules/models/lstm.py
new file mode 100644
index 0000000..abc84da
--- /dev/null
+++ b/autonmt/modules/models/lstm.py
@@ -0,0 +1,64 @@
+import torch.nn as nn
+
+from autonmt.modules.layers import PositionalEmbedding
+from autonmt.modules.seq2seq import LitSeq2Seq
+
+
+class LSTM(LitSeq2Seq):
+    def __init__(self,
+                 src_vocab_size, trg_vocab_size,
+                 encoder_embed_dim=256,
+                 decoder_embed_dim=256,
+                 encoder_hidden_dim=512,
+                 decoder_hidden_dim=512,
+                 encoder_n_layers=2,
+                 decoder_n_layers=2,
+                 encoder_dropout=0.5,
+                 decoder_dropout=0.5,
+                 padding_idx=None,
+                 **kwargs):
+        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)
+
+        # Model
+        self.src_embeddings = nn.Embedding(src_vocab_size, encoder_embed_dim)
+        self.trg_embeddings = nn.Embedding(trg_vocab_size, decoder_embed_dim)
+
+        self.encoder_dropout = nn.Dropout(encoder_dropout)
+        self.decoder_dropout = nn.Dropout(decoder_dropout)
+
+        self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout)
+        self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout)
+
+        self.output_layer = nn.Linear(decoder_hidden_dim, trg_vocab_size)  # Project decoder hidden states to vocab logits
+
+        # Checks
+        assert encoder_embed_dim == decoder_embed_dim
+        assert encoder_hidden_dim == decoder_hidden_dim
+        assert encoder_n_layers == decoder_n_layers
+
+    def forward_encoder(self, x):
+        # Encode src: (length, batch) => (length, batch, emb_dim)
+        x_emb = self.src_embeddings(x)
+        x_emb = self.encoder_dropout(x_emb)
+
+        # input: (length, batch, emb_dim)
+        # output: (length, batch, hidden_dim * n_directions)
+        # hidden: (n_layers * n_directions, batch, hidden_dim)
+        # cell: (n_layers * n_directions, batch, hidden_dim)
+        outputs, (hidden, cell) = self.encoder_rnn(x_emb)
+        return hidden, cell
+
+    def forward_decoder(self, y, hidden, cell):
+        # Encode trg: (1..length, batch) => (length, batch, emb_dim)
+        y_emb = self.trg_embeddings(y)
+        y_emb = self.decoder_dropout(y_emb)
+
+        # (1..length, batch, emb_dim) =>
+        # output: (length, batch, hidden_dim * n_directions)
+        # hidden: (n_layers * n_directions, batch, hidden_dim)
+        # cell: (n_layers * n_directions, batch, hidden_dim)
+        output, (hidden, cell) = self.decoder_rnn(y_emb, (hidden, cell))
+
+        # Get output: (length, batch, hidden_dim * n_directions) => (length, batch, trg_vocab_size)
+        output = self.output_layer(output)
+        return output
diff --git a/autonmt/preprocessing/builder.py b/autonmt/preprocessing/builder.py
index d3e9d52..090d7e9 100644
--- a/autonmt/preprocessing/builder.py
+++ b/autonmt/preprocessing/builder.py
@@ -633,7 +633,7 @@ def _plot_datasets(self, force_overwrite, save_figures=True, show_figures=False,
 
         # Set default vars
         if vocab_top_k is None:
-            vocab_top_k = [50]
+            vocab_top_k = [256]
 
         # Set backend
         if save_figures:
@@ -733,13 +733,26 @@
             df = df.sort_values(by='frequency', ascending=False, na_position='last')
             for top_k in vocab_top_k:
-                title = f"Vocabulary distribution (top {str(top_k)} {ds.subword_model.title()}; {lang_file})"
-                title = title if not add_dataset_title else f"{ds_title}:\n{title}"
-                p_fname = f"vocab_distr_{lang_file}_top{str(top_k)}__{suffix_fname}".lower()
-                plots.barplot(data=df.head(top_k), x="token", y="frequency",
+                # Sample a subset of words for visualization
+                len_vocab = len(df)
+                df2 = df.sample(n=top_k, random_state=1).sort_values(by='frequency', ascending=False)
+
+                # Sampled
+                d = {"word": "Words", "bpe": "BPE", "char": "Chars", "bytes": "Bytes"}
+                title = f"Vocabulary distribution ({d[ds.subword_model]} - {len(df):,})"
+                p_fname = f"vocab_distr_{lang_file}_sampled{str(top_k)}__{suffix_fname}".lower()
+
+                # Top
+                # title = f"Vocabulary distribution ({d[ds.subword_model]} - Top {top_k})"
+                # p_fname = f"vocab_distr_{lang_file}_top{str(top_k)}__{suffix_fname}".lower()
+
+                # title = f"Vocabulary distribution ({d[ds.subword_model]} - {len_vocab:,})"
+                # title = f"Vocabulary distribution (top {str(top_k)} {ds.subword_model.title()}; {lang_file})"
+                # title = title if not add_dataset_title else f"{ds_title}:\n{title}"
+                plots.barplot(data=df2, x="token", y="frequency",
                               output_dir=plots_encoded_path, fname=p_fname,
-                              title=title, xlabel="Token frequency", ylabel="Frequency",
-                              aspect_ratio=(6, 4), size=1.25, save_fig=save_figures, show_fig=show_figures,
+                              title=title, xlabel="Tokens", ylabel="Frequency",
+                              aspect_ratio=(12, 8), size=2.5, save_fig=save_figures, show_fig=show_figures,
                               overwrite=force_overwrite)
 
     def merge_datasets(self,
name="europarl", language_pair="xx-yy", dataset_size_name="original", diff --git a/autonmt/preprocessing/tokenizers.py b/autonmt/preprocessing/tokenizers.py index eea13bd..b297007 100644 --- a/autonmt/preprocessing/tokenizers.py +++ b/autonmt/preprocessing/tokenizers.py @@ -48,7 +48,7 @@ def spm_train_file(input_file, model_prefix, subword_model, vocab_size, input_se model_type=subword_model, vocab_size=vocab_size, input_sentence_size=input_sentence_size, byte_fallback=byte_fallback, character_coverage=character_coverage, split_digits=split_digits, - pad_id=3) + pad_id=3) # max_sentencepiece_length=2, def spm_encode_file(spm_model_path, input_file, output_file): diff --git a/autonmt/toolkits/fairseq.py b/autonmt/toolkits/fairseq.py index a20526d..5ad05bc 100644 --- a/autonmt/toolkits/fairseq.py +++ b/autonmt/toolkits/fairseq.py @@ -114,8 +114,6 @@ def __init__(self, wandb_params=None, **kwargs): # Vars self.wandb_params = wandb_params - if self.wandb_params: - raise ValueError("WandB monitoring is disabled for FairSeq due to a bug related to parallelization.") # Custom self.data_bin_name = "data-bin" @@ -176,12 +174,18 @@ def _preprocess(self, ds, output_path, src_lang, trg_lang, train_path, val_path, # Parse args and execute command # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/preprocess.py input_args = sum([str(c).split(' ', 1) for c in input_args], []) # Split key/val (str) and flat list + + # Command + print("COMMAND:") + print("fairseq-preprocess " + ' '.join(input_args)) + + # Run command parser = options.get_preprocessing_parser(default_task="translation") args = parser.parse_args(args=input_args) preprocess.main(args) - def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, run_name, - resume_training, force_overwrite, **kwargs): + def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, force_overwrite, **kwargs): + wandb_params = kwargs.get("wandb_params") # Get data-bin path data_bin_path = train_ds.get_bin_data(self.engine, self.data_bin_name) @@ -196,13 +200,20 @@ def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, r print("\t- [Train]: Skipped. 
             return
 
-        if self.wandb_params:
-            print("\t\t- [WARNING]: 'wandb_params' will be ignored when using Fairseq due to some known bugs")
+        # if self.wandb_params:
+        #     print("\t\t- [WARNING]: 'wandb_params' will be ignored when using Fairseq due to some known bugs")
 
         # Write command
         input_args = [data_bin_path]
         input_args += ["--save-dir", checkpoints_dir] if checkpoints_dir else []
         input_args += ["--tensorboard-logdir", logs_path] if logs_path else []
+        if wandb_params:
+            # raise ValueError("WandB monitoring is disabled for FairSeq due to a bug related to parallelization.")
+            print("\t\t- [WARNING]: 'wandb_params' may trigger some known bugs with Fairseq")
+
+            # Set vars
+            input_args += ["--wandb-project", wandb_params["project"]]
+            os.environ["WANDB_NAME"] = self.run_name
 
         # Parse fairseq args
         input_args += _parse_args(max_tokens=max_tokens, batch_size=batch_size, **kwargs)
@@ -214,6 +225,11 @@ def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, r
         if num_gpus and isinstance(num_gpus, int):
             os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in range(num_gpus)])
 
+        # Command
+        print("COMMAND:")
+        print("fairseq-train " + ' '.join(input_args))
+
+        # Run command
         # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/train.py
         parser = options.get_training_parser()
         args = options.parse_args_and_arch(parser, input_args=input_args)
@@ -252,6 +268,13 @@ def _translate(self, model_ds, data_path, output_path, src_lang, trg_lang, beam_
         # Parse args and execute command
         # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/generate.py
         input_args = sum([str(c).split(' ', 1) for c in input_args], [])  # Split key/val (str) and flat list
+
+        # Command
+        print("COMMAND:")
+        print("fairseq-generate " + ' '.join(input_args))
+
+        # Run command
+        # From: https://github.com/facebookresearch/fairseq/blob/main/fairseq_cli/generate.py
         parser = options.get_generation_parser(default_task="translation")
         args = options.parse_args_and_arch(parser, input_args=input_args)
         generate.main(args)
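
Usage sketch (editor's note, not part of the patch): a minimal example of how the new LSTM model in autonmt/modules/models/lstm.py might be exercised, assuming LitSeq2Seq's constructor accepts only the arguments forwarded here. The vocabulary size, tensor shapes and padding_idx below are illustrative assumptions (padding_idx=3 mirrors the pad_id=3 configured for SentencePiece in tokenizers.py), not values confirmed by the diff.

    # Hypothetical usage; sizes and padding_idx are assumptions.
    import torch
    from autonmt.modules.models import LSTM

    model = LSTM(src_vocab_size=8000, trg_vocab_size=8000, padding_idx=3)

    src = torch.randint(0, 8000, (10, 2))                    # (length, batch), per the shape comments in forward_encoder
    hidden, cell = model.forward_encoder(src)                # each: (n_layers, batch, hidden_dim)

    trg_step = torch.randint(0, 8000, (1, 2))                # a single decoding step
    logits = model.forward_decoder(trg_step, hidden, cell)   # (1, batch, trg_vocab_size)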