diff --git a/lstm.py b/lstm.py index b99bfa8..eb6a9ed 100644 --- a/lstm.py +++ b/lstm.py @@ -95,13 +95,13 @@ def bottom_data_is(self, x, s_prev = None, h_prev = None): self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf) self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo) self.state.s = self.state.g * self.state.i + s_prev * self.state.f - self.state.h = self.state.s * self.state.o + self.state.h = np.tanh(self.state.s) * self.state.o self.xc = xc def top_diff_is(self, top_diff_h, top_diff_s): # notice that top_diff_s is carried along the constant error carousel - ds = self.state.o * top_diff_h + top_diff_s + ds = self.state.o * top_diff_h * tanh_derivative(self.state.s) + top_diff_s do = self.state.s * top_diff_h di = self.state.g * ds dg = self.state.i * ds