Skip to content

Commit

Permalink
add beta buffer function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeshraghian committed Nov 19, 2023
1 parent ee78e39 commit 1872916
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions snntorch/_neurons/leakyparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(

self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu',
bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype)

self._beta_buffer
if beta is not None:
beta = beta.clamp(0, 1)

Expand Down Expand Up @@ -247,7 +247,12 @@ def backward(ctx, grad_output):
return grad, None



def _beta_buffer(self, beta, learn_beta):
if not isinstance(beta, torch.Tensor):
beta = torch.as_tensor(beta) # TODO: or .tensor() if no copy
if not learn_beta:
self.register_buffer("beta", beta)

def _graded_spikes_buffer(
self, graded_spikes_factor, learn_graded_spikes_factor
):
Expand Down

0 comments on commit 1872916

Please sign in to comment.