Skip to content

Commit

Permalink
fix self instance of beta
Browse files Browse the repository at this point in the history
  • Loading branch information
jeshraghian committed Nov 19, 2023
1 parent 1872916 commit 4c8f993
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions snntorch/_neurons/leakyparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def __init__(
self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu',
bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype)
self._beta_buffer
if beta is not None:
beta = beta.clamp(0, 1)
if self.beta is not None:
self.beta = self.beta.clamp(0, 1)

if spike_grad is None:
self.spike_grad = self.ATan.apply
Expand All @@ -164,17 +164,17 @@ def __init__(
self.spike_grad = self._surrogate_bypass

with torch.no_grad():
if beta is not None:
# Set all weights to the scalar value of beta
if isinstance(beta, float) or isinstance(beta, int):
self.rnn.weight_hh_10.fill_(beta)
elif isinstance(beta, torch.Tensor) or isinstance(beta, torch.FloatTensor):
if len(beta) == 1:
self.rnn.weight_hh_10.fill_(beta)
elif len(beta) == hidden_size:
# Replace each value with the corresponding value in beta
if self.beta is not None:
# Set all weights to the scalar value of self.beta
if isinstance(self.beta, float) or isinstance(self.beta, int):
self.rnn.weight_hh_10.fill_(self.beta)
elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor):
if len(self.beta) == 1:
self.rnn.weight_hh_10.fill_(self.beta)
elif len(self.beta) == hidden_size:
# Replace each value with the corresponding value in self.beta
for i in range(hidden_size):
self.rnn.weight_hh_l0.data[i].fill_(beta[i])
self.rnn.weight_hh_l0.data[i].fill_(self.beta[i])
else:
raise ValueError("Beta must be either a single value or of length 'hidden_size'.")

Expand Down

0 comments on commit 4c8f993

Please sign in to comment.