black code formatting
jeshraghian committed Jan 20, 2021
1 parent 22a967c commit 4fe3ee5
Showing 6 changed files with 354 additions and 209 deletions.
89 changes: 66 additions & 23 deletions snntorch/__init__.py
@@ -10,7 +10,9 @@

class LIF(nn.Module):
"""Parent class for leaky integrate and fire neuron models."""

instances = []

def __init__(self, alpha, beta, threshold=1.0, spike_grad=None):
super(LIF, self).__init__()
LIF.instances.append(self)
@@ -30,13 +32,13 @@ def fire(self, mem):
mem_shift = mem - self.threshold
spk = self.spike_grad(mem_shift).to(device)
reset = torch.zeros_like(mem)
spk_idx = (mem_shift > 0)
spk_idx = mem_shift > 0
reset[spk_idx] = torch.ones_like(mem)[spk_idx]
return spk, reset

@classmethod
def clear_instances(cls):
cls.instances = []
cls.instances = []

@staticmethod
def init_stein(batch_size, *args):
@@ -71,7 +73,7 @@ def detach(*args):
@staticmethod
def zeros(*args):
"""Used to clear hidden state variables to zero.
Intended for use where hidden state variables are global variables."""
Intended for use where hidden state variables are global variables."""
for state in args:
state = torch.zeros_like(state)
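
Note that, as written, zeros() rebinds the loop variable to a fresh tensor rather than writing into the caller's state, so the tensors passed in are left untouched. A quick illustration of that Python behaviour (the in-place alternative named in the comment is a suggestion, not part of this commit):

import torch


def zeros(*args):
    """Mirror of the staticmethod above: rebinding only affects the local name."""
    for state in args:
        state = torch.zeros_like(state)


mem = torch.ones(3)
zeros(mem)
print(mem)  # tensor([1., 1., 1.]) -- the caller's tensor is unchanged
# Zeroing in place would require e.g. state.zero_(), or reassignment at the call site.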

@@ -91,12 +93,13 @@ def forward(ctx, input_):

@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input_ < 0] = 0.0
grad = grad_input
return grad
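
For orientation, the pieces above fit together as follows: LIF.fire() subtracts the threshold from the membrane potential and passes the shifted value through self.spike_grad, and the autograd function whose backward appears in this hunk supplies the surrogate gradient (presumably the default when spike_grad is None). A minimal, self-contained sketch; the class name and forward body below are assumptions, since only the backward is visible here:

import torch


class Heaviside(torch.autograd.Function):
    """Heaviside-style spike function: unit step in the forward pass; the backward
    pass clones the incoming gradient and zeroes it below threshold."""

    @staticmethod
    def forward(ctx, input_):
        ctx.save_for_backward(input_)
        return (input_ > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_ < 0] = 0.0
        return grad_input


# Mirroring LIF.fire(): threshold the membrane potential to produce spikes.
mem = torch.tensor([0.2, 1.3, 0.9], requires_grad=True)
spk = Heaviside.apply(mem - 1.0)  # tensor([0., 1., 0.])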


# Neuron Models


Expand All @@ -111,7 +114,16 @@ class Stein(LIF):
R. B. Stein (1965) A theoretical analysis of neuron variability. Biophys. J. 5, pp. 173-194.
R. B. Stein (1967) Some models of neuronal variability. Biophys. J. 7, pp. 37-68."""

def __init__(self, alpha, beta, threshold=1.0, num_inputs=False, spike_grad=None, batch_size=False, hidden_init=False):
def __init__(
self,
alpha,
beta,
threshold=1.0,
num_inputs=False,
spike_grad=None,
batch_size=False,
hidden_init=False,
):
super(Stein, self).__init__(alpha, beta, threshold, spike_grad)

self.num_inputs = num_inputs
@@ -120,13 +132,21 @@ def __init__(self, alpha, beta, threshold=1.0, num_inputs=False, spike_grad=None

if self.hidden_init:
if not self.num_inputs:
raise ValueError("num_inputs must be specified to initialize hidden states as instance variables.")
raise ValueError(
"num_inputs must be specified to initialize hidden states as instance variables."
)
elif not self.batch_size:
raise ValueError("batch_size must be specified to initialize hidden states as instance variables.")
elif hasattr(self.num_inputs, '__iter__'):
self.spk, self.syn, self.mem = self.init_stein(self.batch_size, *(self.num_inputs)) # need to automatically call batch_size
raise ValueError(
"batch_size must be specified to initialize hidden states as instance variables."
)
elif hasattr(self.num_inputs, "__iter__"):
self.spk, self.syn, self.mem = self.init_stein(
self.batch_size, *(self.num_inputs)
) # need to automatically call batch_size
else:
self.spk, self.syn, self.mem = self.init_stein(self.batch_size, self.num_inputs)
self.spk, self.syn, self.mem = self.init_stein(
self.batch_size, self.num_inputs
)

def forward(self, input_, syn, mem):
if not self.hidden_init:
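
The body of Stein.forward() is collapsed in this view, but the __init__ logic above already shows the two ways state is carried: passed explicitly through forward(), or held on the instance when hidden_init=True. A hedged usage sketch of the explicit-state pattern (layer sizes and decay constants are illustrative; it is assumed that init_stein returns zero-initialized spk, syn, mem tensors and that forward returns the updated triple):

import torch
import torch.nn as nn
import snntorch as snn

batch_size, num_inputs, num_hidden, num_steps = 128, 784, 100, 25

fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Stein(alpha=0.5, beta=0.9, threshold=1.0)

# initialize hidden states explicitly (init_stein is the LIF staticmethod above)
spk1, syn1, mem1 = lif1.init_stein(batch_size, num_hidden)

data = torch.rand(num_steps, batch_size, num_inputs)  # toy input currents
spk_rec = []
for step in range(num_steps):
    cur1 = fc1(data[step])
    spk1, syn1, mem1 = lif1(cur1, syn1, mem1)
    spk_rec.append(spk1)
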
@@ -178,7 +198,16 @@ class SRM0(LIF):
R. Jolivet, T. J. Lewis, W. Gerstner (2003) The spike response model: A framework to predict neuronal spike trains. Artificial Neural Networks and Neural Information Processing, pp. 846-853.
"""

def __init__(self, alpha, beta, threshold=1.0, num_inputs=False, spike_grad=None, batch_size=False, hidden_init=False):
def __init__(
self,
alpha,
beta,
threshold=1.0,
num_inputs=False,
spike_grad=None,
batch_size=False,
hidden_init=False,
):
super(SRM0, self).__init__(alpha, beta, threshold, spike_grad)

self.num_inputs = num_inputs
@@ -187,14 +216,21 @@ def __init__(self, alpha, beta, threshold=1.0, num_inputs=False, spike_grad=None

if self.hidden_init:
if not self.num_inputs:
raise ValueError("num_inputs must be specified to initialize hidden states as instance variables.")
raise ValueError(
"num_inputs must be specified to initialize hidden states as instance variables."
)
elif not self.batch_size:
raise ValueError("batch_size must be specified to initialize hidden states as instance variables.")
elif hasattr(self.num_inputs, '__iter__'):
self.spk, self.syn_pre, self.syn_post, self.mem = self.init_srm0(batch_size=self.batch_size,
*(self.num_inputs))
raise ValueError(
"batch_size must be specified to initialize hidden states as instance variables."
)
elif hasattr(self.num_inputs, "__iter__"):
self.spk, self.syn_pre, self.syn_post, self.mem = self.init_srm0(
batch_size=self.batch_size, *(self.num_inputs)
)
else:
self.spk, self.syn_pre, self.syn_post, self.mem = self.init_srm0(batch_size, num_inputs)
self.spk, self.syn_pre, self.syn_post, self.mem = self.init_srm0(
batch_size, num_inputs
)

self.tau_srm = np.log(self.alpha) / (np.log(self.beta) - np.log(self.alpha)) + 1
if self.alpha <= self.beta:
@@ -206,15 +242,19 @@ def forward(self, input_, syn_pre, syn_post, mem):
spk, reset = self.fire(mem)
syn_pre = (self.alpha * syn_pre + input_) * (1 - reset)
syn_post = (self.beta * syn_post - input_) * (1 - reset)
mem = self.tau_srm * (syn_pre + syn_post)*(1-reset) + (mem*reset - reset)
mem = self.tau_srm * (syn_pre + syn_post) * (1 - reset) + (
mem * reset - reset
)
return spk, syn_pre, syn_post, mem

# if hidden states and outputs are instance variables
if self.hidden_init:
self.spk, self.reset = self.fire(self.mem)
self.syn_pre = (self.alpha * self.syn_pre + input_) * (1 - self.reset)
self.syn_post = (self.beta * self.syn_post - input_) * (1 - self.reset)
self.mem = self.tau_srm * (self.syn_pre + self.syn_post) * (1 - self.reset) + (self.mem * self.reset - self.reset)
self.mem = self.tau_srm * (self.syn_pre + self.syn_post) * (
1 - self.reset
) + (self.mem * self.reset - self.reset)
return self.spk, self.syn_pre, self.syn_post, self.mem

# cool forward function that resulted in burst firing - worth exploring
@@ -230,7 +270,7 @@ def forward(self, input_, syn_pre, syn_post, mem):
# syn_post = self.beta * syn_post - input_
# mem = self.tau_srm * (syn_pre + syn_post) - reset

# return spk, syn_pre, syn_post, mem
# return spk, syn_pre, syn_post, mem

@classmethod
def detach_hidden(cls):
@@ -248,7 +288,10 @@ def zeros_hidden(cls):
Intended for use where hidden state variables are instance variables."""
for layer in range(len(cls.instances)):
cls.instances[layer].spk = torch.zeros_like(cls.instances[layer].spk)
cls.instances[layer].syn_pre = torch.zeros_like(cls.instances[layer].syn_pre)
cls.instances[layer].syn_post = torch.zeros_like(cls.instances[layer].syn_post)
cls.instances[layer].syn_pre = torch.zeros_like(
cls.instances[layer].syn_pre
)
cls.instances[layer].syn_post = torch.zeros_like(
cls.instances[layer].syn_post
)
cls.instances[layer].mem = torch.zeros_like(cls.instances[layer].mem)

4 changes: 3 additions & 1 deletion snntorch/backprop.py
@@ -4,7 +4,9 @@

def BPTT(net, data, target, num_steps, batch_size, optimizer, criterion):
# Net requires hidden instance variables rather than global instance variables for TBPTT
return TBPTT(net, data, target, num_steps, batch_size, optimizer, criterion, K=num_steps)
return TBPTT(
net, data, target, num_steps, batch_size, optimizer, criterion, K=num_steps
)


def TBPTT(net, data, target, num_steps, batch_size, optimizer, criterion, K=1):
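
BPTT is thus the K = num_steps special case of TBPTT: the truncation window spans the entire sequence. Since the body of TBPTT is not shown here, the following is a generic sketch of the truncated-BPTT idea rather than snnTorch's implementation (the per-step net interface and the loss-on-membrane choice are assumptions):

import torch


def tbptt_sketch(net, data, target, num_steps, optimizer, criterion, K=1):
    """Generic truncated BPTT: accumulate loss over K time steps, backpropagate,
    then truncate the graph before continuing. K = num_steps gives full BPTT."""
    loss_window = 0.0
    for step in range(num_steps):
        mem_out = net(data[step])  # assumes net returns membrane potentials per step
        loss_window = loss_window + criterion(mem_out, target)

        if (step + 1) % K == 0 or step == num_steps - 1:
            optimizer.zero_grad()
            loss_window.backward()
            optimizer.step()
            # Truncate: hidden states kept as instance variables must be detached here
            # (the classmethods such as detach_hidden() above exist for this purpose).
            if hasattr(net, "detach_hidden"):
                net.detach_hidden()
            loss_window = 0.0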
