@@ -28,35 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
2828class TreeSplitUtilsSuite
2929 extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3030
31- /**
32- * Iterate over feature values and labels for a specific (node, feature), updating stats
33- * aggregator for the current node.
34- */
35- private [impl] def updateAggregator (
36- statsAggregator : DTStatsAggregator ,
37- featureIndex : Int ,
38- values : Array [Int ],
39- indices : Array [Int ],
40- instanceWeights : Array [Double ],
41- labels : Array [Double ],
42- from : Int ,
43- to : Int ,
44- featureIndexIdx : Int ,
45- featureSplits : Array [Split ]): Unit = {
46- val metadata = statsAggregator.metadata
47- from.until(to).foreach { idx =>
48- val rowIndex = indices(idx)
49- if (metadata.isUnordered(featureIndex)) {
50- AggUpdateUtils .updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
51- featureIndex = featureIndex, featureIndexIdx, featureSplits,
52- instanceWeight = instanceWeights(rowIndex))
53- } else {
54- AggUpdateUtils .updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
55- featureIndexIdx, instanceWeight = instanceWeights(rowIndex))
56- }
57- }
58- }
59-
6031 /**
6132 * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
6233 * with the data from the specified training points.
@@ -71,11 +42,25 @@ class TreeSplitUtilsSuite
7142
7243 val featureIndex = 0
7344 val statsAggregator = new DTStatsAggregator (metadata, featureSubset = None )
74- val instanceWeights = Array .fill[Double ](values.length)(1.0 )
7545 val indices = values.indices.toArray
46+ val instanceWeights = Array .fill[Double ](values.length)(1.0 )
47+ // Update parent impurity stats
7648 AggUpdateUtils .updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
77- updateAggregator(statsAggregator, featureIndex = 0 , values, indices, instanceWeights, labels,
78- from, to, featureIndex, featureSplits)
49+ // Update current aggregator's impurity stats
50+ from.until(to).foreach { idx =>
51+ val rowIndex = indices(idx)
52+ if (metadata.isUnordered(featureIndex)) {
53+ AggUpdateUtils .updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
54+ featureIndex = featureIndex, featureIndexIdx, featureSplits,
55+ instanceWeight = 1.0 )
56+ } else {
57+ AggUpdateUtils .updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
58+ featureIndexIdx, instanceWeight = 1.0 )
59+ }
60+ }
61+
62+ updateAggregator(statsAggregator, featureIndex = 0 , featureIndexIdx = 0 , values, indices,
63+ labels, from, to, featureSplits)
7964 statsAggregator
8065 }
8166
0 commit comments