Commit 9e26473

asolimando authored and sethah committed

[SPARK-3159][ML] Add decision tree pruning
## What changes were proposed in this pull request?

Added subtree pruning in the translation from LearningNode to Node: a learning node whose subtree yields a single prediction value for all of its leaves is translated into a LeafNode, instead of a (redundant) InternalNode.

## How was this patch tested?

Added two unit tests under mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:

- test("SPARK-3159 tree model redundancy - classification")
- test("SPARK-3159 tree model redundancy - regression")

Four existing unit tests relying on the tree structure (the existence of a specific redundant subtree) had to be adapted, since the tested components of the output tree are now pruned; this was fixed by adding an extra _prune_ parameter which can be used to disable pruning for testing.

Author: Alessandro Solimando <[email protected]>

Closes #20632 from asolimando/master.
1 parent 487377e commit 9e26473

5 files changed: +115 -65 lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala

Lines changed: 14 additions & 8 deletions

```diff
@@ -19,8 +19,7 @@ package org.apache.spark.ml.tree
 
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats,
-  InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
 
 /**
  * Decision tree node interface.
@@ -266,15 +265,23 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
 
+  def toNode: Node = toNode(prune = true)
+
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
-  def toNode: Node = {
-    if (leftChild.nonEmpty) {
-      assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+  def toNode(prune: Boolean = true): Node = {
+
+    if (!leftChild.isEmpty || !rightChild.isEmpty) {
+      assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
-      new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
-        leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+        case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+        case (l, r) =>
+          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+            l, r, split.get, stats.impurityCalculator)
+      }
     } else {
       if (stats.valid) {
         new LeafNode(stats.impurityCalculator.predict, stats.impurity,
@@ -283,7 +290,6 @@ private[tree] class LearningNode(
         // Here we want to keep same behavior with the old mllib.DecisionTreeModel
         new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
       }
-
     }
   }
 
```
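The heart of the change is the pattern match above: toNode converts the children first, and a node whose two converted children are both leaves with the same prediction is itself emitted as a LeafNode, so redundant splits collapse bottom-up through the whole subtree. A minimal, self-contained sketch of that collapse rule on a toy tree ADT (the Tree, Leaf, and Internal types below are illustrative stand-ins, not Spark's classes):

```scala
object PruneSketch {
  sealed trait Tree
  case class Leaf(prediction: Double) extends Tree
  case class Internal(left: Tree, right: Tree, prediction: Double) extends Tree

  // Prune children first; if both come back as leaves with an identical
  // prediction, the split is redundant and the parent becomes a leaf too.
  def prune(t: Tree): Tree = t match {
    case Internal(l, r, p) =>
      (prune(l), prune(r)) match {
        case (Leaf(lp), Leaf(rp)) if lp == rp => Leaf(lp) // collapse redundant split
        case (pl, pr) => Internal(pl, pr, p)
      }
    case leaf => leaf
  }

  def main(args: Array[String]): Unit = {
    // Both branches predict 0.0, so the split collapses to a single leaf.
    println(prune(Internal(Leaf(0.0), Leaf(0.0), 0.0))) // Leaf(0.0)
    // A discriminating split survives unchanged.
    println(prune(Internal(Leaf(0.0), Leaf(1.0), 0.5))) // Internal(Leaf(0.0),Leaf(1.0),0.5)
  }
}
```

Because the collapse cascades upward, each pruned split removes an internal node and merges two leaves into one, which is why the unpruned trees in the tests below shrink by two nodes per redundant split.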

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 6 additions & 4 deletions

```diff
@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
       featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation[_]],
+      prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {
 
     val timer = new TimeTracker()
@@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+          topNodes.map(rootNode =>
+            new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
         }
     }
   }
```

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala

Lines changed: 0 additions & 38 deletions

```diff
@@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite
     dt.fit(df)
   }
 
-  test("Use soft prediction for binary classification with ordered categorical features") {
-    // The following dataset is set up such that the best split is {1} vs. {0, 2}.
-    // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0)),
-      LabeledPoint(1.0, Vectors.dense(0.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(1.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(0.0, Vectors.dense(2.0)),
-      LabeledPoint(1.0, Vectors.dense(2.0)))
-    val data = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
-
-    // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
-    val dt = new DecisionTreeClassifier()
-      .setImpurity("gini")
-      .setMaxDepth(1)
-      .setMaxBins(3)
-    val model = dt.fit(df)
-    model.rootNode match {
-      case n: InternalNode =>
-        n.split match {
-          case s: CategoricalSplit =>
-            assert(s.leftCategories === Array(1.0))
-          case other =>
-            fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.")
-        }
-      case other =>
-        fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.")
-    }
-  }
-
   test("Feature importance with toy data") {
     val dt = new DecisionTreeClassifier()
       .setImpurity("gini")
```

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

Lines changed: 90 additions & 10 deletions

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tree.impl
 
+import scala.annotation.tailrec
 import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   import RandomForestSuite.mapToVec
 
+  private val seed = 42
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests for split calculation
   /////////////////////////////////////////////////////////////////////////////
@@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.isLeaf === false)
     assert(topNode.stats === null)
 
-    val nodesForGroup = Map((0, Array(topNode)))
-    val treeToNodeToIndexInfo = Map((0, Map(
-      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
-    )))
+    val nodesForGroup = Map(0 -> Array(topNode))
+    val treeToNodeToIndexInfo = Map(0 -> Map(
+      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+    ))
     val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
     RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
       nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
 
     val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42, instr = None, prune = false).head
+
     model.rootNode match {
       case n: InternalNode => n.split match {
         case s: CategoricalSplit =>
@@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Tests for pruning of redundant subtrees (generated by a split improving the
+  // impurity measure, but always leading to the same prediction).
+  ///////////////////////////////////////////////////////////////////////////////
+
+  test("SPARK-3159 tree model redundancy - classification") {
+    // The following dataset is set up such that splitting over feature_1 for points having
+    // feature_0 = 0 improves the impurity measure, although the prediction will always be 0
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val numClasses = 2
+    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+      numClasses = numClasses, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 5)
+    assert(unprunedTree.numNodes === 7)
+
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
+
+  test("SPARK-3159 tree model redundancy - regression") {
+    // The following dataset is set up such that splitting over feature_0 for points having
+    // feature_1 = 1 improves the impurity measure, although the prediction will always be 0.5
+    // in both branches.
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+    )
+    val rdd = sc.parallelize(arr)
+
+    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
+      numClasses = 0, maxBins = 32)
+
+    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None).head
+
+    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+      seed = 42, instr = None, prune = false).head
+
+    assert(prunedTree.numNodes === 3)
+    assert(unprunedTree.numNodes === 5)
+    assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+  }
 }
 
 private object RandomForestSuite {
-
   def mapToVec(map: Map[Int, Double]): Vector = {
     val size = (map.keys.toSeq :+ 0).max + 1
     val (indices, values) = map.toSeq.sortBy(_._1).unzip
     Vectors.sparse(size, indices.toArray, values.toArray)
   }
+
+  @tailrec
+  private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
+    if (nodes.isEmpty) {
+      acc
+    }
+    else {
+      nodes.head match {
+        case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+        case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+      }
+    }
+  }
 }
```

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 5 additions & 5 deletions

```diff
@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     // if a split does not satisfy min instances per node requirements,
     // this split is invalid, even though the information gain of split is large.
     val arr = Array(
-      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
-      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))
 
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
       Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
-      if (i < 1000) {
+      if (i < 1001) {
         arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       } else if (i < 2000) {
         arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
```
