Skip to content

Commit ff363a7

Browse files
committed
binary search for bins and while loop for categorical feature bins
1 parent 632818f commit ff363a7

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ object DecisionTree extends Serializable with Logging {
319319
true
320320
}
321321

322+
// TODO: Unit test this
322323
/**
323324
* Finds the right bin for the given feature
324325
*/
@@ -328,26 +329,47 @@ object DecisionTree extends Serializable with Logging {
328329
isFeatureContinuous: Boolean)
329330
: Int = {
330331

331-
if (isFeatureContinuous){
332-
for (binIndex <- 0 until strategy.numBins) {
333-
val bin = bins(featureIndex)(binIndex)
332+
val binForFeatures = bins(featureIndex)
333+
val feature = labeledPoint.features(featureIndex)
334+
335+
def binarySearchForBins(): Int = {
336+
var left = 0
337+
var right = binForFeatures.length-1
338+
while (left <= right) {
339+
val mid = left + (right - left) / 2
340+
val bin = binForFeatures(mid)
334341
val lowThreshold = bin.lowSplit.threshold
335342
val highThreshold = bin.highSplit.threshold
336-
val features = labeledPoint.features
337-
if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) {
338-
return binIndex
343+
if ((lowThreshold < feature) & (highThreshold >= feature)){
344+
return mid
345+
}
346+
else if ((lowThreshold >= feature)){
347+
right = mid - 1
339348
}
349+
else {
350+
left = mid + 1
351+
}
352+
}
353+
-1
354+
}
355+
356+
if (isFeatureContinuous){
357+
val binIndex = binarySearchForBins()
358+
if (binIndex == -1){
359+
throw new UnknownError("no bin was found for continuous variable.")
340360
}
341-
throw new UnknownError("no bin was found for continuous variable.")
361+
binIndex
342362
} else {
343363
val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
344-
for (binIndex <- 0 until numCategoricalBins) {
364+
var binIndex = 0
365+
while (binIndex < numCategoricalBins) {
345366
val bin = bins(featureIndex)(binIndex)
346367
val category = bin.category
347368
val features = labeledPoint.features
348369
if (category == features(featureIndex)) {
349370
return binIndex
350371
}
372+
binIndex += 1
351373
}
352374
throw new UnknownError("no bin was found for categorical variable.")
353375

0 commit comments

Comments
 (0)