Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 66 additions & 31 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging {
Some(mutableNodeToFeatures.toMap)
}

// array of nodes to train indexed by node index in group
val nodes = new Array[Node](numNodes)
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
}
}

// Calculate best splits for all nodes in the group
timer.start("chooseSplits")

Expand Down Expand Up @@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging {

// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
binsToBestSplit(aggStats, splits, featuresForNode)
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
}.collectAsMap()

Expand All @@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging {
// Extract info for this node. Create children if not leaf.
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
assert(node.id == nodeIndex)
node.predict = predict.predict
node.predict = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
node.impurity = stats.impurity
logDebug("Node = " + node)

if (!isLeaf) {
node.split = Some(split)
node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
nodeQueue.enqueue((treeIndex, node.leftNode.get))
nodeQueue.enqueue((treeIndex, node.rightNode.get))
val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also check stats.leftImpurity and rightImpurity. If stats.leftImpurity = 0, then we know the left child will be a leaf. Same for the right child.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this is done, it might be good to add 1 more test to make sure it works. A slight modification of the test you already added should work.

val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))

// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
}
if (!rightChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.rightNode.get))
}

logDebug("leftChildIndex = " + node.leftNode.get.id +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + node.rightNode.get.id +
Expand All @@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
metadata: DecisionTreeMetadata): InformationGainStats = {
metadata: DecisionTreeMetadata,
impurity: Double): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count

Expand All @@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging {

val totalCount = leftCount + rightCount

val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)

val impurity = parentNodeAgg.calculate()

val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()

Expand All @@ -649,25 +666,36 @@ object DecisionTree extends Serializable with Logging {
return InformationGainStats.invalidInformationGainStats
}

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
// calculate left and right predict
val leftPredict = calculatePredict(leftImpurityCalculator)
val rightPredict = calculatePredict(rightImpurityCalculator)

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
leftPredict, rightPredict)
}

/**
 * Wrap the calculator's predicted value, together with the probability the
 * calculator reports for that value, into a [[Predict]].
 * @param impurityCalculator aggregated stats to derive the prediction from
 * @return prediction paired with its probability
 */
private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
  val predictedValue = impurityCalculator.predict
  new Predict(predictedValue, impurityCalculator.prob(predictedValue))
}

/**
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
* @return predict value for current node
* @return predict value and impurity for current node
*/
private def calculatePredict(
private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): Predict = {
rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = parentNodeAgg.predict
val prob = parentNodeAgg.prob(predict)
val predict = calculatePredict(parentNodeAgg)
val impurity = parentNodeAgg.calculate()

new Predict(predict, prob)
(predict, impurity)
}

/**
Expand All @@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging {
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
featuresForNode: Option[Array[Int]],
node: Node): (Split, InformationGainStats, Predict) = {

// calculate predict only once
var predict: Option[Predict] = None
// compute predict and impurity here only for the top (root) node; any other node
// already carries the values assigned when its parent was split
val level = Node.indexToLevel(node.id)
var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
None
} else {
Some((node.predict, node.impurity))
}

// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Expand All @@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
Expand All @@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
Expand Down Expand Up @@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
rightChildStats, binAggregates.metadata)
rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
Expand All @@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)

assert(predict.isDefined, "must calculate predict for each node")

(bestSplit, bestSplitStats, predict.get)
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param leftPredict left node predict
* @param rightPredict right node predict
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double) extends Serializable {
val rightImpurity: Double,
val leftPredict: Predict,
val rightPredict: Predict) extends Serializable {

override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
Expand All @@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
* denote that current split doesn't satisfy minimum info gain or
* minimum number of instances per node.
*/
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
37 changes: 30 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
*
* @param id integer node id, from 1
* @param predict predicted value at the node
* @param isLeaf whether the leaf is a node
* @param impurity current node impurity
* @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
Expand All @@ -41,15 +42,16 @@ import org.apache.spark.mllib.linalg.Vector
@DeveloperApi
class Node (
val id: Int,
var predict: Double,
var predict: Predict,
var impurity: Double,
var isLeaf: Boolean,
var split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
var stats: Option[InformationGainStats]) extends Serializable with Logging {

/** One-line summary of this node's fields, used in debug logging. */
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
"split = " + split + ", stats = " + stats
"impurity = " + impurity + ", split = " + split + ", stats = " + stats

/**
* build the left node and right nodes if not leaf
Expand All @@ -62,6 +64,7 @@ class Node (
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
logDebug("impurity = " + impurity)
if (!isLeaf) {
leftNode = Some(nodes(Node.leftChildIndex(id)))
rightNode = Some(nodes(Node.rightChildIndex(id)))
Expand All @@ -77,7 +80,7 @@ class Node (
*/
def predict(features: Vector) : Double = {
if (isLeaf) {
predict
predict.predict
} else{
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
Expand Down Expand Up @@ -109,7 +112,7 @@ class Node (
} else {
Some(rightNode.get.deepCopy())
}
new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}

/**
Expand Down Expand Up @@ -154,7 +157,7 @@ class Node (
}
val prefix: String = " " * indentFactor
if (isLeaf) {
prefix + s"Predict: $predict\n"
prefix + s"Predict: ${predict.predict}\n"
} else {
prefix + s"If ${splitToString(split.get, left=true)}\n" +
leftNode.get.subtreeToString(indentFactor + 1) +
Expand All @@ -170,7 +173,27 @@ private[tree] object Node {
/**
* Return a node with the given node id (but nothing else set).
*/
def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
false, None, None, None, None)

/**
 * Build a child node for which only the id, prediction, impurity and leaf flag
 * are known so far.
 * `DecisionTree.findBestSplits` calls this right after choosing the best split
 * of the parent; the split, children and stats are left unset (`None`) and are
 * set at the next level.
 * @param nodeIndex integer node id, from 1
 * @param predict predicted value at the node
 * @param impurity current node impurity
 * @param isLeaf whether the node is a leaf
 * @return new node instance
 */
def apply(
    nodeIndex: Int,
    predict: Predict,
    impurity: Double,
    isLeaf: Boolean): Node =
  new Node(
    nodeIndex,
    predict,
    impurity,
    isLeaf,
    split = None,
    leftNode = None,
    rightNode = None,
    stats = None)

/**
* Return the index of the left child of this node.
Expand Down
Loading