Skip to content

Commit

Permalink
[easy] improve hint on error message in nn.Module.load_state_dict (py…
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaylagawarecki authored and pytorchmergebot committed Jul 27, 2023
1 parent 70bc1b0 commit ca7ece9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,14 @@ def forward(self, input):
with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
net2.load_state_dict(state_dict, strict=False, assign=True)

def test_load_state_dict_warn_assign(self):
    """Loading a non-meta checkpoint tensor into a meta-device parameter
    should emit a UserWarning suggesting ``assign=True``, since the copy
    into a meta tensor is a no-op."""
    with torch.device('meta'):
        module = torch.nn.Linear(3, 5)
    sd = module.state_dict()
    # Replace the (meta) checkpoint weight with a real CPU tensor so the
    # load crosses the non-meta -> meta boundary that triggers the warning.
    sd['weight'] = torch.empty_like(sd['weight'], device='cpu')
    expected = "for weight: copying from a non-meta parameter in the checkpoint to a meta"
    with self.assertWarnsRegex(UserWarning, expected):
        module.load_state_dict(sd)

def test_extra_state_missing_set_extra_state(self):

class MyModule(torch.nn.Module):
Expand Down
7 changes: 7 additions & 0 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,13 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue

if param.is_meta and not input_param.is_meta and not assign_to_params_buffers:
warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta '
'parameter in the current model, which is a no-op. (Did you mean to '
'pass `assign=True` to assign items in the state dictionary to their '
'corresponding key in the module instead of copying them in place?)')

try:
with torch.no_grad():
if assign_to_params_buffers:
Expand Down

0 comments on commit ca7ece9

Please sign in to comment.