
Commit bec5565

Merge pull request #13 from jkbradley/dt-features-2

Added more unit tests

2 parents a8d8583 + 0840e91

9 files changed: +185 -25 lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
Lines changed: 2 additions & 4 deletions

@@ -275,10 +275,8 @@ private[tree] class LearningNode(
         // Here we want to keep same behavior with the old mllib.DecisionTreeModel
         new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
       }
-
     }
   }
-
 }

 private[tree] object LearningNode {
@@ -292,8 +290,8 @@ private[tree] object LearningNode {
   }

   /** Create an empty node with the given node index. Values must be set later on. */
-  def emptyNode(nodeIndex: Int): LearningNode = {
-    new LearningNode(nodeIndex, None, None, None, false, null)
+  def emptyNode(id: Int): LearningNode = {
+    new LearningNode(id, None, None, None, false, null)
   }

   // The below indexing methods were copied from spark.mllib.tree.model.Node

mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
Lines changed: 7 additions & 9 deletions

@@ -137,8 +137,6 @@ private[ml] object AltDT extends Logging {
     // rather than 1 copy per worker. This means a lot of random accesses.
     // We could improve this by applying first-level sorting (by node) to labels.

-    // TODO: RIGHT HERE NOW: JUST ADDED ISUNORDERED
-
     // Sort each column by feature values.
     val colStore: RDD[FeatureVector] = colStoreInit.map { case (featureIndex: Int, col: Vector) =>
       val featureArity: Int = strategy.categoricalFeaturesInfo.getOrElse(featureIndex, 0)
@@ -293,9 +291,11 @@
    * On driver: Grow tree based on chosen splits, and compute new set of active nodes.
    * @param oldPeriphery Old periphery of active nodes.
    * @param bestSplitsAndGains Best (split, gain) pairs, which can be zipped with the old
-   *                           periphery.
+   *                           periphery. These stats will be used to replace the stats in
+   *                           any nodes which are split.
    * @param minInfoGain Threshold for min info gain required to split a node.
-   * @return New active node periphery
+   * @return New active node periphery.
+   *         If a node is split, then this method will update its fields.
    */
   private[impl] def computeActiveNodePeriphery(
       oldPeriphery: Array[LearningNode],
@@ -482,12 +482,13 @@

     var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid
     val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
-    var bestGain: Double = -1.0
+    var bestGain: Double = 0.0
     val fullImpurity = rightImpurityAgg.getCalculator.calculate()
     var leftCount: Double = 0.0
     var rightCount: Double = rightImpurityAgg.getCount
     val fullCount: Double = rightCount

+    // Consider all splits. These only cover valid splits, with at least one category on each side.
     val numSplits = categoriesSortedByCentroid.length - 1
     var sortedCatIndex = 0
     while (sortedCatIndex < numSplits) {
@@ -512,9 +513,6 @@
       sortedCatIndex += 1
     }

-    assert(bestSplitIndex != -1, "Unknown error in AltDT split selection for ordered categorical" +
-      s" variable with numSplits = $numSplits.")
-
     val categoriesForSplit =
       categoriesSortedByCentroid.slice(0, bestSplitIndex + 1).map(_.toDouble)
     val bestFeatureSplit =
@@ -524,7 +522,7 @@
     val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
       bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)

-    if (bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1) {
+    if (bestSplitIndex == -1 || bestGain == 0.0) {
       (None, bestImpurityStats)
     } else {
       (Some(bestFeatureSplit), bestImpurityStats)
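
Note on the AltDT.scala change above: bestGain now starts at 0.0 instead of -1.0, and the final check returns no split when bestSplitIndex stayed -1 or the best gain is exactly 0.0 (e.g., on a pure node), rather than asserting that some split index was always chosen. The scan body is not shown in this diff, so the following standalone Scala sketch of that selection logic is an assumption, not the actual method:

    // Hypothetical sketch of the "no useful split" handling implied by the
    // diff above. candidateGains stands in for the gains computed while
    // scanning categoriesSortedByCentroid; it is not part of the real code.
    val candidateGains = Array(0.0, 0.12, 0.05)
    var bestSplitIndex = -1  // index into categoriesSortedByCentroid
    var bestGain = 0.0       // was -1.0 before this commit
    var i = 0
    while (i < candidateGains.length) {
      if (candidateGains(i) > bestGain) {
        bestSplitIndex = i
        bestGain = candidateGains(i)
      }
      i += 1
    }
    // On a pure node every gain is 0.0, so bestSplitIndex stays -1 and the
    // caller receives None together with the parent's impurity stats.
    val chosen: Option[Int] =
      if (bestSplitIndex == -1 || bestGain == 0.0) None else Some(bestSplitIndex)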

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
Lines changed: 1 addition & 1 deletion

@@ -143,7 +143,7 @@ private[ml] object RandomForest extends Logging {
     rng.setSeed(seed)

     // Allocate and queue root nodes.
-    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(id = 1))
     Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

     while (nodeQueue.nonEmpty) {

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
Lines changed: 4 additions & 0 deletions

@@ -166,4 +166,8 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal

   override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])"

+  private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
+    case o: EntropyCalculator => stats.sameElements(other.stats)
+    case _ => false
+  }
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
Lines changed: 4 additions & 0 deletions

@@ -162,4 +162,8 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul

   override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])"

+  private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
+    case o: GiniCalculator => stats.sameElements(other.stats)
+    case _ => false
+  }
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
Lines changed: 2 additions & 0 deletions

@@ -185,4 +185,6 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
     result._1
   }

+  /** Test exact equality */
+  private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean
 }
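
Note: the new abstract method gives every calculator an exact, element-wise comparison of its sufficient-statistics array, separate from equals. Per the concrete implementations above, two calculators match only when they have the same concrete type and identical stats. A minimal usage sketch (this only compiles inside the org.apache.spark package, since the classes and the method are private[spark]):

    import org.apache.spark.mllib.tree.impurity.{EntropyCalculator, GiniCalculator}

    val a = new GiniCalculator(Array(2.0, 3.0))
    val b = new GiniCalculator(Array(2.0, 3.0))
    val c = new EntropyCalculator(Array(2.0, 3.0))
    assert(a.exactlyEquals(b))   // same type, identical stats
    assert(!a.exactlyEquals(c))  // identical stats, different concrete type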

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
Lines changed: 4 additions & 0 deletions

@@ -136,4 +136,8 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
     s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
   }

+  private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
+    case o: VarianceCalculator => stats.sameElements(other.stats)
+    case _ => false
+  }
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
Lines changed: 10 additions & 0 deletions

@@ -88,6 +88,7 @@ private[spark] object InformationGainStats {
  * @param rightImpurityCalculator impurity statistics for right child node
  * @param valid whether the current split satisfies minimum info gain or
  *              minimum number of instances per node
+ *              TODO: Can we remove this? Not sure if this is used anywhere...
  */
 @DeveloperApi
 private[spark] class ImpurityStats(
@@ -114,6 +115,15 @@ private[spark] class ImpurityStats(
   } else {
     -1.0
   }
+
+  /** Test exact equality */
+  private[spark] def exactlyEquals(other: ImpurityStats): Boolean = {
+    gain == other.gain && impurity == other.impurity &&
+      impurityCalculator.exactlyEquals(other.impurityCalculator) &&
+      leftImpurityCalculator.exactlyEquals(other.leftImpurityCalculator) &&
+      rightImpurityCalculator.exactlyEquals(other.rightImpurityCalculator) &&
+      valid == other.valid
+  }
 }

 private[spark] object ImpurityStats {
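
Note: ImpurityStats.exactlyEquals composes the calculator-level checks: gain, impurity, and valid are compared directly, and the three calculators are delegated to their own exactlyEquals. A minimal sketch (the 0.5 gain value is arbitrary; like the calculators, this is only callable from inside the spark package):

    import org.apache.spark.mllib.tree.impurity.EntropyCalculator
    import org.apache.spark.mllib.tree.model.ImpurityStats

    val parent = new EntropyCalculator(Array(5.0, 7.0))
    val leftC = new EntropyCalculator(Array(4.0, 2.0))
    val rightC = new EntropyCalculator(Array(1.0, 5.0))
    val s1 = new ImpurityStats(0.5, parent.calculate(), parent, leftC, rightC)
    val s2 = new ImpurityStats(0.5, parent.calculate(), parent, leftC, rightC)
    assert(s1.exactlyEquals(s2))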

mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala
Lines changed: 151 additions & 11 deletions

@@ -19,12 +19,13 @@ package org.apache.spark.ml.tree.impl

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressor
-import org.apache.spark.ml.tree.{LeafNode, InternalNode, ContinuousSplit}
+import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo}
 import org.apache.spark.ml.tree.impl.TreeUtil._
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity}
+import org.apache.spark.mllib.tree.impurity._
+import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.collection.BitSet

@@ -44,7 +45,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setMaxDepth(10)
       .setAlgorithm("byCol")
     val model = dt.fit(df)
-    println(model.toDebugString) // TODO: remove println
     assert(model.rootNode.isInstanceOf[InternalNode])
     val root = model.rootNode.asInstanceOf[InternalNode]
     assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[LeafNode])
@@ -61,7 +61,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setMaxDepth(10)
       .setAlgorithm("byCol")
     val model = dt.fit(df)
-    println(model.toDebugString) // TODO: remove println
     assert(model.rootNode.isInstanceOf[InternalNode])
     val root = model.rootNode.asInstanceOf[InternalNode]
     assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode])
@@ -147,18 +146,94 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
   //////////////////////////////// Choosing splits //////////////////////////////////

   test("computeBestSplits") {
+    // TODO
   }

-  test("chooseSplit") {
+  test("chooseSplit: choose correct type of split") {
+    val labels = Seq(0, 0, 0, 1, 1, 1, 1).map(_.toDouble).toArray
+    val fromOffset = 1
+    val toOffset = 4
+    val impurity = Entropy
+    val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
+
+    val col1 = FeatureVector.fromOriginal(featureIndex = 0, featureArity = 0,
+      featureVector = Vectors.dense(0.8, 0.1, 0.1, 0.2, 0.3, 0.5, 0.6))
+    val (split1, _) = AltDT.chooseSplit(col1, labels, fromOffset, toOffset, metadata)
+    assert(split1.nonEmpty && split1.get.isInstanceOf[ContinuousSplit])
+
+    val col2 = FeatureVector.fromOriginal(featureIndex = 0, featureArity = 3,
+      featureVector = Vectors.dense(0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0))
+    val (split2, _) = AltDT.chooseSplit(col2, labels, fromOffset, toOffset, metadata)
+    assert(split2.nonEmpty && split2.get.isInstanceOf[CategoricalSplit])
   }

   test("chooseOrderedCategoricalSplit: basic case") {
+    val featureIndex = 0
+    val values = Seq(0, 0, 1, 2, 2, 2, 2).map(_.toDouble)
+    val featureArity = values.max.toInt + 1
+
+    def testHelper(
+        labels: Seq[Double],
+        expectedLeftCategories: Array[Double],
+        expectedLeftStats: Array[Double],
+        expectedRightStats: Array[Double]): Unit = {
+      val expectedRightCategories = Range(0, featureArity)
+        .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray
+      val impurity = Entropy
+      val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
+      val (split, stats) =
+        AltDT.chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
+      split match {
+        case Some(s: CategoricalSplit) =>
+          assert(s.featureIndex === featureIndex)
+          assert(s.leftCategories === expectedLeftCategories)
+          assert(s.rightCategories === expectedRightCategories)
+        case _ =>
+          throw new AssertionError(
+            s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}")
+      }
+      val fullImpurityStatsArray =
+        Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
+      val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
+      assert(stats.gain === fullImpurity)
+      assert(stats.impurity === fullImpurity)
+      assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
+      assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
+      assert(stats.rightImpurityCalculator.stats === expectedRightStats)
+      assert(stats.valid)
+    }
+
+    val labels1 = Seq(0, 0, 1, 1, 1, 1, 1).map(_.toDouble)
+    testHelper(labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0))
+
+    val labels2 = Seq(0, 0, 0, 1, 1, 1, 1).map(_.toDouble)
+    testHelper(labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0))
   }

-  test("chooseOrderedCategoricalSplit: return bad split if best split is on end") {
+  test("chooseOrderedCategoricalSplit: return bad split if we should not split") {
+    val featureIndex = 0
+    val values = Seq(0, 0, 1, 2, 2, 2, 2).map(_.toDouble)
+    val featureArity = values.max.toInt + 1
+
+    val labels = Seq(1, 1, 1, 1, 1, 1, 1).map(_.toDouble)
+
+    val impurity = Entropy
+    val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
+    val (split, stats) =
+      AltDT.chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
+    assert(split.isEmpty)
+    val fullImpurityStatsArray =
+      Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
+    val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
+    assert(stats.gain === 0.0)
+    assert(stats.impurity === fullImpurity)
+    assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
+    assert(stats.valid)
   }

-  // test("chooseUnorderedCategoricalSplit") { }
+  // test("chooseUnorderedCategoricalSplit: basic case") { }
+
+  // test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }

   test("chooseContinuousSplit: basic case") {
     val featureIndex = 0
@@ -175,7 +250,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
         throw new AssertionError(
           s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}")
     }
-    val fullImpurityStatsArray = Array(2.0, 3.0)
+    val fullImpurityStatsArray =
+      Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
     val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
     assert(stats.gain === fullImpurity)
     assert(stats.impurity === fullImpurity)
@@ -185,8 +261,23 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(stats.valid)
   }

-  // TODO: Add this test once we make this change.
-  // test("chooseContinuousSplit: return bad split if best split is on end") { }
+  test("chooseContinuousSplit: return bad split if we should not split") {
+    val featureIndex = 0
+    val values = Seq(0.1, 0.2, 0.3, 0.4, 0.5)
+    val labels = Seq(0.0, 0.0, 0.0, 0.0, 0.0)
+    val impurity = Entropy
+    val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
+    val (split, stats) = AltDT.chooseContinuousSplit(featureIndex, values, labels, metadata)
+    // split should be None
+    assert(split.isEmpty)
+    // stats for parent node should be correct
+    val fullImpurityStatsArray =
+      Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
+    val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
+    assert(stats.gain === 0.0)
+    assert(stats.impurity === fullImpurity)
+    assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
+  }

   //////////////////////////////// Bit subvectors //////////////////////////////////

@@ -258,6 +349,55 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
   //////////////////////////////// Active nodes //////////////////////////////////

   test("computeActiveNodePeriphery") {
+    // old periphery: 2 nodes
+    val left = LearningNode.emptyNode(id = 1)
+    val right = LearningNode.emptyNode(id = 2)
+    val oldPeriphery: Array[LearningNode] = Array(left, right)
+    // bestSplitsAndGains: Do not split left, but split right node.
+    val lCalc = new EntropyCalculator(Array(8.0, 1.0))
+    val lStats = new ImpurityStats(0.0, lCalc.calculate(),
+      lCalc, lCalc, new EntropyCalculator(Array(0.0, 0.0)))
+
+    val rSplit = new ContinuousSplit(featureIndex = 1, threshold = 0.6)
+    val rCalc = new EntropyCalculator(Array(5.0, 7.0))
+    val rRightChildCalc = new EntropyCalculator(Array(1.0, 5.0))
+    val rLeftChildCalc = new EntropyCalculator(Array(
+      rCalc.stats(0) - rRightChildCalc.stats(0),
+      rCalc.stats(1) - rRightChildCalc.stats(1)))
+    val rGain = {
+      val rightWeight = rRightChildCalc.stats.sum / rCalc.stats.sum
+      val leftWeight = rLeftChildCalc.stats.sum / rCalc.stats.sum
+      rCalc.calculate() -
+        rightWeight * rRightChildCalc.calculate() - leftWeight * rLeftChildCalc.calculate()
+    }
+    val rStats =
+      new ImpurityStats(rGain, rCalc.calculate(), rCalc, rLeftChildCalc, rRightChildCalc)
+
+    val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
+      Array((None, lStats), (Some(rSplit), rStats))
+
+    // Test A: Split right node
+    val newPeriphery1: Array[LearningNode] =
+      AltDT.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 0.0)
+    // Expect 2 active nodes
+    assert(newPeriphery1.length === 2)
+    // Confirm right node was updated
+    assert(right.split.get === rSplit)
+    assert(!right.isLeaf)
+    assert(right.stats.exactlyEquals(rStats))
+    assert(right.leftChild.nonEmpty && right.leftChild.get === newPeriphery1(0))
+    assert(right.rightChild.nonEmpty && right.rightChild.get === newPeriphery1(1))
+    // Confirm new active nodes have stats but no children
+    assert(newPeriphery1(0).leftChild.isEmpty && newPeriphery1(0).rightChild.isEmpty &&
+      newPeriphery1(0).split.isEmpty &&
+      newPeriphery1(0).stats.impurityCalculator.exactlyEquals(rLeftChildCalc))
+    assert(newPeriphery1(1).leftChild.isEmpty && newPeriphery1(1).rightChild.isEmpty &&
+      newPeriphery1(1).split.isEmpty &&
+      newPeriphery1(1).stats.impurityCalculator.exactlyEquals(rRightChildCalc))
+
+    // Test B: Increase minInfoGain, so split nothing
+    val newPeriphery2: Array[LearningNode] =
+      AltDT.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 1000.0)
+    assert(newPeriphery2.isEmpty)
   }
-
 }
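
Note on the computeActiveNodePeriphery test above: rGain is the standard impurity decrease, gain = impurity(parent) - w_left * impurity(left) - w_right * impurity(right), where each weight is the child's instance count over the parent's. A standalone arithmetic check with the test's numbers (entropy in base 2; the entropy helper below is local to this sketch, not the EntropyCalculator API):

    // Worked check of the gain computed in the test: parent stats (5, 7),
    // left child (4, 2), right child (1, 5), so each child weighs 6/12.
    def entropy(counts: Array[Double]): Double = {
      val total = counts.sum
      counts.filter(_ > 0.0).map { c =>
        val p = c / total
        -p * math.log(p) / math.log(2)
      }.sum
    }
    val parentImpurity = entropy(Array(5.0, 7.0))  // ≈ 0.9799
    val gain = parentImpurity -
      (6.0 / 12.0) * entropy(Array(4.0, 2.0)) -    // ≈ 0.9183
      (6.0 / 12.0) * entropy(Array(1.0, 5.0))      // ≈ 0.6500
    // gain ≈ 0.196, which must exceed minInfoGain for the node to split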
