Skip to content

Commit eefeef1

Browse files
committed
adjust comments and check child nodes' impurity
1 parent c41b1b6 commit eefeef1

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,11 @@ object DecisionTree extends Serializable with Logging {
533533
}
534534

535535
// array of nodes to train indexed by node index in group
536-
val nodes = {
537-
val nodes = Array.fill[Node](numNodes)(null)
538-
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
539-
nodesForTree.foreach { node =>
540-
nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
541-
}
536+
val nodes = Array.fill[Node](numNodes)(null)
537+
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
538+
nodesForTree.foreach { node =>
539+
nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
542540
}
543-
nodes
544541
}
545542

546543
// Calculate best splits for all nodes in the group
@@ -607,14 +604,18 @@ object DecisionTree extends Serializable with Logging {
607604
if (!isLeaf) {
608605
node.split = Some(split)
609606
val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
607+
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
608+
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
610609
node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
611-
stats.leftPredict, stats.leftImpurity, childIsLeaf))
610+
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
612611
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
613-
stats.rightPredict, stats.rightImpurity, childIsLeaf))
612+
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
614613

615614
// enqueue left child and right child if they are not leaves
616-
if (!childIsLeaf) {
615+
if (!leftChildIsLeaf) {
617616
nodeQueue.enqueue((treeIndex, node.leftNode.get))
617+
}
618+
if (!rightChildIsLeaf) {
618619
nodeQueue.enqueue((treeIndex, node.rightNode.get))
619620
}
620621

@@ -691,11 +692,10 @@ object DecisionTree extends Serializable with Logging {
691692
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
692693
val parentNodeAgg = leftImpurityCalculator.copy
693694
parentNodeAgg.add(rightImpurityCalculator)
694-
val predict = parentNodeAgg.predict
695-
val prob = parentNodeAgg.prob(predict)
695+
val predict = calculatePredict(parentNodeAgg)
696696
val impurity = parentNodeAgg.calculate()
697697

698-
(new Predict(predict, prob), impurity)
698+
(predict, impurity)
699699
}
700700

701701
/**
@@ -709,7 +709,7 @@ object DecisionTree extends Serializable with Logging {
709709
featuresForNode: Option[Array[Int]],
710710
node: Node): (Split, InformationGainStats, Predict) = {
711711

712-
// calculate predict and impurity if current node are top node
712+
// calculate predict and impurity if current node is top node
713713
val level = Node.indexToLevel(node.id)
714714
var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
715715
None

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.Vector
3333
* @param id integer node id, from 1
3434
* @param predict predicted value at the node
3535
* @param impurity current node impurity
36-
* @param isLeaf whether the leaf is a node
36+
* @param isLeaf whether the node is a leaf
3737
* @param split split to calculate left and right nodes
3838
* @param leftNode left child
3939
* @param rightNode right child
@@ -179,13 +179,13 @@ private[tree] object Node {
179179
/**
180180
* Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
181181
* This is used in `DecisionTree.findBestSplits` to construct child nodes
182-
* after find best splits for each node.
182+
* after finding the best splits for parent nodes.
183183
* Other fields are set at next level.
184184
* @param nodeIndex integer node id, from 1
185185
* @param predict predicted value at the node
186186
* @param impurity current node impurity
187-
* @param isLeaf whether the leaf is a node
188-
* @return newed node instance
187+
* @param isLeaf whether the node is a leaf
188+
* @return new node instance
189189
*/
190190
def apply(
191191
nodeIndex: Int,

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,49 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
748748
assert(topNode.leftNode.get.impurity === 0.0)
749749
assert(topNode.rightNode.get.impurity === 0.0)
750750
}
751+
752+
test("Avoid aggregation if impurity is 0.0") {
753+
val arr = new Array[LabeledPoint](4)
754+
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
755+
arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
756+
arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
757+
arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
758+
val input = sc.parallelize(arr)
759+
760+
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
761+
numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
762+
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
763+
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
764+
765+
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
766+
val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
767+
768+
val topNode = Node.emptyNode(nodeIndex = 1)
769+
assert(topNode.predict.predict === Double.MinValue)
770+
assert(topNode.impurity === -1.0)
771+
assert(topNode.isLeaf === false)
772+
773+
val nodesForGroup = Map((0, Array(topNode)))
774+
val treeToNodeToIndexInfo = Map((0, Map(
775+
(topNode.id, new RandomForest.NodeIndexInfo(0, None))
776+
)))
777+
val nodeQueue = new mutable.Queue[(Int, Node)]()
778+
DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
779+
nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
780+
781+
// don't enqueue a node into node queue if its impurity is 0.0
782+
assert(nodeQueue.isEmpty)
783+
784+
// set impurity and predict for topNode
785+
assert(topNode.predict.predict !== Double.MinValue)
786+
assert(topNode.impurity !== -1.0)
787+
788+
// set impurity and predict for child nodes
789+
assert(topNode.leftNode.get.predict.predict === 0.0)
790+
assert(topNode.rightNode.get.predict.predict === 1.0)
791+
assert(topNode.leftNode.get.impurity === 0.0)
792+
assert(topNode.rightNode.get.impurity === 0.0)
793+
}
751794
}
752795

753796
object DecisionTreeSuite {

0 commit comments

Comments
 (0)