
Commit 711356b

jkbradley authored and mengxr committed
[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements
Summary:
1. Variable numBins for each feature [SPARK-3043]
2. Reduced data reshaping in aggregation [SPARK-3043]
3. Choose ordering for ordered categorical features adaptively [SPARK-3156]
4. Changed nodes to use 1-indexing [SPARK-3086]
5. Small clean-ups

Note: This PR looks bigger than it is, since several functions were moved from inside findBestSplitsPerGroup to outside of it (to make it clear what is serialized in the aggregation).

Speedups: This update helps most when many features use few bins but a few features use many bins. Example speedups with 2M examples and 3.5K features on a 15-worker EC2 cluster:
* Case where the old code was reasonably efficient (1/2 continuous, 1/4 binary, 1/4 20-category): 164.813 --> 116.491 sec
* Case where the old code wasted many bins (1/10 continuous, 81/100 binary, 9/100 20-category): 128.701 --> 39.334 sec

Details:

(1) Variable numBins for each feature [SPARK-3043]

DecisionTreeMetadata now computes a variable numBins for each feature. It also tracks numSplits.

(2) Reduced data reshaping in aggregation [SPARK-3043]

Added DTStatsAggregator, a wrapper around the aggregate statistics array that provides easy but efficient indexing.
* Added ImpurityAggregator and ImpurityCalculator classes to make the DecisionTree code more oblivious to the type of impurity.
* Design note: I originally tried creating Impurity classes which stored data, keeping the aggregates in an Array[Array[Array[Impurity]]]. However, this led to significant slowdowns, perhaps because of the overhead of creating so many objects.

The aggregate statistics are never reshaped, and cumulative sums are computed in place.

Updated the layout of the aggregation functions. The update simplifies things by (1) dividing features into ordered/unordered (instead of ordered/unordered/continuous) and (2) using the DTStatsAggregator for indexing. The following functions were refactored:
* updateBinForOrderedFeature
* updateBinForUnorderedFeature
* binaryOrNotCategoricalBinSeqOp
* multiclassWithCategoricalBinSeqOp
* regressionBinSeqOp

These five functions were replaced with:
* orderedBinSeqOp
* someUnorderedBinSeqOp

Other changes:
* calculateGainForSplit now treats all feature types the same way.
* Eliminated extractLeftRightNodeAggregates.

(3) Choose ordering for ordered categorical features adaptively [SPARK-3156]

Updated binsToBestSplit():
* It now computes cumulative sums of stats for ordered features (see the sketch after the DecisionTree.scala entry below).
* For ordered categorical features, it chooses an ordering for categories. (This used to be done by findSplitsBins.)
* It uses iterators to shorten the code and to avoid building an Array[Array[InformationGainStats]].

Side effects:
* In findSplitsBins: A sample of the data is taken only for data with continuous features; it is not needed for data with only categorical features.
* In findSplitsBins: Splits and bins are no longer pre-computed for ordered categorical features, since they are not needed.
* TreePoint binning is simpler for categorical features.

(4) Changed nodes to use 1-indexing [SPARK-3086]

Nodes used to be indexed from 0; they are now indexed from 1. The node indexing functions are collected in object Node (Node.scala).

(5) Small clean-ups

Eliminated the functions extractNodeInfo() and extractInfoForLowerLevels() to reduce duplicate code. Eliminated InvalidBinIndex, since it is no longer used.

CC: mengxr manishamde. Please let me know if you have thoughts on this. Thanks!

Author: Joseph K. Bradley <[email protected]>

Closes #2125 from jkbradley/dt-opt3alt and squashes the following commits:
* 42c192a Merge branch 'rfs' into dt-opt3alt
* d3cc46b Merge remote-tracking branch 'upstream/master' into dt-opt3alt
* 00e4404 Optimization for TreePoint construction (pre-computing featureArity and isUnordered as arrays)
* 425716c Merge remote-tracking branch 'upstream/master' into rfs
* a2acea5 Small optimizations based on profiling
* aa4e4df Updated DTStatsAggregator with bug fix (nodeString should not be multiplied by statsSize)
* 4651154 Changed numBins semantics for unordered features: before, numBins = numSplits = (1 << k - 1) - 1; now, numBins = 2 * numSplits = 2 * [(1 << k - 1) - 1]. This also involved changing the semantics of DecisionTreeMetadata.numUnorderedBins().
* 1e3b1c7 Merge remote-tracking branch 'upstream/master' into dt-opt3alt
* 1485fcc Made some DecisionTree methods private.
* 92f934f Merge remote-tracking branch 'upstream/master' into dt-opt3alt
* e676da1 Updated documentation for DecisionTree
* 37ca845 Fixed problem with how DecisionTree handles ordered categorical features.
* 105f8ab Removed commented-out getEmptyBinAggregates from DecisionTree
* 062c31d Merge remote-tracking branch 'upstream/master' into dt-opt3alt
* 6d32ccd In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code.
* 807cd00 Finished DTStatsAggregator, a wrapper around the aggregate statistics for easy but hopefully efficient indexing. Modified the old ImpurityAggregator classes and renamed them ImpurityCalculator; added ImpurityAggregator classes which work with DTStatsAggregator but do not store data. Unit tests all succeed.
* f2166fd Still working on DTStatsAggregator
* 92f7118 Added partly written DTStatsAggregator
* fd8df30 Moved some aggregation helpers outside of findBestSplitsPerGroup
* d7c53ee Added more doc for ImpurityAggregator
* a40f8f1 Changed nodes to be indexed from 1. Tests work.
* 95cad7c Merge remote-tracking branch 'upstream/master' into dt-opt3
* 5f94342 Added treeAggregate since not yet merged from master. Moved node indexing functions to Node.
* 61c4509 Fixed bugs from merge: missing DT timer call, and numBins setting. Cleaned up DT Suite some.
* 3ba7166 Merge remote-tracking branch 'upstream/master' into dt-opt3
* b314659 Merge remote-tracking branch 'upstream/master' into dt-opt3
* 9c83363 Partial merge, but not done yet
* 45f7ea7 Partial merge, not yet done
* 5fce635 Merge branch 'dt-opt2' into dt-opt3
* 26d10dd Removed tree/model/Filter.scala since it is no longer used. Removed debugging println calls in DecisionTree.scala.
* 356daba Merge branch 'dt-opt1' into dt-opt2
* 430d782 Added more debug info on binning error. Added some docs.
* d036089 Print timing info to logDebug.
* e66f1b1 TreePoint: updated doc and made some methods private
* 8464a6e Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable.
* a87e08f Merge remote-tracking branch 'upstream/master' into dt-opt1
* dd4d3aa Mid-process in bug fix: bug for binary classification with categorical features.
  * Bug: Categorical features were all treated as ordered for binary classification. This is possible, but it would require the bin ordering to be determined on the fly after the aggregation; currently, the ordering is determined a priori and fixed for all splits.
  * (Temporary) fix: Treat low-arity categorical features as unordered for binary classification.
  * Related change: Removed most tests for isMulticlass in the code; the metadata is instead tested for whether there are unordered features.
  * Status: The bug may be fixed, but more testing needs to be done.
* 438a660 Removed subsampling for mnist8m from DT
* 86e217f Added cache to DT input
* e3c84cc Added stuff for mnist8m to DTRunner
* 51ef781 Fixed bug introduced by the last commit: the variance impurity calculation was incorrect since counts were accidentally swapped
* fd65372 Major changes: created ImpurityAggregator classes rather than the old aggregates; feature split/bin semantics are now based on ordered vs. unordered (e.g., numSplits = numBins for all unordered features, and numSplits = numBins - 1 for all ordered features); numBins can differ for each feature
* c1565a5 Small DecisionTree updates: simplified calculateGainForSplit to take aggregates for a single (feature, split) pair; internal doc for findAggForOrderedFeatureClassification
* b914f3b DecisionTree optimization: eliminated filters, plus small changes
* b2ed1f3 Merge remote-tracking branch 'upstream/master' into dt-opt
* 0f676e2 Optimizations and a bug fix for DecisionTree
* 3211f02 Optimizing DecisionTree: added TreePoint representation to avoid calling findBin multiple times (not working yet, but debugging)
* f61e9d2 Merge remote-tracking branch 'upstream/master' into dt-timing
* bcf874a Merge remote-tracking branch 'upstream/master' into dt-timing
* 511ec85 Merge remote-tracking branch 'upstream/master' into dt-timing
* a95bc22 Timing for DecisionTree internals
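To see why the 1-indexing in item (4) simplifies node bookkeeping: with the root at index 1, child and parent indices are pure shifts. The following is a minimal standalone sketch of that arithmetic; the helper names are illustrative, and the actual object Node API is not shown in this excerpt.

```scala
// Sketch of 1-indexed binary-tree node arithmetic (illustrative names only).
// With the root at index 1, level d holds indices [2^d, 2^(d+1)), and
// children/parents are computed with shifts; 0-indexing would instead need
// 2*i + 1 / 2*i + 2 and (i - 1) / 2.
object NodeIndexSketch {
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1         // 2 * i
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1  // 2 * i + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1            // i / 2

  def main(args: Array[String]): Unit = {
    println(s"children of 1: ${leftChildIndex(1)}, ${rightChildIndex(1)}")  // 2, 3
    println(s"parent of 5: ${parentIndex(5)}")                              // 2
  }
}
```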
1 parent 0d1cc4a commit 711356b

File tree

11 files changed: +1322 -1248 lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 450 additions & 891 deletions
Large diffs are not rendered by default.
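Since the DecisionTree.scala diff is not rendered here, below is a minimal, self-contained sketch of the cumulative-sum pass described in item (3) of the commit message: accumulate per-bin stats once, then read off left/right child stats for every split of an ordered feature. The count-based stats and all names are illustrative assumptions, not the Spark implementation.

```scala
// Illustrative sketch: evaluating all splits of an ordered feature in one
// pass via in-place cumulative sums (per commit item (3)). Stats here are
// simple (label-0, label-1) counts; Spark uses impurity-specific sufficient
// statistics held in DTStatsAggregator instead.
object CumulativeSplitSketch {
  // Gini impurity for (count of label 0, count of label 1).
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0.0) 0.0
    else {
      val f0 = c0 / total
      val f1 = c1 / total
      1.0 - f0 * f0 - f1 * f1
    }
  }

  def main(args: Array[String]): Unit = {
    // Per-bin label counts for one (node, feature): bins(binIndex) = (n0, n1).
    val bins = Array((10.0, 2.0), (8.0, 6.0), (3.0, 9.0), (1.0, 11.0))
    // In-place cumulative sums: afterwards bins(i) holds totals for bins 0..i,
    // so split i separates bins(i) (left child) from total - bins(i) (right).
    for (i <- 1 until bins.length) {
      bins(i) = (bins(i)._1 + bins(i - 1)._1, bins(i)._2 + bins(i - 1)._2)
    }
    val (t0, t1) = bins.last
    // An ordered feature has numSplits = numBins - 1 candidate splits.
    val best = (0 until bins.length - 1).map { split =>
      val (l0, l1) = bins(split)
      val (r0, r1) = (t0 - l0, t1 - l1)
      val weightedImpurity =
        ((l0 + l1) * gini(l0, l1) + (r0 + r1) * gini(r0, r1)) / (t0 + t1)
      (split, weightedImpurity)
    }.minBy(_._2)
    println(s"best split index = ${best._1}, weighted impurity = ${best._2}")
  }
}
```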
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala

Lines changed: 213 additions & 0 deletions (new file)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.tree.impurity._

/**
 * DecisionTree statistics aggregator.
 * This holds a flat array of statistics for a set of (nodes, features, bins)
 * and helps with indexing.
 */
private[tree] class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    val numNodes: Int) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  val statsSize: Int = impurityAggregator.statsSize

  val numFeatures: Int = metadata.numFeatures

  /**
   * Number of bins for each feature. This is indexed by the feature index.
   */
  val numBins: Array[Int] = metadata.numBins

  /**
   * Number of splits for the given feature.
   */
  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)

  /**
   * Indicator for each feature of whether that feature is an unordered feature.
   * TODO: Is Array[Boolean] any faster?
   */
  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
      if (isUnordered(featureIndex)) {
        total + 2 * numBins(featureIndex)
      } else {
        total + numBins(featureIndex)
      }
    }
    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
  }

  /**
   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
   */
  private val nodeStride: Int = featureOffsets.last

  /**
   * Total number of elements stored in this aggregator.
   */
  val allStatsSize: Int = numNodes * nodeStride

  /**
   * Flat array of elements.
   * Index for start of stats for a (node, feature, bin) is:
   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
   *       and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)).
   */
  val allStats: Array[Double] = new Array[Double](allStatsSize)

  /**
   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   */
  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
  }

  /**
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   */
  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute node offset for use with [[nodeUpdate]].
   */
  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
   */
  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For ordered features only.
   */
  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
    require(!isUnordered(featureIndex),
      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
        s" for unordered feature $featureIndex.")
    nodeIndex * nodeStride + featureOffsets(featureIndex)
  }

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For unordered features only.
   */
  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
    require(isUnordered(featureIndex),
      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
        s" but was called for ordered feature $featureIndex.")
    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
  }

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin), using the given label.
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   */
  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
  }

  /**
   * For a given (node, feature), merge the stats for two bins.
   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getNodeFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightNodeFeatureOffsets]].
   * @param binIndex  The other bin is merged into this bin.
   * @param otherBinIndex  This bin is not modified.
   */
  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
      nodeFeatureOffset + otherBinIndex * statsSize)
  }

  /**
   * Merge this aggregator with another, and returns this aggregator.
   * This method modifies this aggregator in-place.
   */
  def merge(other: DTStatsAggregator): DTStatsAggregator = {
    require(allStatsSize == other.allStatsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
    var i = 0
    // TODO: Test BLAS.axpy
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }
    this
  }

}

private[tree] object DTStatsAggregator extends Serializable {

  /**
   * Combines two aggregates (modifying the first) and returns the combination.
   */
  def binCombOp(
      agg1: DTStatsAggregator,
      agg2: DTStatsAggregator): DTStatsAggregator = {
    agg1.merge(agg2)
  }

}
```
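To make the flat layout concrete, here is a small self-contained sketch of the same indexing arithmetic documented on allStats above. It mirrors the formula rather than calling the private Spark class, and its toy numbers and names are assumptions for illustration only.

```scala
// Standalone sketch of DTStatsAggregator's flat indexing:
//   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
// Toy setup: 2 nodes, 3 ordered features with (4, 3, 6) bins,
// and statsSize = 3 (e.g. class counts for a 3-class Gini aggregator).
object FlatIndexSketch {
  val statsSize = 3
  val numBins = Array(4, 3, 6)

  // scanLeft accumulates bin counts; multiplying by statsSize converts to
  // offsets in Doubles: here Array(0, 12, 21, 39).
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)(_ + _).map(statsSize * _)

  val nodeStride: Int = featureOffsets.last  // 39 Doubles per node
  val allStats = new Array[Double](2 * nodeStride)

  // Count-per-class update, in the style of a classification aggregator.
  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
    allStats(i + label.toInt) += 1.0
  }

  def main(args: Array[String]): Unit = {
    update(nodeIndex = 1, featureIndex = 2, binIndex = 5, label = 2.0)
    // Offset = 1*39 + 21 + 5*3 = 75; the class-2 slot is 75 + 2 = 77.
    println(allStats(77))  // 1.0
  }
}
```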

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 57 additions & 16 deletions

```diff
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.rdd.RDD
 
-
 /**
  * Learning and dataset metadata for DecisionTree.
  *
  * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
  *                      For regression: fixed at 0 (no meaning).
+ * @param maxBins  Maximum number of bins, for all features.
  * @param featureArity  Map: categorical feature index --> arity.
  *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins  Number of bins for each feature.
  */
 private[tree] class DecisionTreeMetadata(
     val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
     val maxBins: Int,
     val featureArity: Map[Int, Int],
     val unorderedFeatures: Set[Int],
+    val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy) extends Serializable {
 
@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(
 
   def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
 
+  /**
+   * Number of splits for the given feature.
+   * For unordered features, there are 2 bins per split.
+   * For ordered features, there is 1 more bin than split.
+   */
+  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+    numBins(featureIndex) >> 1
+  } else {
+    numBins(featureIndex) - 1
+  }
+
 }
 
 private[tree] object DecisionTreeMetadata {
 
+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
   def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
 
     val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
       case Regression => 0
     }
 
-    val maxBins = math.min(strategy.maxBins, numExamples).toInt
-    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+    // based on the number of training examples.
+    if (strategy.categoricalFeaturesInfo.nonEmpty) {
+      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+      require(maxCategoriesPerFeature <= maxPossibleBins,
+        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+          s"in categorical features (= $maxCategoriesPerFeature)")
+    }
 
     val unorderedFeatures = new mutable.HashSet[Int]()
+    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
     if (numClasses > 2) {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        if (k - 1 < log2MaxBinsp1) {
-          // Note: The above check is equivalent to checking:
-          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
-          unorderedFeatures.add(f)
+      // Multiclass classification
+      val maxCategoriesForUnorderedFeature =
+        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        // Decide if some categorical features should be treated as unordered features,
+        //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+        // We do this check with log values to prevent overflows in case numCategories is large.
+        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+        if (numCategories <= maxCategoriesForUnorderedFeature) {
+          unorderedFeatures.add(featureIndex)
+          numBins(featureIndex) = numUnorderedBins(numCategories)
         } else {
-          // TODO: Allow this case, where we simply will know nothing about some categories?
-          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-            s"in categorical features (>= $k)")
+          numBins(featureIndex) = numCategories
         }
       }
     } else {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-          s"in categorical features (>= $k)")
+      // Binary classification or regression
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        numBins(featureIndex) = numCategories
       }
     }
 
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
      strategy.impurity, strategy.quantileCalculationStrategy)
   }
 
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+   * there are math.pow(2, arity - 1) - 1 such splits.
+   * Each split has 2 corresponding bins.
+   */
+  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+
 }
```
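A brief sketch of the unordered-vs-ordered threshold added above, using the same formulas as buildMetadata and numUnorderedBins; the object and helper names here are illustrative assumptions, and the concrete numbers are just a worked example.

```scala
// Sketch of the multiclass unordered-feature decision in buildMetadata.
// For maxPossibleBins = 32:
//   maxCategoriesForUnordered = floor(log2(32/2 + 1) + 1) = floor(5.09) = 5,
// so arities <= 5 are treated as unordered, since
//   numUnorderedBins(5) = 2 * ((1 << 4) - 1) = 30 <= 32,
// while arity 6 would need 62 bins and stays ordered.
object UnorderedDecisionSketch {
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)

  def maxCategoriesForUnordered(maxPossibleBins: Int): Int =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt

  def main(args: Array[String]): Unit = {
    val maxPossibleBins = 32
    val threshold = maxCategoriesForUnordered(maxPossibleBins)
    println(s"threshold arity = $threshold")  // 5
    (2 to 7).foreach { arity =>
      val unordered = arity <= threshold
      println(s"arity $arity -> unordered = $unordered, bins = " +
        (if (unordered) numUnorderedBins(arity) else arity))
    }
  }
}
```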
