fix leakyparallel weight_hh_l diagonal bug
jeshraghian committed Nov 19, 2023
1 parent 284ce89 commit 67e5a85
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions snntorch/_neurons/leakyparallel.py
@@ -1,11 +1,11 @@
 from .neurons import _SpikeTensor, _SpikeTorchConv, LIF
 import torch
 import torch.nn as nn
 
 class LeakyParallel(nn.Module):
     """
-    A parallel implementation of the Leaky neuron with an input linear layer.
+    A parallel implementation of the Leaky neuron with a fused input linear layer.
     All time steps are passed to the input at once.
     This implementation uses `torch.nn.RNN` to accelerate computation.
     First-order leaky integrate-and-fire neuron model.
     Input is assumed to be a current injection.
@@ -22,6 +22,15 @@ class LeakyParallel(nn.Module):
     * :math:`U_{\\rm thr}` - Membrane threshold
     * :math:`β` - Membrane potential decay rate
 
+    Several differences between `LeakyParallel` and `Leaky` include:
+
+    * Negative hidden states are clipped due to the forced ReLU operation in RNN
+    * Linear weights are included in addition to recurrent weights
+    * `beta` is clamped to [0, 1] and cloned to `weight_hh_l` only upon layer initialization; it is unused otherwise
+    * There is no explicit reset mechanism
+    * Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel`
+    * Only the output spike is returned; membrane potential is not accessible by default
+    * RNN uses a hidden matrix of size (num_hidden, num_hidden) to transform the hidden state vector. This would 'leak' membrane potential between LIF neurons, so the hidden matrix is forced to be diagonal by default. This can be disabled by setting `weight_hh_enable=True`
 
     Example::
 
         import torch
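
The last bullet above is the crux of this commit: with `torch.nn.RNN`'s default dense recurrent matrix, every hidden unit mixes into every other, which has no LIF interpretation. A minimal standalone sketch (not part of the commit; `hidden_size`, `beta`, and `steps` are illustrative) of why a diagonal `weight_hh_l0` equal to `beta * I` reduces the ReLU RNN update to independent leaky integration per neuron:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    hidden_size, beta, steps = 4, 0.9, 10

    # ReLU RNN: h[t] = relu(W_ih x[t] + b_ih + W_hh h[t-1] + b_hh)
    rnn = nn.RNN(input_size=hidden_size, hidden_size=hidden_size, nonlinearity="relu")
    with torch.no_grad():
        rnn.weight_hh_l0.copy_(beta * torch.eye(hidden_size))  # diagonal recurrence
        rnn.bias_hh_l0.zero_()

    x = torch.rand(steps, 1, hidden_size)  # (time, batch, features)
    out, _ = rnn(x)

    # Manual per-neuron leaky integration with the same input projection:
    # u[t] = relu(beta * u[t-1] + W_ih x[t] + b_ih), no cross-neuron terms.
    u = torch.zeros(1, hidden_size)
    for t in range(steps):
        u = torch.relu(beta * u + x[t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0)
        assert torch.allclose(u, out[t], atol=1e-6)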
@@ -36,8 +45,8 @@ def __init__(self):
                super().__init__()
 
                # initialize layers
-               self.lif1 = snn.ParallelLeaky(input_size=784, hidden_size=128)
-               self.lif2 = snn.ParallelLeaky(input_size=128, hidden_size=10, beta=beta)
+               self.lif1 = snn.LeakyParallel(input_size=784, hidden_size=128)
+               self.lif2 = snn.LeakyParallel(input_size=128, hidden_size=10, beta=beta)
 
            def forward(self, x):
                spk1 = self.lif1(x)
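
As the corrected example names show, the layer consumes the whole spike train in one call. A hedged usage sketch (shapes follow `torch.nn.RNN`'s default `(time, batch, feature)` layout, which this class wraps; the sizes here are illustrative):

    import torch
    import snntorch as snn

    num_steps, batch_size = 25, 32
    x = torch.rand(num_steps, batch_size, 784)   # all time steps passed at once

    lif1 = snn.LeakyParallel(input_size=784, hidden_size=128)
    spk1 = lif1(x)                               # spikes only; no membrane returned
    print(spk1.shape)                            # expected: (25, 32, 128)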
@@ -78,21 +87,16 @@ def forward(self, x):
         to False
     :type surrogate_disable: bool, Optional
 
-    :param init_hidden: Instantiates state variables as instance variables.
-        Defaults to False
-    :type init_hidden: bool, optional
-    :param inhibition: If `True`, suppresses all spiking other than the
-        neuron with the highest state. Defaults to False
-    :type inhibition: bool, optional
 
     :param learn_beta: Option to enable learnable beta. Defaults to False
     :type learn_beta: bool, optional
 
     :param learn_threshold: Option to enable learnable threshold. Defaults
         to False
     :type learn_threshold: bool, optional
 
+    :param weight_hh_enable: Option to set the hidden matrix to be dense or diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. Dense (True) allows the membrane potential of one LIF neuron to influence all others, following the default RNN implementation. Defaults to False
+    :type weight_hh_enable: bool, optional
 
     Inputs: \\input_
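
The new `weight_hh_enable` flag documented above can be checked directly on the layer's recurrent matrix. A short sketch, assuming the `self.rnn.weight_hh_l0` attribute layout shown elsewhere in this diff (the sizes are illustrative, and the dense case relies on PyTorch's default random initialization):

    import torch
    import snntorch as snn

    diag_lif = snn.LeakyParallel(input_size=8, hidden_size=4)   # default: diagonal
    dense_lif = snn.LeakyParallel(input_size=8, hidden_size=4, weight_hh_enable=True)

    off_diag_mask = 1 - torch.eye(4)
    print((diag_lif.rnn.weight_hh_l0 * off_diag_mask).abs().sum())   # tensor(0.): no cross-neuron leakage
    print((dense_lif.rnn.weight_hh_l0 * off_diag_mask).abs().sum())  # > 0: neurons are coupled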
@@ -138,7 +142,7 @@ def __init__(
         learn_threshold=False,
         graded_spikes_factor=1.0,
         learn_graded_spikes_factor=False,
-        diagonal_enable=False,
+        weight_hh_enable=False,
         device=None,
         dtype=None,
     ):
@@ -152,18 +156,24 @@

         if self.beta is not None:
             self.beta = self.beta.clamp(0, 1)
 
-        if diagonal_enable is False:
-            # Initial gradient and weights of w_hh are made diagonal
-            self._diagonal_enable(diagonal_enable)
-            # Register a gradient hook to clamp out non-diagonal matrices in backward pass
-            self.rnn.weight_hh_l0.register_hook(self.grad_hook)
 
         if spike_grad is None:
             self.spike_grad = self.ATan.apply
         else:
             self.spike_grad = spike_grad
 
+        self._beta_to_weight_hh()
+        if weight_hh_enable is False:
+            # Initial weights and gradients of w_hh are made diagonal
+            self.weight_hh_enable()
+            # Register a gradient hook to zero out off-diagonal entries in the backward pass
+            if learn_beta:
+                self.rnn.weight_hh_l0.register_hook(self.grad_hook)
+
+        if not learn_beta:
+            # Make the weights non-learnable
+            self.rnn.weight_hh_l0.requires_grad_(False)
 
         self._threshold_buffer(threshold, learn_threshold)
         self._graded_spikes_buffer(
             graded_spikes_factor, learn_graded_spikes_factor
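
Read together with the deletions above and further down, the likely bug named in the commit title is one of ordering: previously the diagonal mask was applied early while `_beta_to_weight_hh()` ran at the end of `__init__`, so writing `beta` into the recurrent weights could repopulate off-diagonal entries after they had been zeroed. A toy stand-in (plain tensors, not snnTorch; it assumes `_beta_to_weight_hh()` fills the recurrent weights with `beta`) showing why the new order matters:

    import torch

    hidden_size, beta = 4, 0.9
    w = torch.rand(hidden_size, hidden_size)
    eye = torch.eye(hidden_size)

    # Old order: mask first, then write beta -> dense again (the bug)
    old = torch.full_like(w * eye, beta)
    # New order: write beta first, then mask -> clean diagonal
    new = torch.full_like(w, beta) * eye

    print((old * (1 - eye)).abs().sum())   # > 0: off-diagonal leakage survives
    print((new * (1 - eye)).abs().sum())   # tensor(0.)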
@@ -173,17 +183,12 @@
         if self.surrogate_disable:
             self.spike_grad = self._surrogate_bypass
 
-        self._beta_to_weight_hh()
-
-        if not learn_beta:
-            # Make the weights non-learnable
-            self.rnn.weight_hh_l0.requires_grad_(False)
 
     def forward(self, input_):
         mem = self.rnn(input_)
         # mem[0] contains relu'd outputs, mem[1] contains final hidden state
         mem_shift = mem[0] - self.threshold
         # print(mem[0])
         # print(self.rnn.weight_hh_l0)
         spk = self.spike_grad(mem_shift)
         spk = spk * self.graded_spikes_factor
         return spk
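
For reference, the spike generation in `forward` is a threshold shift followed by a step function in the forward pass (the ATan surrogate only changes the backward pass, not these forward values). A minimal sketch with toy values:

    import torch

    threshold = 1.0
    mem = torch.tensor([[0.2, 1.5, 0.9],
                        [1.1, 0.0, 2.3]])   # toy ReLU'd membrane trace

    mem_shift = mem - threshold
    spk = (mem_shift > 0).float()           # forward pass of the surrogate
    print(spk)                              # tensor([[0., 1., 0.], [1., 0., 1.]])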
@@ -243,13 +248,9 @@ def backward(ctx, grad_output):
             )
             return grad, None
 
-    def _diagonal_enable(self, diagonal_enable):
-        if diagonal_enable is False:
-            for i in range(self.hidden_size):
-                for j in range(self.hidden_size):
-                    if i != j:
-                        self.rnn.weight_hh_l0.data[i, j] = 0
-        # self.rnn.weight_hh_l0.grad[i, j] = 0
+    def weight_hh_enable(self):
+        mask = torch.eye(self.hidden_size, self.hidden_size)
+        self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask
 
     def grad_hook(self, grad):
         device = grad.device
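The body of `grad_hook` is truncated in this view. Given that it is registered to keep `weight_hh_l0` diagonal during training, a plausible sketch (an assumption, mirroring the `torch.eye` mask used by `weight_hh_enable()` above) is shown below. Registering it only under `learn_beta` is consistent with the frozen-weight branch: weights with `requires_grad` disabled receive no gradient, so no hook is needed.

    import torch

    def grad_hook(grad):
        device = grad.device
        # Keep only diagonal gradient entries so off-diagonal weights stay zero
        mask = torch.eye(grad.shape[0], grad.shape[1], device=device)
        return grad * mask

    g = torch.rand(4, 4)
    print(grad_hook(g))   # off-diagonal entries zeroed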
