prioritized-exp-replay
Wei-1 committed Oct 3, 2020
1 parent b4f8ec3 commit e3006eb
Showing 4 changed files with 250 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -108,6 +108,8 @@ A very light weight Scala machine learning library that provide some basic ML al

- [x] Asynchronous Advantage Actor-Critic (A3C) [[Code]](src/main/scala/algorithm/reinforcement/A3C.scala) [[Usage]](src/test/scala/algorithm/reinforcement/A3CTest.scala)

- [x] Prioritized Experience Replay (PER-DQN) [[Code]](src/main/scala/algorithm/reinforcement/PER.scala) [[Usage]](src/test/scala/algorithm/reinforcement/PERTest.scala)

### Feature Analysis :

- [x] Student-T Test [[Code]](src/main/scala/algorithm/analysis/StudentT.scala) [[Usage]](src/test/scala/algorithm/analysis/StudentTTest.scala)
@@ -125,8 +127,6 @@ A very light weight Scala machine learning library that provide some basic ML al

## TODO

- [ ] Polarize Experience Replay - Deep Reinforcement Learning

- [ ] Rainbow - Deep Reinforcement Learning

- [ ] Alpha-go Zero (MCTS-NN) - Deep Reinforcement Learning
31 changes: 22 additions & 9 deletions src/main/scala/algorithm/deeplearning/NeuralNetwork.scala
@@ -28,6 +28,7 @@ class Node(
var output: Double = 0.0
/** Error derivative with respect to this node's output. */
var outputDer: Double = 0.0
+ var rawOutputDer: Double = 0.0
/** Error derivative with respect to this node's total input. */
var inputDer: Double = 0.0
/**
@@ -264,15 +265,17 @@
*/
def backProp(
targets: Array[Double],
- errorFunc: ErrorFunction = SQUARE
+ errorFunc: ErrorFunction = SQUARE,
+ _outputWeights: Array[Double] = Array.fill[Double](networkShape.last)(1.0)
): Unit = {
- val outputNodes = network.last
// The output node is a special case. We use the user-defined error
// function for the derivative.
- for((node, target) <- outputNodes.zip(targets)) {
- node.outputDer = errorFunc.der(node.output, target)
+ for((node, target) <- getOutputNodes.zip(targets)) {
+ node.rawOutputDer = errorFunc.der(node.output, target)
}
+ for((node, weight) <- getOutputNodes.zip(_outputWeights)) {
+ node.outputDer = node.rawOutputDer * weight
+ }

// Go through the layers backwards.
for(layerIdx <- network.length - 1 to 1 by -1) {
val currentLayer = network(layerIdx)
@@ -374,9 +377,13 @@
def clear() = reset(false)

/** Train one input sample against one target; moved and modified from Playground. */
- def trainOne(inputs: Array[Double], targets: Array[Double], errorFunc: ErrorFunction = SQUARE): Unit = {
+ def trainOne(
+ inputs: Array[Double], targets: Array[Double],
+ errorFunc: ErrorFunction = SQUARE,
+ _outputWeights: Array[Double] = Array.fill[Double](networkShape.last)(1.0)
+ ): Unit = {
forwardProp(inputs)
- backProp(targets, errorFunc)
+ backProp(targets, errorFunc, _outputWeights)
if((index - updateIndex + 1) % batchSize == 0) {
updateIndex = index
updateWeights()
@@ -388,10 +395,16 @@
def predictOne = forwardProp _

/** Train all data */
- def train(x: Array[Array[Double]], y: Array[Array[Double]], errorFunc: ErrorFunction = SQUARE, iter: Int = 1, _learningRate: Double = learningRate): Boolean = {
+ def train(
+ x: Array[Array[Double]], y: Array[Array[Double]],
+ errorFunc: ErrorFunction = SQUARE,
+ iter: Int = 1,
+ _learningRate: Double = learningRate,
+ _outputWeights: Array[Double] = Array.fill[Double](networkShape.last)(1.0)
+ ): Boolean = {
learningRate = _learningRate
val data = x.zip(y)
- for(i <- 0 until iter) data.foreach { case (inputs, targets) => trainOne(inputs, targets, errorFunc) }
+ for(i <- 0 until iter) data.foreach { case (inputs, targets) => trainOne(inputs, targets, errorFunc, _outputWeights) }
true
}
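
The new `_outputWeights` parameter scales the error derivative of each output node before backpropagation; PER uses it to apply importance-sampling corrections without touching the error function itself. A minimal sketch of a call site, assuming the `config` and `train` signatures shown in this diff and the `com.scalaml.algorithm` package used by PER.scala; the object name, network shape, data, and weight values are illustrative only:

import com.scalaml.algorithm.NeuralNetwork

object OutputWeightsSketch extends App {
  // 4 inputs -> 8 hidden -> 2 outputs, mirroring the config call used by PER.scala below
  val nn = new NeuralNetwork()
  nn.config(Array(4, 8, 2), _batchSize = 2, _gradientClipping = true)

  // Two toy samples with two-dimensional targets
  val x = Array(Array(0.1, 0.2, 0.3, 0.4), Array(0.4, 0.3, 0.2, 0.1))
  val y = Array(Array(1.0, 0.0), Array(0.0, 1.0))

  // One weight per output node (networkShape.last); 1.0 leaves that node's gradient unscaled
  val outputWeights = Array(1.0, 0.5)

  nn.train(x, y, iter = 5, _learningRate = 0.01, _outputWeights = outputWeights)
}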

158 changes: 158 additions & 0 deletions src/main/scala/algorithm/reinforcement/PER.scala
@@ -0,0 +1,158 @@
// Wei Chen - Prioritized Experience Replay (PER)
// 2017-08-31

package com.scalaml.algorithm

// nextstate, reward, end = simulator(state, action)
class PER(
val layer_neurons: Array[Int],
val initparas: Array[Double],
val actnumber: Int,
val simulator: (Array[Double], Int) => (Array[Double], Double, Boolean),
val batchsize_number: Int = 100,
val epsilon_saturation_number: Int = 10000,
val train_number: Int = 10,
val nn_learning_rate: Double = 0.01,
val prior_eps: Double = 1e-6,
val alpha: Double = 0.6,
var beta: Double = 0.6
) {

val nn = new NeuralNetwork()
nn.config(initparas.size +: layer_neurons :+ actnumber,
_batchSize = batchsize_number, _gradientClipping = true)
val ex = new Exp

class Exp {
var c = 0
var x = Array[Array[Double]]()
var y = Array[Array[Double]]()
var max_priority = 1.0
var fin_priority = Array[Double]()

def consume = {
val indices = _sample_proportional()
var nx = Array[Array[Double]]()
var ny = Array[Array[Double]]()
var nw = Array[Double]()
for (i <- indices) {
nx :+= x(i)
ny :+= y(i)
nw :+= _calculate_weight(i, beta)
}

nn.train(
nx, ny,
iter = train_number,
_learningRate = nn_learning_rate,
_outputWeights = nw
)
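// Bump max_priority using the raw output-error derivatives left by the most
// recent backprop pass as a proxy for the new TD-error magnitude.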
for (node <- nn.getOutputNodes) {
max_priority = math.max(max_priority, node.rawOutputDer + prior_eps)
}

x = Array[Array[Double]]()
y = Array[Array[Double]]()
fin_priority = Array[Double]()
c = 0
}
def add(paras: Array[Double], target: Array[Double]) {
x :+= paras
y :+= target
fin_priority :+= math.pow(max_priority, alpha)
c += 1
if (c >= batchsize_number) consume
}
def end = if (c > 0) consume
// Functions for PER
def _sample_proportional(): Array[Int] = {
// Sample indices based on proportions
val indices = new Array[Int](batchsize_number)
val p_sum = fin_priority.sum
val segment = p_sum / batchsize_number
for (i <- 0 until batchsize_number) {
val a = segment * i
val b = segment * (i + 1)
val upperbound = scala.util.Random.nextDouble * (b - a) + a
val idx = _retrieve(upperbound)
indices(i) = idx
}
indices
}
def _retrieve(upperbound: Double): Int = {
var a = 0.0
var i = 0
while (a < upperbound) {
a += fin_priority(i)
i += 1
}
i - 1
}
def _calculate_weight(idx: Int, beta: Double): Double = {
// Calculate the weight of the experience at idx
// get max weight
val p_sum = fin_priority.sum
val p_min = fin_priority.min / p_sum
val max_weight = math.pow(p_min * c, -beta)
// calculate weights
val p_sample = fin_priority(idx) / p_sum
val weight = math.pow(p_sample * c, -beta)
weight / max_weight
}
def _update_priorities(indices: Array[Int], priorities: Array[Double]) {
// Update priorities of sampled transitions
for ((idx, priority) <- indices.zip(priorities)) {
fin_priority(idx) = math.pow(priority, alpha)
max_priority = math.max(max_priority, priority)
}
}
}

class DQState (val paras: Array[Double]) {
def learn(lr: Double, df: Double, epoch: Int): Double = {
val q_s = nn.predictOne(paras)
val act = (if (scala.util.Random.nextDouble > epsilon) q_s.zipWithIndex.maxBy(_._1)._2
else scala.util.Random.nextInt(actnumber))
if (epsilon > 0.1) epsilon -= depsilon
val (newparas, newreward, newfinish) = simulator(paras, act)
if (epoch > 0 && !newfinish) {
val newstate = new DQState(newparas)
val gradient = newreward + df * newstate.learn(lr, df, epoch - 1) // max -> a: Q(s+1, a)
q_s(act) = (1 - lr) * q_s(act) + lr * gradient
} else {
q_s(act) = newreward
}
ex.add(paras, q_s) // nn.train(Array(paras), Array(q_s), batchsize_number, lr)
q_s.max
}
val bestAct: Int = nn.predictOne(paras).zipWithIndex.maxBy(_._1)._2
}

var epsilon = 1.0
var depsilon = 0.9 / epsilon_saturation_number
var state = new DQState(initparas)
def train(number: Int = 1, lr: Double = 0.1, df: Double = 0.6, epoch: Int = 100): Unit = {
for (n <- 0 until number) {
state.learn(lr, df, epoch)
val fraction = math.min(n.toDouble / number, 1.0) // fraction of training completed, used to anneal beta toward 1.0
beta += fraction * (1.0 - beta)
}
ex.end
}
def result(epoch: Int = 100): Array[DQState] = {
var paras = initparas
var curstate = new DQState(initparas)
var arr: Array[DQState] = Array(curstate)
var i = 0
while (i < epoch) {
i += 1
val act = curstate.bestAct
val (newparas, newreward, newfinish) = simulator(paras, act)
if (newfinish) i = epoch
paras = newparas
curstate = new DQState(newparas)
arr :+= curstate
}
arr
}
}
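
For reference, `_sample_proportional` and `_calculate_weight` follow the proportional-prioritization scheme of the PER paper (Schaul et al., 2015). Writing N for the number of buffered transitions, each transition i carries a priority p_i, is sampled with probability P(i), and its update is corrected by a normalized importance-sampling weight w_i:

p_i = |\delta_i| + \epsilon

P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}

w_i = \frac{(N \cdot P(i))^{-\beta}}{\max_j (N \cdot P(j))^{-\beta}}

In the code above, `fin_priority` stores p_i^alpha (with the raw output-error derivative standing in for the TD error delta_i), `prior_eps` is epsilon, `c` plays the role of N, and `train` anneals `beta` toward 1.0 so the correction becomes unbiased late in training.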
68 changes: 68 additions & 0 deletions src/test/scala/algorithm/reinforcement/PERTest.scala
@@ -0,0 +1,68 @@
// Wei Chen - Prioritized Experience Replay (PER)
// 2017-09-01

import com.scalaml.TestData._
import com.scalaml.algorithm.PER
import org.scalatest.funsuite.AnyFunSuite

class PERSuite extends AnyFunSuite {

val learning_rate = 0.1
val scale = 1
val limit = 10000
val epoch = 100

test("PER Test : Result 1") { // Case 1
def simulator(paras: Array[Double], act: Int): (Array[Double], Double, Boolean) = {
val links = Map(0 -> Array(1, 2),
1 -> Array(3, 4))
val scores = Map(2 -> 10.0, 3 -> 0.0, 4 -> 100.0)
val atloc = paras.zipWithIndex.maxBy(_._1)._2
val moves = links.getOrElse(atloc, Array[Int]())
if (moves.size == 0) {
null
} else {
val endloc = moves(act)
val result = Array(0.0, 0.0, 0.0, 0.0, 0.0)
result(endloc) = 1.0
val nextmoves = links.getOrElse(endloc, Array[Int]())
(result, scores.getOrElse(endloc, 0.0), nextmoves.size == 0)
}
}

val ql = new PER(Array(5, 4), Array(1.0, 0.0, 0.0, 0.0, 0.0), 2, simulator, 10)
ql.train(limit, learning_rate, scale, epoch)
val result = ql.result(epoch)
assert(result.size == 3)
assert(result.head.bestAct == 0)
assert(result(1).bestAct == 1)
assert(result.last.paras.zipWithIndex.maxBy(_._1)._2 == 4)
}

test("PER Test : Result 2") { // Case 2
def simulator(paras: Array[Double], act: Int): (Array[Double], Double, Boolean) = {
val links = Map(0 -> Array(1, 2),
1 -> Array(3, 4))
val scores = Map(2 -> 10.0, 3 -> 0.0, 4 -> 12.0)
val atloc = paras.zipWithIndex.maxBy(_._1)._2
val moves = links.getOrElse(atloc, Array[Int]())
if (moves.size == 0) {
null
} else {
val endloc = moves(act)
val result = Array(0.0, 0.0, 0.0, 0.0, 0.0)
result(endloc) = 1.0
val nextmoves = links.getOrElse(endloc, Array[Int]())
(result, scores.getOrElse(endloc, 0.0), nextmoves.size == 0)
}
}

val ql = new PER(Array(5, 4), Array(1.0, 0.0, 0.0, 0.0, 0.0), 2, simulator, 10)
ql.train(limit, learning_rate, scale, epoch)
val result = ql.result(epoch)
assert(result.size == 3)
assert(result.head.bestAct == 0)
assert(result(1).bestAct == 1)
assert(result.last.paras.zipWithIndex.maxBy(_._1)._2 == 4)
}
}
