128 changes: 80 additions & 48 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
@@ -25,12 +25,12 @@ import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.TreeUtil._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.{BitSet, OpenHashSet}
import org.apache.spark.util.collection.BitSet


/**
@@ -151,7 +151,7 @@ private[ml] object AltDT extends Logging {
iterator.foreach(groupedCols += _)
if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator()
}
groupedColStore.repartition(1).persist(StorageLevel.MEMORY_AND_DISK) // TODO: remove repartition
groupedColStore.persist(StorageLevel.MEMORY_AND_DISK)

// Initialize partitions with 1 node (each instance at the root node).
var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols =>
@@ -178,7 +178,7 @@
val partitionInfos = partitionInfosDebug.last

// Compute best split for each active node.
val bestSplitsAndGains: Array[(Split, ImpurityStats)] =
val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
computeBestSplits(partitionInfos, labelsBc, metadata)
/*
// NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under
@@ -246,20 +246,22 @@
* @param partitionInfos
* @param labelsBc
* @param metadata
* @return
* @return Array over active nodes of (best split, impurity stats for split),
* where the split is None if no useful split exists
*/
private[impl] def computeBestSplits(
partitionInfos: RDD[PartitionInfo],
labelsBc: Broadcast[Array[Double]],
metadata: AltDTMetadata): Array[(Split, ImpurityStats)] = {
metadata: AltDTMetadata): Array[(Option[Split], ImpurityStats)] = {
// On each partition, for each feature on the partition, select the best split for each node.
// This will use:
// - groupedColStore (the features)
// - partitionInfos (the node -> instance mapping)
// - labelsBc (the labels column)
// Each worker returns:
// for each active node, best split + info gain
val partBestSplitsAndGains: RDD[Array[(Split, ImpurityStats)]] = partitionInfos.map {
// for each active node, best split + info gain,
// where the best split is None if no useful split exists
val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map {
case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) =>
val localLabels = labelsBc.value
// Iterate over the active nodes in the current level.
@@ -270,7 +272,6 @@
columns.map { col =>
chooseSplit(col, localLabels, fromOffset, toOffset, metadata)
}
// We use Iterator and flatMap to handle empty partitions.
splitsAndStats.maxBy(_._2.gain)
}.toArray
}
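Editor's note: a small self-contained sketch of the per-node reduction over the new (Option[Split], ImpurityStats) candidates. The Candidate case class and the string "splits" below are editorial stand-ins, not the patch's API; the point is that a feature whose best split is None still competes via its gain, and maxBy simply keeps the candidate with the highest gain.

case class Candidate(split: Option[String], gain: Double)

val perFeatureCandidates = Seq(
  Candidate(Some("f0 <= 0.5"), 0.10),
  Candidate(None, 0.0),                  // feature with no useful split still reports its gain/stats
  Candidate(Some("f2 in {a, b}"), 0.25))

val bestForNode = perFeatureCandidates.maxBy(_.gain)  // Candidate(Some("f2 in {a, b}"), 0.25)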
@@ -298,18 +299,18 @@
*/
private[impl] def computeActiveNodePeriphery(
oldPeriphery: Array[LearningNode],
bestSplitsAndGains: Array[(Split, ImpurityStats)],
bestSplitsAndGains: Array[(Option[Split], ImpurityStats)],
minInfoGain: Double): Array[LearningNode] = {
bestSplitsAndGains.zipWithIndex.flatMap {
case ((split, stats), nodeIdx) =>
val node = oldPeriphery(nodeIdx)
if (stats.gain > minInfoGain) {
if (split.nonEmpty && stats.gain > minInfoGain) {
// TODO: remove node id
node.leftChild = Some(LearningNode(node.id * 2, isLeaf = false,
ImpurityStats(stats.leftImpurity, stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(node.id * 2 + 1, isLeaf = false,
ImpurityStats(stats.rightImpurity, stats.rightImpurityCalculator)))
node.split = Some(split)
node.split = split
node.isLeaf = false
node.stats = stats
Iterator(node.leftChild.get, node.rightChild.get)
@@ -328,21 +329,22 @@
* Correction: Aggregate only the pieces of that vector corresponding to instances at
* active nodes.
* @param partitionInfos RDD with feature data, plus current status metadata
* @param bestSplits Split for each active node
* @param bestSplits Split for each active node, or None if that node will not be split
* @return Array of bit vectors, ordered by offset ranges
*/
private[impl] def collectBitVectors(
partitionInfos: RDD[PartitionInfo],
bestSplits: Array[Split]): Array[BitSubvector] = {
val bestSplitsBc: Broadcast[Array[Split]] = partitionInfos.sparkContext.broadcast(bestSplits)
bestSplits: Array[Option[Split]]): Array[BitSubvector] = {
val bestSplitsBc: Broadcast[Array[Option[Split]]] =
partitionInfos.sparkContext.broadcast(bestSplits)
val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map {
case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
activeNodes: BitSet) =>
val localBestSplits: Array[Split] = bestSplitsBc.value
val localBestSplits: Array[Option[Split]] = bestSplitsBc.value
// localFeatureIndex[feature index] = index into PartitionInfo.columns
val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap
activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
case (nodeIndexInLevel: Int, split: Split) =>
case (nodeIndexInLevel: Int, Some(split: Split)) =>
if (localFeatureIndex.contains(split.featureIndex)) {
// This partition has the column (feature) used for this split.
val fromOffset = nodeOffsets(nodeIndexInLevel)
@@ -352,6 +354,10 @@
} else {
Iterator()
}
case (nodeIndexInLevel: Int, None) =>
// Do not create a BitSubvector when there is no split.
// This requires PartitionInfo.update to handle missing BitSubvectors.
Iterator()
}.toArray
}
val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge)
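Editor's note: a hedged illustration of what "missing BitSubvectors" means downstream, using toy ranges and string splits rather than the real types. Only active nodes whose best split is Some(...) contribute a covered row range, so the merged array can have gaps that PartitionInfo.update must skip over.

val activeNodeRanges = Seq((0, 5), (5, 9), (9, 12))            // (from, to) per active node
val bestSplits: Seq[Option[String]] = Seq(Some("s0"), None, Some("s2"))
val coveredRanges = activeNodeRanges.zip(bestSplits).collect {
  case (range, Some(_)) => range
}
// coveredRanges == Seq((0, 5), (9, 12)); rows [5, 9) have no covering bit vector.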
@@ -369,14 +375,15 @@
* @param labels
* @param fromOffset
* @param toOffset
* @return
* @return (best split, statistics for split). If the best split actually puts all instances
*         in one leaf node, then the split is None.
*/
private[impl] def chooseSplit(
col: FeatureVector,
labels: Array[Double],
fromOffset: Int,
toOffset: Int,
metadata: AltDTMetadata): (Split, ImpurityStats) = {
metadata: AltDTMetadata): (Option[Split], ImpurityStats) = {
val valuesForNode = col.values.view.slice(fromOffset, toOffset)
val labelsForNode = col.indices.view.slice(fromOffset, toOffset).map(labels.apply)
if (col.isCategorical) {
@@ -405,14 +412,16 @@
* @param featureIndex Index of feature being split.
* @param values Feature values at this node. Sorted in increasing order.
* @param labels Labels corresponding to values, in the same order.
* @return (best split, corresponding impurity statistics)
* @return (best split, statistics for split). If the best split actually puts all instances
*         in one leaf node, then the split is None. The impurity stats may still be
*         useful, so they are returned.
*/
private[impl] def chooseOrderedCategoricalSplit(
featureIndex: Int,
values: Seq[Double],
labels: Seq[Double],
metadata: AltDTMetadata,
featureArity: Int): (Split, ImpurityStats) = {
featureArity: Int): (Option[Split], ImpurityStats) = {
// TODO: Support high-arity features by using a single array to hold the stats.

// aggStats(category) = label statistics for category
@@ -514,24 +523,32 @@
val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
(bestFeatureSplit, bestImpurityStats)

if (bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1) {
(None, bestImpurityStats)
} else {
(Some(bestFeatureSplit), bestImpurityStats)
}
}
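Editor's note: a brief sketch of the new endpoint guard for ordered categorical features (toy data, hypothetical helper name). When the best cut index lands at either end of the centroid-sorted category order, the resulting split would put all instances in one child, so the method now reports None instead of a degenerate split.

val categoriesSortedByCentroid = Array(2, 0, 1)
def isDegenerateCut(bestSplitIndex: Int): Boolean =
  bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1
assert(isDegenerateCut(0) && isDegenerateCut(2) && !isDegenerateCut(1))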

private[impl] def chooseUnorderedCategoricalSplit(
featureIndex: Int,
values: Seq[Double],
labels: Seq[Double],
metadata: AltDTMetadata,
featureArity: Int): (Split, ImpurityStats) = ???
featureArity: Int): (Option[Split], ImpurityStats) = ???

/**
* Choose splitting rule: feature value <= threshold
* @return (best split, statistics for split). If the best split actually puts all instances
*         in one leaf node, then the split is None. The impurity stats may still be
*         useful, so they are returned.
*/
private[impl] def chooseContinuousSplit(
featureIndex: Int,
values: Seq[Double],
labels: Seq[Double],
metadata: AltDTMetadata): (Split, ImpurityStats) = {
metadata: AltDTMetadata): (Option[Split], ImpurityStats) = {

val leftImpurityAgg = metadata.createImpurityAggregator()
val rightImpurityAgg = metadata.createImpurityAggregator()
@@ -571,7 +588,12 @@
val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
(new ContinuousSplit(featureIndex, bestThreshold), bestImpurityStats)
val split = if (bestThreshold != Double.NegativeInfinity && bestThreshold != values.last) {
Some(new ContinuousSplit(featureIndex, bestThreshold))
} else {
None
}
(split, bestImpurityStats)
}
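Editor's note: a tiny example of the continuous-case guard with assumed values (not part of the patch). If every value at the node is identical, the best threshold the scan can settle on is the last value, which routes all instances to one child, so the split comes back as None while the node's ImpurityStats are still returned.

val values = Seq(1.0, 1.0, 1.0)
val bestThreshold = values.last          // suppose the scan settles on the last value
val split: Option[Double] =
  if (bestThreshold != Double.NegativeInfinity && bestThreshold != values.last) Some(bestThreshold)
  else None
assert(split.isEmpty)                    // the node stays a leaf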

/**
@@ -713,6 +735,8 @@
* and the second level is by row index.
* bitVector(i) = false iff instance i goes to the left child.
* For instances at inactive (leaf) nodes, the value can be arbitrary.
* When an active node is not split (e.g., because no good split was found),
* the corresponding BitSubvector may be missing.
* @return Updated partition info
*/
def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = {
@@ -722,44 +746,52 @@
activeNodes.iterator.foreach { nodeIdx =>
val from = nodeOffsets(nodeIdx)
val to = nodeOffsets(nodeIdx + 1)
// Note: Each node is guaranteed to be covered within 1 bit vector.
// TODO: Allow missing vectors when no split is chosen.
if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
val curBitVector = bitVectors(curBitVecIdx)
// Sort range [from, to) based on indices. This is required to match the bit vector
// across all workers. See [[bitSubvectorFromSplit]] for details.
val rangeIndices = col.indices.view.slice(from, to).toArray
val rangeValues = col.values.view.slice(from, to).toArray
val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
// Sort range [from, to) based on bit vector.
sortedRange.zipWithIndex.map { case ((idx, value), i) =>
val bit = curBitVector.get(from + i)
// TODO: In-place merge, rather than general sort.
// TODO: We don't actually need to sort the categorical features using our approach.
(bit, value, idx)
}.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
col.values(from + i) = value
col.indices(from + i) = idx
// If the current BitVector does not cover this node, then this node was not split,
// so we do not need to update its part of the column. Otherwise, we update it.
if (curBitVector.from <= from && to <= curBitVector.to) {
// Sort range [from, to) based on indices. This is required to match the bit vector
// across all workers. See [[bitSubvectorFromSplit]] for details.
val rangeIndices = col.indices.view.slice(from, to).toArray
val rangeValues = col.values.view.slice(from, to).toArray
val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
// Sort range [from, to) based on bit vector.
sortedRange.zipWithIndex.map { case ((idx, value), i) =>
val bit = curBitVector.get(from + i)
// TODO: In-place merge, rather than general sort.
// TODO: We don't actually need to sort the categorical features using our approach.
(bit, value, idx)
}.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
col.values(from + i) = value
col.indices(from + i) = idx
}
}
}
col
}

// Create a 2-level representation of the new nodeOffsets (to be flattened).
// These 2 levels correspond to original nodes and their children (if split).
val newNodeOffsets = nodeOffsets.map(Array(_))
var curBitVecIdx = 0
activeNodes.iterator.foreach { nodeIdx =>
val from = nodeOffsets(nodeIdx)
val to = nodeOffsets(nodeIdx + 1)
if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
val curBitVector = bitVectors(curBitVecIdx)
assert(curBitVector.from <= from && to <= curBitVector.to)
// Count number of values splitting to left vs. right
val numRight = Range(from, to).count(curBitVector.get)
val numLeft = to - from - numRight
if (numLeft != 0 && numRight != 0) {
// node is split
val oldOffset = newNodeOffsets(nodeIdx).head
newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
// If the current BitVector does not cover this node, then this node was not split,
// so we do not need to create a new node offset. Otherwise, we create an offset.
if (curBitVector.from <= from && to <= curBitVector.to) {
// Count number of values splitting to left vs. right
val numRight = Range(from, to).count(curBitVector.get)
val numLeft = to - from - numRight
if (numLeft != 0 && numRight != 0) {
// node is split
val oldOffset = newNodeOffsets(nodeIdx).head
newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
}
}
}

@@ -152,7 +152,10 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
test("chooseSplit") {
}

test("chooseOrderedCategoricalSplit") {
test("chooseOrderedCategoricalSplit: basic case") {
}

test("chooseOrderedCategoricalSplit: return bad split if best split is on end") {
}

// test("chooseUnorderedCategoricalSplit") { }
@@ -165,7 +168,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
val (split, stats) = AltDT.chooseContinuousSplit(featureIndex, values, labels, metadata)
split match {
case s: ContinuousSplit =>
case Some(s: ContinuousSplit) =>
assert(s.featureIndex === featureIndex)
assert(s.threshold === 0.2)
case _ =>
@@ -182,9 +185,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(stats.valid)
}

test("chooseContinuousSplit: some equal values") {
}

// TODO: Add this test once we make this change.
// test("chooseContinuousSplit: return bad split if best split is on end") { }

@@ -232,7 +232,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
val partitionInfos = sc.parallelize(Seq(info))
val bestSplit = new ContinuousSplit(0, threshold = 0.5)
val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit))
val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
assert(bitVectors.length === 1)
val bitv = bitVectors.head
assert(bitv.numBits === numRows)
@@ -248,7 +248,7 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
val partitionInfos = sc.parallelize(Seq(info))
val bestSplit = new ContinuousSplit(0, threshold = -2.0)
val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit))
val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
assert(bitVectors.length === 1)
val bitv = bitVectors.head
assert(bitv.numBits === numRows)