Commit b860aa5: General fixes + LSTM (dev)
salvacarrion committed Jun 24, 2024 (1 parent: 76a20c7)
Showing 7 changed files with 122 additions and 19 deletions.
autonmt/bundle/plots.py (9 changes: 5 additions & 4 deletions)
@@ -73,17 +73,18 @@ def barplot(data, x, y, output_dir, fname, title="", xlabel="x", ylabel="y", asp
         return False
 
     # Create subplot
-    fig = plt.figure(figsize=(aspect_ratio[0] * size, aspect_ratio[1] * size))
+    fig = plt.figure(figsize=(aspect_ratio[0], aspect_ratio[1]))
     sns.set(font_scale=size)
 
     # Plot barplot
-    g = sns.barplot(data=data, x=x, y=y)
+    g = sns.barplot(data=data, x=x, y=y, edgecolor="none")
 
     # Tweaks
     g.set(xlabel=xlabel, ylabel=ylabel)
     g.set_xticklabels(g.get_xticklabels(), rotation=90)
-    g.tick_params(axis='x', which='major', labelsize=8)  # *size => because of the vocabulary distribution
-    g.tick_params(axis='y', which='major', labelsize=8)  # *size => because of the vocabulary distribution
+    plt.xticks(ticks=[])  # Completely remove x-ticks
+    g.tick_params(axis='x', which='major', labelsize=12)  # *size => because of the vocabulary distribution
+    g.tick_params(axis='y', which='major', labelsize=12)  # *size => because of the vocabulary distribution
     g.yaxis.set_major_formatter(utils.human_format_int)
 
     # properties
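For reference, a usage sketch of this helper that mirrors the updated call in builder.py further down; the output_dir and fname values are hypothetical, and df2 stands for a token/frequency dataframe:

    # Hypothetical call, mirroring the builder.py invocation below
    plots.barplot(data=df2, x="token", y="frequency",
                  output_dir="plots/encoded", fname="vocab_distr_example",
                  title="Vocabulary distribution", xlabel="Tokens", ylabel="Frequency",
                  aspect_ratio=(12, 8), size=2.5, save_fig=True, show_fig=False,
                  overwrite=True)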
autonmt/bundle/utils.py (2 changes: 1 addition & 1 deletion)

@@ -140,7 +140,7 @@ def human_format(num, decimals=2):
 
 
 def human_format_int(x, *args, **kwargs):
-    return human_format(int(x), decimals=0)
+    return human_format(int(x), decimals=1)
 
 
 def load_json(filename):
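The body of human_format sits outside this hunk; the sketch below is a plausible reconstruction (an assumption, shown only to make the decimals change concrete). The (x, *args, **kwargs) wrapper signature is what lets matplotlib 3.3+ call human_format_int directly as a tick formatter with (x, pos), as plots.py does above.

    # Hypothetical reconstruction of human_format (its body is not shown in this diff)
    def human_format(num, decimals=2):
        suffixes = ['', 'K', 'M', 'B', 'T']
        magnitude = 0
        while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
            magnitude += 1
            num /= 1000.0
        return f"{num:.{decimals}f}{suffixes[magnitude]}"

    human_format(123456, decimals=0)  # -> '123K'   (old behavior)
    human_format(123456, decimals=1)  # -> '123.5K' (new behavior)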
autonmt/modules/models/__init__.py (2 changes: 2 additions & 0 deletions)

@@ -1 +1,3 @@
 from autonmt.modules.models.transfomer import Transformer
+from autonmt.modules.models.lstm import LSTM
+
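With this re-export in place, both architectures can be imported from the package itself:

    from autonmt.modules.models import Transformer, LSTM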
autonmt/modules/models/lstm.py (64 changes: 64 additions & 0 deletions, new file)

@@ -0,0 +1,64 @@
import torch.nn as nn

from autonmt.modules.layers import PositionalEmbedding  # (currently unused)
from autonmt.modules.seq2seq import LitSeq2Seq


class LSTM(LitSeq2Seq):
    def __init__(self,
                 src_vocab_size, trg_vocab_size,
                 encoder_embed_dim=256,
                 decoder_embed_dim=256,
                 encoder_hidden_dim=512,
                 decoder_hidden_dim=512,
                 encoder_n_layers=2,
                 decoder_n_layers=2,
                 encoder_dropout=0.5,
                 decoder_dropout=0.5,
                 padding_idx=None,
                 **kwargs):
        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)

        # Model
        self.src_embeddings = nn.Embedding(src_vocab_size, encoder_embed_dim)
        self.trg_embeddings = nn.Embedding(trg_vocab_size, decoder_embed_dim)

        self.encoder_dropout = nn.Dropout(encoder_dropout)
        self.decoder_dropout = nn.Dropout(decoder_dropout)

        self.encoder_rnn = nn.LSTM(encoder_embed_dim, encoder_hidden_dim, encoder_n_layers, dropout=encoder_dropout)
        self.decoder_rnn = nn.LSTM(decoder_embed_dim, decoder_hidden_dim, decoder_n_layers, dropout=decoder_dropout)

        # Project decoder outputs (decoder_hidden_dim features) to vocabulary logits
        self.output_layer = nn.Linear(decoder_hidden_dim, trg_vocab_size)

        # Checks: encoder/decoder shapes must match so states can be handed over
        assert encoder_embed_dim == decoder_embed_dim
        assert encoder_hidden_dim == decoder_hidden_dim
        assert encoder_n_layers == decoder_n_layers

    def forward_encoder(self, x):
        # Embed src: (length, batch) => (length, batch, emb_dim)
        x_emb = self.src_embeddings(x)
        x_emb = self.encoder_dropout(x_emb)

        # input: (length, batch, emb_dim)
        # output: (length, batch, hidden_dim * n_directions)
        # hidden: (n_layers * n_directions, batch, hidden_dim)
        # cell: (n_layers * n_directions, batch, hidden_dim)
        outputs, (hidden, cell) = self.encoder_rnn(x_emb)
        return hidden, cell

    def forward_decoder(self, y, hidden, cell):
        # Embed trg: (length, batch) => (length, batch, emb_dim); length is 1 when decoding step by step
        y_emb = self.trg_embeddings(y)
        y_emb = self.decoder_dropout(y_emb)

        # (length, batch, emb_dim) =>
        # output: (length, batch, hidden_dim * n_directions)
        # hidden: (n_layers * n_directions, batch, hidden_dim)
        # cell: (n_layers * n_directions, batch, hidden_dim)
        output, (hidden, cell) = self.decoder_rnn(y_emb, (hidden, cell))

        # Get output: (length, batch, hidden_dim * n_directions) => (length, batch, trg_vocab_size)
        output = self.output_layer(output)
        return output
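A standalone sketch of the encoder-to-decoder state handoff implemented above, using raw nn.LSTM modules and made-up sizes (shapes only; not part of the commit). It also shows why the projection layer must take decoder_hidden_dim features (512 by default) rather than encoder_embed_dim (256):

    import torch
    import torch.nn as nn

    emb_dim, hid_dim, n_layers, batch, src_len, vocab = 256, 512, 2, 8, 10, 1000
    enc = nn.LSTM(emb_dim, hid_dim, n_layers)  # seq-first inputs: (length, batch, emb_dim)
    dec = nn.LSTM(emb_dim, hid_dim, n_layers)
    proj = nn.Linear(hid_dim, vocab)           # hidden features in, vocabulary logits out

    src_emb = torch.randn(src_len, batch, emb_dim)  # embedded source batch
    _, (h, c) = enc(src_emb)                        # h, c: (n_layers, batch, hid_dim)

    y_emb = torch.randn(1, batch, emb_dim)          # one embedded target step
    out, (h, c) = dec(y_emb, (h, c))                # out: (1, batch, hid_dim)
    logits = proj(out)                              # (1, batch, vocab)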
autonmt/preprocessing/builder.py (27 changes: 20 additions & 7 deletions)

@@ -633,7 +633,7 @@ def _plot_datasets(self, force_overwrite, save_figures=True, show_figures=False,
 
         # Set default vars
         if vocab_top_k is None:
-            vocab_top_k = [50]
+            vocab_top_k = [256]
 
         # Set backend
         if save_figures:
@@ -733,13 +733,26 @@ def _plot_datasets(self, force_overwrite, save_figures=True, show_figures=False,
             df = df.sort_values(by='frequency', ascending=False, na_position='last')
 
             for top_k in vocab_top_k:
-                title = f"Vocabulary distribution (top {str(top_k)} {ds.subword_model.title()}; {lang_file})"
-                title = title if not add_dataset_title else f"{ds_title}:\n{title}"
-                p_fname = f"vocab_distr_{lang_file}_top{str(top_k)}__{suffix_fname}".lower()
-                plots.barplot(data=df.head(top_k), x="token", y="frequency",
+                # Sample a subset of words for visualization
+                len_vocab = len(df)
+                df2 = df.sample(n=top_k, random_state=1).sort_values(by='frequency', ascending=False)
+
+                # Sampled
+                d = {"word": "Words", "bpe": "BPE", "char": "Chars", "bytes": "Bytes"}
+                title = f"Vocabulary distribution ({d[ds.subword_model]} - {len(df):,})"
+                p_fname = f"vocab_distr_{lang_file}_sampled{str(top_k)}__{suffix_fname}".lower()
+
+                # Top
+                # title = f"Vocabulary distribution ({d[ds.subword_model]} - Top {top_k})"
+                # p_fname = f"vocab_distr_{lang_file}_top{str(top_k)}__{suffix_fname}".lower()
+
+                # title = f"Vocabulary distribution ({d[ds.subword_model]} - {len_vocab:,})"
+                # title = f"Vocabulary distribution (top {str(top_k)} {ds.subword_model.title()}; {lang_file})"
+                # title = title if not add_dataset_title else f"{ds_title}:\n{title}"
+                plots.barplot(data=df2, x="token", y="frequency",
                               output_dir=plots_encoded_path, fname=p_fname,
-                              title=title, xlabel="Token frequency", ylabel="Frequency",
-                              aspect_ratio=(6, 4), size=1.25, save_fig=save_figures, show_fig=show_figures,
+                              title=title, xlabel="Tokens", ylabel="Frequency",
+                              aspect_ratio=(12, 8), size=2.5, save_fig=save_figures, show_fig=show_figures,
                               overwrite=force_overwrite)
 
     def merge_datasets(self, name="europarl", language_pair="xx-yy", dataset_size_name="original",
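One caveat about the sampling line above (an observation, not part of the commit): pandas' DataFrame.sample raises a ValueError when n exceeds the number of rows, which can happen now that vocab_top_k defaults to [256] but the vocabulary is smaller (e.g. char or bytes models). A guarded variant:

    # Sketch: clamp the sample size to the vocabulary size
    df2 = df.sample(n=min(top_k, len(df)), random_state=1) \
            .sort_values(by='frequency', ascending=False)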
autonmt/preprocessing/tokenizers.py (2 changes: 1 addition & 1 deletion)

@@ -48,7 +48,7 @@ def spm_train_file(input_file, model_prefix, subword_model, vocab_size, input_se
                                    model_type=subword_model, vocab_size=vocab_size,
                                    input_sentence_size=input_sentence_size, byte_fallback=byte_fallback,
                                    character_coverage=character_coverage, split_digits=split_digits,
-                                   pad_id=3)
+                                   pad_id=3)  # max_sentencepiece_length=2,
 
 
 def spm_encode_file(spm_model_path, input_file, output_file):
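For reference, the surrounding call wraps SentencePiece's Python trainer; a minimal standalone equivalent with hypothetical file names:

    import sentencepiece as spm

    spm.SentencePieceTrainer.train(
        input="train.txt", model_prefix="spm_model",
        model_type="bpe", vocab_size=16000,
        pad_id=3,                       # same non-default pad id as above
        # max_sentencepiece_length=2,   # the experiment left commented out above
    )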
autonmt/toolkits/fairseq.py (35 changes: 29 additions & 6 deletions)

@@ -114,8 +114,6 @@ def __init__(self, wandb_params=None, **kwargs):
 
         # Vars
         self.wandb_params = wandb_params
-        if self.wandb_params:
-            raise ValueError("WandB monitoring is disabled for FairSeq due to a bug related to parallelization.")
 
         # Custom
         self.data_bin_name = "data-bin"
@@ -176,12 +174,18 @@ def _preprocess(self, ds, output_path, src_lang, trg_lang, train_path, val_path,
         # Parse args and execute command
         # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/preprocess.py
         input_args = sum([str(c).split(' ', 1) for c in input_args], [])  # Split key/val (str) and flat list
+
+        # Command
+        print("COMMAND:")
+        print("fairseq-preprocess " + ' '.join(input_args))
+
+        # Run command
         parser = options.get_preprocessing_parser(default_task="translation")
         args = parser.parse_args(args=input_args)
         preprocess.main(args)
 
-    def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, run_name,
-               resume_training, force_overwrite, **kwargs):
+    def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, force_overwrite, **kwargs):
+        wandb_params = kwargs.get("wandb_params")
 
         # Get data-bin path
         data_bin_path = train_ds.get_bin_data(self.engine, self.data_bin_name)
@@ -196,13 +200,20 @@ def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, r
             print("\t- [Train]: Skipped. The checkpoint directory is not empty")
             return
 
-        if self.wandb_params:
-            print("\t\t- [WARNING]: 'wandb_params' will be ignored when using Fairseq due to some known bugs")
+        # if self.wandb_params:
+        #     print("\t\t- [WARNING]: 'wandb_params' will be ignored when using Fairseq due to some known bugs")
 
         # Write command
         input_args = [data_bin_path]
         input_args += ["--save-dir", checkpoints_dir] if checkpoints_dir else []
         input_args += ["--tensorboard-logdir", logs_path] if logs_path else []
+        if wandb_params:
+            # raise ValueError("WandB monitoring is disabled for FairSeq due to a bug related to parallelization.")
+            print("\t\t- [WARNING]: 'wandb_params' may trigger some known bugs")
+
+            # Set vars
+            input_args += ["--wandb-project", wandb_params["project"]]
+            os.environ["WANDB_NAME"] = self.run_name
 
         # Parse fairseq args
         input_args += _parse_args(max_tokens=max_tokens, batch_size=batch_size, **kwargs)
@@ -214,6 +225,11 @@ def _train(self, train_ds, checkpoints_dir, logs_path, max_tokens, batch_size, r
         if num_gpus and isinstance(num_gpus, int):
             os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in range(num_gpus)])
 
+        # Command
+        print("COMMAND:")
+        print("fairseq-train " + ' '.join(input_args))
+
+        # Run command
         # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/train.py
         parser = options.get_training_parser()
         args = options.parse_args_and_arch(parser, input_args=input_args)
@@ -252,6 +268,13 @@ def _translate(self, model_ds, data_path, output_path, src_lang, trg_lang, beam_
         # Parse args and execute command
         # From: https://github.com/pytorch/fairseq/blob/main/fairseq_cli/generate.py
         input_args = sum([str(c).split(' ', 1) for c in input_args], [])  # Split key/val (str) and flat list
+
+        # Command
+        print("COMMAND:")
+        print("fairseq-generate " + ' '.join(input_args))
+
+        # Run command
+        # From: https://github.com/facebookresearch/fairseq/blob/main/fairseq_cli/generate.py
         parser = options.get_generation_parser(default_task="translation")
         args = options.parse_args_and_arch(parser, input_args=input_args)
         generate.main(args)
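The input_args flattening one-liner used in _preprocess and _translate is worth unpacking; a small standalone demonstration (values hypothetical):

    # Each element may be a "key value" string; split once on the first space, then flatten
    input_args = ["--beam 5", "--nbest 1", "data-bin/europarl"]
    flat = sum([str(c).split(' ', 1) for c in input_args], [])
    # flat == ['--beam', '5', '--nbest', '1', 'data-bin/europarl']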
