@@ -19,12 +19,13 @@ package org.apache.spark.ml.tree.impl
1919
2020import org .apache .spark .SparkFunSuite
2121import org .apache .spark .ml .regression .DecisionTreeRegressor
22- import org .apache .spark .ml .tree .{ LeafNode , InternalNode , ContinuousSplit }
22+ import org .apache .spark .ml .tree ._
2323import org .apache .spark .ml .tree .impl .AltDT .{AltDTMetadata , FeatureVector , PartitionInfo }
2424import org .apache .spark .ml .tree .impl .TreeUtil ._
2525import org .apache .spark .mllib .linalg .{SparseVector , Vector , Vectors }
2626import org .apache .spark .mllib .regression .LabeledPoint
27- import org .apache .spark .mllib .tree .impurity .{Variance , Gini , Entropy , Impurity }
27+ import org .apache .spark .mllib .tree .impurity ._
28+ import org .apache .spark .mllib .tree .model .ImpurityStats
2829import org .apache .spark .mllib .util .MLlibTestSparkContext
2930import org .apache .spark .util .collection .BitSet
3031
@@ -44,7 +45,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
4445 .setMaxDepth(10 )
4546 .setAlgorithm(" byCol" )
4647 val model = dt.fit(df)
47- println(model.toDebugString) // TODO: remove println
4848 assert(model.rootNode.isInstanceOf [InternalNode ])
4949 val root = model.rootNode.asInstanceOf [InternalNode ]
5050 assert(root.leftChild.isInstanceOf [InternalNode ] && root.rightChild.isInstanceOf [LeafNode ])
@@ -61,7 +61,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
6161 .setMaxDepth(10 )
6262 .setAlgorithm(" byCol" )
6363 val model = dt.fit(df)
64- println(model.toDebugString) // TODO: remove println
6564 assert(model.rootNode.isInstanceOf [InternalNode ])
6665 val root = model.rootNode.asInstanceOf [InternalNode ]
6766 assert(root.leftChild.isInstanceOf [InternalNode ] && root.rightChild.isInstanceOf [InternalNode ])
@@ -147,18 +146,94 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
147146 // ////////////////////////////// Choosing splits //////////////////////////////////
148147
// Placeholder for a computeBestSplits test. The body contains no statements, so this
// test currently asserts nothing.
149148 test(" computeBestSplits" ) {
149+ // TODO: implement this test; the body is intentionally empty for now.
150150 }
151151
152- test(" chooseSplit" ) {
// Checks the *type* of split returned by AltDT.chooseSplit:
//  - a column built with featureArity = 0 is expected to yield a ContinuousSplit;
//  - a column built with featureArity = 3 is expected to yield a CategoricalSplit.
// Only the first element of the returned pair (the optional split) is inspected; the second
// element is discarded.
152+ test(" chooseSplit: choose correct type of split" ) {
// 7 binary labels: three 0.0 followed by four 1.0.
153+ val labels = Seq (0 , 0 , 0 , 1 , 1 , 1 , 1 ).map(_.toDouble).toArray
// Offsets are passed straight through to chooseSplit.
// NOTE(review): presumably they delimit the row range to consider — confirm against AltDT.chooseSplit.
154+ val fromOffset = 1
155+ val toOffset = 4
156+ val impurity = Entropy
157+ val metadata = new AltDTMetadata (numClasses = 2 , maxBins = 4 , minInfoGain = 0.0 , impurity)
158+
// Case 1: featureArity = 0 with real-valued feature values.
159+ val col1 = FeatureVector .fromOriginal(featureIndex = 0 , featureArity = 0 ,
160+ featureVector = Vectors .dense(0.8 , 0.1 , 0.1 , 0.2 , 0.3 , 0.5 , 0.6 ))
161+ val (split1, _) = AltDT .chooseSplit(col1, labels, fromOffset, toOffset, metadata)
// A split must be returned and it must be a ContinuousSplit.
162+ assert(split1.nonEmpty && split1.get.isInstanceOf [ContinuousSplit ])
163+
// Case 2: featureArity = 3 with feature values drawn from {0.0, 1.0, 2.0}.
164+ val col2 = FeatureVector .fromOriginal(featureIndex = 0 , featureArity = 3 ,
165+ featureVector = Vectors .dense(0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 2.0 , 2.0 ))
166+ val (split2, _) = AltDT .chooseSplit(col2, labels, fromOffset, toOffset, metadata)
// A split must be returned and it must be a CategoricalSplit.
167+ assert(split2.nonEmpty && split2.get.isInstanceOf [CategoricalSplit ])
153168 }
154169
// Checks AltDT.chooseOrderedCategoricalSplit on a 3-category feature for two label
// assignments, verifying both the chosen left/right category sets and the returned
// impurity statistics.
155170 test(" chooseOrderedCategoricalSplit: basic case" ) {
171+ val featureIndex = 0
// Category values per row; categories present are 0, 1 and 2.
172+ val values = Seq (0 , 0 , 1 , 2 , 2 , 2 , 2 ).map(_.toDouble)
// Arity is derived from the largest category value (2) plus one, i.e. 3.
173+ val featureArity = values.max.toInt + 1
174+
// Runs chooseOrderedCategoricalSplit for the given labels and asserts on the result.
//  labels                 - one label (0.0 or 1.0) per entry of `values`
//  expectedLeftCategories - categories the returned split must send left
//  expectedLeftStats      - expected stats array of the left child's impurity calculator
//  expectedRightStats     - expected stats array of the right child's impurity calculator
175+ def testHelper (
176+ labels : Seq [Double ],
177+ expectedLeftCategories : Array [Double ],
178+ expectedLeftStats : Array [Double ],
179+ expectedRightStats : Array [Double ]): Unit = {
// Right side is every category in [0, featureArity) that is not expected on the left.
180+ val expectedRightCategories = Range (0 , featureArity)
181+ .filter(c => ! expectedLeftCategories.contains(c)).map(_.toDouble).toArray
182+ val impurity = Entropy
183+ val metadata = new AltDTMetadata (numClasses = 2 , maxBins = 4 , minInfoGain = 0.0 , impurity)
184+ val (split, stats) =
185+ AltDT .chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
186+ split match {
187+ case Some (s : CategoricalSplit ) =>
188+ assert(s.featureIndex === featureIndex)
189+ assert(s.leftCategories === expectedLeftCategories)
190+ assert(s.rightCategories === expectedRightCategories)
191+ case _ =>
// `split` is an Option, so report the class of the wrapped split (or None) rather than
// the class of the Option wrapper itself.
192+ throw new AssertionError (
193+ s " Expected CategoricalSplit but got ${split.map(_.getClass.getSimpleName)}" )
194+ }
// Parent stats: [number of 0.0 labels, number of 1.0 labels].
195+ val fullImpurityStatsArray =
196+ Array (labels.count(_ == 0.0 ).toDouble, labels.count(_ == 1.0 ).toDouble)
197+ val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
// In both cases below each expected child holds labels of a single class, and the test
// expects the gain to equal the parent's impurity.
198+ assert(stats.gain === fullImpurity)
199+ assert(stats.impurity === fullImpurity)
200+ assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
201+ assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
202+ assert(stats.rightImpurityCalculator.stats === expectedRightStats)
203+ assert(stats.valid)
204+ }
205+
// Case 1: category 0 rows are labelled 0.0, category 1 and 2 rows are labelled 1.0;
// expected left set is {0}.
206+ val labels1 = Seq (0 , 0 , 1 , 1 , 1 , 1 , 1 ).map(_.toDouble)
207+ testHelper(labels1, Array (0.0 ), Array (2.0 , 0.0 ), Array (0.0 , 5.0 ))
208+
// Case 2: category 0 and 1 rows are labelled 0.0, category 2 rows are labelled 1.0;
// expected left set is {0, 1}.
209+ val labels2 = Seq (0 , 0 , 0 , 1 , 1 , 1 , 1 ).map(_.toDouble)
210+ testHelper(labels2, Array (0.0 , 1.0 ), Array (3.0 , 0.0 ), Array (0.0 , 4.0 ))
156211 }
157212
158- test(" chooseOrderedCategoricalSplit: return bad split if best split is on end" ) {
// Checks that AltDT.chooseOrderedCategoricalSplit returns no split (None) when every label
// is identical, while still reporting statistics for the parent node.
213+ test(" chooseOrderedCategoricalSplit: return bad split if we should not split" ) {
214+ val featureIndex = 0
// Category values per row; arity is the largest value (2) plus one, i.e. 3.
215+ val values = Seq (0 , 0 , 1 , 2 , 2 , 2 , 2 ).map(_.toDouble)
216+ val featureArity = values.max.toInt + 1
217+
// All seven labels are 1.0, so only one class is present.
218+ val labels = Seq (1 , 1 , 1 , 1 , 1 , 1 , 1 ).map(_.toDouble)
219+
220+ val impurity = Entropy
221+ val metadata = new AltDTMetadata (numClasses = 2 , maxBins = 4 , minInfoGain = 0.0 , impurity)
222+ val (split, stats) =
223+ AltDT .chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
// No split is expected.
224+ assert(split.isEmpty)
// Parent stats: [number of 0.0 labels, number of 1.0 labels], i.e. [0.0, 7.0] here.
225+ val fullImpurityStatsArray =
226+ Array (labels.count(_ == 0.0 ).toDouble, labels.count(_ == 1.0 ).toDouble)
227+ val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
// Expected: zero gain, parent impurity and counts still populated, and stats flagged valid
// even though no split was returned.
228+ assert(stats.gain === 0.0 )
229+ assert(stats.impurity === fullImpurity)
230+ assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
231+ assert(stats.valid)
159232 }
160233
161- // test("chooseUnorderedCategoricalSplit") { }
234+ // test("chooseUnorderedCategoricalSplit: basic case") { }
235+
236+ // test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }
162237
163238 test(" chooseContinuousSplit: basic case" ) {
164239 val featureIndex = 0
@@ -175,7 +250,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
175250 throw new AssertionError (
176251 s " Expected ContinuousSplit but got ${split.getClass.getSimpleName}" )
177252 }
178- val fullImpurityStatsArray = Array (2.0 , 3.0 )
253+ val fullImpurityStatsArray =
254+ Array (labels.count(_ == 0.0 ).toDouble, labels.count(_ == 1.0 ).toDouble)
179255 val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
180256 assert(stats.gain === fullImpurity)
181257 assert(stats.impurity === fullImpurity)
@@ -185,8 +261,23 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
185261 assert(stats.valid)
186262 }
187263
188- // TODO: Add this test once we make this change.
189- // test("chooseContinuousSplit: return bad split if best split is on end") { }
// Checks that AltDT.chooseContinuousSplit returns no split (None) when every label is
// identical, while still reporting statistics for the parent node.
// NOTE(review): unlike the ordered-categorical "should not split" test in this suite, this
// test does not assert stats.valid — confirm whether that omission is intentional.
264+ test(" chooseContinuousSplit: return bad split if we should not split" ) {
265+ val featureIndex = 0
// Five distinct feature values, all paired with label 0.0.
266+ val values = Seq (0.1 , 0.2 , 0.3 , 0.4 , 0.5 )
267+ val labels = Seq (0.0 , 0.0 , 0.0 , 0.0 , 0.0 )
268+ val impurity = Entropy
269+ val metadata = new AltDTMetadata (numClasses = 2 , maxBins = 4 , minInfoGain = 0.0 , impurity)
270+ val (split, stats) = AltDT .chooseContinuousSplit(featureIndex, values, labels, metadata)
271+ // split should be None
272+ assert(split.isEmpty)
273+ // stats for parent node should be correct
// Parent stats: [number of 0.0 labels, number of 1.0 labels], i.e. [5.0, 0.0] here.
274+ val fullImpurityStatsArray =
275+ Array (labels.count(_ == 0.0 ).toDouble, labels.count(_ == 1.0 ).toDouble)
276+ val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
277+ assert(stats.gain === 0.0 )
278+ assert(stats.impurity === fullImpurity)
279+ assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
280+ }
190281
191282 // ////////////////////////////// Bit subvectors //////////////////////////////////
192283
@@ -258,6 +349,55 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
258349 // ////////////////////////////// Active nodes //////////////////////////////////
259350
// Exercises AltDT.computeActiveNodePeriphery with a two-node periphery in which the first
// node has no candidate split (None) and the second has a ContinuousSplit.
//  Test A (minInfoGain = 0.0): expects the second node to be split, producing two new
//                              active nodes (its left and right children).
//  Test B (minInfoGain = 1000.0): expects an empty new periphery.
260351 test(" computeActiveNodePeriphery" ) {
352+ // old periphery: 2 nodes
353+ val left = LearningNode .emptyNode(id = 1 )
354+ val right = LearningNode .emptyNode(id = 2 )
355+ val oldPeriphery : Array [LearningNode ] = Array (left, right)
356+ // bestSplitsAndGains: Do not split left, but split right node.
// Stats for the left node: gain 0.0, with an all-zero calculator as the last argument.
// NOTE(review): the calculator arguments presumably follow the order parent, left child,
// right child (mirroring rStats below) — confirm against ImpurityStats.
357+ val lCalc = new EntropyCalculator (Array (8.0 , 1.0 ))
358+ val lStats = new ImpurityStats (0.0 , lCalc.calculate(),
359+ lCalc, lCalc, new EntropyCalculator (Array (0.0 , 0.0 )))
360+
// Stats for the right node: parent counts (5, 7), right-child counts (1, 5), and
// left-child counts derived as parent minus right child, i.e. (4, 2).
361+ val rSplit = new ContinuousSplit (featureIndex = 1 , threshold = 0.6 )
362+ val rCalc = new EntropyCalculator (Array (5.0 , 7.0 ))
363+ val rRightChildCalc = new EntropyCalculator (Array (1.0 , 5.0 ))
364+ val rLeftChildCalc = new EntropyCalculator (Array (
365+ rCalc.stats(0 ) - rRightChildCalc.stats(0 ),
366+ rCalc.stats(1 ) - rRightChildCalc.stats(1 )))
// Gain = parent impurity minus the children's impurities, each weighted by the child's
// share of the parent's total count.
367+ val rGain = {
368+ val rightWeight = rRightChildCalc.stats.sum / rCalc.stats.sum
369+ val leftWeight = rLeftChildCalc.stats.sum / rCalc.stats.sum
370+ rCalc.calculate() -
371+ rightWeight * rRightChildCalc.calculate() - leftWeight * rLeftChildCalc.calculate()
372+ }
373+ val rStats =
374+ new ImpurityStats (rGain, rCalc.calculate(), rCalc, rLeftChildCalc, rRightChildCalc)
375+
// One (split, stats) entry per node of oldPeriphery, in the same order: left, then right.
376+ val bestSplitsAndGains : Array [(Option [Split ], ImpurityStats )] =
377+ Array ((None , lStats), (Some (rSplit), rStats))
378+
379+ // Test A: Split right node
380+ val newPeriphery1 : Array [LearningNode ] =
381+ AltDT .computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 0.0 )
382+ // Expect 2 active nodes
383+ assert(newPeriphery1.length === 2 )
384+ // Confirm right node was updated
385+ assert(right.split.get === rSplit)
386+ assert(! right.isLeaf)
387+ assert(right.stats.exactlyEquals(rStats))
// The new periphery is expected to be ordered [right's left child, right's right child].
388+ assert(right.leftChild.nonEmpty && right.leftChild.get === newPeriphery1(0 ))
389+ assert(right.rightChild.nonEmpty && right.rightChild.get === newPeriphery1(1 ))
390+ // Confirm new active nodes have stats but no children
391+ assert(newPeriphery1(0 ).leftChild.isEmpty && newPeriphery1(0 ).rightChild.isEmpty &&
392+ newPeriphery1(0 ).split.isEmpty &&
393+ newPeriphery1(0 ).stats.impurityCalculator.exactlyEquals(rLeftChildCalc))
394+ assert(newPeriphery1(1 ).leftChild.isEmpty && newPeriphery1(1 ).rightChild.isEmpty &&
395+ newPeriphery1(1 ).split.isEmpty &&
396+ newPeriphery1(1 ).stats.impurityCalculator.exactlyEquals(rRightChildCalc))
397+
398+ // Test B: Increase minInfoGain, so split nothing
// NOTE(review): this reuses oldPeriphery, whose `right` node was already updated by the
// call in Test A (see the assertions above), so the input is not a pair of fresh empty
// nodes — confirm that is intended.
399+ val newPeriphery2 : Array [LearningNode ] =
400+ AltDT .computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 1000.0 )
401+ assert(newPeriphery2.isEmpty)
261402 }
262-
263403}
0 commit comments