parallel leaky neuron bug fixes
jeshraghian committed Nov 19, 2023
1 parent 29b7840 commit b30e3f3
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions snntorch/_neurons/leakyparallel.py
@@ -145,8 +145,8 @@ def __init__(

         self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu',
                           bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype)
-        self.beta = beta
-        self._beta_buffer
+
+        self._beta_buffer(beta, learn_beta)
         if self.beta is not None:
             self.beta = self.beta.clamp(0, 1)
@@ -168,10 +168,10 @@ def __init__(
         if self.beta is not None:
             # Set all weights to the scalar value of self.beta
             if isinstance(self.beta, float) or isinstance(self.beta, int):
-                self.rnn.weight_hh_10.fill_(self.beta)
+                self.rnn.weight_hh_l0.fill_(self.beta)
             elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor):
                 if len(self.beta) == 1:
-                    self.rnn.weight_hh_10.fill_(self.beta)
+                    self.rnn.weight_hh_l0.fill_(self.beta[0])
                 elif len(self.beta) == hidden_size:
                     # Replace each value with the corresponding value in self.beta
                     for i in range(hidden_size):
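
Note on the hunk above: in torch.nn, the layer-0 recurrent weight of nn.RNN is spelled weight_hh_l0 (lowercase L, then zero), so the old weight_hh_10 (digit one, zero) raised AttributeError; and Tensor.fill_ accepts only a Python number or a 0-dim tensor, which is why the 1-element tensor branch now indexes self.beta[0]. A minimal standalone sketch of both points (illustration only, not part of the commit):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=1, nonlinearity='relu')

with torch.no_grad():
    beta = torch.tensor([0.9])        # 1-element decay tensor, as in leakyparallel.py
    # rnn.weight_hh_10                # AttributeError: the parameter is weight_hh_l0
    # rnn.weight_hh_l0.fill_(beta)    # RuntimeError: fill_ wants a number or 0-dim tensor
    rnn.weight_hh_l0.fill_(beta[0])   # OK: indexing a 1-element tensor yields a 0-dim tensor

print(rnn.weight_hh_l0)               # 3x3 recurrent weight, every entry now 0.9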
@@ -248,11 +248,11 @@ def backward(ctx, grad_output):
             return grad, None


-    def _beta_buffer(self, learn_beta):
-        if not isinstance(self.beta, torch.Tensor):
-            self.beta = torch.as_tensor(self.beta)  # TODO: or .tensor() if no copy
-        if not learn_beta:
-            self.register_buffer("beta", self.beta)
+    def _beta_buffer(self, beta, learn_beta):
+        if not isinstance(beta, torch.Tensor):
+            if beta is not None:
+                beta = torch.as_tensor([beta])  # TODO: or .tensor() if no copy
+        self.register_buffer("beta", beta)

     def _graded_spikes_buffer(
         self, graded_spikes_factor, learn_graded_spikes_factor
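Taken together, the first and last hunks fix how beta reaches the buffer: as shown, the old call site referenced self._beta_buffer without calling it (a no-op), the old method read self.beta before it was guaranteed to exist, did not handle beta=None (torch.as_tensor(None) raises), and only registered the buffer when learn_beta was False. The rewritten _beta_buffer receives beta explicitly, wraps a scalar into a 1-element tensor, skips conversion for None, and always registers the buffer (register_buffer accepts None), so the `if self.beta is not None` check in __init__ is safe either way; learn_beta stays in the signature but is unused in the body shown. A hedged usage sketch of the fixed behavior, assuming the neuron is exposed as snn.LeakyParallel with the constructor arguments shown and that its forward pass follows nn.RNN with batch_first=False:

import torch
import snntorch as snn

lif = snn.LeakyParallel(input_size=784, hidden_size=128, beta=0.9)

# A scalar beta is now stored as a 1-element buffer, so it moves with .to(device)
# and survives state_dict round-trips:
print(lif.beta)              # tensor([0.9000])

x = torch.rand(25, 1, 784)   # (num_steps, batch, input_size)
spk = lif(x)                 # expected spike output: (num_steps, batch, hidden_size)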
