Skip to content

Commit

Permalink
last changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Attieh committed Dec 11, 2024
1 parent 75d2670 commit 038c0a6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mammoth/utils/model_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def load_parameters_from_checkpoint(
name = component.get_name()
checkpoint_path = f'{checkpoint_prefix}_{name}.pt'
if os.path.isfile(checkpoint_path):
state_dict = torch.load(checkpoint_path)
state_dict = torch.load(checkpoint_path, map_location="cpu")
incompatible_keys = component.load_state_dict(model=model, state_dict=state_dict)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
logger.info(f'Module {name} incompatible keys: {incompatible_keys}')
Expand All @@ -103,7 +103,7 @@ def load_parameters_from_checkpoint(
optimizer_path = f'{checkpoint_prefix}_{name}_optim.pt'
if os.path.isfile(optimizer_path):
# The optimizer parameters are distributed the same way as the components
optim_state_dict = torch.load(optimizer_path)
optim_state_dict = torch.load(optimizer_path, map_location="cpu")
incompatible_keys = optim.suboptimizers[name].load_state_dict(optim_state_dict)
if incompatible_keys and (incompatible_keys.missing_keys or incompatible_keys.unexpected_keys):
logger.info(f'Optim {name} incompatible keys: {incompatible_keys}')
Expand Down

0 comments on commit 038c0a6

Please sign in to comment.