allow for lazy init of the ema model
lucidrains committed Sep 29, 2024
1 parent 11c5931 commit bbff51d
Showing 2 changed files with 48 additions and 32 deletions.
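In short, the EMA copy no longer has to be built in the constructor: passing the new lazy_init_ema = True flag defers it, and the copy is then created either by an explicit call to the new init_ema method or automatically on the first update(). A minimal sketch of the two construction paths (the toy model and hyperparameters are illustrative, not part of this commit):

import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)

# eager (default): the EMA copy is deepcopied immediately, as before
ema = EMA(net, beta = 0.9999)

# lazy (new): no copy is made yet; an explicit ema_model must not be passed alongside the flag
lazy_ema = EMA(net, beta = 0.9999, lazy_init_ema = True)
lazy_ema.init_ema()  # deepcopies the online model now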
78 changes: 47 additions & 31 deletions ema_pytorch/ema_pytorch.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from typing import Set, Tuple
 
 from copy import deepcopy
 from functools import partial
@@ -69,15 +68,16 @@ def __init__(
         inv_gamma = 1.0,
         power = 2 / 3,
         min_value = 0.0,
-        param_or_buffer_names_no_ema: Set[str] = set(),
-        ignore_names: Set[str] = set(),
-        ignore_startswith_names: Set[str] = set(),
+        param_or_buffer_names_no_ema: set[str] = set(),
+        ignore_names: set[str] = set(),
+        ignore_startswith_names: set[str] = set(),
         include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
         allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
-        forward_method_names: Tuple[str, ...] = (),
+        forward_method_names: tuple[str, ...] = (),
         move_ema_to_online_device = False,
-        coerce_dtype = False
+        coerce_dtype = False,
+        lazy_init_ema = False
     ):
         super().__init__()
         self.beta = beta
@@ -95,29 +95,13 @@ def __init__(

         # ema model
 
-        self.ema_model = ema_model
-
-        if not exists(self.ema_model):
-            try:
-                self.ema_model = deepcopy(model)
-            except Exception as e:
-                print(f'Error: While trying to deepcopy model: {e}')
-                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
-                exit()
-
-        for p in self.ema_model.parameters():
-            p.detach_()
+        self.ema_model = None
+        self.forward_method_names = forward_method_names
 
-        # forwarding methods
-
-        for forward_method_name in forward_method_names:
-            fn = getattr(self.ema_model, forward_method_name)
-            setattr(self, forward_method_name, fn)
-
-        # parameter and buffer names
-
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+        if not lazy_init_ema:
+            self.init_ema(ema_model)
+        else:
+            assert not exists(ema_model)
 
         # tensor update functions
 
@@ -163,6 +147,34 @@
         self.register_buffer('initted', torch.tensor(False))
         self.register_buffer('step', torch.tensor(0))
 
+    def init_ema(
+        self,
+        ema_model: Module | None = None
+    ):
+        self.ema_model = ema_model
+
+        if not exists(self.ema_model):
+            try:
+                self.ema_model = deepcopy(self.model)
+            except Exception as e:
+                print(f'Error: While trying to deepcopy model: {e}')
+                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                exit()
+
+        for p in self.ema_model.parameters():
+            p.detach_()
+
+        # forwarding methods
+
+        for forward_method_name in self.forward_method_names:
+            fn = getattr(self.ema_model, forward_method_name)
+            setattr(self, forward_method_name, fn)
+
+        # parameter and buffer names
+
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+
     @property
     def model(self):
         return self.online_model if self.include_online_model else self.online_model[0]
@@ -220,13 +232,17 @@ def update(self):
         if (step % self.update_every) != 0:
             return
 
-        if step <= self.update_after_step:
+        if not self.initted.item():
+            if not exists(self.ema_model):
+                self.init_ema()
+
             self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.tensor(True))
             return
 
-        if not self.initted.item():
+        if step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.tensor(True))
             return
 
         self.update_moving_average(self.ema_model, self.model)

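A plausible use case for the lazy path, hinted at by the LazyLinear error message above: deepcopy can fail on a model whose lazy modules have not yet materialized, so EMA initialization can now be postponed until after a first forward pass. A hedged sketch of that flow (the model, shapes, and keyword arguments other than lazy_init_ema are assumptions for illustration):

import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

# constructing with lazy_init_ema avoids deepcopying the still-uninitialized model
ema = EMA(net, beta = 0.9999, update_every = 1, lazy_init_ema = True)

net(torch.randn(2, 128))  # first forward pass materializes the lazy layers

ema.update()  # initted is False and ema_model is None, so init_ema() runs here first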
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.3',
+  version = '0.6.4',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
