diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index 8852210..98d8390 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -94,8 +94,8 @@ def __init__(
 
         # parameter and buffer names
 
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype in [torch.float, torch.float16]}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype in [torch.float, torch.float16]}
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
 
         # tensor update functions
 
diff --git a/ema_pytorch/post_hoc_ema.py b/ema_pytorch/post_hoc_ema.py
index 61a0b07..9c22730 100644
--- a/ema_pytorch/post_hoc_ema.py
+++ b/ema_pytorch/post_hoc_ema.py
@@ -89,8 +89,8 @@ def __init__(
 
         # parameter and buffer names
 
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype in [torch.float, torch.float16]}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype in [torch.float, torch.float16]}
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
 
         # tensor update functions
 
@@ -402,4 +402,4 @@ def synthesize_ema_model(
 
         return synthesized_ema_model
 
     def __call__(self, *args, **kwargs):
-        return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
\ No newline at end of file
+        return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
diff --git a/setup.py b/setup.py
index 396d867..3cac31e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.2',
+  version = '0.4.3',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',