diff --git a/src/arraymancer/laser/strided_iteration/foreach_common.nim b/src/arraymancer/laser/strided_iteration/foreach_common.nim
index ecfc0f89d..288fae920 100644
--- a/src/arraymancer/laser/strided_iteration/foreach_common.nim
+++ b/src/arraymancer/laser/strided_iteration/foreach_common.nim
@@ -7,12 +7,12 @@ import
   std/[macros, strutils],
   ../compiler_optim_hints
 
-template isVar[T: object](x: T): bool =
+template isVar[T](x: T): bool =
   ## Workaround due to `is` operator not working for `var`
   ## https://github.com/nim-lang/Nim/issues/9443
   compiles(addr(x))
 
-proc aliasTensor(id: int, tensor: NimNode): NimNode =
+proc aliasTensor(id: int, tensor: NimNode): tuple[alias: NimNode, isVar: NimNode] =
   ## Produce an alias variable for a tensor
   ## Supports:
   ## - identifiers
@@ -41,11 +41,22 @@ proc aliasTensor(id: int, tensor: NimNode): NimNode =
       # Rewrite the AST to untyped
       t = nnkBracketExpr.newTree(
-        tensor[0][1],
-        tensor[0][2]
+        tensor[1]
       )
+      for i in 2 ..< tensor.len:
+        t.add tensor[i]
 
   var alias = ""
+  let isVar = block:
+    # Handle slicing cases like foo[0..<1, 0..<2]
+    # that do not return `var` but are technically `var`
+    # if `foo` is var
+    if t.kind in {nnkDotExpr, nnkBracketExpr}:
+      let t0 = t[0]
+      quote do: isVar(`t0`)
+    else:
+      quote do: isVar(`t`)
+
   while t.kind in {nnkDotExpr, nnkBracketExpr}:
     if t[0].kind notin {nnkIdent, nnkSym}:
       error "Expected a field name but found \"" & t[0].repr()
@@ -57,7 +68,7 @@ proc aliasTensor(id: int, tensor: NimNode): NimNode =
 
   alias &= $t
 
-  return newIdentNode($alias & "_alias" & $id & '_')
+  return (newIdentNode($alias & "_alias" & $id & '_'), isVar)
 
 proc initForEach*(
       params: NimNode,
@@ -105,10 +116,10 @@ proc initForEach*(
   aliases_stmt.add newCall(bindSym"withCompilerOptimHints")
 
   for i, tensor in tensors:
-    let alias = aliasTensor(i, tensor)
+    let (alias, detectVar) = aliasTensor(i, tensor)
     aliases.add alias
     aliases_stmt.add quote do:
-      when isVar(`tensor`):
+      when `detectVar`:
        var `alias`{.align_variable.} = `tensor`
       else:
        let `alias`{.align_variable.} = `tensor`
diff --git a/src/arraymancer/nn_primitives/nnp_gru.nim b/src/arraymancer/nn_primitives/nnp_gru.nim
index 1453081eb..4d02a9a1c 100644
--- a/src/arraymancer/nn_primitives/nnp_gru.nim
+++ b/src/arraymancer/nn_primitives/nnp_gru.nim
@@ -63,7 +63,6 @@ proc gru_cell_inference*[T: SomeFloat](
 
     # Slices
     sr = (0 ..< H)|1
    sz = (H ..< 2*H)|1
-    srz = (0 ..< 2*H)|1
     s = (2*H ..< 3*H)|1
 
@@ -73,19 +72,29 @@ proc gru_cell_inference*[T: SomeFloat](
   linear(input, W3, bW3, W3x)
   linear(hidden, U3, bU3, U3h)
 
-  # Step 2 - Computing reset (r) and update (z) gate
-  var W2ru = W3x[_, srz] # shape [batch_size, 2*H] - we reuse the previous buffer
-  apply2_inline(W2ru, U3h[_, srz]):
-    sigmoid(x + y)
-
-  # Step 3 - Computing candidate hidden state ñ
-  var n = W3x[_, s] # shape [batch_size, H] - we reuse the previous buffer
-  apply3_inline(n, W2ru[_, sr], U3h[_, s]):
-    tanh(x + y * z)
-
-  # Step 4 - Update the hidden state
-  apply3_inline(hidden, W3x[_, sz], n):
-    (1 - y) * z + y * x
+  # Step 2 - Fused evaluation of the 4 GRU equations
+  #   r  = σ(Wr * x + bWr + Ur * h + bUr)
+  #   z  = σ(Wz * x + bWz + Uz * h + bUz)
+  #   n  = tanh(W * x + bW + r *. (U * h + bU))
+  #   h' = (1 - z) *. n + z *. h
+
+  # shape [batch_size, H] - we reuse the previous buffers
+  forEach wrx in W3x[_, sr],  # Wr*x
+          wzx in W3x[_, sz],  # Wz*x
+          wx in W3x[_, s],    # W*x
+          urh in U3h[_, sr],  # Ur*h
+          uzh in U3h[_, sz],  # Uz*h
+          uh in U3h[_, s],    # U*h
+          h in hidden:        # hidden state
+    # Reset (r) gate and Update (z) gate
+    let r = sigmoid(wrx + urh)
+    let z = sigmoid(wzx + uzh)
+
+    # Candidate hidden state ñ
+    let n = tanh(wx + r * uh)
+
+    # h' = (1 - z) *. ñ + z *. h
+    h = (1-z) * n + z*h
 
 proc gru_cell_forward*[T: SomeFloat](
       input,
@@ -124,26 +133,38 @@
   linear(input, W3, bW3, W3x)
   linear(hidden, U3, bU3, U3h)
 
-  # # Saving for backprop
-  apply2_inline(Uh, U3h[_, s]):
-    y
-
-  # Step 2 - Computing reset (r) and update (z) gate
-  apply3_inline(r, W3x[_, sr], U3h[_, sr]):
-    sigmoid(y + z)
-
-  apply3_inline(z, W3x[_, sz], U3h[_, sz]):
-    sigmoid(y + z)
-
-  # Step 3 - Computing candidate hidden state ñ
-  # TODO: need apply4 / loopfusion for efficient
-  #       buffer passing in Stacked GRU implementation
-  n = map3_inline(W3x[_, s], r, U3h[_, s]):
-    tanh(x + y * z)
-
-  # Step 4 - Update the hidden state
-  apply3_inline(hidden, z, n):
-    (1 - y) * z + y * x
+  # Step 2 - Fused evaluation of the 4 GRU equations
+  #   and saving for backprop
+  #   r  = σ(Wr * x + bWr + Ur * h + bUr)
+  #   z  = σ(Wz * x + bWz + Uz * h + bUz)
+  #   n  = tanh(W * x + bW + r *. (U * h + bU))
+  #   h' = (1 - z) *. n + z *. h
+
+  # shape [batch_size, H] - we reuse the previous buffers
+  forEach wrx in W3x[_, sr],    # Wr*x
+          wzx in W3x[_, sz],    # Wz*x
+          wx in W3x[_, s],      # W*x
+          urh in U3h[_, sr],    # Ur*h
+          uzh in U3h[_, sz],    # Uz*h
+          uh in U3h[_, s],      # U*h
+          h in hidden,          # hidden state
+          saveUh in Uh,         # U*h cache for backprop
+          reset in r,           # reset gate cache for backprop
+          update in z,          # update gate cache for backprop
+          candidate in n:       # candidate hidden state cache for backprop
+
+    # Cache for backprop
+    saveUh = uh
+
+    # Reset (r) gate and Update (z) gate
+    reset = sigmoid(wrx + urh)
+    update = sigmoid(wzx + uzh)
+
+    # Candidate hidden state ñ
+    candidate = tanh(wx + reset * uh)
+
+    # h' = (1 - z) *. ñ + z *. h
+    h = (1-update) * candidate + update*h
 
 proc gru_cell_backward*[T: SomeFloat](
       dx, dh, dW3, dU3,       # input and weights gradients
@@ -162,6 +183,9 @@
   ## - dnext: gradient flowing back from the next layer
   ## - x, h, W3, U3: inputs saved from the forward pass
   ## - r, z, n, Uh: intermediate results saved from the forward pass of shape [batch_size, hidden_size]
+
+  # TODO: fused backprop with forEach
+
   # Backprop of step 4 - z part
   let dz = (h - n) *. dnext
   let dn = (1.0.T -. z) *. dnext
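
Note on the foreach_common.nim change: `aliasTensor` now returns the mutability test as a NimNode alongside the alias, so that for a sliced argument such as `foo[0..<1, 0..<2]` the generated preamble tests the base tensor `foo` rather than the slice expression (slicing produces a non-`var` temporary even when `foo` itself is mutable). Roughly, the emitted preamble for such an argument looks like the following sketch of the macro expansion (`foo_alias0_` is the alias name produced for id 0; this is illustrative, not literal compiler output):

  when isVar(foo):   # detectVar tests `foo`, not `foo[0..<1, 0..<2]`
    var foo_alias0_{.align_variable.} = foo[0..<1, 0..<2]
  else:
    let foo_alias0_{.align_variable.} = foo[0..<1, 0..<2]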
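
Note on the nnp_gru.nim change: the three `apply2_inline`/`apply3_inline` passes (and the `srz` trick of treating r and z as one 2*H-wide slice) are replaced by a single `forEach` traversal that evaluates all four GRU equations per element, reusing the W3x and U3h buffers as inputs. The sketch below is a minimal, self-contained rendering of that fused update on plain `seq[float]` instead of Arraymancer tensors; the proc name `gruCellUpdate` is illustrative only, and `wrx`/`wzx`/`wx`/`urh`/`uzh`/`uh` stand for the precomputed W3x/U3h slices:

  import std/math

  proc sigmoid(x: float): float =
    1.0 / (1.0 + exp(-x))

  proc gruCellUpdate(wrx, wzx, wx, urh, uzh, uh: seq[float];
                     hidden: var seq[float]) =
    ## Element-wise equivalent of the fused loop:
    ##   r  = σ(Wr*x + Ur*h)
    ##   z  = σ(Wz*x + Uz*h)
    ##   n  = tanh(W*x + r *. (U*h))
    ##   h' = (1 - z) *. n + z *. h
    for i in 0 ..< hidden.len:
      let r = sigmoid(wrx[i] + urh[i])
      let z = sigmoid(wzx[i] + uzh[i])
      let n = tanh(wx[i] + r * uh[i])
      hidden[i] = (1.0 - z) * n + z * hidden[i]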
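
Note on gru_cell_backward: the `dz` and `dn` lines visible in the hunk context follow directly from differentiating h' = (1 - z) *. n + z *. h, which gives ∂h'/∂z = h - n and ∂h'/∂n = 1 - z. A standalone numeric spot-check of those two gradients (the helper `hPrime` is hypothetical, not part of this patch):

  proc hPrime(z, n, h: float): float =
    (1.0 - z) * n + z * h

  let (z, n, h) = (0.3, 0.5, 0.8)
  let eps = 1e-6

  # Central finite differences vs the analytic gradients
  let numDz = (hPrime(z + eps, n, h) - hPrime(z - eps, n, h)) / (2.0 * eps)
  let numDn = (hPrime(z, n + eps, h) - hPrime(z, n - eps, h)) / (2.0 * eps)

  doAssert abs(numDz - (h - n)) < 1e-6   # ∂h'/∂z = h - n
  doAssert abs(numDn - (1.0 - z)) < 1e-6 # ∂h'/∂n = 1 - z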