diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index af2b914..027b544 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from typing import Set, Tuple
 
 from copy import deepcopy
 from functools import partial
@@ -69,15 +68,16 @@ def __init__(
         inv_gamma = 1.0,
         power = 2 / 3,
         min_value = 0.0,
-        param_or_buffer_names_no_ema: Set[str] = set(),
-        ignore_names: Set[str] = set(),
-        ignore_startswith_names: Set[str] = set(),
+        param_or_buffer_names_no_ema: set[str] = set(),
+        ignore_names: set[str] = set(),
+        ignore_startswith_names: set[str] = set(),
         include_online_model = True,            # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
         allow_different_devices = False,        # if the EMA model is on a different device (say CPU), automatically move the tensor
         use_foreach = False,
-        forward_method_names: Tuple[str, ...] = (),
+        forward_method_names: tuple[str, ...] = (),
         move_ema_to_online_device = False,
-        coerce_dtype = False
+        coerce_dtype = False,
+        lazy_init_ema = False
     ):
         super().__init__()
         self.beta = beta
@@ -95,29 +95,13 @@ def __init__(
 
         # ema model
 
-        self.ema_model = ema_model
-
-        if not exists(self.ema_model):
-            try:
-                self.ema_model = deepcopy(model)
-            except Exception as e:
-                print(f'Error: While trying to deepcopy model: {e}')
-                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
-                exit()
-
-        for p in self.ema_model.parameters():
-            p.detach_()
+        self.ema_model = None
+        self.forward_method_names = forward_method_names
 
-        # forwarding methods
-
-        for forward_method_name in forward_method_names:
-            fn = getattr(self.ema_model, forward_method_name)
-            setattr(self, forward_method_name, fn)
-
-        # parameter and buffer names
-
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+        if not lazy_init_ema:
+            self.init_ema(ema_model)
+        else:
+            assert not exists(ema_model)
 
         # tensor update functions
 
@@ -163,6 +147,34 @@ def __init__(
         self.register_buffer('initted', torch.tensor(False))
         self.register_buffer('step', torch.tensor(0))
 
+    def init_ema(
+        self,
+        ema_model: Module | None = None
+    ):
+        self.ema_model = ema_model
+
+        if not exists(self.ema_model):
+            try:
+                self.ema_model = deepcopy(self.model)
+            except Exception as e:
+                print(f'Error: While trying to deepcopy model: {e}')
+                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                exit()
+
+        for p in self.ema_model.parameters():
+            p.detach_()
+
+        # forwarding methods
+
+        for forward_method_name in self.forward_method_names:
+            fn = getattr(self.ema_model, forward_method_name)
+            setattr(self, forward_method_name, fn)
+
+        # parameter and buffer names
+
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
+
     @property
     def model(self):
         return self.online_model if self.include_online_model else self.online_model[0]
@@ -220,13 +232,17 @@ def update(self):
         if (step % self.update_every) != 0:
             return
 
-        if step <= self.update_after_step:
+        if not self.initted.item():
+            if not exists(self.ema_model):
+                self.init_ema()
+
+            self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.tensor(True))
             return
 
-        if not self.initted.item():
+        if step <= self.update_after_step:
             self.copy_params_from_model_to_ema()
-            self.initted.data.copy_(torch.tensor(True))
+            return
 
         self.update_moving_average(self.ema_model, self.model)
diff --git a/setup.py b/setup.py
index 453ec3e..144856f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.6.3',
+  version = '0.6.4',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
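
For reference, a minimal sketch of how the new lazy_init_ema flag is meant to be used, per the diff above: construction of the EMA copy is deferred until the first update(), so a model containing lazy modules (the deepcopy failure case named in the error message) can be wrapped before its parameters are materialized. The LazyLinear network and the hyperparameters below are illustrative, not part of this diff.

import torch
from torch import nn
from ema_pytorch import EMA

# a network with uninitialized lazy modules - deepcopying it at EMA
# construction time would fail, which is what lazy_init_ema sidesteps
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

# ema_model must be left as None here (the constructor asserts this)
ema = EMA(net, update_every = 10, lazy_init_ema = True)

# a forward pass materializes the lazy parameters of the online model
net(torch.randn(1, 512))

# the first update() now calls init_ema() internally, deepcopying the
# fully materialized online model into the EMA copy before the usual
# copy_params_from_model_to_ema() initialization
ema.update()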