diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py
index 2e92ff9a..5781a4d1 100644
--- a/snntorch/_neurons/leakyparallel.py
+++ b/snntorch/_neurons/leakyparallel.py
@@ -145,6 +145,7 @@ def __init__(
         self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu',
             bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype)
 
-        self._beta_buffer(beta, learn_beta)
+        self.beta = beta
+        self._beta_buffer(learn_beta)
         if self.beta is not None:
             self.beta = self.beta.clamp(0, 1)
@@ -247,11 +248,11 @@ def backward(ctx, grad_output):
             return grad, None
 
-    def _beta_buffer(self, beta, learn_beta):
-        if not isinstance(beta, torch.Tensor):
-            beta = torch.as_tensor(beta)  # TODO: or .tensor() if no copy
+    def _beta_buffer(self, learn_beta):
+        if not isinstance(self.beta, torch.Tensor):
+            self.beta = torch.as_tensor(self.beta)  # TODO: or .tensor() if no copy
         if not learn_beta:
-            self.register_buffer("beta", beta)
+            self.register_buffer("beta", self.beta)
 
     def _graded_spikes_buffer(
         self, graded_spikes_factor, learn_graded_spikes_factor
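
For context, here is a minimal sketch (not part of the patch) of how the reworked `_beta_buffer(self, learn_beta)` is expected to behave on a bare `nn.Module`. The class name and the `del self.beta` guard are assumptions added for illustration: `nn.Module.register_buffer` raises a `KeyError` when a plain attribute of the same name already exists, and the patched `__init__` creates exactly such an attribute via `self.beta = beta`.

```python
import torch
import torch.nn as nn


class _BetaDemo(nn.Module):
    """Illustration only: mirrors the patched __init__/_beta_buffer flow."""

    def __init__(self, beta, learn_beta=False):
        super().__init__()
        self.beta = beta                  # plain attribute, as in the patch
        self._beta_buffer(learn_beta)
        if self.beta is not None:
            self.beta = self.beta.clamp(0, 1)

    def _beta_buffer(self, learn_beta):
        # Convert scalars/lists to a tensor in place, as the patch does.
        if not isinstance(self.beta, torch.Tensor):
            self.beta = torch.as_tensor(self.beta)
        if not learn_beta:
            beta = self.beta
            # Assumed guard: register_buffer() refuses a name that already
            # exists as a plain attribute, so drop the attribute first.
            del self.beta
            self.register_buffer("beta", beta)


m = _BetaDemo(beta=0.9)
print(m.beta)                   # tensor(0.9000)
print(dict(m.named_buffers()))  # {'beta': tensor(0.9000)}
```

Once registered as a buffer, the non-learnable `beta` travels with `state_dict()` and `.to(device)` without becoming a trainable parameter, which is the point of routing the non-learnable path through `register_buffer`.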