Skip to content

Commit

Permalink
per-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei-1 committed Oct 5, 2020
1 parent e8f61b7 commit d3ce4f0
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/main/scala/algorithm/reinforcement/PER.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// 2017-08-31

package com.scalaml.algorithm
import com.scalaml.general.MatrixFunc._

// nextstate, reward, end = simulator(state, action)
class PER(
Expand Down Expand Up @@ -40,17 +41,22 @@ class PER(
ny :+= y(i)
nw :+= _calculate_weight(i, beta)
}
val nyp = nn.predict(nx)
val new_priorities = ny.zip(nyp).map { case (nyi, nypi) =>
math.sqrt(arrayminussquare(nyi, nypi).sum) + prior_eps
}

nn.train(
nx, ny,
iter = train_number,
_learningRate = nn_learning_rate,
_outputWeights = nw
)
for (node <- nn.getOutputNodes) {
max_priority = math.max(max_priority, node.rawOutputDer + prior_eps)
}

// _update_priorities(indices, new_priorities)
for (priority <- new_priorities) {
max_priority = math.max(max_priority, priority)
}
x = Array[Array[Double]]()
y = Array[Array[Double]]()
fin_priority = Array[Double]()
Expand Down Expand Up @@ -99,13 +105,13 @@ class PER(
val weight = math.pow(p_sample * c, -beta)
weight / max_weight
}
def _update_priorities(indices: Array[Int], priorities: Array[Double]) {
// Update priorities of sampled transitions
for ((idx, priority) <- indices.zip(priorities)) {
fin_priority(idx) = math.pow(priority, alpha)
max_priority = math.max(max_priority, priority)
}
}
// def _update_priorities(indices: Array[Int], priorities: Array[Double]) {
// // Update priorities of sampled transitions
// for ((idx, priority) <- indices.zip(priorities)) {
// fin_priority(idx) = math.pow(priority, alpha)
// max_priority = math.max(max_priority, priority)
// }
// }
}

class DQState (val paras: Array[Double]) {
Expand Down

0 comments on commit d3ce4f0

Please sign in to comment.