@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2) extends AbstractParams[Params]
fracTest: Double = 0.2,
useNodeIdCache: Boolean = false,
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]

def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
opt[Boolean]("useNodeIdCache")
.text(s"whether to use node Id cache during training, " +
s"default: ${defaultParams.useNodeIdCache}")
.action((x, c) => c.copy(useNodeIdCache = x))
opt[String]("checkpointDir")
.text(s"checkpoint directory where intermediate node Id caches will be stored, " +
s"default: ${defaultParams.checkpointDir match {
case Some(strVal) => strVal
case None => "None"
}}")
.action((x, c) => c.copy(checkpointDir = Some(x)))
opt[Int]("checkpointInterval")
.text(s"how often to checkpoint the node Id cache, " +
s"default: ${defaultParams.checkpointInterval}")
.action((x, c) => c.copy(checkpointInterval = x))
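As a hedged illustration of how these new options surface to users, the snippet below sketches an argument list the command-line options defined above would accept. The checkpoint directory and dataset path are hypothetical values chosen for the example, not part of this change.

// Hypothetical arguments for DecisionTreeRunner (illustrative values only): enable the
// node Id cache, checkpoint it to a local directory every 10 updates, and train on an
// assumed sample libsvm-format input file.
val exampleArgs = Array(
  "--useNodeIdCache", "true",
  "--checkpointDir", "/tmp/node-id-cache-checkpoints",  // hypothetical directory
  "--checkpointInterval", "10",
  "data/mllib/sample_libsvm_data.txt")                  // assumed dataset path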
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
minInfoGain = params.minInfoGain,
useNodeIdCache = params.useNodeIdCache,
checkpointDir = params.checkpointDir,
checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
114 changes: 98 additions & 16 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to prevent the need
* to pass the entire tree to the executors during
* the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
timer: TimeTracker = new TimeTracker): Unit = {
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {

/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)

/**
* Performs a sequential aggregation over a partition for a particular tree and node.
*
* For each feature, the aggregate sufficient statistics are updated for the relevant
* bins.
*
* @param treeIndex Index of the tree that we want to perform aggregation for.
* @param nodeInfo The node info for the tree node.
* @param agg Array storing aggregate calculation, with a set of sufficient statistics
* for each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
*/
def nodeBinSeqOp(
treeIndex: Int,
nodeInfo: RandomForest.NodeIndexInfo,
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Unit = {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
}

/**
* Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
// If the example does not reach a node in this group, then nodeIndex = null.
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}

agg
}

/**
* Do the same thing as binSeqOp, but with nodeIdCache.
*/
def binSeqOpWithNodeIdCache(
agg: Array[DTStatsAggregator],
dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val baggedPoint = dataPoint._1
val nodeIdCache = dataPoint._2
val nodeIndex = nodeIdCache(treeIndex)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}

agg
}
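To make the contrast concrete, here is a small self-contained sketch using toy types rather than the MLlib classes: binSeqOp locates an instance's current node by walking the partially built tree (via predictNodeIndex), while binSeqOpWithNodeIdCache only reads the cached node Id for that tree.

// Toy stand-ins; not the MLlib Node/Split classes.
sealed trait ToyNode
case class ToyLeaf(id: Int) extends ToyNode
case class ToySplit(id: Int, featureIndex: Int, threshold: Double,
                    left: ToyNode, right: ToyNode) extends ToyNode

// Without the cache: find the node by walking the tree built so far, so the whole tree
// has to be available on the executor.
def nodeIndexByTraversal(node: ToyNode, features: Array[Double]): Int = node match {
  case ToyLeaf(id) => id
  case ToySplit(_, f, t, left, right) =>
    if (features(f) <= t) nodeIndexByTraversal(left, features)
    else nodeIndexByTraversal(right, features)
}

// With the cache: the node Id for this (row, tree) pair is a plain array lookup, so the
// trees never need to be shipped for this step.
def nodeIndexFromCache(cachedNodeIds: Array[Int], treeIndex: Int): Int =
  cachedNodeIds(treeIndex)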

@@ -553,7 +595,26 @@
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
val nodeToBestSplits =

val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
}
new DTStatsAggregator(metadata, featuresForNode)
}

// iterate over all instances in the current partition and update aggregate stats
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
Member:
(from above) You can just test useNodeIdCache at this line. That should reduce code duplication.

Author:
Well, one requires zip and the other one doesn't, so it fundamentally changes the type of the rows.

Additionally, I think if we just branch out within mapPartitions, won't it unnecessarily serialize some things that are used in one branch but not the other? E.g. it seems that binSeqOp itself becomes an object and will be serialized, along with closure items.

Member:
Good point; ignore my comment please.
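The type difference the author mentions can be seen in a minimal standalone sketch (toy element types and a local master; none of this is the PR's code): zipping the instance RDD with the node Id RDD turns each partition element into a pair, so the per-element aggregation function necessarily has a different signature.

import org.apache.spark.{SparkConf, SparkContext}

object ZipTypeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("zip-type-sketch"))

    // Stand-ins for the bagged training points and the per-row node Id arrays.
    val instances = sc.parallelize(Seq("p0", "p1", "p2", "p3"), 2)
    val nodeIds = sc.parallelize(Seq(Array(1), Array(1), Array(2), Array(3)), 2)

    // Plain path: partition elements are the instances themselves.
    val a = instances.mapPartitions(points => points.map(_.length)).collect()

    // Cached path: zip pairs each instance with its node Ids, so the element type becomes
    // a tuple and the aggregation closure must accept the pair.
    val b = instances.zip(nodeIds)
      .mapPartitions(points => points.map { case (p, ids) => p.length + ids.head })
      .collect()

    println(a.mkString(",") + " / " + b.mkString(","))
    sc.stop()
  }
}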


// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}
} else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
Expand All @@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partitions using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
}.reduceByKey((a, b) => a.merge(b))
}
}

val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@

timer.stop("chooseSplits")

val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
Array.fill[mutable.Map[Int, NodeIndexUpdater]](
metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
} else {
null
}

// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))

if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
split = split,
nodeIndex = nodeIndex)
nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
}

// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
Expand All @@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
}
}

if (nodeIdCache.nonEmpty) {
// Update the cache if needed.
nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
}
}
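As a rough mental model of what the cache update at the end of this method accomplishes (an assumption based on how NodeIndexUpdater is constructed and consumed here, not a description of its actual implementation): once a split is chosen for a node, every cached row Id equal to that node's index is replaced by the index of the child the row falls into, in the same 1-based heap numbering used by Node.leftChildIndex and Node.rightChildIndex.

// Hedged, toy version of the per-row cache update; `goesLeft` stands in for evaluating
// the chosen split against the row's binned features.
def conceptualNodeIdUpdate(cachedId: Int, splitNodeIndex: Int, goesLeft: Boolean): Int =
  if (cachedId != splitNodeIndex) {
    cachedId                   // the row sits at some other node: leave it alone
  } else if (goesLeft) {
    2 * splitNodeIndex         // assumed left-child index in 1-based heap numbering
  } else {
    2 * splitNodeIndex + 1     // assumed right-child index
  }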

/**
mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
* in lower levels).
*/

// Create an RDD of node Id cache.
// At first, all the rows belong to the root nodes (node Id == 1).
val nodeIdCache = if (strategy.useNodeIdCache) {
Some(NodeIdCache.init(
data = baggedInput,
numTrees = numTrees,
checkpointDir = strategy.checkpointDir,
checkpointInterval = strategy.checkpointInterval,
initVal = 1))
} else {
None
}
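For intuition, the initial contents of this cache can be pictured with a tiny non-Spark sketch (a conceptual illustration, not NodeIdCache's actual implementation): one Array[Int] per training row, with every tree's entry set to the root node Id of 1.

// Conceptual shape of the freshly initialized node Id cache for 3 trees and 5 rows.
val numTrees = 3
val numRows = 5
val initialNodeIds: Seq[Array[Int]] = Seq.fill(numRows)(Array.fill(numTrees)(1))
// initialNodeIds(row)(treeIndex) == 1 for every row and tree: all rows start at the root,
// and findBestSplits rewrites these entries as nodes are split.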

// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()

@@ -182,7 +195,7 @@
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}

@@ -193,6 +206,11 @@
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

// Delete any remaining checkpoints used for node Id cache.
if (nodeIdCache.nonEmpty) {
nodeIdCache.get.deleteAllCheckpoints()
}

val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
val treeWeights = Array.fill[Double](numTrees)(1.0)
new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
* @param checkpointDir If the node Id cache is used, it will help to checkpoint
* the node Id cache periodically. This is the checkpoint directory
* to be used for the node Id cache.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
* E.g. 10 means that the cache will get checkpointed every 10 updates.
*/
@Experimental
class Strategy (
@@ -73,7 +80,10 @@
@BeanProperty var minInstancesPerNode: Int = 1,
@BeanProperty var minInfoGain: Double = 0.0,
@BeanProperty var maxMemoryInMB: Int = 256,
@BeanProperty var subsamplingRate: Double = 1) extends Serializable {
@BeanProperty var subsamplingRate: Double = 1,
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {

if (algo == Classification) {
require(numClassesForClassification >= 2)
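To tie the new Strategy fields together, here is a hedged end-to-end usage sketch. The tiny in-memory dataset, the local master, and the checkpoint directory are illustrative assumptions; only the three new Strategy fields and the DecisionTree.train call come from the change shown above, and the remaining entry points are standard MLlib APIs.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object NodeIdCacheUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("node-id-cache-usage-sketch"))

    // A tiny illustrative dataset; any RDD[LabeledPoint] would do.
    val trainingData = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(0.2, 0.9)),
      LabeledPoint(1.0, Vectors.dense(0.9, 0.1))))

    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 3,
      numClassesForClassification = 2)

    // The fields added by this change: keep per-row node Ids in a cached RDD instead of
    // shipping the growing trees to executors, and checkpoint that RDD every 10 updates.
    strategy.useNodeIdCache = true
    strategy.checkpointDir = Some("/tmp/node-id-cache-checkpoints") // hypothetical directory
    strategy.checkpointInterval = 10

    val model = DecisionTree.train(trainingData, strategy)
    println(model)

    sc.stop()
  }
}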