From 367e49b85014b35dd88019755cda7125726aa3a8 Mon Sep 17 00:00:00 2001 From: Pavlo Pohrebnyi Date: Mon, 20 Nov 2023 09:21:01 +0200 Subject: [PATCH] Added fused Batch Norm (#99) * Added fused batch norm --- src/main/scala/scanet/core/Shape.scala | 73 +++--- .../scala/scanet/math/alg/AllKernels.scala | 95 +++++++- .../scala/scanet/math/nn/AllKernels.scala | 226 +++++++++++++++++- .../scala/scanet/models/layer/BatchNorm.scala | 87 +++++-- src/test/scala/scanet/core/ViewSpec.scala | 14 +- .../scala/scanet/math/alg/KernelsSpec.scala | 15 ++ .../scanet/math/logical/KernelsSpec.scala | 4 +- src/test/scala/scanet/models/ANNSpec.scala | 2 +- src/test/scala/scanet/models/CNNSpec.scala | 4 +- .../scanet/models/layer/BatchNormSpec.scala | 12 +- 10 files changed, 442 insertions(+), 90 deletions(-) diff --git a/src/main/scala/scanet/core/Shape.scala b/src/main/scala/scanet/core/Shape.scala index bc15ef0..a58d0ea 100644 --- a/src/main/scala/scanet/core/Shape.scala +++ b/src/main/scala/scanet/core/Shape.scala @@ -30,11 +30,11 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] { def rank: Int = dims.size - def axis: List[Int] = dims.indices.toList + def axes: List[Int] = dims.indices.toList - def axisExcept(other: Int*): List[Int] = { - val indexedAxis = indexAxis(other) - (dims.indices.toSet -- indexedAxis.toSet).toList.sorted + def axesExcept(other: Int*): List[Int] = { + val indexedAxes = indexAxes(other) + (dims.indices.toSet -- indexedAxes.toSet).toList.sorted } def isScalar: Boolean = rank == 0 @@ -158,8 +158,8 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] { def broadcastableAny(other: Shape): Boolean = broadcastableBy(other) || other.broadcastableBy(this) - def broadcastableAxis(other: Shape): Seq[Int] = { - require(broadcastableAny(other), s"cannot find broadcastable axis for $this and $other") + def broadcastableAxes(other: Shape): Seq[Int] = { + require(broadcastableAny(other), s"cannot find broadcastable axes for $this and $other") if (rank < other.rank) { Seq() } else { @@ -179,34 +179,34 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] { Shape(dimsResult) } - def permute(axis: Int*): Shape = { - val indexedAxis = indexAxis(axis) + def permute(axes: Int*): Shape = { + val indexedAxes = indexAxes(axes) require( - rank == indexedAxis.size, + rank == indexedAxes.size, "the number of permutation indexes " + - s"should be equal to rank $rank, but was (${axis.mkString(", ")})") - Shape(indexedAxis.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse) + s"should be equal to rank $rank, but was (${axes.mkString(", ")})") + Shape(indexedAxes.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse) } - def select(axis: Int*): Shape = { - val indexedAxis = indexAxis(axis) + def select(axes: Int*): Shape = { + val indexedAxes = indexAxes(axes) require( - indexedAxis.forall(i => i < rank && i >= 0), - s"the number of selected axis " + - s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})") - Shape(indexedAxis.map(get).toList) + indexedAxes.forall(i => i < rank && i >= 0), + s"the number of selected axes " + + s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})") + Shape(indexedAxes.map(get).toList) } - def remove(axis: Int*): Shape = { - val indexedAxis = indexAxis(axis) + def remove(axes: Int*): Shape = { + val indexedAxes = indexAxes(axes) require( - indexedAxis.forall(i => i < rank && i >= 0), - s"the number of removed axis " + - s"should be less or equal to rank $rank, but was 
(${axis.mkString(", ")})") + indexedAxes.forall(i => i < rank && i >= 0), + s"the number of removed axes " + + s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})") val filteredDims = dims.zipWithIndex .filter { case (_, i) => - !indexedAxis.contains(i) + !indexedAxes.contains(i) } .map { case (dim, _) => dim } Shape(filteredDims) @@ -214,27 +214,30 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] { def updated(axis: Int, value: Int): Shape = updateAll(value)(axis) - def updateAll(value: Int)(axis: Int*): Shape = { - val indexedAxis = indexAxis(axis) + def updateAll(value: Int)(axes: Int*): Shape = { + val indexedAxes = indexAxes(axes) require( - indexedAxis.forall(i => i < rank && i >= 0), - s"the number of updated axis " + - s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})") + indexedAxes.forall(i => i < rank && i >= 0), + s"the number of updated axes " + + s"should be less or equal to rank $rank, but was (${axes.mkString(", ")})") val updatedDims = dims.zipWithIndex.map { case (dim, i) => - if (indexedAxis.contains(i)) value else dim + if (indexedAxes.contains(i)) value else dim } Shape(updatedDims) } - def updateAllExcept(value: Int)(axis: Int*): Shape = { - val indexedAxis = indexAxis(axis) - val axisToUpdate = dims.indices.toSet -- indexedAxis.toSet - updateAll(value)(axisToUpdate.toList: _*) + def updateAllExcept(value: Int)(axes: Int*): Shape = { + val indexedAxes = indexAxes(axes) + val axesToUpdate = dims.indices.toSet -- indexedAxes.toSet + updateAll(value)(axesToUpdate.toList: _*) } - private def indexAxis(axis: Seq[Int]): Seq[Int] = - axis.map(a => if (a == -1) dims.size - 1 else a) + def indexAxes(axes: Seq[Int]): Seq[Int] = + axes.map(indexAxis) + + def indexAxis(axis: Int): Int = + if (axis == -1) dims.size - 1 else axis def minus(other: Shape): Shape = { require(broadcastableAny(other), s"cannot $this - $other") diff --git a/src/main/scala/scanet/math/alg/AllKernels.scala b/src/main/scala/scanet/math/alg/AllKernels.scala index 0c99fdf..266810b 100644 --- a/src/main/scala/scanet/math/alg/AllKernels.scala +++ b/src/main/scala/scanet/math/alg/AllKernels.scala @@ -23,8 +23,8 @@ case class Plus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] { current: Expr[A], parentGrad: Expr[R]): Seq[Expr[R]] = { val parentShape = parentGrad.shape - val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList - val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList + val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList + val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList List( parentGrad.sum(shrinkLeftAxis).reshape(left.shape), parentGrad.sum(shrinkRightAxis).reshape(right.shape)) @@ -74,8 +74,8 @@ case class Minus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] { current: Expr[A], parentGrad: Expr[R]): Seq[Expr[R]] = { val parentShape = parentGrad.shape - val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList - val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList + val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList + val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList List( parentGrad.sum(shrinkLeftAxis).reshape(left.shape), -parentGrad.sum(shrinkRightAxis).reshape(right.shape)) @@ -111,8 +111,8 @@ case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends current: Expr[A], parentGrad: Expr[R]): Seq[Expr[R]] = { val parentShape = parentGrad.shape - val 
shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList - val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList + val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList + val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList List( (right.cast[R] * parentGrad).sum(shrinkLeftAxis).reshape(left.shape), (left.cast[R] * parentGrad).sum(shrinkRightAxis).reshape(right.shape)) @@ -137,7 +137,7 @@ case class Pow[A: Numeric](expr: Expr[A], exponent: Expr[Float]) extends Expr[A] } } -case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { +case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self => override def name: String = "Sqrt" override def tpe: Option[TensorType[A]] = Some(TensorType[A]) override def shape: Shape = expr.shape @@ -147,12 +147,46 @@ case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { override def calc[R: Floating]( current: Expr[A], parentGrad: Expr[R]): Seq[Expr[R]] = { - val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R] - List(local * parentGrad) + // val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R] + // List(local * parentGrad) + List(SqrtGrad(self.cast[R], parentGrad)) } } } +case class SqrtGrad[A: Numeric](sqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] { + override def name: String = "SqrtGrad" + override def tpe: Option[TensorType[A]] = Some(TensorType[A]) + override def shape: Shape = sqrt.shape + override def inputs: Seq[Expr[_]] = Seq(sqrt, parentGrad) + override def compiler: Compiler[A] = DefaultCompiler[A]() +} + +case class Rsqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self => + override def name: String = "Rsqrt" + override def tpe: Option[TensorType[A]] = Some(TensorType[A]) + override def shape: Shape = expr.shape + override def inputs: Seq[Expr[_]] = Seq(expr) + override def compiler: Compiler[A] = DefaultCompiler[A]() + override def localGrad: Grad[A] = new Grad[A] { + override def calc[R: Floating]( + current: Expr[A], + parentGrad: Expr[R]): Seq[Expr[R]] = { + // val local = (expr.cast[R] ^ -1.5f) * -0.5f.const.cast[R] + // List(local * parentGrad) + List(RsqrtGrad(self.cast[R], parentGrad)) + } + } +} + +case class RsqrtGrad[A: Numeric](rsqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] { + override def name: String = "RsqrtGrad" + override def tpe: Option[TensorType[A]] = Some(TensorType[A]) + override def shape: Shape = rsqrt.shape + override def inputs: Seq[Expr[_]] = Seq(rsqrt, parentGrad) + override def compiler: Compiler[A] = DefaultCompiler[A]() +} + case class Exp[A: Numeric](expr: Expr[A]) extends Expr[A] { override def name: String = "Exp" override def tpe: Option[TensorType[A]] = Some(TensorType[A]) @@ -182,8 +216,8 @@ case class Div[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] { current: Expr[A], parentGrad: Expr[R]): Seq[Expr[R]] = { val parentShape = parentGrad.shape - val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList - val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList + val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList + val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList List( (parentGrad / right.cast[R]).sum(shrinkLeftAxis).reshape(left.shape), (-left.cast[R] * parentGrad / right.sqr.cast[R]) @@ -452,6 +486,8 @@ trait AllKernels { def sqrt[A: Numeric](expr: Expr[A]): Expr[A] = Sqrt(expr) + def rsqrt[A: Numeric](expr: Expr[A]): Expr[A] = Rsqrt(expr) + def sqrtZeroSafe[A: Numeric](out: Expr[A], epsilon: Expr[A]): Expr[A] = sqrt(plus(out, 
epsilon)) @@ -466,6 +502,20 @@ trait AllKernels { keepDims: Boolean = false): Expr[A] = Mean(expr, axis, keepDims) def mean[A: Numeric](expr: Expr[A]): Expr[A] = mean(expr, 0 until expr.rank) + def moments[A: Numeric]( + expr: Expr[A], + axis: Seq[Int], + keepDims: Boolean = false): (Expr[A], Expr[A]) = { + val m = mean(expr, axis, keepDims) + // try squared_difference, it has optimized kernel op + val v = mean((expr - m).sqr, axis, keepDims) + (m, v) + } + + def moments[A: Numeric]( + expr: Expr[A]): (Expr[A], Expr[A]) = + moments(expr, 0 until expr.rank) + def max[A: TensorType, C](left: Expr[A], right: C)(implicit c: Convertible[C, Expr[A]]): Expr[A] = Max(left, c.convert(right)) @@ -597,6 +647,12 @@ object kernels extends AllKernels { */ def sqr: Expr[A] = pow(2.0f) + /** Computes reciprocal (inversed) of square root of x element-wise: `1 / sqrt(x))` + * + * @return tensor `^` -0.5 + */ + def rsqrt: Expr[A] = f.rsqrt(expr) + /** Returns square root of the given tensor * * {{{Tensor.vector(1.0f, 4.0f, 9.0f).const.sqrt.eval should be(Tensor.vector(1.0f, 2.0f, 3.0f))}}} @@ -676,6 +732,23 @@ object kernels extends AllKernels { */ def mean: Expr[A] = f.mean(expr) + /** Computes the frequency-weighted mean and variance across dimensions of a tensor. + * + * Reduces `(mean, variance)` along the dimensions given in `axis`. + * The rank of the tensor is reduced by 1 for each entry in `axis`. + * + * @param axis to sum + * @return tensors `(mean, variance)` + */ + def moments(axis: Seq[Int], keepDims: Boolean = false): (Expr[A], Expr[A]) = + f.moments(expr, axis, keepDims) + + /** Computes the frequency-weighted mean and variance across all dimensions of a tensor. + * * + * @return tensors `(mean, variance)` + */ + def moments: (Expr[A], Expr[A]) = f.moments(expr) + /** Shuffle dimensions of `out` according to a permutation. 
* * {{{ diff --git a/src/main/scala/scanet/math/nn/AllKernels.scala b/src/main/scala/scanet/math/nn/AllKernels.scala index bbe5525..1467ec1 100644 --- a/src/main/scala/scanet/math/nn/AllKernels.scala +++ b/src/main/scala/scanet/math/nn/AllKernels.scala @@ -2,10 +2,10 @@ package scanet.math.nn import scanet.core import scanet.core.Require.fail -import scanet.core.syntax._ import scanet.core._ import scanet.math.nn.ConvFormat._ import scanet.math.nn.Padding._ +import scanet.math.syntax._ import scala.collection.immutable.Seq @@ -155,7 +155,11 @@ case class Conv2D[A: Floating] private ( // input = (batch_shape, in_height, in_width, in_channels) // filters = (filter_height, filter_width, in_channels, out_channels) // output = (batch_shape, out_height, out_width, out_channels) - val convolved = padding.shape(format, input.shape, format.shapeOf(filters.shape(0), filters.shape(1)), strides) + val convolved = padding.shape( + format, + input.shape, + format.shapeOf(filters.shape(0), filters.shape(1)), + strides) convolved.updated(format.cAxis, filters.shape(3)) } @@ -392,6 +396,161 @@ object Pool2D { } } +case class FusedBatchNorm[A: Floating]( + input: Expr[A], + scale: Expr[A], + offset: Expr[A], + mean: Expr[A], + variance: Expr[A], + format: ConvFormat, + training: Boolean, + epsilon: Option[Float] = None, + exponentialAvgFactor: Option[Float] = None) + extends Expr[A] { + override def name: String = "FusedBatchNormV3" + override def tpe: Option[TensorType[A]] = Some(TensorType[A]) + override def shape: Shape = input.shape + override val inputs: Seq[Expr[_]] = Seq(input, scale, offset, mean, variance) + override def compiler: core.Compiler[A] = { + // find out good wat to update state with seq of optional values + val comp1 = DefaultCompiler[A]() + .withAttr("data_format", format.name) + .withAttr("is_training", training) + val comp2 = epsilon.fold(comp1)(e => comp1.withAttr("epsilon", e)) + val comp3 = exponentialAvgFactor.fold(comp2)(e => comp1.withAttr("exponential_avg_factor", e)) + comp3 + } + override def localGrad: Grad[A] = new Grad[A] { + override def calc[R: Floating]( + current: Expr[A], + parentGrad: Expr[R]): Seq[Expr[R]] = { + val outs = outputs + val grads = FusedBatchNormGrad[R]( + input.cast[R], + parentGrad, + scale.cast[R], + outs.reserveSpace1.cast[R], + outs.reserveSpace2.cast[R], + outs.reserveSpace3.cast[R], + format, + training, + epsilon).outputs + Seq(grads.inputGrad, grads.scaleGrad, grads.offsetGrad) + } + } + def outputs: FusedBatchNormOutputs[A] = { + // reserveSpace1 + reserveSpace2 + reserveSpace3 shape? + val channelShape = mean.shape + FusedBatchNormOutputs( + output = this, + batchMean = TakeOut[A](this, 1, channelShape), + batchVariance = TakeOut[A](this, 1, channelShape), + reserveSpace1 = TakeOut[A](this, 2, channelShape), + reserveSpace2 = TakeOut[A](this, 3, channelShape), + reserveSpace3 = TakeOut[A](this, 4, channelShape)) + } +} + +trait BatchNormOutputs[A] { + def output: Expr[A] + def batchMean: Expr[A] + def batchVariance: Expr[A] + def mapMean(f: Expr[A] => Expr[A]): BatchNormOutputs[A] + def mapVariance(f: Expr[A] => Expr[A]): BatchNormOutputs[A] +} + +/** Outputs of [[FusedBatchNorm]] operator + * @param output A 4D Output Tensor + * @param batchMean A 1D Tensor for the computed batch mean, to be used by TensorFlow to compute the running mean. + * @param batchVariance A 1D Tensor for the computed batch variance, to be used by TensorFlow to compute the running variance. 
+ * @param reserveSpace1 A 1D Tensor for the computed batch mean, to be reused in the gradient computation. + * @param reserveSpace2 A 1D Tensor for the computed batch variance (inverted variance in the cuDNN case), to be reused in the gradient computation. + * @param reserveSpace3 A 1D Tensor for some intermediate results, to be reused in the gradient computation for better efficiency. + */ +case class FusedBatchNormOutputs[A]( + output: Expr[A], + batchMean: Expr[A], + batchVariance: Expr[A], + reserveSpace1: Expr[A], + reserveSpace2: Expr[A], + reserveSpace3: Expr[A]) + extends BatchNormOutputs[A] { + override def mapMean(f: Expr[A] => Expr[A]): BatchNormOutputs[A] = + copy(batchMean = f(batchMean)) + override def mapVariance(f: Expr[A] => Expr[A]): BatchNormOutputs[A] = + copy(batchVariance = f(batchVariance)) +} + +/** Outputs of batch norm + * @param output An Output Tensor + * @param batchMean A Tensor for the computed batch mean, to be used by TensorFlow to compute the running mean. + * @param batchVariance A Tensor for the computed batch variance, to be used by TensorFlow to compute the running variance. + */ +case class StdBatchNormOutputs[A]( + output: Expr[A], + batchMean: Expr[A], + batchVariance: Expr[A]) + extends BatchNormOutputs[A] { + override def mapMean(f: Expr[A] => Expr[A]): BatchNormOutputs[A] = + copy(batchMean = f(batchMean)) + override def mapVariance(f: Expr[A] => Expr[A]): BatchNormOutputs[A] = + copy(batchVariance = f(batchVariance)) +} + +/** @param input A 4D Tensor for input data + * @param parentGrad A 4D Tensor for the gradient with respect to output + * @param scale A 1D Tensor for scaling factor, to scale the normalized input. + * @param reserveSpace1 When `training` is `true`, a 1D Tensor for the computed batch mean to be reused in gradient computation. + * When `training` is `false`, a 1D Tensor for the population mean to be reused + * in both 1st and 2nd order gradient computation. + * @param reserveSpace2 When `training` is `true`, a 1D Tensor for the computed batch variance + * (inverted variance in the cuDNN case) to be reused in gradient computation. + * When `training` is `false`, a 1D Tensor for the population variance to be reused + * in both 1st and 2nd order gradient computation. + * @param reserveSpace3 When `training` is `true`, a 1D Tensor for some intermediate results to be reused in gradient computation. + * When `training` is `false`, a dummy empty Tensor will be created. + * @param format One of [[NCHW]] or [[NHWC]] + * @param training A bool value to indicate the operation is for training (default) or inference. 
+ * @param epsilon A small float number added to the variance of input + */ +case class FusedBatchNormGrad[A: Floating]( + input: Expr[A], + parentGrad: Expr[A], + scale: Expr[A], + reserveSpace1: Expr[A], + reserveSpace2: Expr[A], + reserveSpace3: Expr[A], + format: ConvFormat, + training: Boolean, + epsilon: Option[Float] = None) + extends Expr[A] { + override def name: String = "FusedBatchNormGradV3" + override def tpe: Option[TensorType[A]] = Some(TensorType[A]) + override def shape: Shape = input.shape + override val inputs: Seq[Expr[_]] = + Seq(parentGrad, input, scale, reserveSpace1, reserveSpace2, reserveSpace3) + override def compiler: core.Compiler[A] = { + val comp1 = DefaultCompiler[A]() + .withAttr("data_format", format.name) + .withAttr("is_training", training) + epsilon.fold(comp1)(e => comp1.withAttr("epsilon", e)) + } + def outputs: FusedBatchNormGradOutputs[A] = { + val channelShape = scale.shape + FusedBatchNormGradOutputs( + inputGrad = this, + scaleGrad = TakeOut[A](this, 1, channelShape), + offsetGrad = TakeOut[A](this, 2, channelShape)) + } +} + +/** Outputs of [[FusedBatchNormGrad]] operator + * @param inputGrad A 4D Tensor for the gradient with respect to input + * @param scaleGrad A 1D Tensor for the gradient with respect to scale + * @param offsetGrad A 1D Tensor for the gradient with respect to offset + */ +case class FusedBatchNormGradOutputs[A](inputGrad: Expr[A], scaleGrad: Expr[A], offsetGrad: Expr[A]) + trait AllKernels { /** Computes a 2-D convolution given input and 4-D filters tensors. @@ -460,6 +619,69 @@ trait AllKernels { format: ConvFormat = NHWC, reduce: Reduce = Reduce.Max): Expr[A] = Pool2D(input, window, strides, padding, format, reduce) + + /** Batch normalization. Note that the size of 4D Tensors are defined by either [[NCHW]] or [[NHWC]]. + * + * @param input A 4D Tensor for input data + * @param scale A 1D Tensor for scaling factor to scale the normalized input + * @param offset A 1D Tensor for offset, to shift to the normalized input + * @param mean A 1D Tensor for population mean. Used for inference only; must be empty for training + * @param variance A 1D Tensor for population variance. Used for inference only; must be empty for training + * @param format One of [[NCHW]] or [[NHWC]] + * @param training A bool value to indicate the operation is for training (default) or inference. + * @param epsilon A small float number added to the variance of input + * @param expAvgFactor The exponential avg factor + * @return Outputs + */ + def fusedBatchNorm[A: Floating]( + input: Expr[A], + scale: Expr[A], + offset: Expr[A], + mean: Expr[A], + variance: Expr[A], + format: ConvFormat, + training: Boolean, + epsilon: Option[Float] = None, + expAvgFactor: Option[Float] = None): FusedBatchNormOutputs[A] = { + // format: off + FusedBatchNorm(input, scale, offset, mean, variance, format, training, epsilon, expAvgFactor).outputs + // format: on + } + + /** Batch normalization. Supports input of any shape + * + * @param input Tensor for input data + * @param scale Tensor for scaling factor to scale the normalized input + * @param offset Tensor for offset, to shift to the normalized input + * @param mean A Tensor for population mean. Used for inference only; must be empty for training + * @param variance A Tensor for population variance. Used for inference only; must be empty for training + * @param training A bool value to indicate the operation is for training (default) or inference. 
+ * @param axes Axes that should be normalized (typically the features axes). + * @param epsilon A small float number added to the variance of input + * @return output + */ + def batchNorm[A: Floating]( + input: Expr[A], + scale: Expr[A], + offset: Expr[A], + mean: Expr[A], + variance: Expr[A], + training: Boolean, + axes: Seq[Int] = Seq(-1), + epsilon: Float = 1e-3f): StdBatchNormOutputs[A] = { + val (batchMean, batchVariance) = + if (training) { + val reduceAxis = input.shape.axesExcept(axes: _*) + input.moments(reduceAxis, keepDims = true) + } else { + (mean, variance) + } + val epsilonE = epsilon.const.cast[A] + // use rsqrt instead of sqrt to increase performance ~20% + val inv = rsqrt(batchVariance + epsilonE) * scale + val output = ((input * inv) - (batchMean * inv)) + offset + StdBatchNormOutputs(output, batchMean, batchVariance) + } } object kernels { diff --git a/src/main/scala/scanet/models/layer/BatchNorm.scala b/src/main/scala/scanet/models/layer/BatchNorm.scala index 1433126..a43b992 100644 --- a/src/main/scala/scanet/models/layer/BatchNorm.scala +++ b/src/main/scala/scanet/models/layer/BatchNorm.scala @@ -1,11 +1,14 @@ package scanet.models.layer +import scanet.core.Require.fail import scanet.core._ +import scanet.math.nn.ConvFormat import scanet.math.syntax.zeros import scanet.models.Aggregation.Avg import scanet.models.layer.BatchNorm.{Beta, Gamma, MovingMean, MovingVariance} import scanet.models.{Initializer, ParamDef, Regularization} import scanet.syntax._ +import scala.collection.immutable.Seq /** Layer that normalizes its inputs. * @@ -42,7 +45,7 @@ import scanet.syntax._ * * Reference [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). * - * @param axis Axis that should be normalized (typically the features axis). + * @param axes Axes that should be normalized (typically the features axes). * @param momentum Momentum for the moving average. * @param epsilon Small float added to variance to avoid dividing by zero. * @param betaInitializer Initializer for the beta weight @@ -51,10 +54,14 @@ import scanet.syntax._ * @param movingVarianceInitializer Initializer for the moving variance. * @param betaRegularizer Regularizer for the beta weight. * @param gammaRegularizer Regularizer for the gamma weight. + * @param fused if `true`, use a faster, fused implementation, or raise an error + * if the fused implementation cannot be used. + * If `false`, do not used the fused implementation. + * If `None`, use the faster implementation if possible. 
* @param trainable Whether layer is trainable */ case class BatchNorm( - axis: Seq[Int] = Seq(-1), + axes: Seq[Int] = Seq(-1), momentum: Float = 0.99f, epsilon: Float = 1e-3f, betaInitializer: Initializer = Initializer.Zeros, @@ -63,6 +70,7 @@ case class BatchNorm( movingVarianceInitializer: Initializer = Initializer.Ones, betaRegularizer: Regularization = Regularization.Zero, gammaRegularizer: Regularization = Regularization.Zero, + fused: Option[Boolean] = None, override val trainable: Boolean = true) extends Layer { @@ -74,7 +82,7 @@ case class BatchNorm( // so result shape = (1, 4, 6) // note 1: in case of reduction operation, such as mean - it works vice-versa, specified dimension will become 1 // note 2: we keep all 1 dimensions without squeezing to perform proper broadcast with complex deep shapes - val reduceAxis = input.axisExcept(axis: _*) + val reduceAxis = input.axesExcept(axes: _*) input.updateAll(1)(reduceAxis: _*) } @@ -87,33 +95,64 @@ case class BatchNorm( MovingVariance -> ParamDef(shape, movingVarianceInitializer, Some(Avg))) } + private def fusedDataFormat(input: Shape): Option[ConvFormat] = { + val indexAxes = input.indexAxes(axes) + (indexAxes, input.rank) match { + case (Seq(1), 4) => Some(ConvFormat.NCHW) + case (Seq(3), 4) => Some(ConvFormat.NHWC) + case _ => None + } + } + override def build[E: Floating]( input: Expr[E], params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { - val prevMovingMean = params(MovingMean) - val prevMovingVariance = params(MovingVariance) - val (movingMean, movingVariance) = + + val mean = params(MovingMean) + val variance = params(MovingVariance) + val gamma = params(Gamma) + val beta = params(Beta) + + val result = (fused, fusedDataFormat(input.shape)) match { + case (Some(true) | None, Some(format)) => + // fused optimized kernel version, works only for input of shape NCHW, NHWC + fusedBatchNorm( + input = input, + scale = gamma.squeeze, + offset = beta.squeeze, + mean = mean.squeeze, + variance = variance.squeeze, + format = format, + training = trainable, + epsilon = Some(epsilon)) + .mapMean(_.reshape(mean.shape)) + .mapVariance(_.reshape(variance.shape)) + case (Some(true), None) => + fail(s"Cannot use fused implementation with input shape ${input.shape} " + + s"and axes: ${axes.mkString("[", ",", "]")}") + case (_, _) => + // generic batch norm, works for any input + batchNorm( + input = input, + scale = gamma, + offset = beta, + mean = mean, + variance = variance, + training = trainable, + axes = axes, + epsilon = epsilon) + } + + val nextParams: Params[Expr[E]] = if (trainable) { val momentumE = momentum.const.cast[E] - val reduceAxis = input.shape.axisExcept(axis: _*) - val batchMean = input.mean(reduceAxis, keepDims = true) - val batchVariance = (input - batchMean).sqr.mean(reduceAxis, keepDims = true) - val movingMean = prevMovingMean.decayingAvg(batchMean, momentumE) - val movingVariance = prevMovingVariance.decayingAvg(batchVariance, momentumE) - (movingMean, movingVariance) + val movingMean = mean.decayingAvg(result.batchMean, momentumE) + val movingVariance = variance.decayingAvg(result.batchVariance, momentumE) + Params(MovingMean -> movingMean, MovingVariance -> movingVariance) } else { - (prevMovingMean, prevMovingVariance) + Params.empty } - val epsilonE = epsilon.const.cast[E] - val output = - (input - movingMean) * params(Gamma) / - (movingVariance.sqrt + epsilonE) - params(Beta) - val nextState: Params[Expr[E]] = - if (trainable) Params( - MovingMean -> movingMean, - MovingVariance -> movingVariance) - else 
Params.empty - (output, nextState) + (result.output, nextParams) } override def penalty[E: Floating](params: Params[Expr[E]]): Expr[E] = @@ -124,7 +163,7 @@ case class BatchNorm( override def makeTrainable(trainable: Boolean): BatchNorm = copy(trainable = trainable) - override def toString: String = s"BatchNorm($axis)" + override def toString: String = s"BatchNorm($axes)" } object BatchNorm { diff --git a/src/test/scala/scanet/core/ViewSpec.scala b/src/test/scala/scanet/core/ViewSpec.scala index e0282ff..53ae98e 100644 --- a/src/test/scala/scanet/core/ViewSpec.scala +++ b/src/test/scala/scanet/core/ViewSpec.scala @@ -175,29 +175,29 @@ class ViewSpec extends AnyWordSpec with Matchers { "have broadcastable axis operation" which { "should leave higher dimensions from first shape if it is bigger" in { - Shape(2, 3, 4) broadcastableAxis Shape(3, 4) should be(Seq(0)) + Shape(2, 3, 4) broadcastableAxes Shape(3, 4) should be(Seq(0)) } "should return dimension index with size one" in { - Shape(2, 3, 4) broadcastableAxis Shape(2, 3, 1) should be(Seq(2)) + Shape(2, 3, 4) broadcastableAxes Shape(2, 3, 1) should be(Seq(2)) } "should return empty shape if both are equal" in { - Shape(2, 3, 4) broadcastableAxis Shape(2, 3, 4) should be(Seq()) + Shape(2, 3, 4) broadcastableAxes Shape(2, 3, 4) should be(Seq()) } "should return empty shape if other dimension is bigger" in { - Shape(2, 3, 4) broadcastableAxis Shape(1, 2, 3, 4) should be(Seq()) + Shape(2, 3, 4) broadcastableAxes Shape(1, 2, 3, 4) should be(Seq()) } "should same shape if other dimension is empty" in { - Shape(2, 3, 4) broadcastableAxis Shape() should be(Seq(0, 1, 2)) + Shape(2, 3, 4) broadcastableAxes Shape() should be(Seq(0, 1, 2)) } "should fail if shapes are incompatible" in { the[IllegalArgumentException] thrownBy { - Shape(2, 3, 4) broadcastableAxis Shape(2, 5, 4) - } should have message "requirement failed: cannot find broadcastable axis for (2, 3, 4) and (2, 5, 4)" + Shape(2, 3, 4) broadcastableAxes Shape(2, 5, 4) + } should have message "requirement failed: cannot find broadcastable axes for (2, 3, 4) and (2, 5, 4)" } } diff --git a/src/test/scala/scanet/math/alg/KernelsSpec.scala b/src/test/scala/scanet/math/alg/KernelsSpec.scala index 67e006e..0894af3 100644 --- a/src/test/scala/scanet/math/alg/KernelsSpec.scala +++ b/src/test/scala/scanet/math/alg/KernelsSpec.scala @@ -377,11 +377,26 @@ class KernelsSpec extends AnyWordSpec with Matchers { } "calculate a gradient" in { + // grad = dy * 0.5 / y, where y = sqrt val x = Tensor.vector(1.0f, 4.0f, 16.0f).const x.sqrt.sum.grad(x).returns[Float].eval should be(Tensor.vector(0.5f, 0.25, 0.125f)) } } + "rsqrt" should { + + "compute reciprocal root of tensor" in { + Tensor.vector(1.0f, 4.0f, 9.0f).const.rsqrt.roundAt(3).eval should be( + Tensor.vector(1.0f, 0.5f, 0.333f)) + } + + "calculate a gradient" in { + val x = Tensor.vector(1.0f, 4.0f, 16.0f).const + x.rsqrt.sum.grad(x).returns[Float].roundAt(3).eval should be( + Tensor.vector(-0.5f, -0.062f, -0.008f)) + } + } + "sum" should { "calculate sum across all axis by default" in { diff --git a/src/test/scala/scanet/math/logical/KernelsSpec.scala b/src/test/scala/scanet/math/logical/KernelsSpec.scala index ddcbd2f..0e484c9 100644 --- a/src/test/scala/scanet/math/logical/KernelsSpec.scala +++ b/src/test/scala/scanet/math/logical/KernelsSpec.scala @@ -92,7 +92,7 @@ class KernelsSpec extends AnyFlatSpec with Matchers { it should "fail when given axis is out of bound" in { the[IllegalArgumentException] thrownBy { Tensor.matrix(Array(true, false), 
Array(true, true)).const.all(Seq(2)).eval - } should have message "requirement failed: the number of removed axis should be less or equal to rank 2, but was (2)" + } should have message "requirement failed: the number of removed axes should be less or equal to rank 2, but was (2)" } "any" should "return true if at least one element is true" in { @@ -116,7 +116,7 @@ class KernelsSpec extends AnyFlatSpec with Matchers { it should "fail when given axis is out of bound" in { the[IllegalArgumentException] thrownBy { Tensor.matrix(Array(true, false), Array(true, true)).const.any(Seq(2)).eval - } should have message "requirement failed: the number of removed axis should be less or equal to rank 2, but was (2)" + } should have message "requirement failed: the number of removed axes should be less or equal to rank 2, but was (2)" } "greater comparison" should "work element wise" in { diff --git a/src/test/scala/scanet/models/ANNSpec.scala b/src/test/scala/scanet/models/ANNSpec.scala index 07cebc6..1ecd4bd 100644 --- a/src/test/scala/scanet/models/ANNSpec.scala +++ b/src/test/scala/scanet/models/ANNSpec.scala @@ -20,7 +20,7 @@ class ANNSpec extends AnyWordSpec with CustomMatchers with SharedSpark with Data "minimize logistic regression" in { val ds = logisticRegression - val model = Dense(4, Sigmoid) >> BatchNorm() >> Dense(1, Sigmoid) + val model = Dense(4, Sigmoid) >> Dense(1, Sigmoid) val trained = ds .train(model) .loss(BinaryCrossentropy) diff --git a/src/test/scala/scanet/models/CNNSpec.scala b/src/test/scala/scanet/models/CNNSpec.scala index ca26e31..7205241 100644 --- a/src/test/scala/scanet/models/CNNSpec.scala +++ b/src/test/scala/scanet/models/CNNSpec.scala @@ -6,7 +6,7 @@ import scanet.core.Shape import scanet.estimators.accuracy import scanet.models.Activation._ import scanet.models.Loss._ -import scanet.models.layer.{Activate, Conv2D, Dense, Flatten, Pool2D} +import scanet.models.layer.{Activate, BatchNorm, Conv2D, Dense, Flatten, Pool2D} import scanet.optimizers.Adam import scanet.optimizers.Effect.{RecordAccuracy, RecordLoss} import scanet.optimizers.syntax._ @@ -21,7 +21,7 @@ class CNNSpec extends AnyWordSpec with CustomMatchers with SharedSpark with Data "train on MNIST dataset" in { val (trainingDs, testDs) = MNIST() val model = - Conv2D(32, activation = ReLU()) >> Pool2D(strides = (2, 2)) >> + Conv2D(32, activation = ReLU()) >> BatchNorm(fused = Some(true)) >> Pool2D(strides = (2, 2)) >> Conv2D(64, activation = ReLU()) >> Pool2D(strides = (2, 2)) >> Activate(ReLU()) >> Flatten >> Dense(64, ReLU()) >> Dense(10, Softmax) val trained = trainingDs diff --git a/src/test/scala/scanet/models/layer/BatchNormSpec.scala b/src/test/scala/scanet/models/layer/BatchNormSpec.scala index 17b07db..65936f1 100644 --- a/src/test/scala/scanet/models/layer/BatchNormSpec.scala +++ b/src/test/scala/scanet/models/layer/BatchNormSpec.scala @@ -38,10 +38,10 @@ class BatchNormSpec extends AnyWordSpec with CustomMatchers { val out = outExpr.const.roundAt(3).eval val state = stateExpr.mapValues(_.const.roundAt(3)).eval out shouldBe Tensor.matrix( - Array(0.374f, -0.81f, -0.564f, 0.0f), - Array(-0.45f, 1.396f, 0.573f, 3.327f), - Array(1.642f, -0.121f, 0.372f, 0.751f), - Array(3.274f, 2.774f, 3.214f, 2.361f)) + Array(-0.595f, -1.168f, -1.042f, -1.231f), + Array(-1.181f, 0.422f, -0.232f, 1.313f), + Array(0.307f, -0.671f, -0.375f, -0.656f), + Array(1.469f, 1.417f, 1.649f, 0.574f)) state("2" / MovingMean) shouldBe Tensor.matrix( Array(7.638f, 2.938f, 5.375f, 3.0f)) state("2" / MovingVariance) shouldBe 
Tensor.matrix( @@ -60,8 +60,8 @@ class BatchNormSpec extends AnyWordSpec with CustomMatchers { out shouldBe Tensor.matrix( Array(1.118f, -1.118f, -0.671f, -0.447f), Array(-0.045f, 0.671f, 0.85f, 0.939f), - Array(2.906f, -0.559f, 0.581f, -0.134f), - Array(5.209f, 1.788f, 4.382f, 0.537f)) + Array(2.907f, -0.559f, 0.581f, -0.134f), + Array(5.21f, 1.789f, 4.383f, 0.537f)) } } }
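
Usage sketch (not part of the diff): the snippet below exercises the APIs this patch introduces — `rsqrt`, `moments`, and the `fused` flag on the `BatchNorm` layer. It is illustrative only and assumes the same imports and `Tensor`/`const`/`eval` syntax used by the test suite; the object and value names here (`FusedBatchNormSketch`, `x`, `model`) are made up for the example.

import scanet.core.Tensor
import scanet.models.Activation.{ReLU, Softmax}
import scanet.models.layer.{BatchNorm, Conv2D, Dense, Flatten, Pool2D}
import scanet.syntax._

object FusedBatchNormSketch extends App {

  // rsqrt: reciprocal square root (1 / sqrt(x)), the kernel the generic
  // batchNorm path now uses for the speedup noted in the code comments
  val inv = Tensor.vector(1.0f, 4.0f, 16.0f).const.rsqrt.eval
  // expected: 1.0, 0.5, 0.25

  // moments: mean and variance in a single call, reduced along the given axes
  val x = Tensor.matrix(
    Array(1.0f, 2.0f, 3.0f),
    Array(3.0f, 4.0f, 5.0f)).const
  val (mean, variance) = x.moments(Seq(0), keepDims = true)
  // expected mean: (2.0, 3.0, 4.0), expected variance: (1.0, 1.0, 1.0)

  // Layer-level batch norm: with a 4D NHWC input and the default axes = Seq(-1)
  // the layer resolves to the fused FusedBatchNormV3 kernel; fused = Some(true)
  // makes it fail fast instead of silently falling back to the generic path.
  val model =
    Conv2D(32, activation = ReLU()) >> BatchNorm(fused = Some(true)) >>
      Pool2D(strides = (2, 2)) >> Flatten >> Dense(10, Softmax)

  println(inv)
  println(mean.eval)
  println(variance.eval)
  println(model)
}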