@@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
 import scala.util.control.Breaks._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.FeatureType._


 class DecisionTree(val strategy: Strategy) extends Serializable with Logging {
@@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging {
     def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
       val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
       val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
-      // logDebug("binData.length = " + binData.length)
-      // logDebug("binData.sum = " + binData.sum)
       for (featureIndex <- 0 until numFeatures) {
-        // logDebug("featureIndex = " + featureIndex)
         val shift = 2 * featureIndex * numSplits
         leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-        // logDebug("binData(shift + 0) = " + binData(shift + 0))
         leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-        // logDebug("binData(shift + 1) = " + binData(shift + 1))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
-        // logDebug(binData(shift + (2 * (numSplits - 1))))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
-        // logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
         for (splitIndex <- 1 until numSplits - 1) {
-          // logDebug("splitIndex = " + splitIndex)
           leftNodeAgg(featureIndex)(2 * splitIndex)
             = binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
           leftNodeAgg(featureIndex)(2 * splitIndex + 1)
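Aside: the hunk above only trims debug logging from extractLeftRightNodeAggregates, whose job is easier to see in isolation. For each feature, binData holds two values per bin (one per label in the binary-classification case), and the loop builds a running prefix sum so that the left aggregate at split s covers bins 0..s; the right-hand accumulation is not fully visible in this hunk but presumably mirrors this from the other end. The sketch below is a simplified, self-contained restatement of that idea with assumed types and a hypothetical helper name, not the MLlib implementation.

// Simplified sketch: per-bin (label0, label1) counts for ONE feature, turned into
// left/right aggregates for each candidate split. Assumes binary labels and a
// dense bin layout; leftRightAggregates is a hypothetical helper, not MLlib API.
object AggregateSketch {
  def leftRightAggregates(
      binCounts: Array[(Double, Double)]): (Array[(Double, Double)], Array[(Double, Double)]) = {
    val numSplitPoints = binCounts.length - 1 // a split boundary sits between adjacent bins
    // Left aggregate for split s = counts accumulated over bins 0..s (prefix sums).
    val left = binCounts
      .scanLeft((0.0, 0.0)) { case ((a0, a1), (c0, c1)) => (a0 + c0, a1 + c1) }
      .slice(1, numSplitPoints + 1)
    // Right aggregate for split s = counts accumulated over bins s+1..last (suffix sums).
    val right = binCounts
      .scanRight((0.0, 0.0)) { case ((c0, c1), (a0, a1)) => (a0 + c0, a1 + c1) }
      .slice(1, numSplitPoints + 1)
    (left, right)
  }
}

For example, bin counts Array((2.0, 1.0), (0.0, 3.0), (4.0, 0.0)) yield left aggregates (2,1), (2,4) and right aggregates (4,3), (4,0) for the two candidate splits.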
@@ -479,33 +472,43 @@ object DecisionTree extends Serializable with Logging {

       // Find all splits
       for (featureIndex <- 0 until numFeatures){
-        val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-
-        val stride: Double = numSamples.toDouble / numBins
-        logDebug("stride = " + stride)
-        for (index <- 0 until numBins - 1) {
-          val sampleIndex = (index + 1) * stride.toInt
-          val split = new Split(featureIndex, featureSamples(sampleIndex), "continuous")
-          splits(featureIndex)(index) = split
+        val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+        if (isFeatureContinuous) {
+          val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+
+          val stride: Double = numSamples.toDouble / numBins
+          logDebug("stride = " + stride)
+          for (index <- 0 until numBins - 1) {
+            val sampleIndex = (index + 1) * stride.toInt
+            val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous)
+            splits(featureIndex)(index) = split
+          }
+        } else {
+          val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+          for (index <- 0 until maxFeatureValue){
+            // TODO: Sort by centroid
+            val split = new Split(featureIndex, index, Categorical)
+            splits(featureIndex)(index) = split
+          }
+        }
       }

       // Find all bins
       for (featureIndex <- 0 until numFeatures){
         bins(featureIndex)(0)
-          = new Bin(new DummyLowSplit("continuous"), splits(featureIndex)(0), "continuous")
+          = new Bin(new DummyLowSplit(Continuous), splits(featureIndex)(0), Continuous)
         for (index <- 1 until numBins - 1){
-          val bin = new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), "continuous")
+          val bin = new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), Continuous)
           bins(featureIndex)(index) = bin
         }
         bins(featureIndex)(numBins - 1)
-          = new Bin(splits(featureIndex)(numBins - 3), new DummyHighSplit("continuous"), "continuous")
+          = new Bin(splits(featureIndex)(numBins - 3), new DummyHighSplit(Continuous), Continuous)
       }

       (splits, bins)
     }
     case MinMax => {
-      (Array.ofDim[Split](numFeatures, numBins), Array.ofDim[Bin](numFeatures, numBins + 2))
+      throw new UnsupportedOperationException("minmax not supported yet.")
     }
     case ApproxHist => {
       throw new UnsupportedOperationException("approximate histogram not supported yet.")
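Taken together, the hunk above makes the split-finding code branch on feature type: continuous features keep the sorted-sample, stride-based thresholds, while categorical features get one candidate split per category value (0 until the arity recorded in categoricalFeaturesInfo). Below is a minimal, self-contained sketch of that branch using simplified stand-ins for Split and FeatureType (the real classes live under org.apache.spark.mllib.tree); splitsForFeature and its parameters are illustrative assumptions, not part of the patch.

// Illustrative stand-ins only; the real Split/FeatureType carry more information.
sealed trait FeatureType
case object Continuous extends FeatureType
case object Categorical extends FeatureType
case class Split(feature: Int, threshold: Double, featureType: FeatureType)

object SplitSketch {
  // Mirrors the Sort-strategy logic above for a single feature.
  def splitsForFeature(
      featureIndex: Int,
      featureSamples: Array[Double],           // sampled values of this feature
      numBins: Int,
      categoricalFeaturesInfo: Map[Int, Int]   // feature index -> number of categories
    ): Array[Split] = {
    val isFeatureContinuous = categoricalFeaturesInfo.get(featureIndex).isEmpty
    if (isFeatureContinuous) {
      // Continuous: pick thresholds at (roughly) equal strides through the sorted
      // sample, truncating the stride exactly as the patch does.
      val sorted = featureSamples.sorted
      val stride: Double = sorted.length.toDouble / numBins
      Array.tabulate(numBins - 1) { index =>
        val sampleIndex = (index + 1) * stride.toInt
        Split(featureIndex, sorted(sampleIndex), Continuous)
      }
    } else {
      // Categorical: one candidate split per category value; ordering them by label
      // centroid is still a TODO in the patch.
      val maxFeatureValue = categoricalFeaturesInfo(featureIndex)
      Array.tabulate(maxFeatureValue)(value => Split(featureIndex, value.toDouble, Categorical))
    }
  }
}

With categoricalFeaturesInfo = Map(2 -> 3), feature 2 would get the three candidate splits Split(2, 0.0, Categorical), Split(2, 1.0, Categorical), Split(2, 2.0, Categorical), while any feature absent from the map falls back to the continuous path.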