diff --git a/test/test_nn.py b/test/test_nn.py
index f277accc4299ec..d12469ee91fd2c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5288,6 +5288,17 @@ def test_batchnorm_nhwc_cuda(self):
             out2 = model(inp2)
             self.assertTrue(torch.equal(out1, out2))
 
+    def test_batchnorm_load_state_dict(self):
+        bn = torch.nn.BatchNorm2d(3)
+        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))
+
+        bn.num_batches_tracked = torch.tensor(10)
+        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
+
+        empty_dict = OrderedDict()
+        bn.load_state_dict(empty_dict, strict=False)
+        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
+
     def test_pairwise_distance(self):
         input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
         input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index e529dcaaf79cec..643b55819ecb02 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -105,7 +105,11 @@ def _load_from_state_dict(
             # this should have a default value of 0
             num_batches_tracked_key = prefix + "num_batches_tracked"
             if num_batches_tracked_key not in state_dict:
-                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
+                state_dict[num_batches_tracked_key] = (
+                    self.num_batches_tracked
+                    if self.num_batches_tracked is not None
+                    else torch.tensor(0, dtype=torch.long)
+                )
 
         super()._load_from_state_dict(
             state_dict,
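
A minimal usage sketch (not part of the patch) of the behavior the change enables, using only what the diff shows: loading a legacy state dict that lacks num_batches_tracked with strict=False now keeps the module's existing counter instead of resetting it to zero. The variable names below (legacy_state, bn) are illustrative only.

import torch
from collections import OrderedDict

bn = torch.nn.BatchNorm2d(3)
for _ in range(10):
    bn(torch.randn(2, 3, 4, 4))   # each training-mode forward increments num_batches_tracked

# Checkpoints written before num_batches_tracked was introduced have no such key.
legacy_state = OrderedDict()

# Previously the missing key was filled with torch.tensor(0), wiping the counter;
# with this change the module's current value is reused (see the test above).
bn.load_state_dict(legacy_state, strict=False)
print(bn.num_batches_tracked)     # tensor(10)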