diff --git a/ema_pytorch/post_hoc_ema.py b/ema_pytorch/post_hoc_ema.py index 7994078..8e20442 100644 --- a/ema_pytorch/post_hoc_ema.py +++ b/ema_pytorch/post_hoc_ema.py @@ -389,7 +389,7 @@ def synthesize_ema_model( # load checkpoint into a temporary ema model - ckpt_state_dict = torch.load(str(checkpoint)) + ckpt_state_dict = torch.load(str(checkpoint), weights_only=True) tmp_ema_model.load_state_dict(ckpt_state_dict) # add weighted checkpoint to synthesized