Skip to content

Commit

Permalink
Merge pull request #315 from gekkom/fix-saveload
Browse files Browse the repository at this point in the history
Don't register mem as buffer
  • Loading branch information
jeshraghian authored Apr 17, 2024
2 parents bdc1b49 + 67973ed commit 7055c47
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 34 deletions.
10 changes: 3 additions & 7 deletions snntorch/_neurons/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,9 @@ def __init__(
self.state_function = self._base_int

def _init_mem(self):
syn_exc = torch.zeros(1)
syn_inh = torch.zeros(1)
mem = torch.zeros(1)

self.register_buffer("syn_exc", syn_exc)
self.register_buffer("syn_inh", syn_inh)
self.register_buffer("mem", mem)
self.syn_exc = torch.zeros(1)
self.syn_inh = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.syn_exc = torch.zeros_like(
Expand Down
3 changes: 1 addition & 2 deletions snntorch/_neurons/lapicque.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def __init__(
self.state_function = self._base_int

def _init_mem(self):
mem = torch.zeros(1)
self.register_buffer("mem", mem)
self.mem = torch.zeros(1)

def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
Expand Down
3 changes: 1 addition & 2 deletions snntorch/_neurons/leaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
mem = torch.zeros(1)
self.register_buffer("mem", mem)
self.mem = torch.zeros(1)

def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
Expand Down
6 changes: 2 additions & 4 deletions snntorch/_neurons/rleaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,8 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
spk = torch.zeros(1)
mem = torch.zeros(1)
self.register_buffer("spk", spk)
self.register_buffer("mem", mem)
self.spk = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.spk = torch.zeros_like(self.spk, device=self.spk.device)
Expand Down
10 changes: 3 additions & 7 deletions snntorch/_neurons/rsynaptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,9 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
spk = torch.zeros(1)
syn = torch.zeros(1)
mem = torch.zeros(1)

self.register_buffer("spk", spk)
self.register_buffer("syn", syn)
self.register_buffer("mem", mem)
self.spk = torch.zeros(1)
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.spk = torch.zeros_like(self.spk, device=self.spk.device)
Expand Down
6 changes: 2 additions & 4 deletions snntorch/_neurons/sconv2dlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,8 @@ def __init__(
)

def _init_mem(self):
syn = torch.zeros(1)
mem = torch.zeros(1)
self.register_buffer("syn", syn)
self.register_buffer("mem", mem)
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
Expand Down
6 changes: 2 additions & 4 deletions snntorch/_neurons/slstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,8 @@ def __init__(
)

def _init_mem(self):
syn = torch.zeros(1)
mem = torch.zeros(1)
self.register_buffer("syn", syn)
self.register_buffer("mem", mem)
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
Expand Down
6 changes: 2 additions & 4 deletions snntorch/_neurons/synaptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,8 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
syn = torch.zeros(1)
mem = torch.zeros(1)
self.register_buffer("syn", syn)
self.register_buffer("mem", mem)
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
Expand Down

0 comments on commit 7055c47

Please sign in to comment.