Skip to content

Commit

Permalink
Fix num_batches_tracked of BatchNorm when loading a state dict (pytorch#11…
Browse files Browse the repository at this point in the history
…0850)

Fixes pytorch#110361

As the title shows.

Pull Request resolved: pytorch#110850
Approved by: https://github.com/mikaylagawarecki
  • Loading branch information
FFFrog authored and pytorchmergebot committed Oct 24, 2023
1 parent 30cbd2e commit 0e0f6a2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,6 +5288,17 @@ def test_batchnorm_nhwc_cuda(self):
out2 = model(inp2)
self.assertTrue(torch.equal(out1, out2))

def test_batchnorm_load_state_dict(self):
bn = torch.nn.BatchNorm2d(3)
self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))

bn.num_batches_tracked = torch.tensor(10)
self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))

empty_dict = OrderedDict()
bn.load_state_dict(empty_dict, strict=False)
self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))

def test_pairwise_distance(self):
input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
Expand Down
6 changes: 5 additions & 1 deletion torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def _load_from_state_dict(
# this should have a default value of 0
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key not in state_dict:
state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
state_dict[num_batches_tracked_key] = (
self.num_batches_tracked
if self.num_batches_tracked is not None
else torch.tensor(0, dtype=torch.long)
)

super()._load_from_state_dict(
state_dict,
Expand Down

0 comments on commit 0e0f6a2

Please sign in to comment.