This repository has been archived by the owner on Sep 15, 2022. It is now read-only.

Commit

[super-hot-fix] changed h_n to c_n by mistake; to be consistent with prior trainings, changing back to h_n
JosephGeoBenjamin committed Aug 20, 2020
1 parent 7c72584 commit ab742d5
Showing 2 changed files with 5 additions and 6 deletions.
6 changes: 4 additions & 2 deletions algorithms/lm_fused_rnn.py
@@ -226,7 +226,7 @@ def forward(self, x, hidden, enc_output):
             hid_for_att = torch.zeros((self.dec_layers, batch_sz,
                                 self.dec_hidden_dim )).to(self.device)
         elif self.dec_rnn_type == 'lstm':
-            hid_for_att = hidden[1] # c_n
+            hid_for_att = hidden[0] # h_n <<<check
         else:
             hid_for_att = hidden

@@ -267,7 +267,9 @@ def get_hidden(self, x, hidden, enc_output):
             hid_for_att = torch.zeros((self.dec_layers, batch_sz,
                                 self.dec_hidden_dim )).to(self.device)
         elif self.dec_rnn_type == 'lstm':
-            hid_for_att = hidden[1] # c_n
+            hid_for_att = hidden[0] # h_n <<<check
         else:
             hid_for_att = hidden

             # x (batch_size, 1, dec_embed_dim) -> after embedding
             x = self.embedding(x)
5 changes: 1 addition & 4 deletions algorithms/recurrent_nets.py
@@ -99,9 +99,6 @@ def get_word_embedding(self, x):
         return out_embed


-
-
-
 class Decoder(nn.Module):
     def __init__(self, output_dim, embed_dim, hidden_dim,
                  rnn_type = 'gru', layers = 1,
@@ -199,7 +196,7 @@ def forward(self, x, hidden, enc_output):
             hid_for_att = torch.zeros((self.dec_layers, batch_sz,
                                 self.dec_hidden_dim )).to(self.device)
         elif self.dec_rnn_type == 'lstm':
-            hid_for_att = hidden[1] # c_n
+            hid_for_att = hidden[0] # h_n
         else:
             hid_for_att = hidden
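For context (not part of the commit, but standard PyTorch behavior): nn.LSTM returns its recurrent state as a tuple (h_n, c_n), while nn.GRU returns a single tensor h_n, which is why only the LSTM branch indexes into the state. A minimal sketch of the distinction the fix relies on, with arbitrary example dimensions:

import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)   # (batch, seq_len, input_dim); example sizes only

gru = nn.GRU(8, 16, num_layers=1, batch_first=True)
_, gru_state = gru(x)      # gru_state is a single tensor h_n: (layers, batch, hidden)

lstm = nn.LSTM(8, 16, num_layers=1, batch_first=True)
_, lstm_state = lstm(x)    # lstm_state is a tuple: (h_n, c_n)

h_n = lstm_state[0]        # hidden state -- what this commit feeds to attention
c_n = lstm_state[1]        # cell state -- what the buggy version used by mistake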
