From b2d0f3accd52908ae97d4e28cc0934c39facd196 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 3 Dec 2024 13:56:40 -0800 Subject: [PATCH] address https://github.com/lucidrains/ema-pytorch/issues/35 --- ema_pytorch/post_hoc_ema.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ema_pytorch/post_hoc_ema.py b/ema_pytorch/post_hoc_ema.py index cbff47a..543a93d 100644 --- a/ema_pytorch/post_hoc_ema.py +++ b/ema_pytorch/post_hoc_ema.py @@ -352,7 +352,7 @@ def checkpoint(self): path = self.checkpoint_folder / filename pkg = { - k: v.to(self.checkpoint_dtype) + k: v.to(device = 'cpu', dtype = self.checkpoint_dtype, copy = True) for k, v in ema_model.state_dict().items() } diff --git a/setup.py b/setup.py index 05e1251..7335d04 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'ema-pytorch', packages = find_packages(exclude=[]), - version = '0.7.6', + version = '0.7.7', license='MIT', description = 'Easy way to keep track of exponential moving average version of your pytorch module', author = 'Phil Wang',