Skip to content

Commit

Permalink
device fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jeshraghian committed Nov 19, 2023
1 parent 0a42690 commit 284ce89
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion snntorch/_neurons/leakyparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ def _diagonal_enable(self, diagonal_enable):
# self.rnn.weight_hh_l0.grad[i, j] = 0

def grad_hook(self, grad):
device = grad.device
# Create a mask that is 1 on the diagonal and 0 elsewhere
mask = torch.eye(self.hidden_size, self.hidden_size)
mask = torch.eye(self.hidden_size, self.hidden_size, device=device)
# Use the mask to zero out non-diagonal elements of the gradient
return grad * mask

Expand Down

0 comments on commit 284ce89

Please sign in to comment.