From c19a127730d12734e4094dcb67525b03e28c1857 Mon Sep 17 00:00:00 2001 From: "@picocreator (Eugene Cheah)" Date: Fri, 8 Sep 2023 00:27:01 +0000 Subject: [PATCH] Fixing v5 model output bug, as per r3 changes --- RWKV-v5/src/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py index 2def9e6e..06cd3b7b 100644 --- a/RWKV-v5/src/model.py +++ b/RWKV-v5/src/model.py @@ -324,8 +324,10 @@ def _forward_state_chunk(self, r, k, v, g, w, wk, wb, ws, x_l, last_state: TimeM # x = self.ln_x(x/self.head_size_divisor).view(B, TT, H*S) x = self.ln_x(x/8).view(B, TT, H*S) - return self.output(x), TimeMixState(x_l, s) - + # Fix missing *g for output as per : + # https://github.com/RWKV/RWKV-infctx-trainer/commit/beb46d599042b77d53db9c7fa59a5966e7d33719#r126730367 + return self.output(x)*g, TimeMixState(x_l, s) + def _forward_chunk(self, x, last_state: TimeMixState): # Forward sizings (Batch, Time/ContextLength, Tokens) B, TT, C = x.size()