Commit

Added loading from checkpoint, dirty
Joseph Attieh committed Feb 10, 2025
1 parent e87f39e commit 39da63a
Showing 4 changed files with 71 additions and 10 deletions.
7 changes: 4 additions & 3 deletions mammoth/bin/train.py
@@ -178,7 +178,8 @@ def train(opts):
     vocabs_dict = OrderedDict()
     # For creating fields, we use a task_queue_manager that doesn't filter by node and gpu
     global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context)
-
+    if opts.train_from:
+        checkpoint = load_checkpoint(ckpt_path=opts.train_from)
     vocab_size = {'src': opts.src_vocab_size or None, 'tgt': opts.tgt_vocab_size or None}
     for side in ('src', 'tgt'):
         for lang in global_task_queue_manager.get_langs(side):
@@ -231,7 +232,7 @@ def train(opts):
         procs.append(
             mp.Process(
                 target=consumer,
-                args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager),
+                args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager, checkpoint),
                 daemon=True,
             )
         )
@@ -274,7 +275,7 @@ def train(opts):
             local_rank=0,
             opts=opts
         )
-        train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager)
+        train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager, checkpoint=checkpoint)
 
 
 def _get_parser():
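Note that as committed, `checkpoint` is bound only inside the `if opts.train_from:` branch, yet both call sites above forward it unconditionally, so a fresh run without `--train_from` would raise NameError. A minimal sketch of the presumably intended guard, assuming `checkpoint=None` means training from scratch (this is not part of the commit):

    # Sketch: bind `checkpoint` up front so runs without --train_from do not
    # hit a NameError when it is forwarded to consumer/train_process below.
    checkpoint = None
    if opts.train_from:
        checkpoint = load_checkpoint(ckpt_path=opts.train_from)
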
3 changes: 2 additions & 1 deletion mammoth/distributed/communication.py
@@ -254,7 +254,7 @@ def batch_producer(generator_to_serve, queue, semaphore, opts, device_id):
         queue.put((batch, metadata, communication_batch_id))
 
 
-def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager):
+def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint):
     """Run `process_fn` on `device_id` with data from `batch_queue`."""
     try:
         logger.info(
@@ -271,6 +271,7 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint):
             batch_queue=batch_queue,
             semaphore=semaphore,
             task_queue_manager=task_queue_manager,
+            checkpoint=checkpoint,
         )
 
     except KeyboardInterrupt:
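The `consumer` change is pure plumbing: the new `checkpoint` argument is forwarded into `process_fn` as a keyword, which assumes the wrapped `train_process` accepts it. A compatible signature sketch follows; everything beyond the keyword names visible in this diff is an assumption:

    # Hypothetical callee signature; only the keywords visible in the diff
    # are known, the defaults and remaining parameters are assumed.
    def train_process(opts, device_context, error_queue=None, batch_queue=None,
                      semaphore=None, task_queue_manager=None, checkpoint=None):
        ...
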
65 changes: 59 additions & 6 deletions mammoth/model_builder.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from torch.nn.init import xavier_uniform_
-
+from pathlib import Path
 from collections import defaultdict
 
 import mammoth.modules
@@ -300,16 +300,24 @@ def build_task_specific_model(
         src_emb = build_src_emb(model_opts, vocab)
         src_embs[lang] = src_emb
     pluggable_src_emb = PluggableEmbeddings(src_embs)
-    encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager)
+    encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager, checkpoint)
 
     for side, lang, _, vocab in task_queue_manager.get_vocabs(side='tgt', vocabs_dict=vocabs_dict):
         tgt_emb = build_tgt_emb(model_opts, vocab)
         tgt_embs[lang] = tgt_emb
         generator = build_generator(model_opts, len(vocab), tgt_emb)
         generators_md.add_module(f'generator_{lang}', generator)
 
+    if checkpoint:
+        trainstep = checkpoint['opt'].train_step
+        for modname, gen in generators_md.items():
+            mod_path = Path(checkpoint['opt'].save_model + f"_step_{trainstep}_{modname}.pt")
+            if mod_path.exists():
+                module = torch.load(mod_path)
+                gen.load_state_dict(module)
+                logger.info(f"Successfully loaded {modname} from the checkpoint.")
     pluggable_tgt_emb = PluggableEmbeddings(tgt_embs)
-    decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager)
+    decoder = build_only_dec(model_opts, pluggable_tgt_emb, task_queue_manager, checkpoint)
 
     # TODO: implement hierarchical approach to layer sharing
     attention_bridge = AttentionBridge.from_opts(model_opts)
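Per-module generator weights are resolved by name under the scheme `<save_model>_step_<train_step>_<module>.pt`. A worked sketch with hypothetical values:

    from pathlib import Path

    save_model = 'models/opus'   # hypothetical checkpoint['opt'].save_model
    trainstep = 10000            # hypothetical checkpoint['opt'].train_step
    modname = 'generator_en'     # a key of generators_md
    mod_path = Path(f"{save_model}_step_{trainstep}_{modname}.pt")
    # -> models/opus_step_10000_generator_en.pt
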
@@ -360,7 +368,7 @@ def has_grad_hook(module, input, output) -> None:
     return nmt_model, generators_md
 
 
-def build_only_enc(model_opts, src_emb, task_queue_manager):
+def build_only_enc(model_opts, src_emb, task_queue_manager, checkpoint=None):
     """Truly only builds encoder: no embeddings"""
     encoder = build_encoder(model_opts, src_emb, task_queue_manager)
     if model_opts.param_init != 0.0:
@@ -373,13 +381,36 @@ def build_only_enc(model_opts, src_emb, task_queue_manager):
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))
if checkpoint:
trainstep = checkpoint['opt'].train_step
embnames = [srctgt['src_tgt'].split('-')[0] for srctgt in checkpoint['opt'].data.values()]
embnames = set(embnames)
groupnames = [
(idx, modname) for srctgt in checkpoint['opt'].data.values()
for idx, modname in enumerate(srctgt['enc_sharing_group'])
]
groupnames = set(groupnames)
# load embs
for modname in embnames:
module = torch.load(checkpoint['opt'].save_model + f"_step_{trainstep}_src_embeddings_{modname}.pt")
if f'embeddings_{modname}' in encoder.embeddings._modules.keys():
encoder.embeddings._modules[f'embeddings_{modname}'].load_state_dict(module)
logger.info(f"Successfully loaded the embeddings of {modname} from the checkpoint.")

# load layers
for idx, modname in groupnames:
mod_path = Path(checkpoint['opt'].save_model + f"_step_{trainstep}_encoder_{idx}_{modname}.pt")
if mod_path.exists() and modname in encoder.encoders._modules[str(idx)].keys():
module = torch.load(mod_path)
encoder.encoders._modules[str(idx)][modname].load_state_dict(module)
logger.info(f"Successfully loaded layer {str(idx)} of {modname} from the checkpoint.")
if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
encoder.half()

return encoder


def build_only_dec(model_opts, tgt_emb, task_queue_manager):
def build_only_dec(model_opts, tgt_emb, task_queue_manager, checkpoint=None):
decoder = build_decoder(model_opts, tgt_emb, task_queue_manager)
if model_opts.param_init != 0.0:
for name, p in decoder.named_parameters():
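`embnames` and `groupnames` are derived from the `data` option saved in the checkpoint: one embedding table per source (or target) language, and one module per position in each task's sharing-group list. A worked sketch with a hypothetical two-task config (task names and dict shape are assumptions inferred from the code above):

    # Hypothetical checkpoint['opt'].data for two translation tasks:
    data = {
        'train_en-de': {'src_tgt': 'en-de', 'enc_sharing_group': ['full']},
        'train_fr-de': {'src_tgt': 'fr-de', 'enc_sharing_group': ['full']},
    }
    embnames = set(d['src_tgt'].split('-')[0] for d in data.values())
    # -> {'en', 'fr'}
    groupnames = set(
        (idx, name)
        for d in data.values()
        for idx, name in enumerate(d['enc_sharing_group'])
    )
    # -> {(0, 'full')}
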
@@ -390,7 +421,29 @@ def build_only_dec(model_opts, tgt_emb, task_queue_manager):
if not ("embedding" in name and "pe" not in name and model_opts.enable_embeddingless is True):
if p.dim() > 1:
xavier_uniform_(p, gain=nn.init.calculate_gain('relu'))

if checkpoint:
trainstep = checkpoint['opt'].train_step
embnames = [srctgt['src_tgt'].split('-')[1] for srctgt in checkpoint['opt'].data.values()]
embnames = set(embnames)
groupnames = [
(idx, modname) for srctgt in checkpoint['opt'].data.values()
for idx, modname in enumerate(srctgt['dec_sharing_group'])
]
groupnames = set(groupnames)
# load embs
for modname in embnames:
if f'embeddings_{modname}' in decoder.embeddings._modules.keys():
module = torch.load(checkpoint['opt'].save_model + f"_step_{trainstep}_tgt_embeddings_{modname}.pt")
decoder.embeddings._modules[f'embeddings_{modname}'].load_state_dict(module)
logger.info(f"Successfully loaded the embeddings of {modname} from the checkpoint.")

# load layers
for idx, modname in groupnames:
mod_path = Path(checkpoint['opt'].save_model + f"_step_{trainstep}_decoder_{idx}_{modname}.pt")
if mod_path.exists() and modname in decoder.decoders._modules[str(idx)].keys():
module = torch.load(mod_path)
decoder.decoders._modules[str(idx)][modname].load_state_dict(module)
logger.info(f"Successfully loaded layer {str(idx)} of {modname} from the checkpoint.")
if model_opts.model_dtype == 'fp16' and model_opts.optim == 'fusedadam':
decoder.half()

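The lookup `decoder.decoders._modules[str(idx)][modname]` implies a nested container: an outer module list indexed by layer-stack position whose entries are `ModuleDict`s keyed by sharing-group name. A minimal sketch of that assumed layout (inferred from the indexing above, not shown in this diff):

    import torch.nn as nn

    # Assumed layout: decoders[idx] is a ModuleDict keyed by sharing-group name.
    decoders = nn.ModuleList([nn.ModuleDict({'full': nn.Linear(8, 8)})])
    idx, modname = 0, 'full'
    if modname in decoders._modules[str(idx)].keys():
        layer = decoders._modules[str(idx)][modname]  # same as decoders[idx][modname]
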
6 changes: 6 additions & 0 deletions mammoth/opts.py
@@ -665,6 +665,12 @@ def _add_train_general_opts(parser):
         default='model',
         help="Model filename (the model will be saved as <save_model>_N.pt where N is the number of steps)",
     )
+    group.add(
+        '--train_step',
+        '-train_step',
+        default=0,
+        help="Training step of the checkpoint to load parameters from.",
+    )
     group.add(
         "--save_all_gpus",
         "-save_all_gpus",
