Commit
2022-12-28-regression-tree
Wei-1 committed Dec 28, 2022
1 parent 10360a7 commit 1b8433d
Showing 5 changed files with 147 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -56,6 +56,8 @@ A very light weight Scala machine learning library that provide some basic ML al

- [x] Stochastic Gradient Decent [[Code]](src/main/scala/algorithm/regression/MultivariateLinearRegression.scala) [[Usage]](src/test/scala/algorithm/regression/MultivariateLinearRegressionTest.scala)

- [x] Regression Tree [[Code]](src/main/scala/algorithm/regression/RegressionTree.scala) [[Usage]](src/test/scala/algorithm/regression/RegressionTreeTest.scala)

### Clustering :

- [x] Hierarchical [[Code]](src/main/scala/algorithm/clustering/Hierarchical.scala) [[Usage]](src/test/scala/algorithm/clustering/HierarchicalTest.scala)
101 changes: 101 additions & 0 deletions src/main/scala/algorithm/regression/RegressionTree.scala
@@ -0,0 +1,101 @@
// Wei Chen - Regression Tree
// 2022-12-28

package com.scalaml.algorithm

class RegressionNode(
    val col: Int, val v: Double,                           // split column index and split value
    val tnode: RegressionNode, val fnode: RegressionNode,  // true / false branch children (null at a leaf)
    val r: Double = 0,                                     // leaf prediction (mean of the targets)
    val cats: Set[Int] = Set[Int]()                        // indices of columns treated as categorical
) {
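    // Internal nodes (both children non-null) send x to tnode when the split test passes:
    // equality with v for a categorical column, x(col) >= v otherwise. Leaves return r.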
def predict(x: Array[Double]): Double = {
if(tnode != null && fnode != null) {
if((!cats.contains(col) && x(col) > v) || x(col) == v) tnode.predict(x)
else fnode.predict(x)
} else r
}
override def toString: String = {
if(tnode != null && fnode != null) {
s"col[$col]" + (if(cats.contains(col)) " == " else " >= ") + v +
s" ? ($tnode) : ($fnode)"
} else s"class[$r]"
}
}

class RegressionTree() extends Regression {
val algoname: String = "RegressionTree"
val version: String = "0.1"

    var tree: RegressionNode = null        // root of the fitted tree
    var catColumns: Set[Int] = Set[Int]()  // column indices treated as categorical
    var maxLayer: Int = 5                  // maximum tree depth

override def clear(): Boolean = {
tree = null
true
}

    // Config keys: "CATEGORYCOLUMNS" (or "catColumns") -> Set[Int] of categorical column
    // indices; "maxLayer" -> Double giving the maximum tree depth.
    override def config(paras: Map[String, Any]): Boolean = try {
        catColumns = paras.getOrElse("CATEGORYCOLUMNS", paras.getOrElse("catColumns", Set[Int]())).asInstanceOf[Set[Int]]
        maxLayer = paras.getOrElse("maxLayer", 5.0).asInstanceOf[Double].toInt
true
} catch { case e: Exception =>
Console.err.println(e)
false
}

    private def log2(x: Double) = Math.log(x) / Math.log(2) // unused in this regression variant

    // Impurity measure: despite the name, this is the mean absolute deviation of the targets
    // around their mean, e.g. targets (1.0, 2.0, 3.0) give (1.0 + 0.0 + 1.0) / 3.
    private def entropy(data: Array[(Double, Array[Double])]): Double = {
        val dataSize = data.size.toDouble
        val dataAvg = data.map(_._1).sum / dataSize
        data.map(d => Math.abs(d._1 - dataAvg)).sum / dataSize
    }

    // Greedy recursive split search: for every column and every value observed in that column,
    // partition the data (== for categorical columns, >= threshold otherwise) and keep the split
    // with the largest impurity reduction; recurse until no split helps or maxLayer is reached.
    private def buildtree(data: Array[(Double, Array[Double])], layer: Int = maxLayer): RegressionNode = {
var currentScore: Double = entropy(data)
var bestGain: Double = 0
var bestColumn: Int = 0
var bestValue: Double = 0
var bestTrueData = Array[(Double, Array[Double])]()
var bestFalseData = Array[(Double, Array[Double])]()

val dataSize = data.size.toDouble
val columnSize: Int = data.head._2.size
for (col <- 0 until columnSize) {
var valueSet: Set[Double] = Set()
for (d <- data) valueSet += d._2(col)
for (value <- valueSet) {
val (tData, fData) = data.partition { d =>
if(catColumns.contains(col)) d._2(col) == value
else d._2(col) >= value
}
val p = tData.size / dataSize
val gain = currentScore - p * entropy(tData) - (1 - p) * entropy(fData)
if (gain > bestGain && tData.size > 0 && fData.size > 0) {
bestGain = gain
bestColumn = col
bestValue = value
bestTrueData = tData
bestFalseData = fData
}
}
}
        if (bestGain > 0 && layer > 0) {
            val tnode = buildtree(bestTrueData, layer - 1)
            val fnode = buildtree(bestFalseData, layer - 1)
            // Pass catColumns down so predict can apply the equality test on categorical splits.
            new RegressionNode(bestColumn, bestValue, tnode, fnode, cats = catColumns)
        } else new RegressionNode(0, 0, null, null, data.map(_._1).sum / dataSize) // leaf: mean target
}

override def train(data: Array[(Double, Array[Double])]): Boolean = try {
tree = buildtree(data)
true
} catch { case e: Exception =>
Console.err.println(e)
false
}

override def predict(x: Array[Array[Double]]): Array[Double] = x.map(xi => tree.predict(xi))
}
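A minimal usage sketch of the class above, for reference only; it is not part of this commit, and the object name and sample data are made up. Column 1 is treated as categorical via the "CATEGORYCOLUMNS" key handled by config.

// Hypothetical usage sketch (not in the repository)
import com.scalaml.algorithm.RegressionTree

object RegressionTreeExample extends App {
    // (target, features); feature column 1 holds a category code (0.0 or 1.0)
    val data: Array[(Double, Array[Double])] = Array(
        (1.0, Array(0.5, 0.0)), (1.2, Array(0.7, 0.0)),
        (3.0, Array(2.5, 1.0)), (3.3, Array(2.8, 1.0))
    )
    val rt = new RegressionTree()
    rt.config(Map("CATEGORYCOLUMNS" -> Set(1), "maxLayer" -> 3.0)) // maxLayer is read as a Double
    rt.train(data)
    println(rt.predict(Array(Array(0.6, 0.0), Array(2.6, 1.0))).mkString(", ")) // one prediction per input row
    println(rt.tree) // RegressionNode.toString shows the fitted splits
}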
12 changes: 5 additions & 7 deletions src/test/scala/algorithm/regression/GradientBoostTest.scala
@@ -23,22 +23,20 @@ class GradientBoostSuite extends AnyFunSuite {
assert(arraysimilar(nResult, LABEL_LINEAR_DATA.map(_.toDouble), 0.9))
}

-    test("GradientBoost Test : Nonlinear Data, 1 Linear Model - WRONG") {
+    test("GradientBoost Test : Nonlinear Data, 1 Model - WRONG") {
assert(gb.clear())
assert(gb.config(Map[String, Any]()))
assert(gb.train(LABELED_NONLINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
val result = gb.predict(UNLABELED_NONLINEAR_DATA)
assert(!arraysimilar(result, LABEL_NONLINEAR_DATA.map(_.toDouble), 0.45))
}

// More linear regressors will not solve nonlinear problems
-    test("GradientBoost Test : Nonlinear Data, 5 Linear Models - WRONG") {
+    test("GradientBoost Test : Nonlinear Data, 4 Models - WRONG") {
val regressors: Any = Array(
new MultipleLinearRegression,
new MultivariateLinearRegression,
new StochasticGradientDecent,
-            new StochasticGradientDecent,
-            new StochasticGradientDecent,
-            new StochasticGradientDecent,
-            new StochasticGradientDecent
+            new RegressionTree
)
assert(gb.clear())
assert(gb.config(Map("regressors" -> regressors)))
2 changes: 1 addition & 1 deletion src/test/scala/algorithm/regression/MultipleLinearRegressionTest.scala
@@ -27,7 +27,7 @@ class MultipleLinearRegressionSuite extends AnyFunSuite {
assert(mlr.config(Map[String, Double]()))
assert(mlr.train(LABELED_NONLINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
val result = mlr.predict(UNLABELED_NONLINEAR_DATA)
-        assert(!arraysimilar(result, LABEL_LINEAR_DATA.map(_.toDouble), 0.45))
+        assert(!arraysimilar(result, LABEL_NONLINEAR_DATA.map(_.toDouble), 0.45))
}

test("MultipleLinearRegression Test : Invalid Data") {
38 changes: 38 additions & 0 deletions src/test/scala/algorithm/regression/RegressionTreeTest.scala
@@ -0,0 +1,38 @@
// Wei Chen - Regression Tree Test
// 2022-12-28

import com.scalaml.TestData._
import com.scalaml.general.MatrixFunc._
import com.scalaml.algorithm.RegressionTree
import org.scalatest.funsuite.AnyFunSuite

class RegressionTreeSuite extends AnyFunSuite {

val rt = new RegressionTree()

test("RegressionTree Test : Clear") {
assert(rt.clear())
}

test("RegressionTree Test : Linear Data") {
assert(rt.clear())
assert(rt.config(Map[String, Double]()))
assert(rt.train(LABELED_LINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
val result = rt.predict(UNLABELED_LINEAR_DATA)
assert(arraysimilar(result, LABEL_LINEAR_DATA.map(_.toDouble), 0.9))
Console.err.println(result.mkString(","), LABEL_LINEAR_DATA.mkString(","))
}

test("RegressionTree Test : Nonlinear Data - WRONG") {
assert(rt.clear())
assert(rt.config(Map[String, Double]()))
assert(rt.train(LABELED_NONLINEAR_DATA.map(yx => yx._1.toDouble -> yx._2)))
val result = rt.predict(UNLABELED_NONLINEAR_DATA)
assert(!arraysimilar(result, LABEL_NONLINEAR_DATA.map(_.toDouble), 0.45))
}

test("RegressionTree Test : Invalid Data") {
assert(rt.clear())
assert(!rt.train(Array((1, Array(1, 2)), (1, Array()))))
}
}
