From 67e5a8515accc2a42f26bf57e8fd17c7f7ad2e51 Mon Sep 17 00:00:00 2001
From: Jason Eshraghian
Date: Sun, 19 Nov 2023 11:31:57 -0800
Subject: [PATCH] fix leakyparallel weight_hh_l diagonal bug

---
 snntorch/_neurons/leakyparallel.py | 67 +++++++++++++++---------------
 1 file changed, 34 insertions(+), 33 deletions(-)

diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py
index 35942b58..de81266d 100644
--- a/snntorch/_neurons/leakyparallel.py
+++ b/snntorch/_neurons/leakyparallel.py
@@ -1,11 +1,11 @@
-from .neurons import _SpikeTensor, _SpikeTorchConv, LIF
 import torch
 import torch.nn as nn
 
 
 class LeakyParallel(nn.Module):
     """
-    A parallel implementation of the Leaky neuron with an input linear layer.
+    A parallel implementation of the Leaky neuron with a fused input linear layer.
     All time steps are passed to the input at once.
+    This implementation uses `torch.nn.RNN` to accelerate computation.
 
     First-order leaky integrate-and-fire neuron model.
     Input is assumed to be a current injection.
@@ -22,6 +22,15 @@ class LeakyParallel(nn.Module):
     * :math:`U_{\\rm thr}` - Membrane threshold
     * :math:`β` - Membrane potential decay rate
 
+    Several differences between `LeakyParallel` and `Leaky` include:
+
+    * Negative hidden states are clipped due to the forced ReLU operation
+      in the RNN
+    * Linear weights are included in addition to recurrent weights
+    * `beta` is clipped to the range [0, 1] and cloned to `weight_hh_l0`
+      only upon layer initialization; it is unused otherwise
+    * There is no explicit reset mechanism
+    * Several functions such as `init_hidden`, `output`, `inhibition`,
+      and `state_quant` are unavailable in `LeakyParallel`
+    * Only the output spike is returned; membrane potential is not
+      accessible by default
+    * The RNN applies a dense hidden matrix of size
+      (hidden_size, hidden_size) to the hidden state vector. This would
+      'leak' membrane potential between LIF neurons, so the hidden matrix
+      is forced to be diagonal by default. This can be disabled by setting
+      `weight_hh_enable=True`
+
     Example::
 
         import torch
@@ -36,8 +45,8 @@ def __init__(self):
                 super().__init__()
 
                 # initialize layers
-                self.lif1 = snn.ParallelLeaky(input_size=784, hidden_size=128)
-                self.lif2 = snn.ParallelLeaky(input_size=128, hidden_size=10, beta=beta)
+                self.lif1 = snn.LeakyParallel(input_size=784, hidden_size=128)
+                self.lif2 = snn.LeakyParallel(input_size=128, hidden_size=10, beta=beta)
 
             def forward(self, x):
                 spk1 = self.lif1(x)
@@ -78,14 +87,6 @@ def forward(self, x):
         to False
     :type surrogate_disable: bool, Optional
 
-    :param init_hidden: Instantiates state variables as instance variables.
-        Defaults to False
-    :type init_hidden: bool, optional
-
-    :param inhibition: If `True`, suppresses all spiking other than the
-        neuron with the highest state. Defaults to False
-    :type inhibition: bool, optional
-
     :param learn_beta: Option to enable learnable beta. Defaults to False
     :type learn_beta: bool, optional
@@ -93,6 +94,9 @@ def forward(self, x):
         to False
     :type learn_threshold: bool, optional
 
+    :param weight_hh_enable: Option to set the hidden matrix to be dense or
+        diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works.
+        Dense (True) would allow the membrane potential of one LIF neuron
+        to influence all others, following the default RNN implementation.
+        Defaults to False
+    :type weight_hh_enable: bool, optional
+
     Inputs: \\input_
@@ -138,7 +142,7 @@ def __init__(
         learn_threshold=False,
         graded_spikes_factor=1.0,
         learn_graded_spikes_factor=False,
-        diagonal_enable=False,
+        weight_hh_enable=False,
         device=None,
         dtype=None,
     ):
@@ -152,18 +156,24 @@
         if self.beta is not None:
             self.beta = self.beta.clamp(0, 1)
-
-        if diagonal_enable is False:
-            # Initial gradient and weights of w_hh are made diagonal
-            self._diagonal_enable(diagonal_enable)
-            # Register a gradient hook to clamp out non-diagonal matrices in backward pass
-            self.rnn.weight_hh_l0.register_hook(self.grad_hook)
 
         if spike_grad is None:
             self.spike_grad = self.ATan.apply
         else:
             self.spike_grad = spike_grad
 
+        self._beta_to_weight_hh()
+        if weight_hh_enable is False:
+            # Initial weights of weight_hh_l0 are made diagonal
+            self.weight_hh_enable()
+            # Register a gradient hook to zero out off-diagonal entries
+            # in the backward pass
+            if learn_beta:
+                self.rnn.weight_hh_l0.register_hook(self.grad_hook)
+
+        if not learn_beta:
+            # Make the weights non-learnable
+            self.rnn.weight_hh_l0.requires_grad_(False)
+
         self._threshold_buffer(threshold, learn_threshold)
         self._graded_spikes_buffer(
             graded_spikes_factor, learn_graded_spikes_factor
@@ -173,17 +183,12 @@
         if self.surrogate_disable:
             self.spike_grad = self._surrogate_bypass
 
-        self._beta_to_weight_hh()
-
-        if not learn_beta:
-            # Make the weights non-learnable
-            self.rnn.weight_hh_l0.requires_grad_(False)
-
-
     def forward(self, input_):
         mem = self.rnn(input_)
         # mem[0] contains relu'd outputs, mem[1] contains final hidden state
         mem_shift = mem[0] - self.threshold
         spk = self.spike_grad(mem_shift)
         spk = spk * self.graded_spikes_factor
         return spk
@@ -243,13 +248,9 @@ def backward(ctx, grad_output):
             )
             return grad, None
 
-    def _diagonal_enable(self, diagonal_enable):
-        if diagonal_enable is False:
-            for i in range(self.hidden_size):
-                for j in range(self.hidden_size):
-                    if i != j:
-                        self.rnn.weight_hh_l0.data[i, j] = 0
-                        # self.rnn.weight_hh_l0.grad[i, j] = 0
+    def weight_hh_enable(self):
+        # Zero out off-diagonal entries of the recurrent weight matrix;
+        # the mask is built on the same device as the weights
+        mask = torch.eye(self.hidden_size, device=self.rnn.weight_hh_l0.device)
+        self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask
 
     def grad_hook(self, grad):
         device = grad.device
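Taken together, the hunks above make `weight_hh_l0` diagonal at initialization and keep it diagonal during training. The following is a minimal standalone sketch of that scheme, not the snntorch implementation itself; because the body of `grad_hook` is truncated above, the hook shown here assumes it multiplies the incoming gradient by an identity mask::

    import torch
    import torch.nn as nn

    hidden_size = 4  # illustrative size, not from the patch
    rnn = nn.RNN(input_size=3, hidden_size=hidden_size, nonlinearity="relu")

    # Force the recurrent weights to a diagonal matrix, as weight_hh_enable() does
    mask = torch.eye(hidden_size, device=rnn.weight_hh_l0.device)
    rnn.weight_hh_l0.data = rnn.weight_hh_l0.data * mask

    # Assumed reading of grad_hook: zero the off-diagonal gradient entries so
    # the off-diagonal weights can never move away from zero during training
    def grad_hook(grad):
        return grad * torch.eye(hidden_size, device=grad.device)

    rnn.weight_hh_l0.register_hook(grad_hook)

    # One forward/backward pass: off-diagonal gradient entries come out zero
    x = torch.randn(5, 2, 3)  # (time_steps, batch, input_size)
    out, _ = rnn(x)
    out.sum().backward()
    assert torch.count_nonzero(rnn.weight_hh_l0.grad * (1 - mask)) == 0

Registering the hook only when `learn_beta=True` matches the patch: when beta is not learnable, `requires_grad_(False)` already freezes the recurrent weights, so no hook is needed.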
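For completeness, a usage sketch consistent with the docstring example; the shapes follow `torch.nn.RNN`'s time-major convention and the beta value is arbitrary, as neither is fixed by the hunks shown here::

    import torch
    import snntorch as snn

    lif = snn.LeakyParallel(input_size=784, hidden_size=128, beta=0.9)

    # All time steps are passed to the layer at once
    x = torch.rand(25, 64, 784)  # (time_steps, batch, input_size)
    spk = lif(x)                 # only spikes are returned, not membrane potential
    print(spk.shape)             # expected: torch.Size([25, 64, 128])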