
Commit

address #7
lucidrains committed Feb 13, 2023
1 parent a3d9583 commit 76d0aaf
Showing 2 changed files with 21 additions and 6 deletions.
ema_pytorch/ema_pytorch.py (25 changes: 20 additions & 5 deletions)
@@ -47,11 +47,22 @@ def __init__(
         min_value = 0.0,
         param_or_buffer_names_no_ema = set(),
         ignore_names = set(),
-        ignore_startswith_names = set()
+        ignore_startswith_names = set(),
+        include_online_model = True # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
     ):
         super().__init__()
         self.beta = beta
-        self.online_model = model
+
+        # whether to include the online model within the module tree, so that state_dict also saves it
+
+        self.include_online_model = include_online_model
+
+        if include_online_model:
+            self.online_model = model
+        else:
+            self.online_model = [model] # hack
 
+        # ema model
+
         self.ema_model = ema_model
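The `[model]` hack above works because `nn.Module.__setattr__` only auto-registers attributes that are themselves modules (or parameters); a model stashed inside a plain Python list is invisible to `parameters()` and `state_dict()`. A standalone sketch of the mechanism, with an illustrative `Holder` class that is not part of this library:

from torch import nn

class Holder(nn.Module):
    def __init__(self, net, register = True):
        super().__init__()
        # a module assigned as an attribute is registered automatically;
        # hiding it inside a plain list opts out of that registration
        self.net = net if register else [net]

visible = Holder(nn.Linear(2, 2), register = True)
hidden = Holder(nn.Linear(2, 2), register = False)

assert 'net.weight' in visible.state_dict()
assert len(hidden.state_dict()) == 0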

@@ -83,6 +94,10 @@ def __init__(
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0]))
 
+    @property
+    def model(self):
+        return self.online_model if self.include_online_model else self.online_model[0]
+
     def restore_ema_model_device(self):
         device = self.initted.device
         self.ema_model.to(device)
@@ -100,10 +115,10 @@ def get_buffers_iter(self, model):
             yield name, buffer
 
     def copy_params_from_model_to_ema(self):
-        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.online_model)):
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
             ma_params.data.copy_(current_params.data)
 
-        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.online_model)):
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
             ma_buffers.data.copy_(current_buffers.data)
 
     def get_current_decay(self):
@@ -130,7 +145,7 @@ def update(self):
             self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))
 
-        self.update_moving_average(self.ema_model, self.online_model)
+        self.update_moving_average(self.ema_model, self.model)
 
     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
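With every internal call path routed through the new `model` property, saving the online weights becomes opt-out while external behavior is otherwise unchanged. A minimal usage sketch, assuming the `EMA` class this package exports and its default arguments:

from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)

# default: the online model sits in the module tree and is serialized
# alongside the ema copy
ema = EMA(net, include_online_model = True)
assert any(k.startswith('online_model') for k in ema.state_dict())

# flag off: only the ema copy (plus its buffers) lands in the state dict,
# and the online model is expected to be saved externally
ema = EMA(net, include_online_model = False)
assert not any(k.startswith('online_model') for k in ema.state_dict())

# either way, the `model` property hands back the underlying network
assert ema.model is net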
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.2.0',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
