
Commit a8d8583

Merge pull request #11 from jkbradley/dt-features-1
Choosing splits returns None for split if no valid one found
2 parents 1855176 + 2e7868e commit a8d8583
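
In short: every split chooser in AltDT.scala now returns (Option[Split], ImpurityStats) instead of (Split, ImpurityStats), and None means "no useful split was found, so leave this node as a leaf." A minimal sketch of that convention, using hypothetical simplified types (SketchSplit, SketchStats) in place of the real Split and ImpurityStats:

// Illustration only: hypothetical stand-ins for the real Split / ImpurityStats types.
case class SketchSplit(featureIndex: Int, threshold: Double)
case class SketchStats(gain: Double)

// A chooser always returns the stats it computed, but only Some(split) when the
// best candidate would actually send instances to two nonempty children.
def chooseSketch(hasUsefulSplit: Boolean): (Option[SketchSplit], SketchStats) =
  if (hasUsefulSplit) (Some(SketchSplit(featureIndex = 0, threshold = 0.5)), SketchStats(gain = 0.4))
  else (None, SketchStats(gain = 0.0))

// Caller pattern used throughout the patch: grow children only for a real split
// with enough gain; otherwise the node stays a leaf.
val minInfoGain = 0.0
val (split, stats) = chooseSketch(hasUsefulSplit = false)
val growChildren = split.nonEmpty && stats.gain > minInfoGain  // false here: node becomes a leaf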

File tree

2 files changed (+87, -55 lines)

mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala

Lines changed: 80 additions & 48 deletions
@@ -25,12 +25,12 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.TreeUtil._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, FeatureType, Strategy}
+import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity}
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.{BitSet, OpenHashSet}
+import org.apache.spark.util.collection.BitSet


 /**
@@ -151,7 +151,7 @@ private[ml] object AltDT extends Logging {
       iterator.foreach(groupedCols += _)
       if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator()
     }
-    groupedColStore.repartition(1).persist(StorageLevel.MEMORY_AND_DISK) // TODO: remove repartition
+    groupedColStore.persist(StorageLevel.MEMORY_AND_DISK)

     // Initialize partitions with 1 node (each instance at the root node).
     var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols =>
@@ -178,7 +178,7 @@ private[ml] object AltDT extends Logging {
     val partitionInfos = partitionInfosDebug.last

     // Compute best split for each active node.
-    val bestSplitsAndGains: Array[(Split, ImpurityStats)] =
+    val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
       computeBestSplits(partitionInfos, labelsBc, metadata)
     /*
     // NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under
@@ -246,20 +246,22 @@ private[ml] object AltDT extends Logging {
    * @param partitionInfos
    * @param labelsBc
    * @param metadata
-   * @return
+   * @return Array over active nodes of (best split, impurity stats for split),
+   *         where the split is None if no useful split exists
    */
   private[impl] def computeBestSplits(
       partitionInfos: RDD[PartitionInfo],
       labelsBc: Broadcast[Array[Double]],
-      metadata: AltDTMetadata): Array[(Split, ImpurityStats)] = {
+      metadata: AltDTMetadata): Array[(Option[Split], ImpurityStats)] = {
     // On each partition, for each feature on the partition, select the best split for each node.
     // This will use:
     // - groupedColStore (the features)
     // - partitionInfos (the node -> instance mapping)
     // - labelsBc (the labels column)
     // Each worker returns:
-    // for each active node, best split + info gain
-    val partBestSplitsAndGains: RDD[Array[(Split, ImpurityStats)]] = partitionInfos.map {
+    // for each active node, best split + info gain,
+    // where the best split is None if no useful split exists
+    val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map {
       case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) =>
         val localLabels = labelsBc.value
         // Iterate over the active nodes in the current level.
@@ -270,7 +272,6 @@ private[ml] object AltDT extends Logging {
           columns.map { col =>
             chooseSplit(col, localLabels, fromOffset, toOffset, metadata)
           }
-        // We use Iterator and flatMap to handle empty partitions.
         splitsAndStats.maxBy(_._2.gain)
       }.toArray
     }
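
A detail worth noting about the per-node reduction above: the gain used for ranking lives in the ImpurityStats half of each pair, so splitsAndStats.maxBy(_._2.gain) is unchanged by the switch to Option[Split]. A tiny illustration, with a String standing in for Split and a bare Double for the stats:

// One node's candidates, one per feature on this worker (illustration only).
val candidates: Seq[(Option[String], Double)] =
  Seq((Some("f0 <= 0.5"), 0.10), (None, 0.30), (Some("f2 <= 1.5"), 0.20))

// Ranking is by gain alone, exactly as before the Option change.
val best: (Option[String], Double) = candidates.maxBy(_._2)
// best == (None, 0.3): the winning candidate carries no usable split, so
// computeActiveNodePeriphery will keep this node as a leaf.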
@@ -298,18 +299,18 @@
    */
   private[impl] def computeActiveNodePeriphery(
       oldPeriphery: Array[LearningNode],
-      bestSplitsAndGains: Array[(Split, ImpurityStats)],
+      bestSplitsAndGains: Array[(Option[Split], ImpurityStats)],
       minInfoGain: Double): Array[LearningNode] = {
     bestSplitsAndGains.zipWithIndex.flatMap {
       case ((split, stats), nodeIdx) =>
         val node = oldPeriphery(nodeIdx)
-        if (stats.gain > minInfoGain) {
+        if (split.nonEmpty && stats.gain > minInfoGain) {
           // TODO: remove node id
           node.leftChild = Some(LearningNode(node.id * 2, isLeaf = false,
             ImpurityStats(stats.leftImpurity, stats.leftImpurityCalculator)))
           node.rightChild = Some(LearningNode(node.id * 2 + 1, isLeaf = false,
             ImpurityStats(stats.rightImpurity, stats.rightImpurityCalculator)))
-          node.split = Some(split)
+          node.split = split
           node.isLeaf = false
           node.stats = stats
           Iterator(node.leftChild.get, node.rightChild.get)
@@ -328,21 +329,22 @@
    * Correction: Aggregate only the pieces of that vector corresponding to instances at
    *             active nodes.
    * @param partitionInfos RDD with feature data, plus current status metadata
-   * @param bestSplits Split for each active node
+   * @param bestSplits Split for each active node, or None if that node will not be split
    * @return Array of bit vectors, ordered by offset ranges
    */
   private[impl] def collectBitVectors(
       partitionInfos: RDD[PartitionInfo],
-      bestSplits: Array[Split]): Array[BitSubvector] = {
-    val bestSplitsBc: Broadcast[Array[Split]] = partitionInfos.sparkContext.broadcast(bestSplits)
+      bestSplits: Array[Option[Split]]): Array[BitSubvector] = {
+    val bestSplitsBc: Broadcast[Array[Option[Split]]] =
+      partitionInfos.sparkContext.broadcast(bestSplits)
     val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map {
       case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
         activeNodes: BitSet) =>
-        val localBestSplits: Array[Split] = bestSplitsBc.value
+        val localBestSplits: Array[Option[Split]] = bestSplitsBc.value
         // localFeatureIndex[feature index] = index into PartitionInfo.columns
         val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap
         activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
-          case (nodeIndexInLevel: Int, split: Split) =>
+          case (nodeIndexInLevel: Int, Some(split: Split)) =>
             if (localFeatureIndex.contains(split.featureIndex)) {
               // This partition has the column (feature) used for this split.
               val fromOffset = nodeOffsets(nodeIndexInLevel)
@@ -352,6 +354,10 @@
             } else {
               Iterator()
             }
+          case (nodeIndexInLevel: Int, None) =>
+            // Do not create a BitSubvector when there is no split.
+            // This requires PartitionInfo.update to handle missing BitSubvectors.
+            Iterator()
         }.toArray
     }
     val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge)
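
With the new None case, a worker simply emits no BitSubvector for a node that is not being split: the flatMap over (node, Option[Split]) pairs drops those entries. The same Option-driven filtering in miniature, with an Int standing in for Split:

// Per-node best splits for one level (illustration only).
val localBestSplits: Array[Option[Int]] = Array(Some(3), None, Some(7))

// flatMap drops the None entries, so unsplit nodes contribute no bit vectors at all.
val emitted: Array[String] = localBestSplits.iterator.zipWithIndex.flatMap {
  case (Some(featureIdx), nodeIdx) => Iterator(s"bits for node $nodeIdx from feature $featureIdx")
  case (None, _) => Iterator.empty
}.toArray
// emitted.length == 2: node 1 was skipped, which PartitionInfo.update must now tolerate.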
@@ -369,14 +375,15 @@
    * @param labels
    * @param fromOffset
    * @param toOffset
-   * @return
+   * @return (best split, statistics for split) If the best split actually puts all instances
+   *         in one leaf node, then it will be set to None.
    */
   private[impl] def chooseSplit(
       col: FeatureVector,
       labels: Array[Double],
       fromOffset: Int,
       toOffset: Int,
-      metadata: AltDTMetadata): (Split, ImpurityStats) = {
+      metadata: AltDTMetadata): (Option[Split], ImpurityStats) = {
     val valuesForNode = col.values.view.slice(fromOffset, toOffset)
     val labelsForNode = col.indices.view.slice(fromOffset, toOffset).map(labels.apply)
     if (col.isCategorical) {
@@ -405,14 +412,16 @@
    * @param featureIndex Index of feature being split.
    * @param values Feature values at this node. Sorted in increasing order.
    * @param labels Labels corresponding to values, in the same order.
-   * @return (best split, corresponding impurity statistics)
+   * @return (best split, statistics for split) If the best split actually puts all instances
+   *         in one leaf node, then it will be set to None. The impurity stats may still be
+   *         useful, so they are returned.
    */
   private[impl] def chooseOrderedCategoricalSplit(
       featureIndex: Int,
       values: Seq[Double],
       labels: Seq[Double],
       metadata: AltDTMetadata,
-      featureArity: Int): (Split, ImpurityStats) = {
+      featureArity: Int): (Option[Split], ImpurityStats) = {
     // TODO: Support high-arity features by using a single array to hold the stats.

     // aggStats(category) = label statistics for category
@@ -514,24 +523,32 @@
     val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
     val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
       bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
-    (bestFeatureSplit, bestImpurityStats)
+
+    if (bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1) {
+      (None, bestImpurityStats)
+    } else {
+      (Some(bestFeatureSplit), bestImpurityStats)
+    }
   }

   private[impl] def chooseUnorderedCategoricalSplit(
       featureIndex: Int,
       values: Seq[Double],
       labels: Seq[Double],
       metadata: AltDTMetadata,
-      featureArity: Int): (Split, ImpurityStats) = ???
+      featureArity: Int): (Option[Split], ImpurityStats) = ???

   /**
    * Choose splitting rule: feature value <= threshold
+   * @return (best split, statistics for split) If the best split actually puts all instances
+   *         in one leaf node, then it will be set to None. The impurity stats may still be
+   *         useful, so they are returned.
    */
   private[impl] def chooseContinuousSplit(
       featureIndex: Int,
       values: Seq[Double],
       labels: Seq[Double],
-      metadata: AltDTMetadata): (Split, ImpurityStats) = {
+      metadata: AltDTMetadata): (Option[Split], ImpurityStats) = {

     val leftImpurityAgg = metadata.createImpurityAggregator()
     val rightImpurityAgg = metadata.createImpurityAggregator()
@@ -571,7 +588,12 @@
     val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
     val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
       bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
-    (new ContinuousSplit(featureIndex, bestThreshold), bestImpurityStats)
+    val split = if (bestThreshold != Double.NegativeInfinity && bestThreshold != values.last) {
+      Some(new ContinuousSplit(featureIndex, bestThreshold))
+    } else {
+      None
+    }
+    (split, bestImpurityStats)
   }

   /**
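
The guard above treats two boundary outcomes as "no split": bestThreshold still at its Double.NegativeInfinity starting value, and bestThreshold equal to the largest value, which would send every instance to the left child under the "value <= threshold" rule. A standalone sketch of that edge-case behavior (not the patch's implementation, which scores every candidate by impurity gain):

// Sketch only: pick a threshold for a sorted column, where "value <= threshold"
// sends an instance to the left child.
def pickThresholdSketch(sortedValues: Seq[Double]): Option[Double] = {
  // Any distinct value except the largest yields two nonempty children;
  // the largest value (or an empty column) cannot.
  val viable = sortedValues.distinct.dropRight(1)
  // Taking the first viable candidate is enough to show the Option shape.
  viable.headOption
}

pickThresholdSketch(Seq(0.1, 0.2, 0.4))  // Some(0.1): a valid split exists
pickThresholdSketch(Seq(0.3, 0.3, 0.3))  // None: all values equal, the node stays a leaf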
@@ -713,6 +735,8 @@
     * and the second level is by row index.
     * bitVector(i) = false iff instance i goes to the left child.
     * For instances at inactive (leaf) nodes, the value can be arbitrary.
+    * When an active node is not split (e.g., because no good split was found),
+    * then the corresponding BitSubvector can be missing.
     * @return Updated partition info
     */
    def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = {
@@ -722,44 +746,52 @@
       activeNodes.iterator.foreach { nodeIdx =>
         val from = nodeOffsets(nodeIdx)
         val to = nodeOffsets(nodeIdx + 1)
-        // Note: Each node is guaranteed to be covered within 1 bit vector.
+        // TODO: Allow missing vectors when no split is chosen.
         if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
         val curBitVector = bitVectors(curBitVecIdx)
-        // Sort range [from, to) based on indices. This is required to match the bit vector
-        // across all workers. See [[bitSubvectorFromSplit]] for details.
-        val rangeIndices = col.indices.view.slice(from, to).toArray
-        val rangeValues = col.values.view.slice(from, to).toArray
-        val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
-        // Sort range [from, to) based on bit vector.
-        sortedRange.zipWithIndex.map { case ((idx, value), i) =>
-          val bit = curBitVector.get(from + i)
-          // TODO: In-place merge, rather than general sort.
-          // TODO: We don't actually need to sort the categorical features using our approach.
-          (bit, value, idx)
-        }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
-          col.values(from + i) = value
-          col.indices(from + i) = idx
+        // If the current BitVector does not cover this node, then this node was not split,
+        // so we do not need to update its part of the column. Otherwise, we update it.
+        if (curBitVector.from <= from && to <= curBitVector.to) {
+          // Sort range [from, to) based on indices. This is required to match the bit vector
+          // across all workers. See [[bitSubvectorFromSplit]] for details.
+          val rangeIndices = col.indices.view.slice(from, to).toArray
+          val rangeValues = col.values.view.slice(from, to).toArray
+          val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
+          // Sort range [from, to) based on bit vector.
+          sortedRange.zipWithIndex.map { case ((idx, value), i) =>
+            val bit = curBitVector.get(from + i)
+            // TODO: In-place merge, rather than general sort.
+            // TODO: We don't actually need to sort the categorical features using our approach.
+            (bit, value, idx)
+          }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
+            col.values(from + i) = value
+            col.indices(from + i) = idx
+          }
         }
       }
       col
     }

     // Create a 2-level representation of the new nodeOffsets (to be flattened).
+    // These 2 levels correspond to original nodes and their children (if split).
     val newNodeOffsets = nodeOffsets.map(Array(_))
     var curBitVecIdx = 0
     activeNodes.iterator.foreach { nodeIdx =>
       val from = nodeOffsets(nodeIdx)
       val to = nodeOffsets(nodeIdx + 1)
       if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
       val curBitVector = bitVectors(curBitVecIdx)
-      assert(curBitVector.from <= from && to <= curBitVector.to)
-      // Count number of values splitting to left vs. right
-      val numRight = Range(from, to).count(curBitVector.get)
-      val numLeft = to - from - numRight
-      if (numLeft != 0 && numRight != 0) {
-        // node is split
-        val oldOffset = newNodeOffsets(nodeIdx).head
-        newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
+      // If the current BitVector does not cover this node, then this node was not split,
+      // so we do not need to create a new node offset. Otherwise, we create an offset.
+      if (curBitVector.from <= from && to <= curBitVector.to) {
+        // Count number of values splitting to left vs. right
+        val numRight = Range(from, to).count(curBitVector.get)
+        val numLeft = to - from - numRight
+        if (numLeft != 0 && numRight != 0) {
+          // node is split
+          val oldOffset = newNodeOffsets(nodeIdx).head
+          newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
+        }
       }
     }
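
Both loops above now rely on the same coverage test: a node's row range [from, to) is processed only when the current bit vector fully covers it, and the old assert is gone because unsplit nodes legitimately have no bits. The test in isolation, with a hypothetical stand-in for BitSubvector:

// Hypothetical stand-in for BitSubvector that only records the row range it covers.
case class BitRangeSketch(from: Int, to: Int)

// A node's rows [from, to) are updated only when one bit vector covers them entirely.
def covers(vec: BitRangeSketch, from: Int, to: Int): Boolean =
  vec.from <= from && to <= vec.to

covers(BitRangeSketch(0, 8), from = 2, to = 5)   // true: this node was split, update it
covers(BitRangeSketch(8, 16), from = 2, to = 5)  // false: no bits for this node, skip it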

mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala

Lines changed: 7 additions & 7 deletions
@@ -152,7 +152,10 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("chooseSplit") {
   }

-  test("chooseOrderedCategoricalSplit") {
+  test("chooseOrderedCategoricalSplit: basic case") {
+  }
+
+  test("chooseOrderedCategoricalSplit: return bad split if best split is on end") {
   }

   // test("chooseUnorderedCategoricalSplit") { }
@@ -165,7 +168,7 @@
     val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
     val (split, stats) = AltDT.chooseContinuousSplit(featureIndex, values, labels, metadata)
     split match {
-      case s: ContinuousSplit =>
+      case Some(s: ContinuousSplit) =>
         assert(s.featureIndex === featureIndex)
         assert(s.threshold === 0.2)
       case _ =>
@@ -182,9 +185,6 @@
     assert(stats.valid)
   }

-  test("chooseContinuousSplit: some equal values") {
-  }
-
   // TODO: Add this test once we make this change.
   // test("chooseContinuousSplit: return bad split if best split is on end") { }

@@ -232,7 +232,7 @@
     val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
     val partitionInfos = sc.parallelize(Seq(info))
     val bestSplit = new ContinuousSplit(0, threshold = 0.5)
-    val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit))
+    val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
     assert(bitVectors.length === 1)
     val bitv = bitVectors.head
     assert(bitv.numBits === numRows)
@@ -248,7 +248,7 @@
     val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
     val partitionInfos = sc.parallelize(Seq(info))
     val bestSplit = new ContinuousSplit(0, threshold = -2.0)
-    val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit))
+    val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
     assert(bitVectors.length === 1)
     val bitv = bitVectors.head
     assert(bitv.numBits === numRows)