Skip to content

Commit 2116360

Browse files
committed
removing dummy bin calculation for categorical variables
1 parent 6068356 commit 2116360

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,8 @@ object DecisionTree extends Serializable with Logging {
340340
}
341341
throw new UnknownError("no bin was found for continuous variable.")
342342
} else {
343-
344-
for (binIndex <- 0 until strategy.numBins) {
343+
val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
344+
for (binIndex <- 0 until numCategoricalBins) {
345345
val bin = bins(featureIndex)(binIndex)
346346
val category = bin.category
347347
val features = labeledPoint.features
@@ -917,13 +917,6 @@ object DecisionTree extends Serializable with Logging {
917917
bins(featureIndex)(numBins-1)
918918
= new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex,
919919
Continuous), Continuous, Double.MinValue)
920-
} else {
921-
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
922-
for (i <- maxFeatureValue until numBins){
923-
bins(featureIndex)(i)
924-
= new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
925-
new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue)
926-
}
927920
}
928921
}
929922
(splits,bins)

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6464
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
6565
assert(arr.length == 1000)
6666
val rdd = sc.parallelize(arr)
67-
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
67+
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2,
68+
1-> 2))
6869
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
6970
assert(splits.length==2)
7071
assert(bins.length==2)
@@ -120,7 +121,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
120121
assert(bins(0)(1).highSplit.categories.contains(1.0))
121122
assert(bins(0)(1).highSplit.categories.contains(0.0))
122123

123-
assert(bins(0)(2).category == Double.MaxValue)
124+
assert(bins(0)(2) == null)
124125

125126
assert(bins(1)(0).category == 0.0)
126127
assert(bins(1)(0).lowSplit.categories.length == 0)
@@ -134,15 +135,16 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
134135
assert(bins(1)(1).highSplit.categories.contains(0.0))
135136
assert(bins(1)(1).highSplit.categories.contains(1.0))
136137

137-
assert(bins(1)(2).category == Double.MaxValue)
138+
assert(bins(1)(2) == null)
138139

139140
}
140141

141142
test("split and bin calculations for categorical variables with no sample for one category"){
142143
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
143144
assert(arr.length == 1000)
144145
val rdd = sc.parallelize(arr)
145-
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
146+
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3,
147+
1-> 3))
146148
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
147149

148150
//Checking splits
@@ -217,7 +219,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
217219
assert(bins(0)(2).highSplit.categories.contains(0.0))
218220
assert(bins(0)(2).highSplit.categories.contains(2.0))
219221

220-
assert(bins(0)(3).category == Double.MaxValue)
222+
assert(bins(0)(3) == null)
221223

222224
assert(bins(1)(0).category == 0.0)
223225
assert(bins(1)(0).lowSplit.categories.length == 0)
@@ -240,7 +242,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
240242
assert(bins(1)(2).highSplit.categories.contains(1.0))
241243
assert(bins(1)(2).highSplit.categories.contains(2.0))
242244

243-
assert(bins(1)(3).category == Double.MaxValue)
245+
assert(bins(1)(3) == null)
244246

245247

246248
}
@@ -249,10 +251,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
249251
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
250252
assert(arr.length == 1000)
251253
val rdd = sc.parallelize(arr)
252-
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
254+
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3,
255+
1-> 3))
253256
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
254257
strategy.numBins = 100
255-
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
258+
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
259+
Array[List[Filter]](), splits, bins)
256260

257261
val split = bestSplits(0)._1
258262
assert(split.categories.length == 1)
@@ -272,10 +276,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
272276
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
273277
assert(arr.length == 1000)
274278
val rdd = sc.parallelize(arr)
275-
val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
279+
val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3,
280+
1-> 3))
276281
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
277282
strategy.numBins = 100
278-
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
283+
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
284+
Array[List[Filter]](), splits, bins)
279285

280286
val split = bestSplits(0)._1
281287
assert(split.categories.length == 1)
@@ -305,7 +311,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
305311
assert(bins(0).length==100)
306312

307313
strategy.numBins = 100
308-
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
314+
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
315+
Array[List[Filter]](), splits, bins)
309316
assert(bestSplits.length == 1)
310317
assert(0==bestSplits(0)._1.feature)
311318
assert(10==bestSplits(0)._1.threshold)
@@ -329,7 +336,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
329336
assert(bins(0).length==100)
330337

331338
strategy.numBins = 100
332-
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
339+
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
340+
Array[List[Filter]](), splits, bins)
333341
assert(bestSplits.length == 1)
334342
assert(0==bestSplits(0)._1.feature)
335343
assert(10==bestSplits(0)._1.threshold)
@@ -355,7 +363,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
355363
assert(bins(0).length==100)
356364

357365
strategy.numBins = 100
358-
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
366+
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
367+
Array[List[Filter]](), splits, bins)
359368
assert(bestSplits.length == 1)
360369
assert(0==bestSplits(0)._1.feature)
361370
assert(10==bestSplits(0)._1.threshold)
@@ -379,7 +388,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
379388
assert(bins(0).length==100)
380389

381390
strategy.numBins = 100
382-
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
391+
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
392+
Array[List[Filter]](), splits, bins)
383393
assert(bestSplits.length == 1)
384394
assert(0==bestSplits(0)._1.feature)
385395
assert(10==bestSplits(0)._1.threshold)

0 commit comments

Comments
 (0)