Commit 7a057c3: Improve search

salvacarrion committed Jun 26, 2024
1 parent a3ee23b

Showing 8 changed files with 89 additions and 61 deletions.
8 changes: 5 additions & 3 deletions autonmt/modules/models/lstm.py
@@ -21,7 +21,7 @@ def __init__(self,
                  teacher_force_ratio=0.5,
                  padding_idx=None,
                  **kwargs):
-        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)
+        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture="lstm", **kwargs)
         self.teacher_forcing_ratio = teacher_force_ratio
 
         # Model
@@ -54,9 +54,11 @@ def forward_encoder(self, x, **kwargs):
         return output, (hidden, cell)
 
     def forward_decoder(self, y, hidden, cell, **kwargs):
-        # Fix y dimensions
-        if len(y.shape) == 1:
+        # Fix "y" dimensions
+        if len(y.shape) == 1:  # (batch) => (batch, 1)
             y = y.unsqueeze(1)
+        if len(y.shape) == 2 and y.shape[1] > 1:
+            y = y[:, -1].unsqueeze(1)  # Get last value
 
         # Decode trg: (batch, 1-length) => (batch, length, emb_dim)
         y_emb = self.trg_embeddings(y)

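Note: this decoder-side fix matters because the new greedy loop (see greedy_search.py below) now feeds the whole generated prefix to forward_decoder, while the LSTM step only consumes the most recent token. A minimal sketch of the same normalization, using made-up dummy tensors:

    import torch

    # Hypothetical prefix as produced by the greedy loop: batch of 2, 4 tokens generated so far
    y = torch.tensor([[1, 5, 9, 3],
                      [1, 7, 2, 4]])

    if len(y.shape) == 1:                     # (batch) => (batch, 1)
        y = y.unsqueeze(1)
    if len(y.shape) == 2 and y.shape[1] > 1:
        y = y[:, -1].unsqueeze(1)             # keep only the last generated token
    print(y.shape)                            # torch.Size([2, 1])
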
16 changes: 8 additions & 8 deletions autonmt/modules/models/transfomer.py
@@ -22,7 +22,7 @@ def __init__(self,
                  padding_idx=None,
                  learned=False,
                  **kwargs):
-        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, **kwargs)
+        super().__init__(src_vocab_size, trg_vocab_size, padding_idx, architecture="transformer", **kwargs)
         self.max_src_positions = max_src_positions
         self.max_trg_positions = max_trg_positions
 
@@ -54,10 +54,10 @@ def forward_encoder(self, x):
         x_emb = self.src_embeddings(x)
         x_emb = (x_emb + x_pos).transpose(0, 1)
 
-        memory = self.transformer.encoder(src=x_emb, mask=None, src_key_padding_mask=None)
-        return memory
+        state = self.transformer.encoder(src=x_emb, mask=None, src_key_padding_mask=None)
+        return None, (state,)
 
-    def forward_decoder(self, y, memory):
+    def forward_decoder(self, y, state):
         assert y.shape[1] <= self.max_trg_positions
 
         # Encode trg
@@ -68,15 +68,15 @@ def forward_decoder(self, y, memory):
         # Make trg mask
         tgt_mask = self.transformer.generate_square_subsequent_mask(y_emb.shape[0]).to(y_emb.device)
 
-        output = self.transformer.decoder(tgt=y_emb, memory=memory, tgt_mask=tgt_mask, memory_mask=None,
+        output = self.transformer.decoder(tgt=y_emb, memory=state, tgt_mask=tgt_mask, memory_mask=None,
                                           tgt_key_padding_mask=None, memory_key_padding_mask=None)
 
         # Get output
         output = output.transpose(0, 1)
         output = self.output_layer(output)
-        return output
+        return output, (state,)  # Return state for compatibility
 
     def forward_enc_dec(self, x, y):
-        memory = self.forward_encoder(x)
-        output = self.forward_decoder(y, memory)
+        _, states = self.forward_encoder(x)
+        output, _ = self.forward_decoder(y, *states)
         return output

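Note: after this change both models share the same contract: forward_encoder returns (output, states) and forward_decoder returns (output, states), with the state packed in a tuple ((hidden, cell) for the LSTM, (memory,) here). A hedged sketch of a model-agnostic decoding step built on that assumption (the function and variable names are illustrative, not part of the codebase):

    import torch

    def greedy_step(model, x, y_prefix):
        # Encode once, then advance the decoder by one token, regardless of architecture.
        _, states = model.forward_encoder(x)
        logits, states = model.forward_decoder(y_prefix, *states)
        next_token = logits[:, -1, :].argmax(-1)  # most probable next token per sentence
        return next_token, states
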
3 changes: 2 additions & 1 deletion autonmt/modules/seq2seq.py
@@ -12,11 +12,12 @@
 
 class LitSeq2Seq(pl.LightningModule):
 
-    def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, **kwargs):
+    def __init__(self, src_vocab_size, trg_vocab_size, padding_idx, architecture="base", **kwargs):
         super().__init__()
         self.src_vocab_size = src_vocab_size
         self.trg_vocab_size = trg_vocab_size
         self.padding_idx = padding_idx
+        self.architecture = architecture
 
         # Hyperparams (PyTorch Lightning stuff)
         self.strategy = None

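Note: the subclasses pass their own tag (architecture="lstm", architecture="transformer"), and the example script later in this commit uses it to label runs. A tiny illustrative sketch, with made-up vocabulary sizes and dataset name:

    model = LSTM(src_vocab_size=8000, trg_vocab_size=8000, padding_idx=0)
    run_prefix = f"{model.architecture}-3ep__" + "europarl_es-en"  # -> "lstm-3ep__europarl_es-en"
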
9 changes: 6 additions & 3 deletions autonmt/search/beam_search.py
@@ -7,11 +7,14 @@ def beam_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_
     raise NotImplemented("Beam search with a width larger than '1' is currently disabled.")
     model.eval()
     device = next(model.parameters()).device
+    pin_memory = False if device.type == "cpu" else True
 
     # Create dataloader
-    collate_fn = lambda x: dataset.collate_fn(x, max_tokens=max_tokens)
-    eval_dataloader = tud.DataLoader(dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size,
-                                     num_workers=num_workers)
+    eval_dataloader = tud.DataLoader(dataset,
+                                     collate_fn=dataset.get_collate_fn(max_tokens),
+                                     num_workers=num_workers, persistent_workers=bool(num_workers),
+                                     pin_memory=pin_memory,
+                                     batch_size=batch_size, shuffle=False)
 
     idxs = []
     probabilities = []

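Note: the loader settings here mirror the ones in greedy_search.py: pin_memory only pays off when batches are copied to an accelerator, and persistent_workers is only valid when num_workers > 0 (hence bool(num_workers)). A generic sketch of the pattern, independent of this project's dataset class:

    import torch.utils.data as tud

    def make_eval_loader(dataset, collate_fn, batch_size, num_workers, device):
        pin_memory = device.type != "cpu"   # host-to-GPU copies benefit, plain CPU does not
        return tud.DataLoader(dataset,
                              collate_fn=collate_fn,
                              num_workers=num_workers,
                              persistent_workers=bool(num_workers),  # True requires num_workers > 0
                              pin_memory=pin_memory,
                              batch_size=batch_size, shuffle=False)
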
68 changes: 32 additions & 36 deletions autonmt/search/greedy_search.py
@@ -3,53 +3,49 @@
 import tqdm
 
 
-def greedy_search(model, dataset, sos_id, eos_id, batch_size, max_tokens, max_len_a, max_len_b, num_workers, **kwargs):
+def greedy_search(model, dataset, sos_id, eos_id, pad_id, batch_size, max_tokens, max_len_a, max_len_b, num_workers, **kwargs):
     model.eval()
     device = next(model.parameters()).device
     pin_memory = False if device.type == "cpu" else True
 
     # Create dataloader
-    collate_fn = lambda x: dataset.collate_fn(x, max_tokens=max_tokens)
-    eval_dataloader = tud.DataLoader(dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size,
-                                     num_workers=num_workers, pin_memory=pin_memory)
+    eval_dataloader = tud.DataLoader(dataset,
+                                     collate_fn=dataset.get_collate_fn(max_tokens),
+                                     num_workers=num_workers, persistent_workers=bool(num_workers),
+                                     pin_memory=pin_memory,
+                                     batch_size=batch_size, shuffle=False)
 
-    idxs = []
-    probabilities = []
     with torch.no_grad():
+        outputs = []
         for x, _ in tqdm.tqdm(eval_dataloader, total=len(eval_dataloader)):
-            # Move to device
-            x = x.to(device)
-
-            # Set start token <s> and initial probabilities
-            # Sentence generated
-            dec_idxs = torch.full((x.shape[0], 1), sos_id, dtype=torch.long).to(device)  # Sentence tokens
-            dec_probs = torch.zeros(x.shape[0]).to(device)  # Sentence probability
+            max_gen_length = int(max_len_a*x.shape[1] + max_len_b)
 
             # Run encoder
-            memory = model.forward_encoder(x)
+            _, states = model.forward_encoder(x.to(device))
 
-            # Iterative decoder
-            all_eos = False
-            max_gen_length = int(max_len_a*x.shape[1] + max_len_b)
-            while not all_eos and dec_idxs.shape[1] <= max_gen_length:
-                # Get next token (probs + idx)
-                next_probabilities = model.forward_decoder(dec_idxs, memory)[:, -1].log_softmax(-1)
-                next_max_probabilities, next_max_idxs = next_probabilities.max(-1)
-
-                # Concat new tokens with previous tokens
-                next_max_idxs = next_max_idxs.unsqueeze(-1)
-                dec_idxs = torch.cat((dec_idxs, next_max_idxs), axis=1)
-                dec_probs += next_max_probabilities  # Sentence probability
-
-                # Check if all sentences have an <eos>
-                if bool((dec_idxs == eos_id).sum(axis=1).bool().all()):
+            # Set start token <sos> and initial probabilities
+            y_pred = torch.full((x.shape[0], max_gen_length), pad_id, dtype=torch.long).to(device)  # (B, L)
+            y_pred[:, 0] = sos_id
+
+            # Iterate over trg tokens
+            eos_mask = torch.zeros(x.shape[0], dtype=torch.bool).to(device)
+            max_iter = 0
+            for i in range(1, max_gen_length):
+                max_iter = i
+                outputs_t, states = model.forward_decoder(y_pred[:, :i], *states)
+                top1 = outputs_t[:, -1, :].argmax(1)  # Get most probable next-word (logits)
+
+                # Update y_pred for next iteration
+                y_pred[:, i] = top1
+
+                # Check for EOS tokens
+                eos_mask |= (top1 == eos_id)  # in-place OR
+
+                # Break if all sentences have an EOS token
+                if eos_mask.all():
                     break
 
-            # Store batch results
-            idxs.append(dec_idxs)
-            probabilities.append(dec_probs)
+            # Add outputs
+            outputs.extend(y_pred[:, :max_iter].tolist())
 
-    # Prettify output
-    idxs = [item for batch_idxs in idxs for item in batch_idxs.tolist()]
-    probabilities = torch.concat(probabilities)
-    return idxs, probabilities
+    return outputs, None

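Note: the rewritten loop pre-allocates a (batch, max_gen_length) tensor filled with pad_id, which is why the new pad_id argument is required, and it now returns plain Python lists plus None instead of per-sentence log-probabilities. A hedged usage sketch mirroring the call in autonmt.py below (the model, dataset, and vocabulary objects are assumed to exist elsewhere; the length parameters are illustrative):

    predictions, log_probabilities = greedy_search(
        model=model, dataset=test_dataset,
        sos_id=trg_vocab.sos_id, eos_id=trg_vocab.eos_id, pad_id=trg_vocab.pad_id,
        batch_size=64, max_tokens=None,
        max_len_a=1.2, max_len_b=50, num_workers=0)
    # predictions: list of token-id lists (sos + generated tokens); log_probabilities: now None
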
33 changes: 29 additions & 4 deletions autonmt/toolkits/autonmt.py
@@ -26,6 +26,33 @@
 # from torchnlp.samplers import BucketBatchSampler
 
 
+def set_model_device(model, accelerator="auto"):
+    # Check available hardware
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+
+    # Choose target device
+    if accelerator == "auto":
+        device = default_device
+    elif accelerator in {"cuda", "gpu"}:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    elif accelerator == "mps":
+        device = "mps" if torch.backends.mps.is_available() else "cpu"
+    else:
+        device = "cpu"
+
+    # Set device
+    if model.device.type != device:
+        print(f"\t-[INFO]: Setting '{device}' as the model's device")
+        model = model.to(device)
+    else:
+        print(f"\t-[INFO]: Model is already on '{device}' device")
+    return model
+
 
 class AutonmtTranslator(BaseTranslator):  # AutoNMT Translator
 
@@ -38,7 +65,6 @@ def __init__(self, model, **kwargs):
         self.val_tds = None
         self.test_tds = None
 
-
     def _preprocess(self, train_path, val_path, test_path,
                     apply2train, apply2val, apply2test,
                     src_lang, trg_lang, src_vocab_path, trg_vocab_path,
@@ -174,15 +200,14 @@ def _translate(self, data_path, output_path, src_lang, trg_lang, beam_width, max
         self.from_checkpoint = self.load_checkpoint(checkpoint)
 
         # Set evaluation model
-        if accelerator in {"auto", "cuda", "gpu"} and self.model.device.type != "cuda":
-            print(f"\t-[INFO]: Setting 'cuda' as the model's device")
-            self.model = self.model.cuda()
+        self.model = set_model_device(self.model, accelerator=accelerator)
 
         # Iterative decoding
         search_algorithm = beam_search if beam_width > 1 else greedy_search
         predictions, log_probabilities = search_algorithm(model=self.model, dataset=self.test_tds[filter_idx],
                                                           sos_id=self.trg_vocab.sos_id,
                                                           eos_id=self.trg_vocab.eos_id,
+                                                          pad_id=self.trg_vocab.pad_id,
                                                           batch_size=batch_size, max_tokens=max_tokens,
                                                           beam_width=beam_width, max_len_a=max_len_a, max_len_b=max_len_b,
                                                           num_workers=num_workers)

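Note: set_model_device generalizes the old CUDA-only branch, so evaluation can also run on Apple's MPS backend and degrades to CPU when the requested accelerator is unavailable. A small hedged usage sketch (assumes model is the Lightning module used above):

    # Picks "cuda", "mps", or "cpu" depending on what the machine offers.
    model = set_model_device(model, accelerator="auto")

    # Asking for a GPU on a CPU-only box falls back to "cpu" instead of raising.
    model = set_model_device(model, accelerator="gpu")
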
1 change: 1 addition & 0 deletions autonmt/toolkits/base.py
@@ -130,6 +130,7 @@ def _save_config(self, fname="config.json", force_overwrite=False):
         make_dir(logs_path)
         save_json(self.config, savepath=os.path.join(logs_path, fname), allow_overwrite=force_overwrite)
 
+
     def fit(self, train_ds, max_tokens=None, batch_size=128, max_epochs=1, patience=None,
             optimizer="adam", learning_rate=0.001, weight_decay=0, gradient_clip_val=0.0, accumulate_grad_batches=1,
             criterion="cross_entropy", monitor="val_loss",

12 changes: 6 additions & 6 deletions examples/dev/0_test_custom_model.py
@@ -21,7 +21,7 @@
 preprocess_splits_fn = lambda data, ds: preprocess_pairs(data["src"]["lines"], data["trg"]["lines"], normalize_fn=normalize_fn, shuffle_lines=False)
 preprocess_predict_fn = lambda data, ds: preprocess_lines(data["lines"], normalize_fn=normalize_fn)
 
-BASE_PATH = "/home/scarrion/datasets/translate"  # Remote
+BASE_PATH = "/Users/salvacarrion/Documents/Programming/datasets/translate"  # Remote
 
 def main():
     # Create preprocessing for training
@@ -70,15 +70,15 @@ def main():
     else:
         raise ValueError(f"Unknown subword model: {train_ds.subword_model}")
 
-    for iters in [10]:
+    for iters in [3]:
         # Instantiate vocabs and model
         src_vocab = Vocabulary(max_tokens=max_tokens_src).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
         trg_vocab = Vocabulary(max_tokens=max_tokens_tgt).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
         model = LSTM(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)
 
         # Define trainer
         runs_dir = train_ds.get_runs_path(toolkit="autonmt")
-        run_prefix = f"{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
+        run_prefix = f"{model.architecture}-{iters}ep__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
         run_name = train_ds.get_run_name(run_prefix=run_prefix)  #+ f"__{int(time.time())}"
         trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab,
                                     runs_dir=runs_dir, run_name=run_name)
@@ -91,9 +91,9 @@ def main():
 
         # Train model
         wandb_params = None  #dict(project="vocab-comparison", entity="salvacarrion", reinit=True)
-        trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=None,
-                    patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
-                    wandb_params=wandb_params)
+        # trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=128, seed=None,
+        #             patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
+        #             wandb_params=wandb_params)
 
         # Test model
         m_scores = trainer.predict(ts_datasets, metrics={"bleu"}, beams=[1], load_checkpoint="best",
