
Commit e79abfd

add unit test
1 parent 93c7e0f commit e79abfd

1 file changed: +38 -19 lines

mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

Lines changed: 38 additions & 19 deletions
@@ -232,7 +232,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(tree2.rootNode.prediction === lp.label)
   }
 
-  ignore("Multiclass classification with unordered categorical features: split calculations") {
+  test("Multiclass classification with unordered categorical features: split calculations") {
     val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -249,28 +249,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
     val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
     assert(splits.length === 2)
-    assert(splits(0).length === 3)
+    assert(splits(0).length === 0)
     assert(metadata.numSplits(0) === 3)
     assert(metadata.numBins(0) === 3)
     assert(metadata.numSplits(1) === 3)
     assert(metadata.numBins(1) === 3)
-
-    // Expecting 2^2 - 1 = 3 splits per feature
-    def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
-      assert(s.featureIndex === featureIndex)
-      assert(s.isInstanceOf[CategoricalSplit])
-      val s0 = s.asInstanceOf[CategoricalSplit]
-      assert(s0.leftCategories === leftCategories)
-      assert(s0.numCategories === 3)  // for this unit test
-    }
-    // Feature 0
-    checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
-    checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
-    checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
-    // Feature 1
-    checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
-    checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
-    checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
   }
 
   test("Multiclass classification with ordered categorical features: split calculations") {
@@ -631,6 +614,42 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("traverseUnorderedSplits") {
+
+    val numBins = 8
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(numBins)
+
+    val resultCheck = Array.fill(numSplits + 1)(false)
+
+    RandomForest.traverseUnorderedSplits[Int](numBins, 0,
+      (statsVal, binIndex) => statsVal + (1 << binIndex),
+      (bitSet, statsVal) => {
+        // Each call visits one combination: the bitSet has exactly those bits
+        // set that belong to the combination, and statsVal encodes the same
+        // combination as a number.
+        // e.g.
+        // for the combination [0,0,1,0,1,1,0,1] (binIndex from high to low),
+        // statsVal is the number whose binary representation is "00101101".
+
+        // 1. Check that no combination is visited more than once.
+        assert(resultCheck(statsVal) === false)
+        resultCheck(statsVal) = true
+
+        // 2. Check that statsVal is correct, i.e. its binary representation
+        // (e.g. "00101101") matches the combination stored in the bitSet
+        // (e.g. [0,0,1,0,1,1,0,1]).
+        for (i <- 0 until numBins) {
+          val testBit = (((statsVal >> i) & 1) == 1)
+          assert(bitSet.get(i) === testBit)
+        }
+      }
+    )
+    // 3. Check that the traversal covers all combinations (numSplits in total).
+    for (i <- 1 to numSplits) {
+      assert(resultCheck(i) === true)
+    }
+  }
 }
 
 private object RandomForestSuite {
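The new traverseUnorderedSplits test drives a pair of callbacks over every bin combination, accumulating statsVal as a bitmask of the bins on one side of the split. Under the assumption that DecisionTreeMetadata.numUnorderedSplits(8) equals 2^(8-1) - 1 = 127 (the usual count of unordered splits; the formula is not stated in this commit), the behaviour the test pins down can be sketched with a plain reference enumeration. Everything below, object and method names included, is a hypothetical illustration, not Spark's implementation:

object TraverseUnorderedSplitsSketch extends App {
  import scala.collection.immutable.BitSet

  // Visit each non-empty bitmask over bins 0 .. numBins-2 exactly once,
  // passing both the bit set and its integer encoding. This mirrors what the
  // new test asserts about RandomForest.traverseUnorderedSplits.
  def visitUnorderedSplits(numBins: Int)(visit: (BitSet, Int) => Unit): Unit = {
    val numSplits = (1 << (numBins - 1)) - 1  // assumed: 2^7 - 1 = 127 for numBins = 8
    for (mask <- 1 to numSplits) {
      val bits = BitSet((0 until numBins).filter(i => ((mask >> i) & 1) == 1): _*)
      visit(bits, mask)
    }
  }

  // The test's three checks, replayed against this sketch: every encoding in
  // 1..numSplits is seen exactly once, and bit i of the encoding agrees with
  // membership of bin i in the bit set.
  val numBins = 8
  val numSplits = (1 << (numBins - 1)) - 1
  val seen = Array.fill(numSplits + 1)(false)
  visitUnorderedSplits(numBins) { (bits, mask) =>
    assert(!seen(mask))
    seen(mask) = true
    (0 until numBins).foreach(i => assert(bits(i) == (((mask >> i) & 1) == 1)))
  }
  assert((1 to numSplits).forall(i => seen(i)))
}

In this sketch the upper bound 2^(numBins-1) - 1 means the highest bin never appears in the bit set, so each split and its complement are enumerated only once; that is consistent with the coverage loop in the test, which expects exactly numSplits distinct encodings.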

0 commit comments
