Commit b32114f

future annotation
lucidrains committed Nov 21, 2024
1 parent 43dc731 commit b32114f
Showing 1 changed file with 5 additions and 7 deletions.
ema_pytorch/post_hoc_ema.py (12 changes: 5 additions & 7 deletions)

@@ -11,8 +11,6 @@
 
 import numpy as np
 
-from typing import Set, Tuple
-
 def exists(val):
     return val is not None
 
@@ -60,9 +58,9 @@ def __init__(
         ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
         update_every: int = 100,
         frozen: bool = False,
-        param_or_buffer_names_no_ema: Set[str] = set(),
-        ignore_names: Set[str] = set(),
-        ignore_startswith_names: Set[str] = set(),
+        param_or_buffer_names_no_ema: set[str] = set(),
+        ignore_names: set[str] = set(),
+        ignore_startswith_names: set[str] = set(),
         allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
         move_ema_to_online_device = False # will move entire EMA model to the same device as online model, if different
     ):
@@ -284,8 +282,8 @@ def __init__(
         self,
         model: Module,
         ema_model: Callable[[], Module] | None = None,
-        sigma_rels: Tuple[float, ...] | None = None,
-        gammas: Tuple[float, ...] | None = None,
+        sigma_rels: tuple[float, ...] | None = None,
+        gammas: tuple[float, ...] | None = None,
         checkpoint_every_num_steps: int | Literal['manual'] = 1000,
         checkpoint_folder: str = './post-hoc-ema-checkpoints',
         checkpoint_dtype: torch.dtype = torch.float16,
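
The commit title refers to PEP 563's from __future__ import annotations, which the file presumably already enables near its top (both hunks sit below the import block). Under that import, annotations are stored as strings instead of being evaluated at import time, so the builtin generics set[str] and tuple[float, ...] (PEP 585) and the X | None unions (PEP 604) seen above parse fine even on interpreters older than Python 3.9/3.10, and the typing.Set / typing.Tuple imports become dead weight. A minimal sketch of the effect, using a hypothetical function that is not from the repository:

    # hypothetical example, not from ema_pytorch
    from __future__ import annotations  # PEP 563: annotations become lazy strings

    def configure(
        ignore_names: set[str] = set(),               # PEP 585 builtin generic
        sigma_rels: tuple[float, ...] | None = None,  # PEP 604 union
    ):
        # the annotations are never evaluated, so this runs on Python 3.7+
        return ignore_names, sigma_rels

    print(configure.__annotations__)
    # {'ignore_names': 'set[str]', 'sigma_rels': 'tuple[float, ...] | None'}

One caveat: the future import only defers evaluation. Code that resolves the hints at runtime (e.g. typing.get_type_hints) would still fail on set[str] before Python 3.9.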
