-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 5 changed files with 187 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Wei Chen - Random Cut Forest | ||
// 2022-03-05 | ||
|
||
package com.scalaml.algorithm | ||
import com.scalaml.general.MatrixFunc._ | ||
|
||
// Random Cut Forest: an ensemble of RandomCutTree models whose per-row
// predictions are averaged to produce one score per input row.
class RandomCutForest() extends Abnormal {
    val algoname: String = "RandomCutForest"
    val version: String = "0.1"

    var trees = Array[RandomCutTree]()
    var tree_n = 10 // Number of Trees
    var sample_n = 10 // Number of Sample Data in a Tree
    var maxLayer = 5

    // Drop every trained tree and restore the default hyper-parameters.
    override def clear(): Boolean = {
        trees = Array[RandomCutTree]()
        tree_n = 10 // Number of Trees
        sample_n = 10 // Number of Sample Data in a Tree
        maxLayer = 5
        true
    }

    // Look up the first key of `keys` that is present in `paras` and convert its
    // value to Int. Any boxed number (Double, Int, Long, ...) is accepted;
    // a non-numeric value raises so that config() can report the failure.
    private def intParam(paras: Map[String, Any], keys: Seq[String], default: Int): Int =
        keys.find(key => paras.contains(key)).map(key => paras(key)) match {
            case None => default
            case Some(n: java.lang.Number) => n.intValue()
            case Some(other) =>
                throw new IllegalArgumentException(
                    s"Parameter ${keys.mkString("/")} must be numeric, got: $other")
        }

    // Configure the forest.
    //   TREE_NUMBER / tree_number / tree_n       : number of trees (default 10)
    //   SAMPLE_NUMBER / sample_number / sample_n : rows sampled per tree (default 10)
    //   maxLayer                                 : depth limit passed to each tree (default 5)
    // Returns false (after printing the error to stderr) when a value is not
    // numeric; in that case no field is modified.
    override def config(paras: Map[String, Any]): Boolean = try {
        // Parse everything first so that a bad value cannot leave the
        // configuration half-applied.
        val newTreeN = intParam(paras, Seq("TREE_NUMBER", "tree_number", "tree_n"), 10)
        val newSampleN = intParam(paras, Seq("SAMPLE_NUMBER", "sample_number", "sample_n"), 10)
        val newMaxLayer = intParam(paras, Seq("maxLayer"), 5)
        tree_n = newTreeN
        sample_n = newSampleN
        maxLayer = newMaxLayer
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // Draw up to `sample_n` rows from `data` without replacement, in random order.
    private def randomSelect(data: Array[Array[Double]], sample_n: Int) =
        scala.util.Random.shuffle(data.toList).take(sample_n).toArray

    // Configure and train one tree on `data`; the tree is kept only when both
    // steps succeed. Returns whether the tree was added.
    private def addTree(data: Array[Array[Double]]): Boolean = {
        val itree = new RandomCutTree()
        // maxLayer is passed as a Double, the value type the tree's config reads.
        val paras: Map[String, Any] = Map("maxLayer" -> maxLayer.toDouble)
        val check = itree.config(paras) && itree.train(data)
        if (check) trees :+= itree
        check
    }

    // Train the forest. When there are more rows than `sample_n`, `tree_n` trees
    // are each trained on an independent random sample (stopping at the first
    // failure); otherwise a single tree is trained on all rows.
    // Trees are appended to `trees`; call clear() first to start from scratch.
    override def train(data: Array[Array[Double]]): Boolean = {
        if (data.size > sample_n) {
            (0 until tree_n).forall(_ => addTree(randomSelect(data, sample_n)))
        } else addTree(data)
    }

    // Score each row as the mean of the per-tree predictions.
    // The divisor is the number of trees actually held, not the configured
    // `tree_n`: the two differ when training fell back to a single tree, when a
    // tree failed to train, or when train() was called more than once.
    override def predict(data: Array[Array[Double]]): Array[Double] = {
        val treeCount = trees.size
        // matrixaccumulate presumably sums the per-tree score arrays element-wise
        // (defined in MatrixFunc) -- TODO confirm.
        matrixaccumulate(trees.map(_.predict(data))).map(_ / treeCount)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Wei Chen - Random Cut Tree | ||
// 2022-03-04 | ||
|
||
package com.scalaml.algorithm | ||
|
||
// Random Cut Tree: recursively splits the data with random axis-aligned cuts,
// choosing the cut column with probability proportional to the column's value
// range, and records the layer at which each branch stops.
class RandomCutTree() extends Abnormal {
    val algoname: String = "RandomCutTree"
    val version: String = "0.1"

    var maxLayer = 5
    var tree: DecisionNode = null

    // Restore the default depth limit and discard any trained tree so that a
    // cleared model cannot keep predicting with stale state.
    override def clear(): Boolean = {
        maxLayer = 5
        tree = null
        true
    }

    // Configure the tree.
    //   maxLayer : maximum split depth (default 5); any boxed number is accepted.
    // Returns false (after printing the error to stderr) for a non-numeric value.
    override def config(paras: Map[String, Any]): Boolean = try {
        maxLayer = paras.get("maxLayer") match {
            case None => 5
            case Some(n: java.lang.Number) => n.intValue()
            case Some(other) =>
                throw new IllegalArgumentException(s"maxLayer must be numeric, got: $other")
        }
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // Build the (sub)tree for `data`, where `layer` is the depth of this node.
    // Throws on empty or ragged input; train() turns that into a `false` result.
    private def buildtree(data: Array[Array[Double]], layer: Int = 0): DecisionNode = {
        val columnSize: Int = data.head.size
        val colMinMax = (0 until columnSize).map { col =>
            val colData = data.map(d => d(col))
            (colData.min, colData.max)
        }
        // Pick the cut column with probability proportional to its range: draw a
        // point in [0, total range) and take the first column whose cumulative
        // range exceeds it. Falls back to column 0 when no column qualifies
        // (e.g. every column is constant).
        val colRanges = colMinMax.map { case (lo, hi) => hi - lo }
        val threshold = colRanges.sum * scala.util.Random.nextDouble()
        val cumulative = colRanges.scanLeft(0.0)(_ + _).tail
        val hit = cumulative.indexWhere(_ > threshold)
        val bestColumn = if (hit >= 0) hit else 0
        // Cut position drawn uniformly in [min, max) of the chosen column.
        val (minV, maxV) = colMinMax(bestColumn)
        val value = (maxV - minV) * scala.util.Random.nextDouble() + minV
        val (tData, fData) = data.partition(d => d(bestColumn) >= value)
        if (tData.nonEmpty && fData.nonEmpty && layer < maxLayer) {
            val tnode = buildtree(tData, layer + 1)
            val fnode = buildtree(fData, layer + 1)
            new DecisionNode(bestColumn, value, tnode, fnode)
        } else new DecisionNode(0, 0, null, null, layer) // leaf: carries the layer reached
    }

    // Train on `data`; returns false (after printing the error to stderr) when
    // the tree cannot be built, e.g. for empty or ragged input.
    override def train(data: Array[Array[Double]]): Boolean = try {
        tree = buildtree(data)
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // Score every row with the trained tree. train() must have succeeded first;
    // `tree` is null before training and after clear().
    override def predict(x: Array[Array[Double]]): Array[Double] = x.map(xi => tree.predict(xi))
}
31 changes: 31 additions & 0 deletions
31
src/test/scala/algorithm/abnormal/RandomCutForestTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// Wei Chen - Random Cut Forest Test | ||
// 2022-03-05 | ||
|
||
import com.scalaml.TestData._ | ||
import com.scalaml.general.MatrixFunc._ | ||
import com.scalaml.algorithm.RandomCutForest | ||
import org.scalatest.funsuite.AnyFunSuite | ||
|
||
// Test suite for RandomCutForest: reset, scoring on the large unlabeled data
// set, and rejection of invalid configuration / training input.
class RandomCutForestSuite extends AnyFunSuite {

    val rcforest = new RandomCutForest()

    test("RandomCutForest Test : Clear") {
        val cleared = rcforest.clear()
        assert(cleared)
    }

    test("RandomCutForest Test : Abnormal Large Data") {
        assert(rcforest.clear())
        val forestParas = Map("tree_n" -> 100.0)
        assert(rcforest.config(forestParas))
        assert(rcforest.train(UNLABELED_LARGE_DATA))

        val scores = rcforest.predict(UNLABELED_LARGE_DATA)
        val allOnes = UNLABELED_LARGE_DATA.map(_ => 1.0)
        // NOTE(review): the third argument comes from UNLABELED_NONLINEAR_DATA
        // although this test runs on UNLABELED_LARGE_DATA -- confirm it is intended.
        assert(arraysimilar(scores, allOnes, UNLABELED_NONLINEAR_DATA.size))

        // The last row must score below the mean score of all rows.
        val meanScore = scores.sum / scores.size
        assert(scores.last < meanScore)
    }

    test("RandomCutForest Test : Invalid Data") {
        assert(rcforest.clear())
        // A non-numeric maxLayer must be rejected.
        val configured = rcforest.config(Map("maxLayer" -> "test"))
        assert(!configured)
        // Ragged input (second row is empty) must fail to train.
        val trained = rcforest.train(Array(Array(1, 2), Array()))
        assert(!trained)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Wei Chen - Random Cut Tree Test | ||
// 2022-03-05 | ||
|
||
import com.scalaml.TestData._ | ||
import com.scalaml.general.MatrixFunc._ | ||
import com.scalaml.algorithm.RandomCutTree | ||
import org.scalatest.funsuite.AnyFunSuite | ||
|
||
// Test suite for RandomCutTree: reset, scoring on the large unlabeled data
// set, and rejection of invalid configuration / training input.
class RandomCutTreeSuite extends AnyFunSuite {

    val rctree = new RandomCutTree()

    test("RandomCutTree Test : Clear") {
        val cleared = rctree.clear()
        assert(cleared)
    }

    test("RandomCutTree Test : Abnormal Large Data") {
        assert(rctree.clear())
        // An empty parameter map keeps the defaults.
        val defaultParas = Map[String, Double]()
        assert(rctree.config(defaultParas))
        assert(rctree.train(UNLABELED_LARGE_DATA))

        val scores = rctree.predict(UNLABELED_LARGE_DATA)
        val allOnes = UNLABELED_LARGE_DATA.map(_ => 1.0)
        // NOTE(review): the third argument comes from UNLABELED_NONLINEAR_DATA
        // although this test runs on UNLABELED_LARGE_DATA -- confirm it is intended.
        assert(arraysimilar(scores, allOnes, UNLABELED_NONLINEAR_DATA.size))
    }

    test("RandomCutTree Test : Invalid Data") {
        assert(rctree.clear())
        // A non-numeric maxLayer must be rejected.
        val configured = rctree.config(Map("maxLayer" -> "test"))
        assert(!configured)
        // Ragged input (second row is empty) must fail to train.
        val trained = rctree.train(Array(Array(1, 2), Array()))
        assert(!trained)
    }
}