@@ -25,12 +25,12 @@ import org.apache.spark.ml.tree._
2525import org .apache .spark .ml .tree .impl .TreeUtil ._
2626import org .apache .spark .mllib .linalg .Vector
2727import org .apache .spark .mllib .regression .LabeledPoint
28- import org .apache .spark .mllib .tree .configuration .{ Algo => OldAlgo , FeatureType , Strategy }
28+ import org .apache .spark .mllib .tree .configuration .Strategy
2929import org .apache .spark .mllib .tree .impurity .{Variance , Gini , Entropy , Impurity }
3030import org .apache .spark .mllib .tree .model .ImpurityStats
3131import org .apache .spark .rdd .RDD
3232import org .apache .spark .storage .StorageLevel
33- import org .apache .spark .util .collection .{ BitSet , OpenHashSet }
33+ import org .apache .spark .util .collection .BitSet
3434
3535
3636/**
@@ -151,7 +151,7 @@ private[ml] object AltDT extends Logging {
151151 iterator.foreach(groupedCols += _)
152152 if (groupedCols.nonEmpty) Iterator (groupedCols.toArray) else Iterator ()
153153 }
154- groupedColStore.repartition( 1 ). persist(StorageLevel .MEMORY_AND_DISK ) // TODO: remove repartition
154+ groupedColStore.persist(StorageLevel .MEMORY_AND_DISK )
155155
156156 // Initialize partitions with 1 node (each instance at the root node).
157157 var partitionInfosA : RDD [PartitionInfo ] = groupedColStore.map { groupedCols =>
@@ -178,7 +178,7 @@ private[ml] object AltDT extends Logging {
178178 val partitionInfos = partitionInfosDebug.last
179179
180180 // Compute best split for each active node.
181- val bestSplitsAndGains : Array [(Split , ImpurityStats )] =
181+ val bestSplitsAndGains : Array [(Option [ Split ] , ImpurityStats )] =
182182 computeBestSplits(partitionInfos, labelsBc, metadata)
183183 /*
184184 // NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under
@@ -246,20 +246,22 @@ private[ml] object AltDT extends Logging {
246246 * @param partitionInfos
247247 * @param labelsBc
248248 * @param metadata
249- * @return
249+ * @return Array over active nodes of (best split, impurity stats for split),
250+ * where the split is None if no useful split exists
250251 */
251252 private [impl] def computeBestSplits (
252253 partitionInfos : RDD [PartitionInfo ],
253254 labelsBc : Broadcast [Array [Double ]],
254- metadata : AltDTMetadata ): Array [(Split , ImpurityStats )] = {
255+ metadata : AltDTMetadata ): Array [(Option [ Split ] , ImpurityStats )] = {
255256 // On each partition, for each feature on the partition, select the best split for each node.
256257 // This will use:
257258 // - groupedColStore (the features)
258259 // - partitionInfos (the node -> instance mapping)
259260 // - labelsBc (the labels column)
260261 // Each worker returns:
261- // for each active node, best split + info gain
262- val partBestSplitsAndGains : RDD [Array [(Split , ImpurityStats )]] = partitionInfos.map {
262+ // for each active node, best split + info gain,
263+ // where the best split is None if no useful split exists
264+ val partBestSplitsAndGains : RDD [Array [(Option [Split ], ImpurityStats )]] = partitionInfos.map {
263265 case PartitionInfo (columns : Array [FeatureVector ], nodeOffsets : Array [Int ], activeNodes : BitSet ) =>
264266 val localLabels = labelsBc.value
265267 // Iterate over the active nodes in the current level.
@@ -270,7 +272,6 @@ private[ml] object AltDT extends Logging {
270272 columns.map { col =>
271273 chooseSplit(col, localLabels, fromOffset, toOffset, metadata)
272274 }
273- // We use Iterator and flatMap to handle empty partitions.
274275 splitsAndStats.maxBy(_._2.gain)
275276 }.toArray
276277 }
@@ -298,18 +299,18 @@ private[ml] object AltDT extends Logging {
298299 */
299300 private [impl] def computeActiveNodePeriphery (
300301 oldPeriphery : Array [LearningNode ],
301- bestSplitsAndGains : Array [(Split , ImpurityStats )],
302+ bestSplitsAndGains : Array [(Option [ Split ] , ImpurityStats )],
302303 minInfoGain : Double ): Array [LearningNode ] = {
303304 bestSplitsAndGains.zipWithIndex.flatMap {
304305 case ((split, stats), nodeIdx) =>
305306 val node = oldPeriphery(nodeIdx)
306- if (stats.gain > minInfoGain) {
307+ if (split.nonEmpty && stats.gain > minInfoGain) {
307308 // TODO: remove node id
308309 node.leftChild = Some (LearningNode (node.id * 2 , isLeaf = false ,
309310 ImpurityStats (stats.leftImpurity, stats.leftImpurityCalculator)))
310311 node.rightChild = Some (LearningNode (node.id * 2 + 1 , isLeaf = false ,
311312 ImpurityStats (stats.rightImpurity, stats.rightImpurityCalculator)))
312- node.split = Some ( split)
313+ node.split = split
313314 node.isLeaf = false
314315 node.stats = stats
315316 Iterator (node.leftChild.get, node.rightChild.get)
@@ -328,21 +329,22 @@ private[ml] object AltDT extends Logging {
328329 * Correction: Aggregate only the pieces of that vector corresponding to instances at
329330 * active nodes.
330331 * @param partitionInfos RDD with feature data, plus current status metadata
331- * @param bestSplits Split for each active node
332+ * @param bestSplits Split for each active node, or None if that node will not be split
332333 * @return Array of bit vectors, ordered by offset ranges
333334 */
334335 private [impl] def collectBitVectors (
335336 partitionInfos : RDD [PartitionInfo ],
336- bestSplits : Array [Split ]): Array [BitSubvector ] = {
337- val bestSplitsBc : Broadcast [Array [Split ]] = partitionInfos.sparkContext.broadcast(bestSplits)
337+ bestSplits : Array [Option [Split ]]): Array [BitSubvector ] = {
338+ val bestSplitsBc : Broadcast [Array [Option [Split ]]] =
339+ partitionInfos.sparkContext.broadcast(bestSplits)
338340 val workerBitSubvectors : RDD [Array [BitSubvector ]] = partitionInfos.map {
339341 case PartitionInfo (columns : Array [FeatureVector ], nodeOffsets : Array [Int ],
340342 activeNodes : BitSet ) =>
341- val localBestSplits : Array [Split ] = bestSplitsBc.value
343+ val localBestSplits : Array [Option [ Split ] ] = bestSplitsBc.value
342344 // localFeatureIndex[feature index] = index into PartitionInfo.columns
343345 val localFeatureIndex : Map [Int , Int ] = columns.map(_.featureIndex).zipWithIndex.toMap
344346 activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
345- case (nodeIndexInLevel : Int , split : Split ) =>
347+ case (nodeIndexInLevel : Int , Some ( split : Split ) ) =>
346348 if (localFeatureIndex.contains(split.featureIndex)) {
347349 // This partition has the column (feature) used for this split.
348350 val fromOffset = nodeOffsets(nodeIndexInLevel)
@@ -352,6 +354,10 @@ private[ml] object AltDT extends Logging {
352354 } else {
353355 Iterator ()
354356 }
357+ case (nodeIndexInLevel : Int , None ) =>
358+ // Do not create a BitSubvector when there is no split.
359+ // This requires PartitionInfo.update to handle missing BitSubvectors.
360+ Iterator ()
355361 }.toArray
356362 }
357363 val aggBitVectors : Array [BitSubvector ] = workerBitSubvectors.reduce(BitSubvector .merge)
@@ -369,14 +375,15 @@ private[ml] object AltDT extends Logging {
369375 * @param labels
370376 * @param fromOffset
371377 * @param toOffset
372- * @return
378+ * @return (best split, statistics for split) If the best split actually puts all instances
379+ * in one leaf node, then it will be set to None.
373380 */
374381 private [impl] def chooseSplit (
375382 col : FeatureVector ,
376383 labels : Array [Double ],
377384 fromOffset : Int ,
378385 toOffset : Int ,
379- metadata : AltDTMetadata ): (Split , ImpurityStats ) = {
386+ metadata : AltDTMetadata ): (Option [ Split ] , ImpurityStats ) = {
380387 val valuesForNode = col.values.view.slice(fromOffset, toOffset)
381388 val labelsForNode = col.indices.view.slice(fromOffset, toOffset).map(labels.apply)
382389 if (col.isCategorical) {
@@ -405,14 +412,16 @@ private[ml] object AltDT extends Logging {
405412 * @param featureIndex Index of feature being split.
406413 * @param values Feature values at this node. Sorted in increasing order.
407414 * @param labels Labels corresponding to values, in the same order.
408- * @return (best split, corresponding impurity statistics)
415+ * @return (best split, statistics for split) If the best split actually puts all instances
416+ * in one leaf node, then it will be set to None. The impurity stats maybe still be
417+ * useful, so they are returned.
409418 */
410419 private [impl] def chooseOrderedCategoricalSplit (
411420 featureIndex : Int ,
412421 values : Seq [Double ],
413422 labels : Seq [Double ],
414423 metadata : AltDTMetadata ,
415- featureArity : Int ): (Split , ImpurityStats ) = {
424+ featureArity : Int ): (Option [ Split ] , ImpurityStats ) = {
416425 // TODO: Support high-arity features by using a single array to hold the stats.
417426
418427 // aggStats(category) = label statistics for category
@@ -514,24 +523,32 @@ private[ml] object AltDT extends Logging {
514523 val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
515524 val bestImpurityStats = new ImpurityStats (bestGain, fullImpurity, fullImpurityAgg.getCalculator,
516525 bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
517- (bestFeatureSplit, bestImpurityStats)
526+
527+ if (bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1 ) {
528+ (None , bestImpurityStats)
529+ } else {
530+ (Some (bestFeatureSplit), bestImpurityStats)
531+ }
518532 }
519533
520534 private [impl] def chooseUnorderedCategoricalSplit (
521535 featureIndex : Int ,
522536 values : Seq [Double ],
523537 labels : Seq [Double ],
524538 metadata : AltDTMetadata ,
525- featureArity : Int ): (Split , ImpurityStats ) = ???
539+ featureArity : Int ): (Option [ Split ] , ImpurityStats ) = ???
526540
527541 /**
528542 * Choose splitting rule: feature value <= threshold
543+ * @return (best split, statistics for split) If the best split actually puts all instances
544+ * in one leaf node, then it will be set to None. The impurity stats maybe still be
545+ * useful, so they are returned.
529546 */
530547 private [impl] def chooseContinuousSplit (
531548 featureIndex : Int ,
532549 values : Seq [Double ],
533550 labels : Seq [Double ],
534- metadata : AltDTMetadata ): (Split , ImpurityStats ) = {
551+ metadata : AltDTMetadata ): (Option [ Split ] , ImpurityStats ) = {
535552
536553 val leftImpurityAgg = metadata.createImpurityAggregator()
537554 val rightImpurityAgg = metadata.createImpurityAggregator()
@@ -571,7 +588,12 @@ private[ml] object AltDT extends Logging {
571588 val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
572589 val bestImpurityStats = new ImpurityStats (bestGain, fullImpurity, fullImpurityAgg.getCalculator,
573590 bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
574- (new ContinuousSplit (featureIndex, bestThreshold), bestImpurityStats)
591+ val split = if (bestThreshold != Double .NegativeInfinity && bestThreshold != values.last) {
592+ Some (new ContinuousSplit (featureIndex, bestThreshold))
593+ } else {
594+ None
595+ }
596+ (split, bestImpurityStats)
575597 }
576598
577599 /**
@@ -713,6 +735,8 @@ private[ml] object AltDT extends Logging {
713735 * and the second level is by row index.
714736 * bitVector(i) = false iff instance i goes to the left child.
715737 * For instances at inactive (leaf) nodes, the value can be arbitrary.
738+ * When an active node is not split (e.g., because no good split was found),
739+ * then the corresponding BitSubvector can be missing.
716740 * @return Updated partition info
717741 */
718742 def update (bitVectors : Array [BitSubvector ], newNumNodeOffsets : Int ): PartitionInfo = {
@@ -722,44 +746,52 @@ private[ml] object AltDT extends Logging {
722746 activeNodes.iterator.foreach { nodeIdx =>
723747 val from = nodeOffsets(nodeIdx)
724748 val to = nodeOffsets(nodeIdx + 1 )
725- // Note: Each node is guaranteed to be covered within 1 bit vector .
749+ // TODO: Allow missing vectors when no split is chosen .
726750 if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
727751 val curBitVector = bitVectors(curBitVecIdx)
728- // Sort range [from, to) based on indices. This is required to match the bit vector
729- // across all workers. See [[bitSubvectorFromSplit]] for details.
730- val rangeIndices = col.indices.view.slice(from, to).toArray
731- val rangeValues = col.values.view.slice(from, to).toArray
732- val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
733- // Sort range [from, to) based on bit vector.
734- sortedRange.zipWithIndex.map { case ((idx, value), i) =>
735- val bit = curBitVector.get(from + i)
736- // TODO: In-place merge, rather than general sort.
737- // TODO: We don't actually need to sort the categorical features using our approach.
738- (bit, value, idx)
739- }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
740- col.values(from + i) = value
741- col.indices(from + i) = idx
752+ // If the current BitVector does not cover this node, then this node was not split,
753+ // so we do not need to update its part of the column. Otherwise, we update it.
754+ if (curBitVector.from <= from && to <= curBitVector.to) {
755+ // Sort range [from, to) based on indices. This is required to match the bit vector
756+ // across all workers. See [[bitSubvectorFromSplit]] for details.
757+ val rangeIndices = col.indices.view.slice(from, to).toArray
758+ val rangeValues = col.values.view.slice(from, to).toArray
759+ val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
760+ // Sort range [from, to) based on bit vector.
761+ sortedRange.zipWithIndex.map { case ((idx, value), i) =>
762+ val bit = curBitVector.get(from + i)
763+ // TODO: In-place merge, rather than general sort.
764+ // TODO: We don't actually need to sort the categorical features using our approach.
765+ (bit, value, idx)
766+ }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
767+ col.values(from + i) = value
768+ col.indices(from + i) = idx
769+ }
742770 }
743771 }
744772 col
745773 }
746774
747775 // Create a 2-level representation of the new nodeOffsets (to be flattened).
776+ // These 2 levels correspond to original nodes and their children (if split).
748777 val newNodeOffsets = nodeOffsets.map(Array (_))
749778 var curBitVecIdx = 0
750779 activeNodes.iterator.foreach { nodeIdx =>
751780 val from = nodeOffsets(nodeIdx)
752781 val to = nodeOffsets(nodeIdx + 1 )
753782 if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
754783 val curBitVector = bitVectors(curBitVecIdx)
755- assert(curBitVector.from <= from && to <= curBitVector.to)
756- // Count number of values splitting to left vs. right
757- val numRight = Range (from, to).count(curBitVector.get)
758- val numLeft = to - from - numRight
759- if (numLeft != 0 && numRight != 0 ) {
760- // node is split
761- val oldOffset = newNodeOffsets(nodeIdx).head
762- newNodeOffsets(nodeIdx) = Array (oldOffset, oldOffset + numLeft)
784+ // If the current BitVector does not cover this node, then this node was not split,
785+ // so we do not need to create a new node offset. Otherwise, we create an offset.
786+ if (curBitVector.from <= from && to <= curBitVector.to) {
787+ // Count number of values splitting to left vs. right
788+ val numRight = Range (from, to).count(curBitVector.get)
789+ val numLeft = to - from - numRight
790+ if (numLeft != 0 && numRight != 0 ) {
791+ // node is split
792+ val oldOffset = newNodeOffsets(nodeIdx).head
793+ newNodeOffsets(nodeIdx) = Array (oldOffset, oldOffset + numLeft)
794+ }
763795 }
764796 }
765797
0 commit comments