Skip to content

Commit 31ef80b

Browse files
committed
WIP simplifying TreeSplitUtilsSuite
1 parent b4a5f3b commit 31ef80b

File tree

1 file changed

+17
-32
lines changed

1 file changed

+17
-32
lines changed

mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala

Lines changed: 17 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -28,35 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
2828
class TreeSplitUtilsSuite
2929
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3030

31-
/**
32-
* Iterate over feature values and labels for a specific (node, feature), updating stats
33-
* aggregator for the current node.
34-
*/
35-
private[impl] def updateAggregator(
36-
statsAggregator: DTStatsAggregator,
37-
featureIndex: Int,
38-
values: Array[Int],
39-
indices: Array[Int],
40-
instanceWeights: Array[Double],
41-
labels: Array[Double],
42-
from: Int,
43-
to: Int,
44-
featureIndexIdx: Int,
45-
featureSplits: Array[Split]): Unit = {
46-
val metadata = statsAggregator.metadata
47-
from.until(to).foreach { idx =>
48-
val rowIndex = indices(idx)
49-
if (metadata.isUnordered(featureIndex)) {
50-
AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
51-
featureIndex = featureIndex, featureIndexIdx, featureSplits,
52-
instanceWeight = instanceWeights(rowIndex))
53-
} else {
54-
AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
55-
featureIndexIdx, instanceWeight = instanceWeights(rowIndex))
56-
}
57-
}
58-
}
59-
6031
/**
6132
* Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
6233
* with the data from the specified training points.
@@ -71,11 +42,25 @@ class TreeSplitUtilsSuite
7142

7243
val featureIndex = 0
7344
val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
74-
val instanceWeights = Array.fill[Double](values.length)(1.0)
7545
val indices = values.indices.toArray
46+
val instanceWeights = Array.fill[Double](values.length)(1.0)
47+
// Update parent impurity stats
7648
AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
77-
updateAggregator(statsAggregator, featureIndex = 0, values, indices, instanceWeights, labels,
78-
from, to, featureIndex, featureSplits)
49+
// Update current aggregator's impurity stats
50+
from.until(to).foreach { idx =>
51+
val rowIndex = indices(idx)
52+
if (metadata.isUnordered(featureIndex)) {
53+
AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
54+
featureIndex = featureIndex, featureIndexIdx, featureSplits,
55+
instanceWeight = 1.0)
56+
} else {
57+
AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
58+
featureIndexIdx, instanceWeight = 1.0)
59+
}
60+
}
61+
62+
updateAggregator(statsAggregator, featureIndex = 0, featureIndexIdx = 0, values, indices,
63+
labels, from, to, featureSplits)
7964
statsAggregator
8065
}
8166

0 commit comments

Comments
 (0)