Skip to content
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
517 changes: 226 additions & 291 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

18 changes: 14 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
* Used for "binning" the features bins for faster best split calculation. For a continuous
* feature, a bin is determined by a low and a high "split". For a categorical feature,
* a bin is determined using a single label value (category).
* Used for "binning" feature values to speed up the best-split calculation.
*
* For a continuous feature, the bin is determined by a low and a high split,
* where an example with featureValue falls into the bin s.t.
* lowSplit.threshold < featureValue <= highSplit.threshold.
*
* For ordered categorical features, there is a 1-1-1 correspondence between
* bins, splits, and feature values. The bin is determined by category/feature value.
* However, the bins are not necessarily ordered by feature value;
* they are ordered using impurity.
* For unordered categorical features, there is a 1-1 correspondence between bins and splits,
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
*
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin for binary classification
* @param category categorical label value accepted in the bin for ordered features
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predictIfLeaf(features)
topNode.predict(features)
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ class Node (

/**
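* Predict a value for the given features: return this node's prediction if it is a leaf,
* otherwise descend into the left or right child according to this node's split.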
* @param feature feature value
* @param features feature vector of the example, indexed by feature index
* @return predicted value
*/
def predictIfLeaf(feature: Vector) : Double = {
def predict(features: Vector) : Double = {
if (isLeaf) {
predict
} else{
if (split.get.featureType == Continuous) {
if (feature(split.get.feature) <= split.get.threshold) {
leftNode.get.predictIfLeaf(feature)
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
} else {
if (split.get.categories.contains(feature(split.get.feature))) {
leftNode.get.predictIfLeaf(feature)
if (split.get.categories.contains(features(split.get.feature))) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
* :: DeveloperApi ::
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param threshold Threshold for continuous feature.
* Split left if feature <= threshold, else right.
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
* @param categories Split left if categorical feature value is in this set, else right.
*/
@DeveloperApi
case class Split(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import org.scalatest.FunSuite

import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.TreePoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -64,7 +64,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
Expand All @@ -82,7 +82,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
Expand Down Expand Up @@ -162,7 +162,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)

// Check splits.

Expand Down Expand Up @@ -279,7 +279,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)

// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
Expand Down Expand Up @@ -373,7 +373,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)

// 2^10 - 1 > 100, so categorical variables will be ordered

Expand Down Expand Up @@ -428,10 +428,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
Expand All @@ -456,10 +456,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd,strategy)
val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

val split = bestSplits(0)._1
assert(split.categories.length === 1)
Expand Down Expand Up @@ -495,7 +495,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
Expand All @@ -505,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
Expand All @@ -518,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
Expand All @@ -528,7 +528,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
Expand All @@ -542,7 +542,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
Expand All @@ -552,7 +552,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
Expand All @@ -566,7 +566,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
Expand All @@ -576,7 +576,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
Expand All @@ -590,31 +590,36 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
assert(splits(0).length === 99)
assert(bins(0).length === 100)

val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1)
val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
// Train a 1-node model
val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
val nodes: Array[Node] = new Array[Node](7)
nodes(0) = modelOneNode.topNode
nodes(0).leftNode = None
nodes(0).rightNode = None

val parentImpurities = Array(0.5, 0.5, 0.5)

// Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
splits, bins, 10)
val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, nodes,
splits, bins, 10, unorderedFeatures)
assert(bestSplits.length === 2)
assert(bestSplits(0)._2.gain > 0)
assert(bestSplits(1)._2.gain > 0)

// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
filters, splits, bins, 0)
nodes, splits, bins, 0, unorderedFeatures)
assert(bestSplitsWithGroups.length === 2)
assert(bestSplitsWithGroups(0)._2.gain > 0)
assert(bestSplitsWithGroups(1)._2.gain > 0)
Expand All @@ -629,7 +634,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
}

}

test("stump with categorical variables for multiclass classification") {
Expand All @@ -638,10 +642,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
Expand Down Expand Up @@ -690,18 +694,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
numClassesForClassification = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)

val model = DecisionTree.train(input, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)

val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
Expand All @@ -724,10 +729,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
validateClassifier(model, arr, 0.9)

val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
Expand All @@ -749,10 +754,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
validateClassifier(model, arr, 0.9)

val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
Expand All @@ -769,10 +774,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
new Array[Node](0), splits, bins, 10, unorderedFeatures)

assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
Expand Down