Skip to content

Commit fdb302f

Browse files
qiping.lqp authored and mengxr committed
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Added minInstancesPerNode, minInfoGain params to:
* DecisionTreeRunner.scala example
* Python API (tree.py)

Also:
* Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements"
* small style fixes

CC: mengxr

Author: qiping.lqp <[email protected]>
Author: Joseph K. Bradley <[email protected]>
Author: chouqin <[email protected]>

Closes apache#2349 from jkbradley/chouqin-dt-preprune and squashes the following commits:

61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
1 parent 983d6a9 commit fdb302f

File tree

7 files changed

+37
-16
lines changed

7 files changed

+37
-16
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
5555
maxDepth: Int = 5,
5656
impurity: ImpurityType = Gini,
5757
maxBins: Int = 32,
58+
minInstancesPerNode: Int = 1,
59+
minInfoGain: Double = 0.0,
5860
fracTest: Double = 0.2)
5961

6062
def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
7577
opt[Int]("maxBins")
7678
.text(s"max number of bins, default: ${defaultParams.maxBins}")
7779
.action((x, c) => c.copy(maxBins = x))
80+
opt[Int]("minInstancesPerNode")
81+
.text(s"min number of instances required at child nodes to create the parent split," +
82+
s" default: ${defaultParams.minInstancesPerNode}")
83+
.action((x, c) => c.copy(minInstancesPerNode = x))
84+
opt[Double]("minInfoGain")
85+
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
86+
.action((x, c) => c.copy(minInfoGain = x))
7887
opt[Double]("fracTest")
7988
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
8089
.action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
179188
impurity = impurityCalculator,
180189
maxDepth = params.maxDepth,
181190
maxBins = params.maxBins,
182-
numClassesForClassification = numClasses)
191+
numClassesForClassification = numClasses,
192+
minInstancesPerNode = params.minInstancesPerNode,
193+
minInfoGain = params.minInfoGain)
183194
val model = DecisionTree.train(training, strategy)
184195

185196
println(model)

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
303303
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
304304
impurityStr: String,
305305
maxDepth: Int,
306-
maxBins: Int): DecisionTreeModel = {
306+
maxBins: Int,
307+
minInstancesPerNode: Int,
308+
minInfoGain: Double): DecisionTreeModel = {
307309

308310
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
309311

@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
316318
maxDepth = maxDepth,
317319
numClassesForClassification = numClasses,
318320
maxBins = maxBins,
319-
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
321+
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
322+
minInstancesPerNode = minInstancesPerNode,
323+
minInfoGain = minInfoGain)
320324

321325
DecisionTree.train(data, strategy)
322326
}

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
389389
var groupIndex = 0
390390
var doneTraining = true
391391
while (groupIndex < numGroups) {
392-
val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
392+
val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
393393
topNode, splits, bins, timer, numGroups, groupIndex)
394394
doneTraining = doneTraining && doneTrainingGroup
395395
groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
898898
}
899899
}.maxBy(_._2.gain)
900900

901-
require(predict.isDefined, "must calculate predict for each node")
901+
assert(predict.isDefined, "must calculate predict for each node")
902902

903903
(bestSplit, bestSplitStats, predict.get)
904904
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class Strategy (
7777
}
7878
require(minInstancesPerNode >= 1,
7979
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
80+
require(maxMemoryInMB <= 10240,
81+
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
8082

8183
val isMulticlassClassification =
8284
algo == Classification && numClassesForClassification > 2

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,14 @@
1717

1818
package org.apache.spark.mllib.tree.model
1919

20-
import org.apache.spark.annotation.DeveloperApi
21-
2220
/**
23-
* :: DeveloperApi ::
2421
* Predicted value for a node
2522
* @param predict predicted value
2623
* @param prob probability of the label (classification only)
2724
*/
28-
@DeveloperApi
2925
private[tree] class Predict(
3026
val predict: Double,
31-
val prob: Double = 0.0) extends Serializable{
27+
val prob: Double = 0.0) extends Serializable {
3228

3329
override def toString = {
3430
"predict = %f, prob = %f".format(predict, prob)

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
714714
assert(gain == InformationGainStats.invalidInformationGainStats)
715715
}
716716

717-
test("don't choose split that doesn't satisfy min instance per node requirements") {
718-
// if a split doesn't satisfy min instances per node requirements,
717+
test("do not choose split that does not satisfy min instance per node requirements") {
718+
// if a split does not satisfy min instances per node requirements,
719719
// this split is invalid, even though the information gain of split is large.
720720
val arr = new Array[LabeledPoint](4)
721721
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))

python/pyspark/mllib/tree.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class DecisionTree(object):
138138

139139
@staticmethod
140140
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
141-
impurity="gini", maxDepth=5, maxBins=32):
141+
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
142+
minInfoGain=0.0):
142143
"""
143144
Train a DecisionTreeModel for classification.
144145
@@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
154155
E.g., depth 0 means 1 leaf node.
155156
Depth 1 means 1 internal node + 2 leaf nodes.
156157
:param maxBins: Number of bins used for finding splits at each node.
158+
:param minInstancesPerNode: Min number of instances required at child nodes to create
159+
the parent split
160+
:param minInfoGain: Min info gain required to create a split
157161
:return: DecisionTreeModel
158162
"""
159163
sc = data.context
@@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
164168
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
165169
dataBytes._jrdd, "classification",
166170
numClasses, categoricalFeaturesInfoJMap,
167-
impurity, maxDepth, maxBins)
171+
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
168172
dataBytes.unpersist()
169173
return DecisionTreeModel(sc, model)
170174

171175
@staticmethod
172176
def trainRegressor(data, categoricalFeaturesInfo,
173-
impurity="variance", maxDepth=5, maxBins=32):
177+
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
178+
minInfoGain=0.0):
174179
"""
175180
Train a DecisionTreeModel for regression.
176181
@@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo,
185190
E.g., depth 0 means 1 leaf node.
186191
Depth 1 means 1 internal node + 2 leaf nodes.
187192
:param maxBins: Number of bins used for finding splits at each node.
193+
:param minInstancesPerNode: Min number of instances required at child nodes to create
194+
the parent split
195+
:param minInfoGain: Min info gain required to create a split
188196
:return: DecisionTreeModel
189197
"""
190198
sc = data.context
@@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo,
195203
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
196204
dataBytes._jrdd, "regression",
197205
0, categoricalFeaturesInfoJMap,
198-
impurity, maxDepth, maxBins)
206+
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
199207
dataBytes.unpersist()
200208
return DecisionTreeModel(sc, model)
201209

0 commit comments

Comments
 (0)