Commits (23)
ac42378
add min info gain and min instances per node parameters in decision tree
Sep 9, 2014
ff34845
separate calculation of predict of node from calculation of info gain
Sep 9, 2014
987cbf4
fix bug
Sep 9, 2014
f195e83
fix style
Sep 9, 2014
845c6fa
fix style
Sep 9, 2014
e72c7e4
add comments
Sep 9, 2014
46b891f
fix bug
Sep 9, 2014
cadd569
add api docs
Sep 9, 2014
bb465ca
Merge branch 'master' of https://github.com/apache/spark into dt-prep…
Sep 9, 2014
6728fad
minor fix: remove empty lines
Sep 9, 2014
10b8012
fix style
Sep 9, 2014
efcc736
fix bug
Sep 9, 2014
d593ec7
fix docs and change minInstancesPerNode to 1
chouqin Sep 9, 2014
0278a11
remove `noSplit` and set `Predict` private to tree
chouqin Sep 10, 2014
c6e2dfc
Added minInstancesPerNode and minInfoGain parameters to DecisionTreeR…
jkbradley Sep 10, 2014
39f9b60
change edge `minInstancesPerNode` to 2 and add one more test
chouqin Sep 10, 2014
c7ebaf1
fix typo
chouqin Sep 10, 2014
f1d11d1
fix typo
chouqin Sep 10, 2014
19b01af
Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-pr…
jkbradley Sep 10, 2014
e2628b6
Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
jkbradley Sep 10, 2014
95c479d
* Fixed typo in tree suite test "do not choose split that does not sa…
jkbradley Sep 10, 2014
a95e7c8
Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
jkbradley Sep 13, 2014
61b2e72
Added max of 10GB for maxMemoryInMB in Strategy.
jkbradley Sep 13, 2014
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
fracTest: Double = 0.2)

def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
opt[Int]("minInstancesPerNode")
.text(s"min number of instances required at child nodes to create the parent split," +
s" default: ${defaultParams.minInstancesPerNode}")
.action((x, c) => c.copy(minInstancesPerNode = x))
opt[Double]("minInfoGain")
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
.action((x, c) => c.copy(minInfoGain = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = numClasses)
+ numClassesForClassification = numClasses,
+ minInstancesPerNode = params.minInstancesPerNode,
+ minInfoGain = params.minInfoGain)
val model = DecisionTree.train(training, strategy)

println(model)
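
For reference, a minimal sketch of how an application would set the two new pre-pruning parameters through the public API touched above. Strategy, DecisionTree.train, and the parameter names come from this diff; the wrapping method, the example threshold values, and the `training` RDD are illustrative assumptions, not part of the PR.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.rdd.RDD

// `training` is assumed to be an existing RDD[LabeledPoint]; loading it is out of scope here.
def trainWithPrePruning(training: RDD[LabeledPoint]) = {
  val strategy = new Strategy(
    algo = Algo.Classification,
    impurity = Gini,
    maxDepth = 5,
    numClassesForClassification = 2,
    maxBins = 32,
    // New in this PR: reject a split unless each child node would receive
    // at least minInstancesPerNode training instances...
    minInstancesPerNode = 5,
    // ...and unless the split improves the impurity by at least minInfoGain.
    minInfoGain = 0.01)
  val model = DecisionTree.train(training, strategy)
  println(model)
  model
}
```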
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
- maxBins: Int): DecisionTreeModel = {
+ maxBins: Int,
+ minInstancesPerNode: Int,
+ minInfoGain: Double): DecisionTreeModel = {

val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)

@@ -316,7 +318,9 @@
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+ categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ minInstancesPerNode = minInstancesPerNode,
+ minInfoGain = minInfoGain)

DecisionTree.train(data, strategy)
}
@@ -929,7 +929,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)

- require(predict.isDefined, "must calculate predict for each node")
+ assert(predict.isDefined, "must calculate predict for each node")

(bestSplit, bestSplitStats, predict.get)
}
@@ -17,18 +17,14 @@

package org.apache.spark.mllib.tree.model

- import org.apache.spark.annotation.DeveloperApi
-
/**
- * :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
- @DeveloperApi
private[tree] class Predict(
val predict: Double,
- val prob: Double = 0.0) extends Serializable{
+ val prob: Double = 0.0) extends Serializable {

override def toString = {
"predict = %f, prob = %f".format(predict, prob)
@@ -706,8 +706,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
}

test("don't choose split that doesn't satisfy min instance per node requirements") {
// if a split doesn't satisfy min instances per node requirements,
test("do not choose split that does not satisfy min instance per node requirements") {
// if a split does not satisfy min instances per node requirements,
Inline review comment (Contributor): Why is "don't" a typo?

Reply (Author): Not really a typo. But I figured that, if people are munging logs from tests, quote characters might be troublesome to deal with.

Reply (Contributor): It sounds reasonable, thanks.
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
@@ -727,7 +727,7 @@

assert(bestSplits.length == 1)
val bestSplit = bestSplits(0)._1
- val bestSplitStats = bestSplits(0)._1
+ val bestSplitStats = bestSplits(0)._2
assert(bestSplit.feature == 1)
assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
}
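
To make the rule this test exercises concrete, here is a minimal sketch of the pre-pruning check, assuming the semantics described in the parameter docs above: a candidate split is rejected when either child would receive fewer than minInstancesPerNode instances or when its information gain falls below minInfoGain, even if that gain is the largest among candidates. The helper name and signature are hypothetical; the actual check lives inside DecisionTree's best-split search, which marks rejected splits with InformationGainStats.invalidInformationGainStats.

```scala
// Hypothetical helper illustrating the pre-pruning rule; not the real DecisionTree internals.
def splitSatisfiesPrePruning(
    leftChildCount: Long,
    rightChildCount: Long,
    infoGain: Double,
    minInstancesPerNode: Int,
    minInfoGain: Double): Boolean = {
  // Reject the split if either child would hold too few instances,
  // or if the information gain is below the configured threshold.
  leftChildCount >= minInstancesPerNode &&
    rightChildCount >= minInstancesPerNode &&
    infoGain >= minInfoGain
}
```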
16 changes: 12 additions & 4 deletions python/pyspark/mllib/tree.py
@@ -138,7 +138,8 @@ class DecisionTree(object):

@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32):
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.

@@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)

@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32):
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.

@@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
+ :param minInstancesPerNode: Min number of instances required at child nodes to create
+ the parent split
+ :param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
@@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo,
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
- impurity, maxDepth, maxBins)
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
