diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 9c15b1188d91..bc94433f7a27 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -219,4 +219,37 @@ class BitSet(numBits: Int) extends Serializable {
 
   /** Return the number of longs it would take to hold numBits. */
   private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+
+  /**
+   * Bit-wise OR between two BitSets, where the i-th bit of `other` is ORed into the
+   * (i + offset)-th bit of this instance. For performance, the OR is computed word-by-word
+   * rather than bit-by-bit.
+   *
+   * This function mutates the current BitSet instance (i.e. not `other`).
+   *
+   * @param offset the amount to left-shift (with zero padding) `other` before performing
+   *               the OR; must be >= 0.
+   */
+  private[spark] def orWithOffset(other: BitSet, offset: Int): Unit = {
+    // Number of bits of `other` that land within this BitSet's words.
+    val numBitsToCopy = math.min(other.capacity, this.capacity - offset)
+    if (numBitsToCopy <= 0) return
+    val wordOffset = offset >> 6 // divide by 64
+
+    // Bit vectors have memory layout [63..0|127..64|...] where | denotes word boundaries, so
+    // left/right within a word and left/right across words are flipped.
+    val rightOffset = offset & 0x3f // mod 64
+    val leftOffset = (64 - rightOffset) & 0x3f // mod 64
+
+    // Number of words of `this`, starting at wordOffset, that receive bits,
+    // including the final carry-only word when the shift straddles a word boundary.
+    val numWords = bit2words(numBitsToCopy + rightOffset)
+    var wordIndex = 0
+    while (wordIndex < numWords) {
+      // Fill in lowest-order bits from other's previous word's highest-order bits if available.
+      // Note: Use the unsigned shift (>>>) so sign extension cannot set spurious bits.
+      if (rightOffset > 0 && wordIndex > 0) {
+        val maskedShiftedPrevWord =
+          (other.words(wordIndex - 1) & (-1L << leftOffset)) >>> leftOffset
+        words(wordIndex + wordOffset) = words(wordIndex + wordOffset) | maskedShiftedPrevWord
+      }
+
+      // Mask, shift, and OR with the current word (absent for the final carry-only iteration).
+      if (wordIndex < other.words.length) {
+        val maskedShiftedOtherWord = (other.words(wordIndex) & (-1L >>> rightOffset)) << rightOffset
+        words(wordIndex + wordOffset) = words(wordIndex + wordOffset) | maskedShiftedOtherWord
+      }
+
+      wordIndex += 1
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
index 69dbfa9cd714..88e4a5cfcf52 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -41,6 +41,29 @@ class BitSetSuite extends SparkFunSuite {
     assert(bitset.cardinality() === setBits.size)
   }
 
+  test("orWithOffset") {
+    val setBits = Seq(0, 9, 1, 10, 90, 96)
+    val bitset = new BitSet(100)
+    setBits.foreach(i => bitset.set(i))
+
+    for {
+      offset <- Seq(0, 1, 63, 64, 65)
+    } {
+      val copyBitset = new BitSet(100)
+      copyBitset.orWithOffset(bitset, offset)
+      for (i <- 0 until offset) {
+        assert(!copyBitset.get(i))
+      }
+      for (i <- offset until 100) {
+        if (setBits.contains(i - offset)) {
+          assert(copyBitset.get(i))
+        } else {
+          assert(!copyBitset.get(i))
+        }
+      }
+    }
+  }
+
   test("100% full bit set") {
     val bitset = new BitSet(10000)
     for (i <- 0 until 10000) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index f28671f7869f..52616319ebd1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -60,6 +60,7 @@ object DecisionTreeExample {
       testInput: String = "",
       dataFormat: String = "libsvm",
       algo: String = "Classification",
"Classification", + algorithm: String = "byRow", maxDepth: Int = 5, maxBins: Int = 32, minInstancesPerNode: Int = 1, @@ -77,6 +78,9 @@ object DecisionTreeExample { opt[String]("algo") .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") .action((x, c) => c.copy(algo = x)) + opt[String]("algorithm") + .text(s"algorithm (byRow, byCol), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algorithm = x)) opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) @@ -236,16 +240,18 @@ object DecisionTreeExample { } // (2) Identify categorical features using VectorIndexer. // Features with more than maxCategories values will be treated as continuous. + /* val featuresIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer + */ // (3) Learn Decision Tree val dt = algo match { case "classification" => new DecisionTreeClassifier() - .setFeaturesCol("indexedFeatures") + .setFeaturesCol("features") // indexedFeatures .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) @@ -253,9 +259,10 @@ object DecisionTreeExample { .setMinInfoGain(params.minInfoGain) .setCacheNodeIds(params.cacheNodeIds) .setCheckpointInterval(params.checkpointInterval) + .setAlgorithm(params.algorithm) case "regression" => new DecisionTreeRegressor() - .setFeaturesCol("indexedFeatures") + .setFeaturesCol("features") // indexedFeatures .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) @@ -263,6 +270,7 @@ object DecisionTreeExample { .setMinInfoGain(params.minInfoGain) .setCacheNodeIds(params.cacheNodeIds) .setCheckpointInterval(params.checkpointInterval) + .setAlgorithm(params.algorithm) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } stages += dt @@ -278,14 +286,14 @@ object DecisionTreeExample { algo match { case "classification" => val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel] - if (treeModel.numNodes < 20) { + if (treeModel.numNodes < 200) { println(treeModel.toDebugString) // Print full model. } else { println(treeModel) // Print model summary. } case "regression" => val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel] - if (treeModel.numNodes < 20) { + if (treeModel.numNodes < 200) { println(treeModel.toDebugString) // Print full model. } else { println(treeModel) // Print model summary. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6f70b96b17ec..41bd5a906dd4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,9 +18,9 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
-import org.apache.spark.ml.tree.impl.RandomForest
+import org.apache.spark.ml.tree.impl.{AltDT, RandomForest}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -62,6 +62,25 @@ final class DecisionTreeClassifier(override val uid: String)
 
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  /**
+   * Algorithm used for learning.
+   * Supported: "byRow" or "byCol" (case sensitive).
+   * (default = "byRow")
+   * @group param
+   */
+  val algorithm: Param[String] = new Param[String](this, "algorithm", "Algorithm used " +
+    "for learning. Supported options:" +
+    s" ${DecisionTreeClassifier.supportedAlgorithms.mkString(", ")}",
+    (value: String) => DecisionTreeClassifier.supportedAlgorithms.contains(value))
+
+  setDefault(algorithm -> "byRow")
+
+  /** @group setParam */
+  def setAlgorithm(value: String): this.type = set(algorithm, value)
+
+  /** @group getParam */
+  def getAlgorithm: String = $(algorithm)
+
   override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -74,9 +93,15 @@ final class DecisionTreeClassifier(override val uid: String)
     }
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, parentUID = Some(uid))
-    trees.head.asInstanceOf[DecisionTreeClassificationModel]
+    val model = getAlgorithm match {
+      case "byRow" =>
+        val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
+          featureSubsetStrategy = "all", seed = 0L, parentUID = Some(uid))
+        trees.head
+      case "byCol" =>
+        AltDT.train(oldDataset, strategy, parentUID = Some(uid))
+    }
+    model.asInstanceOf[DecisionTreeClassificationModel]
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -94,6 +119,8 @@ final class DecisionTreeClassifier(override val uid: String)
 object DecisionTreeClassifier {
   /** Accessor for supported impurities: entropy, gini */
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+  /** Accessor for supported algorithms: byRow, byCol */
+  final val supportedAlgorithms: Array[String] = Array("byRow", "byCol")
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index a2bcd67401d0..8f07755878a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
-import org.apache.spark.ml.tree.impl.RandomForest
+import org.apache.spark.ml.tree.impl.{AltDT, RandomForest}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -62,14 +62,39 @@ final class DecisionTreeRegressor(override val uid: String)
 
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  /**
+   * Algorithm used for learning.
+   * Supported: "byRow" or "byCol" (case sensitive).
+   * (default = "byRow")
+   * @group param
+   */
+  val algorithm: Param[String] = new Param[String](this, "algorithm", "Algorithm used " +
+    "for learning. Supported options:" +
+    s" ${DecisionTreeRegressor.supportedAlgorithms.mkString(", ")}",
+    (value: String) => DecisionTreeRegressor.supportedAlgorithms.contains(value))
+
+  setDefault(algorithm -> "byRow")
+
+  /** @group setParam */
+  def setAlgorithm(value: String): this.type = set(algorithm, value)
+
+  /** @group getParam */
+  def getAlgorithm: String = $(algorithm)
+
   override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, parentUID = Some(uid))
-    trees.head.asInstanceOf[DecisionTreeRegressionModel]
+    val model = getAlgorithm match {
+      case "byRow" =>
+        val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
+          featureSubsetStrategy = "all", seed = 0L, parentUID = Some(uid))
+        trees.head
+      case "byCol" =>
+        AltDT.train(oldDataset, strategy, parentUID = Some(uid))
+    }
+    model.asInstanceOf[DecisionTreeRegressionModel]
   }
 
   /** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -85,6 +110,8 @@ final class DecisionTreeRegressor(override val uid: String)
 object DecisionTreeRegressor {
   /** Accessor for supported impurities: variance */
   final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+  /** Accessor for supported algorithms: byRow, byCol */
+  final val supportedAlgorithms: Array[String] = Array("byRow", "byCol")
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index cd2493129390..4bc9cad1920a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -288,7 +288,7 @@ private[tree] object LearningNode {
       id: Int,
       isLeaf: Boolean,
       stats: ImpurityStats): LearningNode = {
-    new LearningNode(id, None, None, None, false, stats)
+    new LearningNode(id, None, None, None, isLeaf, stats)
   }
 
   /** Create an empty node with the given node index. Values must be set later on. */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 78199cc2df58..34aa080a74a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -47,6 +47,12 @@ sealed trait Split extends Serializable {
    */
   private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
 
+  /**
+   * Return true (split to left) or false (split to right).
+   * @param feature Feature value (original value, not binned)
+   */
+  private[tree] def shouldGoLeft(feature: Double): Boolean
+
   /** Convert to old Split format */
   private[tree] def toOld: OldSplit
 }
@@ -112,6 +118,14 @@ final class CategoricalSplit private[ml] (
     }
   }
 
+  override private[tree] def shouldGoLeft(feature: Double): Boolean = {
+    if (isLeft) {
+      categories.contains(feature)
+    } else {
+      !categories.contains(feature)
+    }
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: CategoricalSplit => featureIndex == other.featureIndex &&
@@ -172,6 +186,10 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
     }
   }
 
+  override private[tree] def shouldGoLeft(feature: Double): Boolean = {
+    feature <= threshold
+  }
+
   override def equals(o: Any): Boolean = {
     o match {
       case other: ContinuousSplit =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
new file mode 100644
index 000000000000..ff4eea2ef336
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
@@ -0,0 +1,787 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.impl.TreeUtil._
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, FeatureType, Strategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model.ImpurityStats
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.{BitSet, OpenHashSet}
+
+
+/**
+ * DecisionTree which partitions data by feature.
+ *
+ * Algorithm:
+ *  - Repartition data, grouping by feature.
+ *  - Prep data (sort continuous features).
+ *  - On each partition, initialize the instance-to-node map with each instance at the root node.
+ *  - Iterate, training 1 new level of the tree at a time:
+ *     - On each partition, for each feature on the partition, select the best split for each node.
+ *     - Aggregate best split for each node.
+ *     - Aggregate bit vector (1 bit/instance) indicating whether each instance splits
+ *       left or right.
+ *     - Broadcast bit vector.  On each partition, update the instance-to-node map.
+ *
+ * TODO: Update to use a sparse column store.
+ */
+private[ml] object AltDT extends Logging {
+
+  private[impl] class AltDTMetadata(
+      val numClasses: Int,
+      val maxBins: Int,
+      val minInfoGain: Double,
+      val impurity: Impurity) extends Serializable {
+
+    private val maxCategoriesForUnorderedFeature =
+      ((math.log(maxBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+
+    def isClassification: Boolean = numClasses >= 2
+
+    def isMulticlass: Boolean = numClasses > 2
+
+    /**
+     * Indicates whether a categorical feature should be treated as unordered.
+     *
+     * TODO(SPARK-9957): If a categorical feature has only 1 category, we treat it as continuous.
+     *                   Later, handle this properly by filtering out those features.
+     */
+    def isUnorderedFeature(numCategories: Int): Boolean = {
+      // Decide if some categorical features should be treated as unordered features,
+      // which require 2 * ((1 << (numCategories - 1)) - 1) bins.
+      // We do this check with log values to prevent overflows in case numCategories is large.
+      // The last inequality is equivalent to: 2 * ((1 << (numCategories - 1)) - 1) <= maxBins.
+      isMulticlass && numCategories > 1 &&
+        numCategories <= maxCategoriesForUnorderedFeature
+    }
+
+    def createImpurityAggregator(): ImpurityAggregatorSingle = {
+      impurity match {
+        case Entropy => new EntropyAggregatorSingle(numClasses)
+        case Gini => new GiniAggregatorSingle(numClasses)
+        case Variance => new VarianceAggregatorSingle
+      }
+    }
+  }
+
+  private[impl] object AltDTMetadata {
+    def fromStrategy(strategy: Strategy) = new AltDTMetadata(strategy.numClasses, strategy.maxBins,
+      strategy.minInfoGain, strategy.impurity)
+  }
+
+  /**
+   * Method to train a decision tree model over an RDD.
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      parentUID: Option[String] = None): DecisionTreeModel = {
+    // TODO: Check validity of params
+    val rootNode = trainImpl(input, strategy)
+    RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
+  }
+
+  private[impl] def trainImpl(input: RDD[LabeledPoint], strategy: Strategy): Node = {
+    val metadata = AltDTMetadata.fromStrategy(strategy)
+
+    // The case with 1 node (depth = 0) is handled separately.
+    // This allows all iterations in the depth > 0 case to use the same code.
+    // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of
+    //       other parameters).
+    if (strategy.maxDepth == 0) {
+      val impurityAggregator: ImpurityAggregatorSingle =
+        input.aggregate(metadata.createImpurityAggregator())(
+          (agg, lp) => agg.update(lp.label, 1.0),
+          (agg1, agg2) => agg1.add(agg2))
+      val impurityCalculator = impurityAggregator.getCalculator
+      return new LeafNode(impurityCalculator.getPredict.predict, impurityCalculator.calculate(),
+        impurityCalculator)
+    }
+
+    // Prepare column store.
+    //   Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue.
+    // TODO: Check whether the mapping from arrays to iterators and back to arrays when
+    //       constructing the column store makes costly copies, or whether it is lazy
+    //       (i.e., not costly).
+    val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features))
+    val numRows: Int = colStoreInit.first()._2.size
+    val labels = new Array[Double](numRows)
+    input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) =>
+      labels(rowIndex.toInt) = label
+    }
+    val labelsBc = input.sparkContext.broadcast(labels)
+    // NOTE: Labels are not sorted with features since that would require 1 copy per feature,
+    //       rather than 1 copy per worker.  This means a lot of random accesses.
+    //       We could improve this by applying first-level sorting (by node) to labels.
+
+    // Sort each column by feature values.
+    val colStore: RDD[FeatureVector] = colStoreInit.map { case (featureIndex: Int, col: Vector) =>
+      val featureArity: Int = strategy.categoricalFeaturesInfo.getOrElse(featureIndex, 0)
+      FeatureVector.fromOriginal(featureIndex, featureArity, col)
+    }
+    // Group columns together into one array of columns per partition.
+    // TODO: Test avoiding this grouping, and see if it matters.
+    val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator =>
+      val groupedCols = new ArrayBuffer[FeatureVector]
+      iterator.foreach(groupedCols += _)
+      if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator()
+    }
+    groupedColStore.persist(StorageLevel.MEMORY_AND_DISK)
+
+    // Initialize partitions with 1 node (each instance at the root node).
+    val partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols =>
+      val initActive = new BitSet(1)
+      initActive.set(0)
+      new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive)
+    }
+
+    // Initialize model.
+    //   Note: We do not use node indices.
+    val rootNode = LearningNode.emptyNode(1) // TODO: remove node id
+    // Active nodes (still being split), updated each iteration
+    var activeNodePeriphery: Array[LearningNode] = Array(rootNode)
+    var numNodeOffsets: Int = 2
+
+    val partitionInfosDebug = new scala.collection.mutable.ArrayBuffer[RDD[PartitionInfo]]()
+    partitionInfosDebug.append(partitionInfosA)
+
+    // Iteratively learn, one level of the tree at a time.
+    var currentLevel = 0
+    var doneLearning = false
+    while (currentLevel < strategy.maxDepth && !doneLearning) {
+
+      val partitionInfos = partitionInfosDebug.last
+
+      // Compute best split for each active node.
+      val bestSplitsAndGains: Array[(Split, ImpurityStats)] =
+        computeBestSplits(partitionInfos, labelsBc, metadata)
+      /*
+      // NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under
+      //       bestSplitsAndGains since
+      assert(activeNodePeriphery.length == bestSplitsAndGains.length,
+        s"activeNodePeriphery.length=${activeNodePeriphery.length} does not equal" +
+          s" bestSplitsAndGains.length=${bestSplitsAndGains.length}")
+      */
+
+      // Update current model and node periphery.
+      // Note: This flatMap has side effects (on the model).
+      activeNodePeriphery =
+        computeActiveNodePeriphery(activeNodePeriphery, bestSplitsAndGains, strategy.getMinInfoGain)
+      // We keep all old nodeOffsets and add one for each node split.
+      // Each node split adds 2 nodes to activeNodePeriphery.
+      // TODO: Should this be calculated after filtering for impurity?
+      numNodeOffsets = numNodeOffsets + activeNodePeriphery.length / 2
+
+      // Count nodes in the active periphery that could still be split (nonzero impurity).
+      val estimatedRemainingActive = activeNodePeriphery.count(_.stats.impurity > 0.0)
+
+      // TODO: Check to make sure we split something, and stop otherwise.
+      doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0
+
+      if (!doneLearning) {
+        // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right.
+        val aggBitVectors: Array[BitSubvector] =
+          collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1))
+
+        // Broadcast aggregated bit vectors.  On each partition, update the instance-to-node map.
+        val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors)
+        // partitionInfos = partitionInfos.map { partitionInfo =>
+        val partitionInfosB = partitionInfos.map { partitionInfo =>
+          partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets)
+        }
+        // TODO: remove this cache/count.  For some reason, it is needed to make things work.
+        //       Probably messing up somewhere above...
+        partitionInfosB.cache().count()
+        partitionInfosDebug.append(partitionInfosB)
+
+        // TODO: unpersist aggBitVectorsBc after action.
+      }
+
+      currentLevel += 1
+    }
+
+    // Done with learning
+    groupedColStore.unpersist()
+    labelsBc.unpersist()
+    rootNode.toNode
+  }
+
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+   * there are math.pow(2, arity - 1) - 1 such splits.
+   * Each split has 2 corresponding bins.
+   */
+  def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)
+
+  /**
+   * Find the best splits for all active nodes.
+   *  - On each partition, for each feature on the partition, select the best split for each node.
+   *    Each worker returns: for each active node, best split + info gain.
+   *  - The splits across workers are aggregated to the driver.
+   * @param partitionInfos RDD of feature columns plus the node-to-instance mapping,
+   *                       one entry per partition.
+   * @param labelsBc Broadcast of the array of labels, indexed by row.
+   * @param metadata Learning and dataset metadata.
+   * @return Best (split, impurity stats) pair for each active node.
+   */
+  private[impl] def computeBestSplits(
+      partitionInfos: RDD[PartitionInfo],
+      labelsBc: Broadcast[Array[Double]],
+      metadata: AltDTMetadata): Array[(Split, ImpurityStats)] = {
+    // On each partition, for each feature on the partition, select the best split for each node.
+    // This will use:
+    //  - groupedColStore (the features)
+    //  - partitionInfos (the node -> instance mapping)
+    //  - labelsBc (the labels column)
+    // Each worker returns:
+    //   for each active node, best split + info gain
+    val partBestSplitsAndGains: RDD[Array[(Split, ImpurityStats)]] = partitionInfos.map {
+      case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
+          activeNodes: BitSet) =>
+        val localLabels = labelsBc.value
+        // Iterate over the active nodes in the current level.
+        activeNodes.iterator.map { nodeIndexInLevel: Int =>
+          val fromOffset = nodeOffsets(nodeIndexInLevel)
+          val toOffset = nodeOffsets(nodeIndexInLevel + 1)
+          val splitsAndStats =
+            columns.map { col =>
+              chooseSplit(col, localLabels, fromOffset, toOffset, metadata)
+            }
+          splitsAndStats.maxBy(_._2.gain)
+        }.toArray
+    }
+
+    // TODO: treeReduce
+    // Aggregate best split for each active node.
+    partBestSplitsAndGains.reduce { case (splitsGains1, splitsGains2) =>
+      splitsGains1.zip(splitsGains2).map { case ((split1, gain1), (split2, gain2)) =>
+        if (gain1.gain >= gain2.gain) {
+          (split1, gain1)
+        } else {
+          (split2, gain2)
+        }
+      }
+    }
+  }
+
+  /**
+   * On driver: Grow tree based on chosen splits, and compute new set of active nodes.
+   * @param oldPeriphery Old periphery of active nodes.
+   * @param bestSplitsAndGains Best (split, gain) pairs, which can be zipped with the old
+   *                           periphery.
+   * @param minInfoGain Threshold for min info gain required to split a node.
+   * @return New active node periphery.
+   */
+  private[impl] def computeActiveNodePeriphery(
+      oldPeriphery: Array[LearningNode],
+      bestSplitsAndGains: Array[(Split, ImpurityStats)],
+      minInfoGain: Double): Array[LearningNode] = {
+    bestSplitsAndGains.zipWithIndex.flatMap {
+      case ((split, stats), nodeIdx) =>
+        val node = oldPeriphery(nodeIdx)
+        if (stats.gain > minInfoGain) {
+          // TODO: remove node id
+          node.leftChild = Some(LearningNode(node.id * 2, isLeaf = false,
+            ImpurityStats(stats.leftImpurity, stats.leftImpurityCalculator)))
+          node.rightChild = Some(LearningNode(node.id * 2 + 1, isLeaf = false,
+            ImpurityStats(stats.rightImpurity, stats.rightImpurityCalculator)))
+          node.split = Some(split)
+          node.isLeaf = false
+          node.stats = stats
+          Iterator(node.leftChild.get, node.rightChild.get)
+        } else {
+          node.isLeaf = true
+          Iterator()
+        }
+    }
+  }
+
+  /**
+   * Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right.
+   *  - Send chosen splits to workers.
+   *  - Each worker creates part of the bit vector corresponding to the splits it created.
+   *  - Aggregate the partial bit vectors to create one vector (of length numRows).
+   *    More precisely: only the pieces of that vector corresponding to instances at
+   *    active nodes are aggregated.
+   * @param partitionInfos RDD with feature data, plus current status metadata.
+   * @param bestSplits Split for each active node.
+   * @return Array of bit vectors, ordered by offset ranges.
+   */
+  private[impl] def collectBitVectors(
+      partitionInfos: RDD[PartitionInfo],
+      bestSplits: Array[Split]): Array[BitSubvector] = {
+    val bestSplitsBc: Broadcast[Array[Split]] = partitionInfos.sparkContext.broadcast(bestSplits)
+    val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map {
+      case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
+          activeNodes: BitSet) =>
+        val localBestSplits: Array[Split] = bestSplitsBc.value
+        // localFeatureIndex[feature index] = index into PartitionInfo.columns
+        val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap
+        activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
+          case (nodeIndexInLevel: Int, split: Split) =>
+            if (localFeatureIndex.contains(split.featureIndex)) {
+              // This partition has the column (feature) used for this split.
+              val fromOffset = nodeOffsets(nodeIndexInLevel)
+              val toOffset = nodeOffsets(nodeIndexInLevel + 1)
+              val colIndex: Int = localFeatureIndex(split.featureIndex)
+              Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split))
+            } else {
+              Iterator()
+            }
+        }.toArray
+    }
+    val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge)
+    bestSplitsBc.unpersist()
+    aggBitVectors
+  }
+
+  /**
+   * Choose the best split for a feature at a node.
+   *
+   * TODO: Return null or None when the split is invalid, such as putting all instances on one
+   *       child node.
+   *
+   * @param col Feature column, sorted by feature value within each node's range.
+   * @param labels Array of all labels, indexed by row.
+   * @param fromOffset Start offset in col for this node's instances.
+   * @param toOffset End offset (exclusive) in col for this node's instances.
+   * @param metadata Learning and dataset metadata.
+   * @return (best split, corresponding impurity statistics)
+   */
+  private[impl] def chooseSplit(
+      col: FeatureVector,
+      labels: Array[Double],
+      fromOffset: Int,
+      toOffset: Int,
+      metadata: AltDTMetadata): (Split, ImpurityStats) = {
+    val valuesForNode = col.values.view.slice(fromOffset, toOffset)
+    val labelsForNode = col.indices.view.slice(fromOffset, toOffset).map(labels.apply)
+    if (col.isCategorical) {
+      if (metadata.isUnorderedFeature(col.featureArity)) {
+        chooseUnorderedCategoricalSplit(col.featureIndex, valuesForNode, labelsForNode, metadata,
+          col.featureArity)
+      } else {
+        chooseOrderedCategoricalSplit(col.featureIndex, valuesForNode, labelsForNode, metadata,
+          col.featureArity)
+      }
+    } else {
+      chooseContinuousSplit(col.featureIndex, valuesForNode, labelsForNode, metadata)
+    }
+  }
+
+  /**
+   * Find the best split for an ordered categorical feature at a single node.
+   *
+   * Algorithm:
+   *  - For each category, compute a "centroid."
+   *     - For multiclass classification, the centroid is the label impurity.
+   *     - For binary classification and regression, the centroid is the average label.
+   *  - Sort the centroids, and consider splits anywhere in this order.
+   *    Thus, with K categories, we consider K - 1 possible splits.
+   *
+   * @param featureIndex Index of feature being split.
+   * @param values Feature values at this node.  Sorted in increasing order.
+   * @param labels Labels corresponding to values, in the same order.
+   * @param metadata Learning and dataset metadata.
+   * @param featureArity Number of categories for this feature.
+   * @return (best split, corresponding impurity statistics)
+   */
+  private[impl] def chooseOrderedCategoricalSplit(
+      featureIndex: Int,
+      values: Seq[Double],
+      labels: Seq[Double],
+      metadata: AltDTMetadata,
+      featureArity: Int): (Split, ImpurityStats) = {
+    // TODO: Support high-arity features by using a single array to hold the stats.
+
+    // aggStats(category) = label statistics for category
+    val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)(
+      _ => metadata.createImpurityAggregator())
+    values.zip(labels).foreach { case (cat, label) =>
+      aggStats(cat.toInt).update(label)
+    }
+
+    // Compute centroids.  centroidsForCategories is a list: (category, centroid)
+    val centroidsForCategories: Seq[(Int, Double)] = if (metadata.isMulticlass) {
+      // For categorical variables in multiclass classification,
+      // the bins are ordered by the impurity of their corresponding labels.
+      Range(0, featureArity).map { featureValue =>
+        val categoryStats = aggStats(featureValue)
+        val centroid = if (categoryStats.getCount != 0) {
+          categoryStats.getCalculator.calculate()
+        } else {
+          Double.MaxValue
+        }
+        (featureValue, centroid)
+      }
+    } else if (metadata.isClassification) { // binary classification
+      // For categorical variables in binary classification,
+      // the bins are ordered by the centroid of their corresponding labels.
+      Range(0, featureArity).map { featureValue =>
+        val categoryStats = aggStats(featureValue)
+        val centroid = if (categoryStats.getCount != 0) {
+          assert(categoryStats.stats.length == 2)
+          (categoryStats.stats(1) - categoryStats.stats(0)) / categoryStats.getCount
+        } else {
+          Double.MaxValue
+        }
+        (featureValue, centroid)
+      }
+    } else { // regression
+      // For categorical variables in regression,
+      // the bins are ordered by the centroid of their corresponding labels.
+      Range(0, featureArity).map { featureValue =>
+        val categoryStats = aggStats(featureValue)
+        val centroid = if (categoryStats.getCount != 0) {
+          categoryStats.getCalculator.predict
+        } else {
+          Double.MaxValue
+        }
+        (featureValue, centroid)
+      }
+    }
+
+    logDebug("Centroids for categorical variable: " + centroidsForCategories.mkString(","))
+
+    val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1)
+
+    // Cumulative sums of bin statistics for left, right parts of split.
+    val leftImpurityAgg = metadata.createImpurityAggregator()
+    val rightImpurityAgg = metadata.createImpurityAggregator()
+    aggStats.foreach(rightImpurityAgg.add)
+
+    var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid
+    val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
+    var bestGain: Double = -1.0
+    val fullImpurity = rightImpurityAgg.getCalculator.calculate()
+    var leftCount: Double = 0.0
+    var rightCount: Double = rightImpurityAgg.getCount
+    val fullCount: Double = rightCount
+
+    val numSplits = categoriesSortedByCentroid.length - 1
+    var sortedCatIndex = 0
+    while (sortedCatIndex < numSplits) {
+      val cat = categoriesSortedByCentroid(sortedCatIndex)
+      // Update left, right stats
+      val catStats = aggStats(cat)
+      leftImpurityAgg.add(catStats)
+      rightImpurityAgg.subtract(catStats)
+      leftCount += catStats.getCount
+      rightCount -= catStats.getCount
+      // Compute impurity
+      val leftWeight = leftCount / fullCount
+      val rightWeight = rightCount / fullCount
+      val leftImpurity = leftImpurityAgg.getCalculator.calculate()
+      val rightImpurity = rightImpurityAgg.getCalculator.calculate()
+      val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+      if (gain > bestGain && gain > metadata.minInfoGain) {
+        bestSplitIndex = sortedCatIndex
+        leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats)
+        bestGain = gain
+      }
+      sortedCatIndex += 1
+    }
+
+    assert(bestSplitIndex != -1, "Unknown error in AltDT split selection for ordered categorical" +
+      s" variable with numSplits = $numSplits.")
+
+    val categoriesForSplit =
+      categoriesSortedByCentroid.slice(0, bestSplitIndex + 1).map(_.toDouble)
+    val bestFeatureSplit =
+      new CategoricalSplit(featureIndex, categoriesForSplit.toArray, featureArity)
+    val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg)
+    val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
+    val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
+      bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
+    (bestFeatureSplit, bestImpurityStats)
+  }
+
+  // TODO: Implement.  Unordered categorical features are not yet supported by AltDT.
+  private[impl] def chooseUnorderedCategoricalSplit(
+      featureIndex: Int,
+      values: Seq[Double],
+      labels: Seq[Double],
+      metadata: AltDTMetadata,
+      featureArity: Int): (Split, ImpurityStats) = ???
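
[Editorial aside: to make the centroid-based ordering in chooseOrderedCategoricalSplit concrete, here is a small self-contained sketch — illustrative only, not part of the patch; the category counts and label sums are hypothetical — showing how K categories yield K - 1 candidate splits for a regression node:]

    object OrderedCategoricalSplitSketch extends App {
      // Hypothetical per-category (count, label sum) aggregates for categories 0, 1, 2.
      val counts = Array(4.0, 2.0, 3.0)
      val labelSums = Array(8.0, 1.0, 3.0)
      // Regression centroid = average label per category: 2.0, 0.5, 1.0.
      val centroids = counts.indices.map(cat => (cat, labelSums(cat) / counts(cat)))
      // Sort categories by centroid: 1 (0.5), then 2 (1.0), then 0 (2.0).
      val sortedCats = centroids.sortBy(_._2).map(_._1)
      // K - 1 = 2 candidate splits: left = {1} or left = {1, 2}.
      val candidateSplits = (1 until sortedCats.length).map(i => sortedCats.take(i))
      println(candidateSplits.map(_.mkString("{", ",", "}")).mkString(" | "))
      // Prints: {1} | {1,2}
    }

[This is why only K - 1 splits need to be scanned instead of all 2^(K-1) - 1 category partitions; for binary classification and regression the centroid ordering is known to be optimal.]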
+
+  /**
+   * Choose the best split for a continuous feature at a single node.
+   * The splitting rule is: feature value <= threshold.
+   * @param values Feature values at this node.  Sorted in increasing order.
+   * @param labels Labels corresponding to values, in the same order.
+   */
+  private[impl] def chooseContinuousSplit(
+      featureIndex: Int,
+      values: Seq[Double],
+      labels: Seq[Double],
+      metadata: AltDTMetadata): (Split, ImpurityStats) = {
+
+    val leftImpurityAgg = metadata.createImpurityAggregator()
+    val rightImpurityAgg = metadata.createImpurityAggregator()
+    labels.foreach(rightImpurityAgg.update(_, 1.0))
+
+    var bestThreshold: Double = Double.NegativeInfinity
+    val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
+    var bestGain: Double = 0.0
+    val fullImpurity = rightImpurityAgg.getCalculator.calculate()
+    var leftCount: Double = 0.0
+    var rightCount: Double = rightImpurityAgg.getCount
+    val fullCount: Double = rightCount
+    var currentThreshold = values.headOption.getOrElse(bestThreshold)
+    values.zip(labels).foreach { case (value, label) =>
+      if (value != currentThreshold) {
+        // Check gain
+        val leftWeight = leftCount / fullCount
+        val rightWeight = rightCount / fullCount
+        val leftImpurity = leftImpurityAgg.getCalculator.calculate()
+        val rightImpurity = rightImpurityAgg.getCalculator.calculate()
+        val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+        if (gain > bestGain && gain > metadata.minInfoGain) {
+          bestThreshold = currentThreshold
+          leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats)
+          bestGain = gain
+        }
+        currentThreshold = value
+      }
+      // Move this instance from right to left side of split.
+      leftImpurityAgg.update(label, 1.0)
+      rightImpurityAgg.update(label, -1.0)
+      leftCount += 1.0
+      rightCount -= 1.0
+    }
+
+    val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg)
+    val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
+    val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
+      bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)
+    (new ContinuousSplit(featureIndex, bestThreshold), bestImpurityStats)
+  }
+
+  /**
+   * Feature vector types are based on (feature type, representation).
+   * The feature type can be continuous or categorical.
+   *
+   * Features are sorted by value, so we must store indices + values.
+   * These values are currently stored in a dense representation only.
+   * TODO: Support sparse storage (to optimize deeper levels of the tree), and maybe compressed
+   *       storage (to optimize upper levels of the tree).
+   * @param featureArity  For categorical features, this gives the number of categories.
+   *                      For continuous features, this should be set to 0.
+   */
+  private[impl] class FeatureVector(
+      val featureIndex: Int,
+      val featureArity: Int,
+      val values: Array[Double],
+      val indices: Array[Int])
+    extends Serializable {
+
+    def isCategorical: Boolean = featureArity > 0
+
+    /** For debugging */
+    override def toString: String = {
+      "  FeatureVector(\n" +
+        s"    featureIndex: $featureIndex,\n" +
+        s"    featureType: ${if (featureArity == 0) "Continuous" else "Categorical"},\n" +
+        s"    featureArity: $featureArity,\n" +
+        s"    values: ${values.mkString(", ")},\n" +
+        s"    indices: ${indices.mkString(", ")},\n" +
+        "  )"
+    }
+
+    def deepCopy(): FeatureVector =
+      new FeatureVector(featureIndex, featureArity, values.clone(), indices.clone())
+
+    override def equals(other: Any): Boolean = {
+      other match {
+        case o: FeatureVector =>
+          featureIndex == o.featureIndex && featureArity == o.featureArity &&
+            values.sameElements(o.values) && indices.sameElements(o.indices)
+        case _ => false
+      }
+    }
+  }
+
+  private[impl] object FeatureVector {
+    /** Store column sorted by feature values. */
+    def fromOriginal(
+        featureIndex: Int,
+        featureArity: Int,
+        featureVector: Vector): FeatureVector = {
+      val (values, indices) = featureVector.toArray.zipWithIndex.sorted.unzip
+      new FeatureVector(featureIndex, featureArity, values.toArray, indices.toArray)
+    }
+  }
+
+  /**
+   * For a given feature, for a given node, apply a split and return a bit vector indicating the
+   * outcome of the split for each instance at that node.
+   *
+   * @param col Column for feature
+   * @param fromOffset Start offset in col for the node
+   * @param toOffset End offset in col for the node
+   * @param split Split to apply to instances at this node.
+   * @return Bits indicating splits for instances at this node.
+   *         These bits are sorted by the row indices, in order to guarantee an ordering
+   *         understood by all workers.
+   *         Thus, the bit indices used are based on 2-level sorting: first by node, and
+   *         second by sorted row indices within the node's rows.
+   *         bit[index in sorted array of row indices] = false for left, true for right
+   */
+  private[impl] def bitSubvectorFromSplit(
+      col: FeatureVector,
+      fromOffset: Int,
+      toOffset: Int,
+      split: Split): BitSubvector = {
+    val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset).toArray
+    val nodeRowValues = col.values.view.slice(fromOffset, toOffset).toArray
+    val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2)
+    val bitv = new BitSubvector(fromOffset, toOffset)
+    nodeRowValuesSortedByIndices.zipWithIndex.foreach { case (value, i) =>
+      if (!split.shouldGoLeft(value)) {
+        bitv.set(fromOffset + i)
+      }
+    }
+    bitv
+  }
+
+  /**
+   * Intermediate data stored on each partition during learning.
+   *
+   * Node indexing for nodeOffsets, activeNodes:
+   * Nodes are indexed left-to-right along the periphery of the tree, with 0-based indices.
+   * The periphery is the set of leaf nodes (active and inactive).
+   *
+   * @param columns  Subset of columns (features) stored in this partition.
+   *                 Each column is sorted first by nodes (left-to-right along the tree periphery);
+   *                 all columns share this first level of sorting.
+   *                 Within each node's group, each column is sorted based on feature value;
+   *                 this second level of sorting differs across columns.
+   * @param nodeOffsets  Offsets into the columns indicating the first level of sorting (by node).
+   *                     The rows corresponding to node i are in the range
+   *                     [nodeOffsets(i), nodeOffsets(i+1)).
+   * @param activeNodes  Nodes which are active (still being split).
+   *                     Inactive nodes are known to be leaves in the final tree.
+   *                     TODO: Should this (and even nodeOffsets) not be stored in PartitionInfo,
+   *                           but instead on the driver?
+   */
+  private[impl] case class PartitionInfo(
+      columns: Array[FeatureVector],
+      nodeOffsets: Array[Int],
+      activeNodes: BitSet)
+    extends Serializable {
+
+    /** For debugging */
+    override def toString: String = {
+      "PartitionInfo(" +
+        "  columns: {\n" +
+        columns.mkString(",\n") +
+        "  },\n" +
+        s"  nodeOffsets: ${nodeOffsets.mkString(", ")},\n" +
+        s"  activeNodes: ${activeNodes.iterator.mkString(", ")},\n" +
+        ")\n"
+    }
+
+    /**
+     * Update columns and nodeOffsets for the next level of the tree.
+     *
+     * Update columns:
+     *   For each column,
+     *     For each (previously) active node,
+     *       Sort corresponding range of instances based on bit vector.
+     * Update nodeOffsets, activeNodes:
+     *   Split offsets for nodes which split (which can be identified using the bit vector).
+     *
+     * @param bitVectors  Bit vectors encoding splits for the next level of the tree.
+     *                    These must follow a 2-level ordering, where the first level is by node
+     *                    and the second level is by row index.
+     *                    bitVector(i) = false iff instance i goes to the left child.
+     *                    For instances at inactive (leaf) nodes, the value can be arbitrary.
+     * @return Updated partition info
+     */
+    def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = {
+      val newColumns = columns.map { oldCol =>
+        val col = oldCol.deepCopy()
+        var curBitVecIdx = 0
+        activeNodes.iterator.foreach { nodeIdx =>
+          val from = nodeOffsets(nodeIdx)
+          val to = nodeOffsets(nodeIdx + 1)
+          // Note: Each node is guaranteed to be covered within 1 bit vector.
+          if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
+          val curBitVector = bitVectors(curBitVecIdx)
+          // Sort range [from, to) based on indices.  This is required to match the bit vector
+          // across all workers.  See [[bitSubvectorFromSplit]] for details.
+          val rangeIndices = col.indices.view.slice(from, to).toArray
+          val rangeValues = col.values.view.slice(from, to).toArray
+          val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
+          // Sort range [from, to) based on bit vector.
+          sortedRange.zipWithIndex.map { case ((idx, value), i) =>
+            val bit = curBitVector.get(from + i)
+            // TODO: In-place merge, rather than general sort.
+            // TODO: We don't actually need to sort the categorical features using our approach.
+            (bit, value, idx)
+          }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
+            col.values(from + i) = value
+            col.indices(from + i) = idx
+          }
+        }
+        col
+      }
+
+      // Create a 2-level representation of the new nodeOffsets (to be flattened).
+      val newNodeOffsets = nodeOffsets.map(Array(_))
+      var curBitVecIdx = 0
+      activeNodes.iterator.foreach { nodeIdx =>
+        val from = nodeOffsets(nodeIdx)
+        val to = nodeOffsets(nodeIdx + 1)
+        if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1
+        val curBitVector = bitVectors(curBitVecIdx)
+        assert(curBitVector.from <= from && to <= curBitVector.to)
+        // Count the number of values splitting to left vs. right.
+        val numRight = Range(from, to).count(curBitVector.get)
+        val numLeft = to - from - numRight
+        if (numLeft != 0 && numRight != 0) {
+          // node is split
+          val oldOffset = newNodeOffsets(nodeIdx).head
+          newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
+        }
+      }
+
+      assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets,
+        s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," +
+          s" newNumNodeOffsets: $newNumNodeOffsets")
+
+      // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets.
+      val newActiveNodes = new BitSet(newNumNodeOffsets - 1)
+      var newNodeOffsetsIdx = 0
+      newNodeOffsets.foreach { offsets =>
+        if (offsets.length == 2) {
+          newActiveNodes.set(newNodeOffsetsIdx)
+          newActiveNodes.set(newNodeOffsetsIdx + 1)
+          newNodeOffsetsIdx += 2
+        } else {
+          newNodeOffsetsIdx += 1
+        }
+      }
+
+      PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes)
+    }
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala
new file mode 100644
index 000000000000..b94f4903dba4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.util.collection.BitSet
+
+
+/**
+ * A range of bits within a larger distributed bit vector.
+ * @param from starting index (inclusive), within the larger distributed bit vector, of the
+ *             range represented by this instance
+ * @param to ending index (exclusive), within the larger distributed bit vector, of the
+ *           range represented by this instance
+ */
+private[impl] class BitSubvector(val from: Int, val to: Int) extends Serializable {
+
+  val numBits: Int = to - from
+
+  private val bits: BitSet = new BitSet(numBits)
+
+  /** Set a bit in this instance using an external index */
+  def set(idx: Int): Unit = bits.set(toInternalIdx(idx))
+
+  def get(idx: Int): Boolean = bits.get(toInternalIdx(idx))
+
+  /** Get an iterator over the external indices of the set bits. */
+  def iterator: Iterator[Int] = new Iterator[Int] {
+    val iter = bits.iterator
+    override def hasNext: Boolean = iter.hasNext
+    override def next(): Int = toExternalIdx(iter.next())
+  }
+
+  /**
+   * Bit-wise OR with another BitSubvector.  Bits are matched according to external index
+   * (ORed against 0 if absent in the other BitSubvector).  This method mutates the current
+   * instance in-place.
+   */
+  def |=(other: BitSubvector): Unit = {
+    require(from <= other.from && to >= other.to)
+    val delta = other.from - from
+    bits.orWithOffset(other.bits, delta)
+  }
+
+  private def toInternalIdx(idx: Int): Int = {
+    require(idx >= from && idx < to)
+    idx - from
+  }
+
+  private def toExternalIdx(idx: Int): Int = {
+    idx + from
+  }
+}
+
+private[impl] object BitSubvector {
+
+  def merge(parts1: Array[BitSubvector], parts2: Array[BitSubvector]): Array[BitSubvector] = {
+    // Merge sorted parts1, parts2
+    val sortedSubvectors = (parts1 ++ parts2).sortBy(_.from)
+    if (sortedSubvectors.nonEmpty) {
+      // Merge adjacent BitSubvectors (for adjacent node ranges)
+      val newSubvectorRanges: Array[(Int, Int)] = {
+        val newSubvRanges = ArrayBuffer.empty[(Int, Int)]
+        var i = 1
+        var currentFrom = sortedSubvectors.head.from
+        while (i < sortedSubvectors.length) {
+          if (sortedSubvectors(i - 1).to != sortedSubvectors(i).from) {
+            newSubvRanges.append((currentFrom, sortedSubvectors(i - 1).to))
+            currentFrom = sortedSubvectors(i).from
+          }
+          i += 1
+        }
+        newSubvRanges.append((currentFrom, sortedSubvectors.last.to))
+        newSubvRanges.toArray
+      }
+      val newSubvectors = newSubvectorRanges.map { case (from, to) => new BitSubvector(from, to) }
+      var curNewSubvIdx = 0
+      sortedSubvectors.foreach { subv =>
+        if (subv.to > newSubvectors(curNewSubvIdx).to) curNewSubvIdx += 1
+        val newSubv = newSubvectors(curNewSubvIdx)
+        newSubv |= subv
+      }
+      assert(curNewSubvIdx + 1 == newSubvectors.length) // sanity check
+      newSubvectors
+    } else {
+      Array.empty[BitSubvector]
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/InfoGainStats.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/InfoGainStats.scala
new file mode 100644
index 000000000000..7520ac225fde
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/InfoGainStats.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, Predict}
+
+/**
+ * Information gain statistics for each split
+ * @param prediction predicted value for the node
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
+ */
+private[tree] class InfoGainStats(
+    val prediction: Double,
+    val gain: Double,
+    val impurity: Double,
+    val leftImpurity: Double,
+    val rightImpurity: Double,
+    val leftPredict: Predict,
+    val rightPredict: Predict) extends Serializable {
+
+  override def toString: String = {
+    s"prediction = $prediction, gain = $gain, impurity = $impurity, " +
+      s"left impurity = $leftImpurity, right impurity = $rightImpurity"
+  }
+
+  override def equals(o: Any): Boolean = o match {
+    case other: InfoGainStats =>
+      prediction == other.prediction &&
+        gain == other.gain &&
+        impurity == other.impurity &&
+        leftImpurity == other.leftImpurity &&
+        rightImpurity == other.rightImpurity &&
+        leftPredict == other.leftPredict &&
+        rightPredict == other.rightPredict
+    case _ => false
+  }
+
+  override def hashCode: Int = {
+    com.google.common.base.Objects.hashCode(
+      prediction: java.lang.Double,
+      gain: java.lang.Double,
+      impurity: java.lang.Double,
+      leftImpurity: java.lang.Double,
+      rightImpurity: java.lang.Double,
+      leftPredict,
+      rightPredict)
+  }
+
+  def toOld: OldInformationGainStats = new OldInformationGainStats(gain, impurity, leftImpurity,
+    rightImpurity, leftPredict, rightPredict)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 4ac51a475474..525168fc476c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -179,22 +179,26 @@ private[ml] object RandomForest extends Logging {
       }
     }
 
+    topNodes.map(lNode => finalizeTree(lNode.toNode, strategy.algo, strategy.numClasses, parentUID))
+  }
+
+  private[tree] def finalizeTree(
+      rootNode: Node,
+      algo: OldAlgo.Algo,
+      numClasses: Int,
+      parentUID: Option[String]): DecisionTreeModel = {
     parentUID match {
       case Some(uid) =>
-        if (strategy.algo == OldAlgo.Classification) {
-          topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
-          }
+        if (algo == OldAlgo.Classification) {
+          new DecisionTreeClassificationModel(uid, rootNode, numClasses)
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+          new DecisionTreeRegressionModel(uid, rootNode)
         }
       case None =>
-        if (strategy.algo == OldAlgo.Classification) {
-          topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
-          }
+        if (algo == OldAlgo.Classification) {
+          new DecisionTreeClassificationModel(rootNode, numClasses)
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+          new DecisionTreeRegressionModel(rootNode)
         }
     }
  }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala
new file mode 100644
index 000000000000..fa6eb63685d2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala
@@ -0,0 +1,386 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+
+
+private[tree] object TreeUtil {
+
+  /**
+   * Convert a dataset of [[Vector]] from row storage to column storage.
+   * This can take any [[Vector]] type but stores data as [[DenseVector]].
+   *
+   * WARNING: This shuffles the ENTIRE dataset across the network, so it is a VERY EXPENSIVE
+   *          operation.  This can also fail if 1 column is too large to fit on 1 partition.
+   *
+   * Note: Since data are stored as [[DenseVector]], sparsity in the input is NOT preserved
+   *       (see the first TODO below).
+   *
+   * This maintains matrix structure.  I.e., each partition of the output RDD holds adjacent
+   * columns.  The number of partitions will be min(input RDD's number of partitions, numColumns).
+   *
+   * @param rowStore The input vectors are data rows/instances.
+   * @return RDD of (columnIndex, columnValues) pairs,
+   *         where each pair corresponds to one entire column.
+   *         If either dimension of the given data is 0, this returns an empty RDD.
+   *         If vector lengths do not match, this throws an exception.
+   *
+   * TODO: Add implementation for sparse data.
+   *       For sparse data, distribute more evenly based on number of non-zeros.
+   *       (First collect stats to decide how to partition.)
+   * TODO: Move elsewhere in MLlib.
+   */
+  def rowToColumnStoreDense(rowStore: RDD[Vector]): RDD[(Int, Vector)] = {
+
+    val numRows = {
+      val longNumRows: Long = rowStore.count()
+      require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," +
+        s" but can handle at most ${Int.MaxValue} rows")
+      longNumRows.toInt
+    }
+    if (numRows == 0) {
+      return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)])
+    }
+    val numCols = rowStore.take(1)(0).size
+    val numSourcePartitions = rowStore.partitions.length
+    val numTargetPartitions = Math.min(numCols, numSourcePartitions)
+    if (numTargetPartitions == 0) {
+      return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)])
+    }
+    // Base number of columns assigned to each target partition.
+    // The last partition also receives the remainder, so it may hold more than this.
+    val columnsPerPartition = numCols / numTargetPartitions
+
+    def getNumColsInGroup(groupIndex: Int) = {
+      if (groupIndex + 1 < numTargetPartitions) {
+        columnsPerPartition
+      } else {
+        numCols - (numTargetPartitions - 1) * columnsPerPartition // last partition
+      }
+    }
+
+    /* On each partition, re-organize into groups of columns:
+         (groupIndex, (sourcePartitionIndex, partCols)),
+       where partCols(colIdx) = partial column.
+       The groupIndex will be used to groupByKey.
+       The sourcePartitionIndex is used to ensure instance indices match up after the shuffle.
+       The partial columns will be stacked into full columns after the shuffle.
+       Note: By design, partCols will always have at least 1 column.
+ */ + val partialColumns: RDD[(Int, (Int, Array[Array[Double]]))] = + rowStore.mapPartitionsWithIndex { case (sourcePartitionIndex, iterator) => + // columnSets(groupIndex)(colIdx) + // = column values for each instance in sourcePartitionIndex, + // where colIdx is a 0-based index for columns for groupIndex + val columnSets = new Array[Array[ArrayBuffer[Double]]](numTargetPartitions) + Range(0, numTargetPartitions).foreach { groupIndex => + columnSets(groupIndex) = + Array.fill[ArrayBuffer[Double]](getNumColsInGroup(groupIndex))(ArrayBuffer[Double]()) + } + iterator.foreach { row => + Range(0, numTargetPartitions).foreach { groupIndex => + val fromCol = groupIndex * colsPerGroup + val numColsInTargetPartition = getNumColsInGroup(groupIndex) + // TODO: match-case here on row as Dense or Sparse Vector (for speed) + var colIdx = 0 + while (colIdx < numColsInTargetPartition) { + columnSets(groupIndex)(colIdx) += row(fromCol + colIdx) + colIdx += 1 + } + } + } + Range(0, numTargetPartitions).map { groupIndex => + (groupIndex, + (sourcePartitionIndex, columnSets(groupIndex).map(_.toArray))) + }.toIterator + } + + // Shuffle data + val groupedPartialColumns: RDD[(Int, Iterable[(Int, Array[Array[Double]])])] = + partialColumns.groupByKey() + + // Each target partition now holds its set of columns. + // Group the partial columns into full columns. + val fullColumns = groupedPartialColumns.flatMap { case (groupIndex, iterator) => + // We do not know the number of rows per group, so we need to collect the groups + // before filling the full columns. + val collectedPartCols = new Array[Array[Array[Double]]](numSourcePartitions) + iterator.foreach { case (sourcePartitionIndex, partCols) => + collectedPartCols(sourcePartitionIndex) = partCols + } + val rowOffsets: Array[Int] = collectedPartCols.map(_(0).length).scanLeft(0)(_ + _) + val numRows = rowOffsets.last + // Initialize full columns + val fromCol = groupIndex * colsPerGroup + val numColumnsInPartition = getNumColsInGroup(groupIndex) + val partitionColumns: Array[Array[Double]] = + Array.fill[Array[Double]](numColumnsInPartition)(new Array[Double](numRows)) + var colIdx = 0 // index within group + while (colIdx < numColumnsInPartition) { + var sourcePartitionIndex = 0 + while (sourcePartitionIndex < numSourcePartitions) { + val partColLength = + rowOffsets(sourcePartitionIndex + 1) - rowOffsets(sourcePartitionIndex) + Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx), 0, + partitionColumns(colIdx), rowOffsets(sourcePartitionIndex), partColLength) + sourcePartitionIndex += 1 + } + colIdx += 1 + } + val columnIndices = Range(0, numColumnsInPartition).map(_ + fromCol) + val columns = partitionColumns.map(Vectors.dense) + columnIndices.zip(columns) + } + + fullColumns + } + + /** + * Count the number of non-zero elements in each column. + * This will throw an exception if any rows have non-matching numbers of features. + * @param rowStore Dataset of vectors which all have the same length (number of columns). + * @return Array over columns of the number of non-zero elements in each column. + * Returns empty array if the RDD is empty (0 rows or 0 columns).
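+ * For illustration (a sketch): given the rows Vectors.dense(1.0, 0.0, 2.0) and + * Vectors.dense(0.0, 0.0, 3.0), this returns Array(1L, 0L, 2L).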
+ */ + private def countNonZerosPerColumn(rowStore: RDD[Vector]): Array[Long] = { + val firstRow = rowStore.take(1) + if (firstRow.length == 0) { + return Array.empty[Long] + } + val numCols = firstRow(0).size + val colSizes: Array[Long] = rowStore.mapPartitions { iterator => + val partColSizes = Array.fill[Long](numCols)(0) + iterator.foreach { + case dv: DenseVector => + var col = 0 + while (col < dv.size) { + if (dv(col) != 0.0) partColSizes(col) += 1 + col += 1 + } + case sv: SparseVector => + var k = 0 + while (k < sv.indices.length) { + if (sv.values(k) != 0.0) partColSizes(sv.indices(k)) += 1 + k += 1 + } + } + Iterator(partColSizes) + }.fold(Array.fill[Long](numCols)(0)) { + case (v1, v2) => v1.zip(v2).map(v12 => v12._1 + v12._2) + } + colSizes + } + + /** + * Convert a dataset of [[Vector]] from row storage to column storage, maintaining sparsity: + * columns are returned as [[SparseVector]]. + * + * The returned RDD sets the number of partitions as follows: + * - The targeted number is: + * numTargetPartitions = min(rowStore num partitions, num columns) * overPartitionFactor. + * - The actual number will be in the range [numTargetPartitions, 2 * numTargetPartitions]. + * Partitioning is done such that each partition holds consecutive columns. + * + * TODO: Update this to adaptively make columns dense or sparse based on a sparsity threshold. + * + * TODO: Cache rowStore temporarily. + * + * @param rowStore RDD of dataset rows + * @param overPartitionFactor Multiplier for the targeted number of partitions. This parameter + * helps to ensure that P partitions handled by P compute cores + * do not get split into slightly more than P partitions; + * if that occurred, then work would not be shared evenly. + * @return RDD of (column index, column) pairs + */ + def rowToColumnStoreSparse( + rowStore: RDD[Vector], + overPartitionFactor: Int = 3): RDD[(Int, Vector)] = { + + val numRows = { + val longNumRows: Long = rowStore.count() + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } + if (numRows == 0) { + return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) + } + + // Compute the number of non-zeros in each column. + val colSizes: Array[Long] = countNonZerosPerColumn(rowStore) + val numCols = colSizes.length + val numSourcePartitions = rowStore.partitions.length + if (numCols == 0 || numSourcePartitions == 0) { + return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) + } + val totalNonZeros = colSizes.sum + + // Split columns into groups. + // Groups are chosen greedily and sequentially, putting as many columns as possible in each + // group (limited by the number of non-zeros). Try to limit the number of non-zeros per + // group to at most targetNonZerosPerPartition.
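+ // For example (illustrative): with colSizes = [5, 1, 1, 5] and + // targetNonZerosPerPartition = 6, the scan yields groupStartColumns = [0, 2, 4], + // i.e. column groups {0, 1} and {2, 3} with 6 non-zeros each.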
+ val numTargetPartitions = math.min(numSourcePartitions, numCols) * overPartitionFactor + val targetNonZerosPerPartition = (totalNonZeros / numTargetPartitions.toDouble).floor.toLong + val groupStartColumns: Array[Int] = { + val startCols = new ArrayBuffer[Int]() + startCols += 0 + var currentNonZeros: Long = 0 + var col = 0 + while (col < numCols) { + if (currentNonZeros >= targetNonZerosPerPartition && col != startCols.last) { + startCols += col + currentNonZeros = 0 + } + // Always count this column's non-zeros towards its group, including the first + // column of a newly started group. + currentNonZeros += colSizes(col) + col += 1 + } + startCols += numCols + startCols.toArray + } + val numGroups = groupStartColumns.length - 1 // actual number of destination partitions + + /* On each partition, re-organize into groups of columns: + (groupIndex, (sourcePartitionIndex, partCols)), + where partCols(colIdx) = partial column. + The groupIndex will be used to groupByKey. + The sourcePartitionIndex is used to ensure instance indices match up after the shuffle. + The partial columns will be stacked into full columns after the shuffle. + Note: By design, partCols will always have at least 1 column. + */ + val partialColumns: RDD[(Int, (Int, Array[SparseVector]))] = + rowStore.zipWithIndex().mapPartitionsWithIndex { case (sourcePartitionIndex, iterator) => + // columnSets(groupIndex)(colIdx) + // = column values for each instance in sourcePartitionIndex, + // where colIdx is a 0-based index for columns for groupIndex, + // and where column values are in sparse format: (size, indices, values) + val columnSetSizes = new Array[Array[Int]](numGroups) + val columnSetIndices = new Array[Array[ArrayBuffer[Int]]](numGroups) + val columnSetValues = new Array[Array[ArrayBuffer[Double]]](numGroups) + Range(0, numGroups).foreach { groupIndex => + val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) + columnSetSizes(groupIndex) = Array.fill[Int](numColsInGroup)(0) + columnSetIndices(groupIndex) = + Array.fill[ArrayBuffer[Int]](numColsInGroup)(new ArrayBuffer[Int]) + columnSetValues(groupIndex) = + Array.fill[ArrayBuffer[Double]](numColsInGroup)(new ArrayBuffer[Double]) + } + iterator.foreach { + case (dv: DenseVector, rowIndex: Long) => + Range(0, numGroups).foreach { groupIndex => + val fromCol = groupStartColumns(groupIndex) + val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) + var colIdx = 0 + while (colIdx < numColsInGroup) { + columnSetSizes(groupIndex)(colIdx) += 1 + columnSetIndices(groupIndex)(colIdx) += rowIndex.toInt + columnSetValues(groupIndex)(colIdx) += dv(fromCol + colIdx) + colIdx += 1 + } + } + case (sv: SparseVector, rowIndex: Long) => + /* + A sparse vector is chopped into groups (destination partitions). + We iterate through the non-zeros (indexed by k), going to the next group when + sv.indices(k) passes the current group's boundary.
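+ For example (a sketch): with groupStartColumns = [0, 2, 4] and a row with + sv.indices = [0, 3], the non-zero at column 0 lands in group 0; the scan then + advances to group 1 (columns [2, 4)) and places the non-zero at column 3 there.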
+ */ + var groupIndex = 0 + var k = 0 // index into SparseVector non-zeros + val nnz = sv.indices.length + while (groupIndex < numGroups && k < nnz) { + val fromColumn = groupStartColumns(groupIndex) + val groupEndColumn = groupStartColumns(groupIndex + 1) + while (k < nnz && sv.indices(k) < groupEndColumn) { + val columnIndex = sv.indices(k) // index in full row + val colIdx = columnIndex - fromColumn // index in group of columns + columnSetSizes(groupIndex)(colIdx) += 1 + columnSetIndices(groupIndex)(colIdx) += rowIndex.toInt + columnSetValues(groupIndex)(colIdx) += sv.values(k) + k += 1 + } + groupIndex += 1 + } + } + Range(0, numGroups).map { groupIndex => + val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) + val groupPartialColumns: Array[SparseVector] = Range(0, numColsInGroup).map { colIdx => + new SparseVector(columnSetSizes(groupIndex)(colIdx), + columnSetIndices(groupIndex)(colIdx).toArray, + columnSetValues(groupIndex)(colIdx).toArray) + }.toArray + (groupIndex, (sourcePartitionIndex, groupPartialColumns)) + }.toIterator + } + + // Shuffle data + val groupedPartialColumns: RDD[(Int, Iterable[(Int, Array[SparseVector])])] = + partialColumns.groupByKey() + + // Each target partition now holds its set of columns. + // Group the partial columns into full columns. + val fullColumns = groupedPartialColumns.flatMap { case (groupIndex, iterator) => + val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) + + // We do not know the number of rows or non-zeros per group, so we need to collect the groups + // before filling the full columns. + // collectedPartCols(sourcePartitionIndex)(colIdx) = partial column + val collectedPartCols = new Array[Array[SparseVector]](numSourcePartitions) + // nzCounts(colIdx)(sourcePartitionIndex) = number of non-zeros + val nzCounts = Array.fill[Array[Int]](numColsInGroup)(Array.fill[Int](numSourcePartitions)(0)) + iterator.foreach { case (sourcePartitionIndex, partCols) => + collectedPartCols(sourcePartitionIndex) = partCols + partCols.zipWithIndex.foreach { case (partCol, colIdx) => + nzCounts(colIdx)(sourcePartitionIndex) += partCol.indices.length + } + } + // nzOffsets(colIdx)(sourcePartitionIndex) = cumulative number of non-zeros + val nzOffsets: Array[Array[Int]] = nzCounts.map(_.scanLeft(0)(_ + _)) + + // Initialize full columns + val columnNZIndices: Array[Array[Int]] = + nzOffsets.map(colNZOffsets => new Array[Int](colNZOffsets.last)) + val columnNZValues: Array[Array[Double]] = + nzOffsets.map(colNZOffsets => new Array[Double](colNZOffsets.last)) + + // Fill columns + var colIdx = 0 // index within group + while (colIdx < numColsInGroup) { + var sourcePartitionIndex = 0 + while (sourcePartitionIndex < numSourcePartitions) { + val nzStartOffset = nzOffsets(colIdx)(sourcePartitionIndex) + val partColLength = nzOffsets(colIdx)(sourcePartitionIndex + 1) - nzStartOffset + Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx).indices, 0, + columnNZIndices(colIdx), nzStartOffset, partColLength) + Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx).values, 0, + columnNZValues(colIdx), nzStartOffset, partColLength) + sourcePartitionIndex += 1 + } + colIdx += 1 + } + val columns = columnNZIndices.zip(columnNZValues).map { case (indices, values) => + Vectors.sparse(numRows, indices, values) + } + val fromColumn = groupStartColumns(groupIndex) + val columnIndices = Range(0, numColsInGroup).map(_ + fromColumn) + columnIndices.zip(columns) + } + + fullColumns + } +} diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impurities.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impurities.scala new file mode 100644 index 000000000000..88dc9c2d4c25 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impurities.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.tree.impurity.{EntropyCalculator, GiniCalculator, ImpurityCalculator, + VarianceCalculator} + +/** + * Version of impurity aggregator which owns its data and is only for 1 node. + */ +private[tree] abstract class ImpurityAggregatorSingle(val stats: Array[Double]) + extends Serializable { + + def statsSize: Int = stats.length + + /** + * Add two aggregators: this + other + * @return This aggregator (modified). + */ + def add(other: ImpurityAggregatorSingle): this.type = { + var i = 0 + while (i < statsSize) { + stats(i) += other.stats(i) + i += 1 + } + this + } + + /** + * Subtract another aggregator from this one: this - other + * @return This aggregator (modified). + */ + def subtract(other: ImpurityAggregatorSingle): this.type = { + var i = 0 + while (i < statsSize) { + stats(i) -= other.stats(i) + i += 1 + } + this + } + + /** + * Update stats with the given label and instance weight. + * @return This aggregator (modified). + */ + def update(label: Double, instanceWeight: Double): this.type + + /** + * Update stats with the given label. + * @return This aggregator (modified). + */ + def update(label: Double): this.type = update(label, 1.0) + + /** Get an [[ImpurityCalculator]] for the current stats. */ + def getCalculator: ImpurityCalculator + + def deepCopy(): ImpurityAggregatorSingle + + /** Total (weighted) count of instances in this aggregator. */ + def getCount: Double + + /** Resets this aggregator as though nothing has been added to it. */ + def clear(): this.type = { + var i = 0 + while (i < statsSize) { + stats(i) = 0.0 + i += 1 + } + this + } +} + +/** + * Version of Entropy aggregator which owns its data and is only for one node.
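+ * + * A usage sketch (illustrative only, assuming two classes): + * {{{ + *   val agg = new EntropyAggregatorSingle(numClasses = 2) + *   agg.update(0.0).update(1.0).update(1.0) + *   val entropy = agg.getCalculator.calculate() // entropy of class counts [1.0, 2.0] + * }}}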
+ */ +private[tree] class EntropyAggregatorSingle private (stats: Array[Double]) + extends ImpurityAggregatorSingle(stats) with Serializable { + + def this(numClasses: Int) = this(new Array[Double](numClasses)) + + def update(label: Double, instanceWeight: Double): this.type = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"EntropyAggregatorSingle given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + stats(label.toInt) += instanceWeight + this + } + + def getCalculator: EntropyCalculator = new EntropyCalculator(stats) + + override def deepCopy(): ImpurityAggregatorSingle = new EntropyAggregatorSingle(stats.clone()) + + override def getCount: Double = stats.sum +} + +/** + * Version of Gini aggregator which owns its data and is only for one node. + */ +private[tree] class GiniAggregatorSingle private (stats: Array[Double]) + extends ImpurityAggregatorSingle(stats) with Serializable { + + def this(numClasses: Int) = this(new Array[Double](numClasses)) + + def update(label: Double, instanceWeight: Double): this.type = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"GiniAggregatorSingle given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + stats(label.toInt) += instanceWeight + this + } + + def getCalculator: GiniCalculator = new GiniCalculator(stats) + + override def deepCopy(): ImpurityAggregatorSingle = new GiniAggregatorSingle(stats.clone()) + + override def getCount: Double = stats.sum +} + +/** + * Version of Variance aggregator which owns its data and is only for one node. + */ +private[tree] class VarianceAggregatorSingle + extends ImpurityAggregatorSingle(new Array[Double](3)) with Serializable { + + def update(label: Double, instanceWeight: Double): this.type = { + stats(0) += instanceWeight + stats(1) += instanceWeight * label + stats(2) += instanceWeight * label * label + this + } + + def getCalculator: VarianceCalculator = new VarianceCalculator(stats) + + override def deepCopy(): ImpurityAggregatorSingle = { + val tmp = new VarianceAggregatorSingle() + stats.copyToArray(tmp.stats) + tmp + } + + override def getCount: Double = stats(0) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b74e3f1f4652..d8769643329b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -136,7 +136,7 @@ class Strategy ( * Check validity of parameters. * Throws exception if invalid. */ - private[tree] def assertValid(): Unit = { + private[spark] def assertValid(): Unit = { algo match { case Classification => require(numClasses >= 2, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 21ee49c45788..375bea0d56ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -205,7 +205,7 @@ private[spark] object DecisionTreeMetadata extends Logging { buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") } - /** + /** * Given the arity of a categorical feature (arity = number of categories), * return the number of bins for the feature if it is to be treated as an unordered feature. 
* There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4637dcceea7f..b9d0f73c85c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.impurity import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.mllib.tree.model.Predict /** * :: Experimental :: @@ -158,6 +159,12 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten */ def prob(label: Double): Double = -1 + /** Get a [[Predict]] object (prediction and its probability) for the current stats. */ + def getPredict: Predict = { + val pred = this.predict + new Predict(predict = pred, prob = this.prob(pred)) + } + /** * Return the index of the largest array element. * Fails if the array is empty. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 091a0462c204..edb4ad88c48e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -103,13 +103,13 @@ private[spark] class ImpurityStats( s"right impurity = $rightImpurity" } - def leftImpurity: Double = if (leftImpurityCalculator != null) { + lazy val leftImpurity: Double = if (leftImpurityCalculator != null) { leftImpurityCalculator.calculate() } else { -1.0 } - def rightImpurity: Double = if (rightImpurityCalculator != null) { + lazy val rightImpurity: Double = if (rightImpurityCalculator != null) { rightImpurityCalculator.calculate() } else { -1.0 @@ -118,6 +118,16 @@ private[spark] object ImpurityStats { + /** + * Create a stats object missing the child node info. + */ + def apply( + impurity: Double, + impurityCalculator: ImpurityCalculator, + valid: Boolean = true): ImpurityStats = { + new ImpurityStats(Double.NaN, impurity, impurityCalculator, null, null, valid) + } + /** * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to * denote that current split doesn't satisfies minimum info gain or diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala new file mode 100644 index 000000000000..6f05640a08a1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.tree.{ContinuousSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo} +import org.apache.spark.ml.tree.impl.TreeUtil._ +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.collection.BitSet + +/** + * Test suite for [[AltDT]]. + */ +class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { + + //////////////////////////////// Integration tests ////////////////////////////////// + + test("run deep example") { + val data = Range(0, 3).map(x => LabeledPoint(math.pow(x, 3), Vectors.dense(x))) + val df = sqlContext.createDataFrame(data) + val dt = new DecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(10) + .setAlgorithm("byCol") + val model = dt.fit(df) + println(model.toDebugString) // TODO: remove println + assert(model.rootNode.isInstanceOf[InternalNode]) + val root = model.rootNode.asInstanceOf[InternalNode] + assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[LeafNode]) + val left = root.leftChild.asInstanceOf[InternalNode] + assert(left.leftChild.isInstanceOf[LeafNode] && left.rightChild.isInstanceOf[LeafNode]) + } + + test("run example") { + val data = Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x))) + val df = sqlContext.createDataFrame(data) + val dt = new DecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(10) + .setAlgorithm("byCol") + val model = dt.fit(df) + println(model.toDebugString) // TODO: remove println + assert(model.rootNode.isInstanceOf[InternalNode]) + val root = model.rootNode.asInstanceOf[InternalNode] + assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode]) + val left = root.leftChild.asInstanceOf[InternalNode] + val right = root.rightChild.asInstanceOf[InternalNode] + val grandkids = Array(left.leftChild, left.rightChild, right.leftChild, right.rightChild) + assert(grandkids.forall(_.isInstanceOf[InternalNode])) + } + + //////////////////////////////// Helper classes ////////////////////////////////// + + test("FeatureVector") { + val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0)) + + val vCopy = v.deepCopy() + vCopy.values(0) = 1000 + assert(v.values(0) !== vCopy.values(0)) + + val original = Vectors.dense(0.7, 0.1, 0.3) + val v2 = FeatureVector.fromOriginal(1, 0, original) + assert(v === v2) + } + + test("PartitionInfo") { + val numRows = 4 + val col1 = + FeatureVector.fromOriginal(0, 0, Vectors.dense(0.8, 0.2, 0.1, 0.6)) + val col2 = + FeatureVector.fromOriginal(1, 3, Vectors.dense(0, 1, 0, 2)) + assert(col1.values.length === numRows) + assert(col2.values.length === numRows) + val nodeOffsets = Array(0, numRows) + val activeNodes = new BitSet(1) + activeNodes.set(0) + + val info = PartitionInfo(Array(col1, col2), nodeOffsets, activeNodes) + + // Create bitVector for splitting the 4 rows: L, R, L, R + // New groups are {0, 2}, {1, 3} + val bitVector = new BitSubvector(0, numRows) + bitVector.set(1) + bitVector.set(3) + + val newInfo = info.update(Array(bitVector),
newNumNodeOffsets = 3) + + assert(newInfo.columns.length === 2) + val expectedCol1a = + new FeatureVector(0, 0, Array(0.1, 0.8, 0.2, 0.6), Array(2, 0, 1, 3)) + val expectedCol1b = + new FeatureVector(1, 3, Array(0.0, 0.0, 1.0, 2.0), Array(0, 2, 1, 3)) + assert(newInfo.columns(0) === expectedCol1a) + assert(newInfo.columns(1) === expectedCol1b) + assert(newInfo.nodeOffsets === Array(0, 2, 4)) + assert(newInfo.activeNodes.iterator.toSet === Set(0, 1)) + + // Create 2 bitVectors for splitting into: 0, 2, 1, 3 + val bv2a = new BitSubvector(0, 2) + bv2a.set(1) + val bv2b = new BitSubvector(2, 4) + bv2b.set(3) + + val newInfo2 = newInfo.update(Array(bv2a, bv2b), newNumNodeOffsets = 5) + + assert(newInfo2.columns.length === 2) + val expectedCol2a = + new FeatureVector(0, 0, Array(0.8, 0.1, 0.2, 0.6), Array(0, 2, 1, 3)) + val expectedCol2b = + new FeatureVector(1, 3, Array(0.0, 0.0, 1.0, 2.0), Array(0, 2, 1, 3)) + assert(newInfo2.columns(0) === expectedCol2a) + assert(newInfo2.columns(1) === expectedCol2b) + assert(newInfo2.nodeOffsets === Array(0, 1, 2, 3, 4)) + assert(newInfo2.activeNodes.iterator.toSet === Set(0, 1, 2, 3)) + } + + //////////////////////////////// Misc ////////////////////////////////// + + test("numUnorderedBins") { + // Note: We have duplicate bins (the inverse) for unordered features. This should be fixed! + assert(AltDT.numUnorderedBins(2) === 2) // 2 categories => 2 bins + assert(AltDT.numUnorderedBins(3) === 6) // 3 categories => 6 bins + } + + //////////////////////////////// Choosing splits ////////////////////////////////// + + test("computeBestSplits") { + } + + test("chooseSplit") { + } + + test("chooseOrderedCategoricalSplit") { + } + + // test("chooseUnorderedCategoricalSplit") { } + + test("chooseContinuousSplit: basic case") { + val featureIndex = 0 + val values = Seq(0.1, 0.2, 0.3, 0.4, 0.5) + val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0) + val impurity = Entropy + val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity) + val (split, stats) = AltDT.chooseContinuousSplit(featureIndex, values, labels, metadata) + split match { + case s: ContinuousSplit => + assert(s.featureIndex === featureIndex) + assert(s.threshold === 0.2) + case _ => + throw new AssertionError( + s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") + } + val fullImpurityStatsArray = Array(2.0, 3.0) + val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) + assert(stats.gain === fullImpurity) + assert(stats.impurity === fullImpurity) + assert(stats.impurityCalculator.stats === fullImpurityStatsArray) + assert(stats.leftImpurityCalculator.stats === Array(2.0, 0.0)) + assert(stats.rightImpurityCalculator.stats === Array(0.0, 3.0)) + assert(stats.valid) + } + + test("chooseContinuousSplit: some equal values") { + } + + // TODO: Add this test once we make this change.
+ // test("chooseContinuousSplit: return bad split if best split is on end") { } + + //////////////////////////////// Bit subvectors ////////////////////////////////// + + test("bitSubvectorFromSplit: 1 node") { + val col = + FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7)) + val fromOffset = 0 + val toOffset = col.values.length + val split = new ContinuousSplit(0, threshold = 0.5) + val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) + assert(bitv.from === fromOffset) + assert(bitv.to === toOffset) + assert(bitv.iterator.toSet === Set(3, 4)) + } + + test("bitSubvectorFromSplit: 2 nodes") { + // Initially, 1 split: (0, 2, 4) | (1, 3) + val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7), + Array(4, 2, 0, 1, 3)) + def checkSplit( + fromOffset: Int, toOffset: Int, threshold: Double, expectedRight: Set[Int]): Unit = { + val split = new ContinuousSplit(0, threshold) + val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) + assert(bitv.from === fromOffset) + assert(bitv.to === toOffset) + assert(bitv.iterator.toSet === expectedRight) + } + // Left child node + checkSplit(0, 3, 0.15, Set(0, 1)) + checkSplit(0, 3, 0.2, Set(0)) + checkSplit(0, 3, 0.5, Set()) + // Right child node + checkSplit(3, 5, 0.1, Set(3, 4)) + checkSplit(3, 5, 0.65, Set(4)) + checkSplit(3, 5, 0.8, Set()) + } + + test("collectBitVectors with 1 vector") { + val col = + FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7)) + val numRows = col.values.length + val activeNodes = new BitSet(1) + activeNodes.set(0) + val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) + val partitionInfos = sc.parallelize(Seq(info)) + val bestSplit = new ContinuousSplit(0, threshold = 0.5) + val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit)) + assert(bitVectors.length === 1) + val bitv = bitVectors.head + assert(bitv.numBits === numRows) + assert(bitv.iterator.toArray === Array(3, 4)) + } + + test("collectBitVectors with 1 vector, with tied threshold") { + val col = new FeatureVector(0, 0, + Array(-4.0, -4.0, -2.0, -2.0, -1.0, -1.0, 1.0, 1.0), Array(3, 7, 2, 6, 1, 5, 0, 4)) + val numRows = col.values.length + val activeNodes = new BitSet(1) + activeNodes.set(0) + val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) + val partitionInfos = sc.parallelize(Seq(info)) + val bestSplit = new ContinuousSplit(0, threshold = -2.0) + val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(bestSplit)) + assert(bitVectors.length === 1) + val bitv = bitVectors.head + assert(bitv.numBits === numRows) + assert(bitv.iterator.toArray === Array(0, 1, 4, 5)) + } + + //////////////////////////////// Active nodes ////////////////////////////////// + + test("computeActiveNodePeriphery") { + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala new file mode 100644 index 000000000000..0a8f838c0a31 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[BitSubvector]]. + */ +class BitSubvectorSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("basic set and get") { + val from = 1 + val to = 4 + val bs = new BitSubvector(from, to) + val setVals = Array(from, to - 1) + + assert(bs.numBits === to - from) + Range(from, to).foreach(x => assert(!bs.get(x))) + setVals.foreach { x => + bs.set(x) + assert(bs.get(x)) + } + assert(bs.iterator.toSet === setVals.toSet) + } + + test("|=") { + val from = 1 + val to = 4 + val bs = new BitSubvector(from, to) + val setVals = Array(from, to - 1) + setVals.foreach(i => bs.set(i)) + + val copyBs = new BitSubvector(from, to) + copyBs |= bs + assert(copyBs.iterator.toSet === setVals.toSet) + } + + test("merge") { + val b1 = new BitSubvector(0, 5) + b1.set(1) + val b2 = new BitSubvector(5, 7) + b2.set(5) + val b3 = new BitSubvector(9, 12) + b3.set(11) + val parts1 = Array(b1) + val parts2 = Array(b2, b3) + val newParts = BitSubvector.merge(parts1, parts2) + + val r1 = new BitSubvector(0, 7) + r1.set(1) + r1.set(5) + val r2 = new BitSubvector(9, 12) + r2.set(11) + val expectedParts = Array(r1, r2) + newParts.zip(expectedParts).foreach { case (x, y) => + assert(x.from === y.from) + assert(x.to === y.to) + assert(x.iterator.toSet === y.iterator.toSet) + } + } + + test("merge with empty BitSubvectors") { + val parts = BitSubvector.merge(Array.empty[BitSubvector], Array.empty[BitSubvector]) + assert(parts.isEmpty) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala new file mode 100644 index 000000000000..483fb8568f8b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.impl.TreeUtil._ +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[TreeUtil]].
+ */ +class TreeUtilSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def checkDense(rows: Seq[Vector]): Unit = { + val numRowPartitions = 2 + val rowStore = sc.parallelize(rows, numRowPartitions) + val colStore = rowToColumnStoreDense(rowStore) + val numColPartitions = colStore.partitions.length + val cols: Map[Int, Vector] = colStore.collect().toMap + val numRows = rows.size + if (numRows == 0) { + assert(cols.isEmpty) + return + } + val numCols = rows.head.size + if (numCols == 0) { + assert(cols.isEmpty) + return + } + rows.zipWithIndex.foreach { case (row, i) => + var j = 0 + while (j < numCols) { + assert(row(j) == cols(j)(i)) + j += 1 + } + } + val expectedNumColPartitions = math.min(rowStore.partitions.length, numCols) + assert(numColPartitions === expectedNumColPartitions) + } + + private def checkSparse(rows: Seq[Vector]): Unit = { + val numRowPartitions = 2 + val overPartitionFactor = 2 + val rowStore = sc.parallelize(rows, numRowPartitions) + val colStore = rowToColumnStoreSparse(rowStore, overPartitionFactor) + val numColPartitions = colStore.partitions.length + val cols: Map[Int, Vector] = colStore.collect().toMap + val numRows = rows.size + // Check cases with 0 rows or cols + if (numRows == 0) { + assert(cols.isEmpty) + return + } + val numCols = rows.head.size + if (numCols == 0) { + assert(cols.isEmpty) + return + } + // Check values (and count non-zeros too) + var expectedNumNonZeros = 0 + rows.zipWithIndex.foreach { case (row, i) => + var j = 0 + while (j < numCols) { + assert(row(j) == cols(j)(i)) + if (row(j) != 0) expectedNumNonZeros += 1 + j += 1 + } + } + // Check sparsity + val numNonZeros = cols.values.map { + case sv: SparseVector => sv.indices.length + case _ => throw new RuntimeException( + "checkSparse() found column which was not converted to SparseVector.") + }.sum + assert(numNonZeros === expectedNumNonZeros) + // Check partitions to make sure they each contain consecutive columns. + val colsByPartition: Array[(Int, Array[(Int, Vector)])] = colStore.mapPartitionsWithIndex { + case (partitionIndex, iterator) => + val partCols = new mutable.ArrayBuffer[(Int, Vector)] + iterator.foreach(col => partCols += col) + Iterator((partitionIndex, partCols.toArray)) + }.collect() + colsByPartition.foreach { case (partitionIndex, partCols) => + var j = 0 + while (j + 1 < partCols.length) { + val curColIndex = partCols(j)._1 + val nextColIndex = partCols(j + 1)._1 + assert(curColIndex + 1 == nextColIndex) + j += 1 + } + } + } + + test("rowToColumnStore: small dense") { + val rows = Seq( + Vectors.dense(1.0, 2.0, 3.0, 4.0), + Vectors.dense(1.1, 2.1, 3.1, 4.1), + Vectors.dense(1.2, 2.2, 3.2, 4.2) + ) + checkDense(rows) + checkSparse(rows) + } + + test("rowToColumnStore: small sparse") { + val rows = Seq( + Vectors.sparse(4, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(4, Array(1, 2), Array(1.1, 2.1)), + Vectors.sparse(4, Array(2, 3), Array(1.2, 2.2)) + ) + checkDense(rows) + checkSparse(rows) + } + + test("rowToColumnStore: large dense") { + // Note: All values must be non-zero: rowToColumnStoreSparse() keeps every element of a + // DenseVector row (explicit zeros included), which would break the non-zero count check + // in checkSparse().
+ val numRows = 100 + val numCols = 90 + val rows = Range(0, numRows).map { i => + Vectors.dense(Range(0, numCols).map(_ + numCols * i + 1.0).toArray) + } + checkDense(rows) + checkSparse(rows) + } + + test("rowToColumnStore: mixed dense and sparse") { + val rows = Seq( + Vectors.dense(1.0, 2.0, 3.0, 4.0), + Vectors.sparse(4, Array(1, 2), Array(1.1, 2.1)), + Vectors.dense(1.2, 2.2, 3.2, 4.2), + Vectors.sparse(4, Array(0, 2), Array(1.3, 2.3)) + ) + checkDense(rows) + checkSparse(rows) + } + + test("rowToColumnStore: 0 rows") { + val rows = Seq.empty[Vector] + checkDense(rows) + checkSparse(rows) + } + + test("rowToColumnStore: 0 cols") { + val rows = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(Array.empty[Double]), + Vectors.dense(Array.empty[Double]) + ) + checkDense(rows) + checkSparse(rows) + } +}