From 8bdba16a7a3f4989269ed8e900d74e1228709bea Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 4 Oct 2024 13:22:12 -0700 Subject: [PATCH] add ability to interpolate current model with ema model, to further explore some ideas set forth from hare/tortoise paper --- README.md | 11 +++++++++++ ema_pytorch/ema_pytorch.py | 11 +++++++++-- setup.py | 2 +- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2771a2a..7e73acd 100644 --- a/README.md +++ b/README.md @@ -107,3 +107,14 @@ synthesized_ema_output = synthesized_ema(data) url = {https://api.semanticscholar.org/CorpusID:265659032} } ``` + +```bibtex +@article{Lee2024SlowAS, + title = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks}, + author = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle}, + journal = {ArXiv}, + year = {2024}, + volume = {abs/2406.02596}, + url = {https://api.semanticscholar.org/CorpusID:270258586} +} +``` diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py index 027b544..7e5b430 100644 --- a/ema_pytorch/ema_pytorch.py +++ b/ema_pytorch/ema_pytorch.py @@ -216,6 +216,12 @@ def copy_params_from_ema_to_model(self): for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)): copy(current_buffers.data, ma_buffers.data) + def update_model_from_ema(self, decay): + if decay == 0.: + return self.copy_params_from_ema_to_model() + + self.update_moving_average(self.model, self.ema_model, decay) + def get_current_decay(self): epoch = (self.step - self.update_after_step - 1).clamp(min = 0.) value = 1 - (1 + epoch / self.inv_gamma) ** - self.power @@ -247,7 +253,7 @@ def update(self): self.update_moving_average(self.ema_model, self.model) @torch.no_grad() - def update_moving_average(self, ma_model, current_model): + def update_moving_average(self, ma_model, current_model, current_decay = None): if self.is_frozen: return @@ -258,7 +264,8 @@ def update_moving_average(self, ma_model, current_model): # get current decay - current_decay = self.get_current_decay() + if not exists(current_decay): + current_decay = self.get_current_decay() # store all source and target tensors to copy or lerp diff --git a/setup.py b/setup.py index 144856f..b2c71d5 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'ema-pytorch', packages = find_packages(exclude=[]), - version = '0.6.4', + version = '0.6.5', license='MIT', description = 'Easy way to keep track of exponential moving average version of your pytorch module', author = 'Phil Wang',