@@ -232,7 +232,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(tree2.rootNode.prediction === lp.label)
   }
 
-  ignore("Multiclass classification with unordered categorical features: split calculations") {
+  test("Multiclass classification with unordered categorical features: split calculations") {
     val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -249,28 +249,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
     val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
     assert(splits.length === 2)
-    assert(splits(0).length === 3)
+    assert(splits(0).length === 0)
     assert(metadata.numSplits(0) === 3)
     assert(metadata.numBins(0) === 3)
     assert(metadata.numSplits(1) === 3)
     assert(metadata.numBins(1) === 3)
-
-    // Expecting 2^2 - 1 = 3 splits per feature
-    def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
-      assert(s.featureIndex === featureIndex)
-      assert(s.isInstanceOf[CategoricalSplit])
-      val s0 = s.asInstanceOf[CategoricalSplit]
-      assert(s0.leftCategories === leftCategories)
-      assert(s0.numCategories === 3)  // for this unit test
-    }
-    // Feature 0
-    checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
-    checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
-    checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
-    // Feature 1
-    checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
-    checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
-    checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
   }
 
   test("Multiclass classification with ordered categorical features: split calculations") {
@@ -631,6 +614,42 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("traverseUnorderedSplits") {
+
+    val numBins = 8
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(numBins)
+
+    val resultCheck = Array.fill(numSplits + 1)(false)
+
+    RandomForest.traverseUnorderedSplits[Int](numBins, 0,
+      (statsVal, binIndex) => statsVal + (1 << binIndex),
+      (bitSet, statsVal) => {
+        // Each invocation receives one combination: the bitSet marks the bins
+        // that are included in the combination, and statsVal encodes that
+        // combination as a number.
+        // E.g., for the combination [0,0,1,0,1,1,0,1] (binIndex running from
+        // high to low), statsVal is the number whose binary representation
+        // is "00101101".
+
+        // 1. Check that this combination is not traversed more than once.
+        assert(resultCheck(statsVal) === false)
+        resultCheck(statsVal) = true
+
+        // 2. Check that the encoded number is correct: e.g., the number
+        // "00101101" (binary) must match the combination [0,0,1,0,1,1,0,1]
+        // stored in the bitSet.
+        for (i <- 0 until numBins) {
+          val testBit = ((statsVal >> i) & 1) == 1
+          assert(bitSet.get(i) === testBit)
+        }
+      }
+    )
+    // 3. Check that the traversal covers all combinations (numSplits in total).
+    for (i <- 1 to numSplits) {
+      assert(resultCheck(i) === true)
+    }
+  }
 }
 
 private object RandomForestSuite {
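
Note on the first hunk: for an unordered categorical feature with M categories, the number of distinct splits is 2^(M-1) - 1, because a left-category set and its complement describe the same split and the empty set is excluded. For the 3-category features in this test that gives 2^2 - 1 = 3 splits ({0}, {1}, {0,1}), which is why metadata.numSplits stays 3 even though findSplits now returns an empty split array for unordered features: the splits appear to be enumerated on the fly rather than materialized up front.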
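
To make the callback protocol exercised by the new test concrete, here is a minimal, self-contained sketch of a traverseUnorderedSplits-style enumerator. It is an illustration under stated assumptions, not Spark's implementation: it assumes the highest bin is pinned to the right side of every split (so exactly 2^(numBins - 1) - 1 nonempty combinations are visited, matching the size of resultCheck above), and the names numUnorderedSplits and traverseUnorderedSplits merely mirror the identifiers in the test.

import java.util.BitSet

object UnorderedSplitsSketch {
  // A split is a nonempty left-category set; a set and its complement
  // describe the same split, so the highest bin is pinned to the right
  // side, giving 2^(numBins - 1) - 1 distinct splits.
  def numUnorderedSplits(numBins: Int): Int = (1 << (numBins - 1)) - 1

  // Enumerate every left-category set over bins [0, numBins - 1), folding
  // seqOp across the bins in the set and handing (bitSet, accumulated
  // value) to finalizer: the same callback shape the test drives.
  def traverseUnorderedSplits[T](
      numBins: Int,
      zero: T,
      seqOp: (T, Int) => T,
      finalizer: (BitSet, T) => Unit): Unit = {
    val bitSet = new BitSet(numBins)
    def recurse(binIndex: Int, acc: T, nonEmpty: Boolean): Unit = {
      if (binIndex == numBins - 1) {
        // The highest bin stays unset; skip the empty set.
        if (nonEmpty) finalizer(bitSet, acc)
      } else {
        recurse(binIndex + 1, acc, nonEmpty) // leave binIndex out
        bitSet.set(binIndex) // put binIndex in
        recurse(binIndex + 1, seqOp(acc, binIndex), nonEmpty = true)
        bitSet.clear(binIndex)
      }
    }
    recurse(0, zero, nonEmpty = false)
  }

  def main(args: Array[String]): Unit = {
    val numBins = 4
    var count = 0
    traverseUnorderedSplits[Int](numBins, 0,
      (acc, bin) => acc + (1 << bin),
      (bits, acc) => {
        count += 1
        println(s"statsVal = $acc, left categories = $bits")
      })
    assert(count == numUnorderedSplits(numBins)) // 2^3 - 1 = 7 splits
  }
}

Running main with numBins = 4 visits the 7 nonempty subsets of {0, 1, 2}, with statsVal taking each value from 1 to 7 exactly once; these are the same uniqueness and coverage invariants the test asserts.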