
Commit 73ab7f1

jkbradley authored and mengxr committed
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up

DecisionTree needs to match each example to a node at each iteration. It currently does this very inefficiently with a set of filters: for each example, it examines each node at the current level and traces up to the root to see whether that example should be handled by that node.

Fix: Filter top-down using the partly built tree itself.

Major changes:
* Eliminated the Filter class and the findBinsForLevel() method.
* Set up node parent links in the main loop over levels in train().
* Added predictNodeIndex() for filtering top-down.
* Added DTMetadata class.

Other changes:
* Pre-compute the set of unorderedFeatures.

Notes for the expected follow-up PR based on [https://issues.apache.org/jira/browse/SPARK-3043]:
* The unorderedFeatures set will next be stored in a metadata structure to simplify function calls (and to store other items, such as the data in strategy).

I've done initial tests indicating that this speeds things up, but am only now running large-scale ones.

CC: mengxr manishamde chouqin. Any comments are welcome---thanks!

Author: Joseph K. Bradley <[email protected]>

Closes #1975 from jkbradley/dt-opt2 and squashes the following commits:

a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
3726d20 [Joseph K. Bradley] Small code improvements based on code review.
ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: now using << instead of math.pow.
db0d773 [Joseph K. Bradley] scala style fix
6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code
931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift.
f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint: updated doc; made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: simplification: updated calculateGainForSplit to take aggregates for a single (feature, split) pair; internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree: added TreePoint representation to avoid calling findBin multiple times (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
1 parent fbad722 commit 73ab7f1
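
To make the fix concrete, here is a minimal sketch of the top-down filtering idea: each example is routed from the root of the partly built tree down to the deepest node that exists so far, instead of checking every node at the current level against a chain of filters traced back to the root. The method name predictNodeIndex comes from the commit message, but the signature and recursive body below are an assumption for illustration, not the committed code.

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.FeatureType.Continuous
import org.apache.spark.mllib.tree.model.Node

// Route one example down the partly built tree and return the id of the
// deepest existing node: O(tree depth) per example, rather than scanning
// every node at the current level and tracing filters up to the root.
def predictNodeIndex(node: Node, features: Vector): Int = {
  if (node.isLeaf) {
    node.id
  } else {
    val split = node.split.get
    val goLeft = if (split.featureType == Continuous) {
      features(split.feature) <= split.threshold
    } else {
      split.categories.contains(features(split.feature))
    }
    (if (goLeft) node.leftNode else node.rightNode) match {
      case Some(child) => predictNodeIndex(child, features)
      case None => node.id // children not constructed yet at this level
    }
  }
}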

File tree

9 files changed: +615 additions, -630 deletions

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 386 additions & 492 deletions
Large diffs are not rendered by default.
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 101 additions & 0 deletions (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import scala.collection.mutable

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD


/**
 * Learning and dataset metadata for DecisionTree.
 *
 * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
 *                      For regression: fixed at 0 (no meaning).
 * @param featureArity  Map: categorical feature index --> arity.
 *                      I.e., the feature takes values in {0, ..., arity - 1}.
 */
private[tree] class DecisionTreeMetadata(
    val numFeatures: Int,
    val numExamples: Long,
    val numClasses: Int,
    val maxBins: Int,
    val featureArity: Map[Int, Int],
    val unorderedFeatures: Set[Int],
    val impurity: Impurity,
    val quantileStrategy: QuantileStrategy) extends Serializable {

  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

  def isClassification: Boolean = numClasses >= 2

  def isMulticlass: Boolean = numClasses > 2

  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)

  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)

  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)

}

private[tree] object DecisionTreeMetadata {

  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

    val numFeatures = input.take(1)(0).features.size
    val numExamples = input.count()
    val numClasses = strategy.algo match {
      case Classification => strategy.numClassesForClassification
      case Regression => 0
    }

    val maxBins = math.min(strategy.maxBins, numExamples).toInt
    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)

    val unorderedFeatures = new mutable.HashSet[Int]()
    if (numClasses > 2) {
      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
        if (k - 1 < log2MaxBinsp1) {
          // Note: The above check is equivalent to checking:
          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
          unorderedFeatures.add(f)
        } else {
          // TODO: Allow this case, where we simply will know nothing about some categories?
          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
            s"in categorical features (>= $k)")
        }
      }
    } else {
      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
          s"in categorical features (>= $k)")
      }
    }

    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
      strategy.impurity, strategy.quantileCalculationStrategy)
  }

}
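
A usage sketch for this class, assuming a SparkContext sc and made-up toy data; since DecisionTreeMetadata and buildMetadata are private[tree], this would live inside that package (e.g., in a test suite). The Strategy constructor arguments follow this era of the API; the assertions trace the buildMetadata logic above.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// 100 examples, 3 classes; feature 0 continuous, feature 1 categorical with arity 3.
val data = sc.parallelize(0 until 100).map { i =>
  LabeledPoint((i % 3).toDouble, Vectors.dense(i.toDouble, (i % 3).toDouble))
}

val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 4,
  numClassesForClassification = 3, maxBins = 32,
  categoricalFeaturesInfo = Map(1 -> 3))

val metadata = DecisionTreeMetadata.buildMetadata(data, strategy)
assert(metadata.isMulticlass)     // 3 classes > 2
assert(metadata.isContinuous(0))
assert(metadata.isUnordered(1))   // arity 3: 3 - 1 < log2(32 + 1), so all subsets fit in maxBins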

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 7 additions & 23 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.tree.impl

 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD

@@ -48,50 +47,35 @@ private[tree] object TreePoint {
    * Convert an input dataset into its TreePoint representation,
    * binning feature values in preparation for DecisionTree training.
    * @param input Input dataset.
-   * @param strategy DecisionTree training info, used for dataset metadata.
    * @param bins Bins for features, of size (numFeatures, numBins).
+   * @param metadata Learning and dataset metadata
    * @return TreePoint dataset representation
    */
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
-      strategy: Strategy,
-      bins: Array[Array[Bin]]): RDD[TreePoint] = {
+      bins: Array[Array[Bin]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
-        strategy.categoricalFeaturesInfo)
+      TreePoint.labeledPointToTreePoint(x, bins, metadata)
     }
   }

   /**
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
-      isMulticlassClassification: Boolean,
       bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+      metadata: DecisionTreeMetadata): TreePoint = {

     val numFeatures = labeledPoint.features.size
     val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      val featureInfo = categoricalFeaturesInfo.get(featureIndex)
-      val isFeatureContinuous = featureInfo.isEmpty
-      if (isFeatureContinuous) {
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
-          bins, categoricalFeaturesInfo)
-      } else {
-        val featureCategories = featureInfo.get
-        val isSpaceSufficientForAllCategoricalSplits
-          = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        val isUnorderedFeature =
-          isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
-          isUnorderedFeature, bins, categoricalFeaturesInfo)
-      }
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
+        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
       featureIndex += 1
     }
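
A brief usage sketch of the refactored entry point. The input, bins, and metadata values are assumed to come from DecisionTree's preprocessing (the split/bin-finding step and DecisionTreeMetadata.buildMetadata); TreePoint, like DecisionTreeMetadata, is private[tree], so this would live inside that package.

// Assumes: input: RDD[LabeledPoint]; bins: Array[Array[Bin]] from the
// split/bin-finding step; metadata from DecisionTreeMetadata.buildMetadata.
val treeInput: RDD[TreePoint] = TreePoint.convertToTreeRDD(input, bins, metadata)
// Each TreePoint caches one bin index per feature, so findBin runs once per
// (example, feature) up front rather than once per tree level.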

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 14 additions & 4 deletions
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._

 /**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the feature values for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ * where an example with featureValue falls into the bin s.t.
+ * lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence between
+ * bins, splits, and feature values. The bin is determined by category/feature value.
+ * However, the bins are not necessarily ordered by feature value;
+ * they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
+ * where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ *
  * @param lowSplit signifying the lower threshold for the continuous feature to be
  *                 accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature to be
  *                  accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for binary classification
+ * @param category categorical label value accepted in the bin for ordered features
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
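
The continuous-feature invariant documented above (lowSplit.threshold < featureValue <= highSplit.threshold) can be checked mechanically. A minimal sketch over one feature's row of bins; the real findBin in DecisionTree uses binary search and raises an error with debug info when no bin matches, so the linear scan and -1 sentinel here are illustration-only assumptions.

import org.apache.spark.mllib.tree.model.Bin

// Linear scan over one feature's bins; returns -1 if the value falls
// outside every bin.
def findContinuousBin(featureValue: Double, featureBins: Array[Bin]): Int =
  featureBins.indexWhere { bin =>
    bin.lowSplit.threshold < featureValue && featureValue <= bin.highSplit.threshold
  }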

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * @return Double prediction from the trained model
    */
   def predict(features: Vector): Double = {
-    topNode.predictIfLeaf(features)
+    topNode.predict(features)
   }

   /**
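
The method delegated to here is the public prediction entry point. A usage sketch, assuming data and strategy as in the metadata sketch above; DecisionTree.train does exist with this shape in this version.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.DecisionTree

// Train a model, then predict a single (made-up) feature vector.
val model = DecisionTree.train(data, strategy)
val prediction: Double = model.predict(Vectors.dense(42.0, 1.0))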

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

Lines changed: 0 additions & 28 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 8 additions & 8 deletions
@@ -69,24 +69,24 @@ class Node (

   /**
    * predict value if node is not leaf
-   * @param feature feature value
+   * @param features feature value
    * @return predicted value
    */
-  def predictIfLeaf(feature: Vector) : Double = {
+  def predict(features: Vector) : Double = {
     if (isLeaf) {
       predict
     } else{
       if (split.get.featureType == Continuous) {
-        if (feature(split.get.feature) <= split.get.threshold) {
-          leftNode.get.predictIfLeaf(feature)
+        if (features(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       } else {
-        if (split.get.categories.contains(feature(split.get.feature))) {
-          leftNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(features(split.get.feature))) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       }
     }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 3 additions & 2 deletions
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
  * :: DeveloperApi ::
  * Split applied to a feature
  * @param feature feature index
- * @param threshold threshold for continuous feature
+ * @param threshold Threshold for continuous feature.
+ *                  Split left if feature <= threshold, else right.
  * @param featureType type of feature -- categorical or continuous
- * @param categories accepted values for categorical variables
+ * @param categories Split left if categorical feature value is in this set, else right.
  */
 @DeveloperApi
 case class Split(
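
To illustrate the clarified semantics, a small sketch constructing both kinds of split. The diff above cuts off after the case class header, so the field list (feature, threshold, featureType, categories) is assumed from this era of the code.

import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.Split

// Continuous: an example goes left iff features(0) <= 0.5.
val contSplit = Split(feature = 0, threshold = 0.5,
  featureType = FeatureType.Continuous, categories = List())

// Categorical: an example goes left iff features(1) is in {0.0, 2.0};
// threshold is unused for categorical splits.
val catSplit = Split(feature = 1, threshold = Double.MinValue,
  featureType = FeatureType.Categorical, categories = List(0.0, 2.0))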
