Batch Norm (#97)
* Added batch norm
pashashiz authored Nov 12, 2023
1 parent 7c1ea4a commit d715657
Showing 27 changed files with 477 additions and 103 deletions.
6 changes: 6 additions & 0 deletions src/main/scala/scanet/core/AllKernels.scala
@@ -85,6 +85,12 @@ case class DependsOn[A: TensorType](expr: Expr[A], dep: Expr[_]) extends Expr[A]
override def inputs: Seq[Expr[_]] = Seq(expr)
override def controls: Seq[Expr[_]] = Seq(dep)
override def compiler: Compiler[A] = DefaultCompiler[A]()
override def localGrad: Grad[A] = new Grad[A] {
override def calc[R: Floating](
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] =
List(parentGrad)
}
}

case class Switch[A: TensorType](cond: Expr[Boolean], output: Expr[A]) extends Expr[(A, A)] {
62 changes: 49 additions & 13 deletions src/main/scala/scanet/core/Shape.scala
@@ -24,11 +24,19 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
indexes.reverse
}

def apply(dim: Int): Int = dims(dim)
def get(dim: Int): Int = dims(dim)
def get(dim: Int): Int = if (dim == -1) last else dims(dim)

def apply(dim: Int): Int = get(dim)

def rank: Int = dims.size

def axis: List[Int] = dims.indices.toList

def axisExcept(other: Int*): List[Int] = {
val indexedAxis = indexAxis(other)
(dims.indices.toSet -- indexedAxis.toSet).toList.sorted
}

def isScalar: Boolean = rank == 0

def isInBound(projection: Projection): Boolean = {
@@ -52,7 +60,10 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
def drop(n: Int): Shape = Shape(dims.drop(n))
def dropRight(n: Int): Shape = Shape(dims.dropRight(n))

def last: Int = dims.last
def last: Int = {
require(!isScalar, "cannot get last dimension for scalar")
dims.last
}

def prepend(dim: Int): Shape = Shape(dim +: dims: _*)
def +:(dim: Int): Shape = prepend(dim)
@@ -159,47 +170,72 @@ }
}
}

def permute(indexes: Int*): Shape = {
def maxDims(other: Shape): Shape = {
val maxRank = rank max other.rank
val left = alignLeft(maxRank, 1)
val right = other.alignLeft(maxRank, 1)
val dimsResult = left.dims.zip(right.dims)
.map { case (l, r) => l max r }
Shape(dimsResult)
}

def permute(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
require(
rank == indexes.size,
rank == indexedAxis.size,
"the number of permutation indexes " +
s"should be equal to rank $rank, but was (${indexes.mkString(", ")})")
Shape(indexes.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
s"should be equal to rank $rank, but was (${axis.mkString(", ")})")
Shape(indexedAxis.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
}

def select(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
require(
axis.forall(_ < rank),
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of selected axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
Shape(axis.map(dims(_)).toList)
Shape(indexedAxis.map(get).toList)
}

def remove(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
require(
axis.forall(_ < rank),
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of removed axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
val filteredDims = dims.zipWithIndex
.filter { case (_, i) => !axis.contains(i) }
.filter {
case (_, i) =>
!indexedAxis.contains(i)
}
.map { case (dim, _) => dim }
Shape(filteredDims)
}

def updated(axis: Int, value: Int): Shape = updateAll(value)(axis)

def updateAll(value: Int)(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
require(
axis.forall(_ < rank),
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of updated axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
val updatedDims = dims.zipWithIndex.map {
case (dim, i) =>
if (axis.contains(i)) value else dim
if (indexedAxis.contains(i)) value else dim
}
Shape(updatedDims)
}

def updateAllExcept(value: Int)(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
val axisToUpdate = dims.indices.toSet -- indexedAxis.toSet
updateAll(value)(axisToUpdate.toList: _*)
}

private def indexAxis(axis: Seq[Int]): Seq[Int] =
axis.map(a => if (a == -1) dims.size - 1 else a)

def minus(other: Shape): Shape = {
require(broadcastableAny(other), s"cannot $this - $other")
if (endsWith(other)) {
3 changes: 2 additions & 1 deletion src/main/scala/scanet/core/core.scala
@@ -6,7 +6,8 @@ package object core {

object syntax extends CoreSyntax

def error(message: String): Nothing = throw new RuntimeException(message)
def error(message: String): Nothing =
throw new RuntimeException(message)

def memoize[I1, O](f: I1 => O): I1 => O = {
val cache = mutable.HashMap[I1, O]()
8 changes: 5 additions & 3 deletions src/main/scala/scanet/math/alg/AllKernels.scala
@@ -103,7 +103,7 @@ case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends
s"cannot multiply tensors with shapes ${left.shape} * ${right.shape}")
override def name: String = "Mul"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
override def shape: Shape = left.shape max right.shape
override val shape: Shape = left.shape maxDims right.shape
override def inputs: Seq[Expr[_]] = Seq(left, right)
override def compiler: core.Compiler[A] = DefaultCompiler[A]()
override def localGrad: Grad[A] = new Grad[A] {
@@ -233,8 +233,10 @@ case class Mean[A: Numeric] private (expr: Expr[A], axis: Seq[Int], keepDims: Bo
override def calc[R: Floating](
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
// we need to recover reduced axis with 1, cause broadcasting will not always work
val parentShape = axis.foldLeft(parentGrad.shape)((s, axis) => s.insert(axis, 1))
val parentShape =
if (keepDims) parentGrad.shape
// we need to recover the reduced axes with 1, because broadcasting will not always work
else axis.foldLeft(parentGrad.shape)((s, axis) => s.insert(axis, 1))
val size = expr.shape.select(axis: _*).power
List(kernels.ones[R](expr.shape) * parentGrad.reshape(parentShape) / size.const.cast[R])
}
10 changes: 6 additions & 4 deletions src/main/scala/scanet/models/Math.scala
@@ -4,21 +4,23 @@ import scanet.core.Params.Weights
import scanet.core.{Expr, Floating, Params, Shape}
import scanet.math.syntax._
import scanet.models.Aggregation.Avg
import scanet.models.layer.StatelessLayer
import scanet.models.layer.{Layer, StatelessLayer}

object Math {

case object `x^2` extends StatelessLayer {
case class `x^2`(override val trainable: Boolean = true) extends StatelessLayer {

override def params(input: Shape): Params[ParamDef] =
Params(Weights -> ParamDef(Shape(), Initializer.Zeros, Some(Avg), trainable = true))

override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] =
override def buildStateless[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] =
pow(params(Weights), 2)

override def penalty[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] =
override def penalty[E: Floating](params: Params[Expr[E]]): Expr[E] =
zeros[E](Shape())

override def outputShape(input: Shape): Shape = input

override def makeTrainable(trainable: Boolean): Layer = copy(trainable = trainable)
}
}
28 changes: 22 additions & 6 deletions src/main/scala/scanet/models/Model.scala
@@ -30,7 +30,7 @@ abstract class Model extends Serializable {
* @param params initialized or calculated model params
* @return penalty
*/
def penalty[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E]
def penalty[E: Floating](params: Params[Expr[E]]): Expr[E]

def result[E: Floating]: (Expr[E], Params[Expr[E]]) => Expr[E] =
(input, params) => build(input, params)._1
@@ -40,6 +40,11 @@

def outputShape(input: Shape): Shape

def trainable: Boolean
def makeTrainable(trainable: Boolean): Model
def freeze: Model = makeTrainable(false)
def unfreeze: Model = makeTrainable(true)

def withLoss(loss: Loss): LossModel = LossModel(this, loss)

private def makeGraph[E: Floating](input: Shape): Expr[E] =
@@ -89,15 +94,14 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
output: Expr[E],
params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = {
val (result, nextParams) = model.build(input, params)
val loss = lossF.build(result, output) plus model.penalty(input.shape, params)
val loss = lossF.build(result, output) plus model.penalty(params)
(loss, nextParams)
}

def loss[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Expr[E] =
(input, output, params) => buildStateful(input, output, params)._1

def lossStateful[E: Floating]
: (Expr[E], Expr[E], Params[Expr[E]]) => (Expr[E], Params[Expr[E]]) =
def lossStateful[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => (Expr[E], Params[Expr[E]]) =
(input, output, params) => buildStateful(input, output, params)

def grad[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Params[Expr[E]] =
@@ -114,7 +118,13 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
(grad, nextState)
}

def trained[E: Floating](params: Params[Tensor[E]]) = new TrainedModel(this, params)
def trainable: Boolean = model.trainable
def makeTrainable(trainable: Boolean): LossModel = copy(model = model.makeTrainable(trainable))
def freeze: LossModel = makeTrainable(false)
def unfreeze: LossModel = makeTrainable(true)

def trained[E: Floating](params: Params[Tensor[E]]): TrainedModel[E] =
TrainedModel(this.freeze, params)

def displayLoss[E: Floating](input: Shape, dir: String = ""): Unit = {
val params = model.params(input)
@@ -141,7 +151,7 @@ case class LossModel(model: Model, lossF: Loss) extends Serializable {
override def toString: String = s"$lossF($model)"
}

class TrainedModel[E: Floating](val lossModel: LossModel, val params: Params[Tensor[E]]) {
case class TrainedModel[E: Floating](lossModel: LossModel, params: Params[Tensor[E]]) {

def buildResult(input: Expr[E]): Expr[E] =
buildResultStateful(input)._1
@@ -168,4 +178,10 @@ class TrainedModel[E: Floating](val lossModel: LossModel, val params: Params[Ten
(input, output) => buildLossStateful(input, output)

def outputShape(input: Shape): Shape = lossModel.model.outputShape(input)

def trainable: Boolean = lossModel.trainable
def makeTrainable(trainable: Boolean): TrainedModel[E] =
copy(lossModel = lossModel.makeTrainable(trainable))
def freeze: TrainedModel[E] = makeTrainable(false)
def unfreeze: TrainedModel[E] = makeTrainable(true)
}
135 changes: 135 additions & 0 deletions src/main/scala/scanet/models/layer/BatchNorm.scala
@@ -0,0 +1,135 @@
package scanet.models.layer

import scanet.core._
import scanet.math.syntax.zeros
import scanet.models.Aggregation.Avg
import scanet.models.layer.BatchNorm.{Beta, Gamma, MovingMean, MovingVariance}
import scanet.models.{Initializer, ParamDef, Regularization}
import scanet.syntax._

/** Layer that normalizes its inputs.
*
* Batch normalization applies a transformation that maintains the mean output
* close to 0 and the output standard deviation close to 1.
*
* Importantly, batch normalization works differently during training and
* during inference.
*
* '''During training''', the layer normalizes its output using
* the mean and standard deviation of the current batch of inputs. That is to
* say, for each channel being normalized, the layer returns
*
* {{{gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta}}}
*
* where:
* - `epsilon` is a small constant (configurable as part of the constructor arguments)
* - `gamma` is a learned scaling factor (initialized as 1)
* - `beta` is a learned offset factor (initialized as 0)
*
* '''During inference''' the layer normalizes its output using a moving average of the
* mean and standard deviation of the batches it has seen during training. That
* is to say, it returns
*
* {{{gamma * (batch - moving_mean) / sqrt(moving_var + epsilon) + beta}}}
*
* where `moving_mean` and `moving_var` are non-trainable variables that
* are updated each time the layer is called in training mode, as follows:
* - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
* - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`
*
* As such, the layer will only normalize its inputs during inference
* after having been trained on data that has statistics similar to the inference data.
*
* Reference [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
*
* @param axis Axis that should be normalized (typically the features axis).
* @param momentum Momentum for the moving average.
* @param epsilon Small float added to variance to avoid dividing by zero.
* @param betaInitializer Initializer for the beta weight.
* @param gammaInitializer Initializer for the gamma weight.
* @param movingMeanInitializer Initializer for the moving mean.
* @param movingVarianceInitializer Initializer for the moving variance.
* @param betaRegularizer Regularizer for the beta weight.
* @param gammaRegularizer Regularizer for the gamma weight.
* @param trainable Whether the layer is trainable
*/
case class BatchNorm(
axis: Seq[Int] = Seq(-1),
momentum: Float = 0.99f,
epsilon: Float = 1e-3f,
betaInitializer: Initializer = Initializer.Zeros,
gammaInitializer: Initializer = Initializer.Ones,
movingMeanInitializer: Initializer = Initializer.Zeros,
movingVarianceInitializer: Initializer = Initializer.Ones,
betaRegularizer: Regularization = Regularization.Zero,
gammaRegularizer: Regularization = Regularization.Zero,
override val trainable: Boolean = true)
extends Layer {

override def stateful: Boolean = true

private def paramsShape(input: Shape): Shape = {
// given shape = (2, 4, 6) and axis = (1, 2)
// the specified axes keep their dimensions while the rest are reduced to 1,
// so the result shape = (1, 4, 6)
// note 1: reduction operations such as mean work the other way around - the specified dimensions become 1
// note 2: we keep all 1-sized dimensions without squeezing so that broadcasting works properly with complex deep shapes
val reduceAxis = input.axisExcept(axis: _*)
input.updateAll(1)(reduceAxis: _*)
}

override def params(input: Shape): Params[ParamDef] = {
val shape = paramsShape(input)
Params(
Beta -> ParamDef(shape, betaInitializer, Some(Avg), trainable = trainable),
Gamma -> ParamDef(shape, gammaInitializer, Some(Avg), trainable = trainable),
MovingMean -> ParamDef(shape, movingMeanInitializer, Some(Avg)),
MovingVariance -> ParamDef(shape, movingVarianceInitializer, Some(Avg)))
}

override def build[E: Floating](
input: Expr[E],
params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = {
val prevMovingMean = params(MovingMean)
val prevMovingVariance = params(MovingVariance)
val (movingMean, movingVariance) =
if (trainable) {
val momentumE = momentum.const.cast[E]
val reduceAxis = input.shape.axisExcept(axis: _*)
val batchMean = input.mean(reduceAxis, keepDims = true)
val batchVariance = (input - batchMean).sqr.mean(reduceAxis, keepDims = true)
val movingMean = prevMovingMean.decayingAvg(batchMean, momentumE)
val movingVariance = prevMovingVariance.decayingAvg(batchVariance, momentumE)
(movingMean, movingVariance)
} else {
(prevMovingMean, prevMovingVariance)
}
val epsilonE = epsilon.const.cast[E]
val output =
(input - movingMean) * params(Gamma) /
(movingVariance + epsilonE).sqrt + params(Beta)
val nextState: Params[Expr[E]] =
if (trainable) Params(
MovingMean -> movingMean,
MovingVariance -> movingVariance)
else Params.empty
(output, nextState)
}

override def penalty[E: Floating](params: Params[Expr[E]]): Expr[E] =
if (trainable) betaRegularizer.build(params(Beta)) + gammaRegularizer.build(params(Gamma))
else zeros[E](Shape())

override def outputShape(input: Shape): Shape = input

override def makeTrainable(trainable: Boolean): BatchNorm = copy(trainable = trainable)

override def toString: String = s"BatchNorm($axis)"
}

object BatchNorm {
val Beta: Path = "beta"
val Gamma: Path = "gamma"
val MovingMean: Path = "moving_mean"
val MovingVariance: Path = "moving_variance"
}
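
For intuition, here is a minimal plain-Scala sketch of the arithmetic described in the BatchNorm docstring above: training-mode normalization of a single feature column followed by the decaying-average update of the moving statistics. It deliberately avoids the scanet API; the object name, sample values, and starting moving statistics are illustrative only.

object BatchNormSketch {
  def main(args: Array[String]): Unit = {
    val batch = Seq(1.0, 2.0, 3.0, 4.0) // values of one feature across the batch
    val momentum = 0.99
    val epsilon = 1e-3
    val (gamma, beta) = (1.0, 0.0) // learned scale and offset, at their initial values

    // training mode: normalize with the statistics of the current batch
    val mean = batch.sum / batch.size
    val variance = batch.map(x => math.pow(x - mean, 2)).sum / batch.size
    val normalized = batch.map(x => gamma * (x - mean) / math.sqrt(variance + epsilon) + beta)

    // moving statistics carried over to inference (decaying average update)
    var (movingMean, movingVariance) = (0.0, 1.0) // matches the Zeros/Ones initializer defaults
    movingMean = movingMean * momentum + mean * (1 - momentum)
    movingVariance = movingVariance * momentum + variance * (1 - momentum)

    println(s"normalized = $normalized")
    println(s"movingMean = $movingMean, movingVariance = $movingVariance")
  }
}

For comparison, the build method above plugs the freshly updated moving averages into the same normalization formula rather than the raw batch statistics.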