diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda032..5d92518dfd0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -40,20 +40,31 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType /** - * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for classification. - * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. + * It supports both continuous and categorical features. * - * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * The implemention offers both Stochastic Gradient Boosting, as in J.H. Friedman 1999, + * "Stochastic Gradient Boosting" and TreeBoost, as in Friedman 1999 + * "Greedy Function Approximation: A Gradient Boosting Machine" * - * Notes on Gradient Boosting vs. TreeBoost: - * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * Notes on Stochastic Gradient Boosting (SGB) vs. TreeBoost: + * - TreeBoost algorithms are a subset of SGB algorithms. * - Both algorithms learn tree ensembles by minimizing loss functions. - * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes - * based on the loss function, whereas the original gradient boosting method does not. - * - We expect to implement TreeBoost in the future: - * [https://issues.apache.org/jira/browse/SPARK-4240] + * - TreeBoost has two additional properties that general SGB trees don't: + * - The loss function gradients are directly used as an approximate impurity measure. + * - The value reported at a leaf is given by optimizing the loss function is optimized on + * that leaf node's partition of the data, rather than just being the mean. + * - In the case of squared error loss, variance impurity and mean leaf estimates happen + * to make the SGB and TreeBoost algorithms identical. + * + * [[GBTClassifier]] will use the `"loss-based"` impurity by default, conforming to + * TreeBoost behavior. For SGB, set impurity to `"variance"`. + * + * Currently, however, even TreeBoost behavior uses variance impurity for split selection for + * ease and speed. This is the approach `R`'s + * [[https://cran.r-project.org/web/packages/gbm/index.html gbm package]] takes. */ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @@ -91,14 +102,16 @@ class GBTClassifier @Since("1.4.0") ( override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** - * The impurity setting is ignored for GBT models. - * Individual trees are built using impurity "Variance." + * Impurity-setting is currently only offered as a way to recover pre-2.0.2 Spark GBT + * behavior (which is Stochastic Gradient Boosting): set impurity to `"variance"` for this. 
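As an illustrative usage sketch (not part of this patch), the two training modes described above are selected through the impurity param added here; the column names and iteration count below are assumptions, not values from this change:

import org.apache.spark.ml.classification.GBTClassifier

// TreeBoost-style training (the new default): loss-based impurity with
// loss-optimized leaf predictions.
val treeBoost = new GBTClassifier()
  .setLossType("logistic")
  .setImpurity("loss-based")
  .setMaxIter(20)

// Pre-2.0.2 Stochastic Gradient Boosting behavior: plain variance impurity.
val sgb = new GBTClassifier()
  .setLossType("logistic")
  .setImpurity("variance")
  .setMaxIter(20)

// val model = treeBoost.fit(trainingData)  // trainingData: DataFrame with "label"/"features" columns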
+ * @param value new impurity value + * @return this */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { - logWarning("GBTClassifier.setImpurity should NOT be used") - this - } + @deprecated( + "Control over impurity will be removed, as it is an implementation detail of GBTs", + "2.0.2") + override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: @@ -136,7 +149,8 @@ class GBTClassifier @Since("1.4.0") ( LabeledPoint(label, features) } val numFeatures = oldDataset.first().features.size - val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val boostingStrategy = super.getOldBoostingStrategy( + categoricalFeatures, OldAlgo.Classification, getOldImpurity) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) @@ -156,11 +170,14 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { - - /** Accessor for supported loss settings: logistic */ + /** Accessor for supported loss settings: logistic, bernoulli */ @Since("1.4.0") final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + /** Accessor for supported entropy settings: loss-based or variance */ + @Since("2.1") + final val supportedImpurities: Array[String] = GBTClassifierParams.supportedImpurities + @Since("2.0.0") override def load(path: String): GBTClassifier = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index fa69d60836e6..dd8957412ded 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -23,7 +23,6 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap @@ -42,21 +41,31 @@ import org.apache.spark.sql.functions._ * learning algorithm for regression. * It supports both continuous and categorical features. * - * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * The implemention offers both Stochastic Gradient Boosting, as in J.H. Friedman 1999, + * "Stochastic Gradient Boosting" and TreeBoost, as in Friedman 1999 + * "Greedy Function Approximation: A Gradient Boosting Machine" * - * Notes on Gradient Boosting vs. TreeBoost: - * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * Notes on Stochastic Gradient Boosting (SGB) vs. TreeBoost: + * - TreeBoost algorithms are a subset of SGB algorithms. * - Both algorithms learn tree ensembles by minimizing loss functions. - * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes - * based on the loss function, whereas the original gradient boosting method does not. - * - When the loss is SquaredError, these methods give the same result, but they could differ - * for other loss functions. - * - We expect to implement TreeBoost in the future: - * [https://issues.apache.org/jira/browse/SPARK-4240] + * - TreeBoost has two additional properties that general SGB trees don't: + * - The loss function gradients are directly used as an approximate impurity measure. 
+ * - The value reported at a leaf is given by optimizing the loss function is optimized on + * that leaf node's partition of the data, rather than just being the mean. + * - In the case of squared error loss, variance impurity and mean leaf estimates happen + * to make the SGB and TreeBoost algorithms identical. + * + * [[GBTRegressor]] will use the usual `"variance"` impurity by default, conforming to + * SGB behavior. For TreeBoost, set impurity to `"loss-based"`. Note TreeBoost is currently + * incompatible with absolute error. + * + * Currently, however, even TreeBoost behavior uses variance impurity for split selection for + * ease and speed. This is the approach `R`'s + * [[https://cran.r-project.org/web/packages/gbm/index.html gbm package]] takes. */ @Since("1.4.0") class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Predictor[Vector, GBTRegressor, GBTRegressionModel] + extends Regressor[Vector, GBTRegressor, GBTRegressionModel] with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") @@ -88,14 +97,18 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** - * The impurity setting is ignored for GBT models. - * Individual trees are built using impurity "Variance." + * Note that the loss-based impurity is currently NOT compatible with absolute loss. + * + * Impurity-setting is currently only offered as a way to recover pre-2.0.2 Spark GBT + * behavior (which is Stochastic Gradient Boosting): set impurity to `"variance"` for this. + * @param value new impurity value + * @return this */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { - logWarning("GBTRegressor.setImpurity should NOT be used") - this - } + @deprecated( + "Control over impurity will be removed, as it is an implementation detail of GBTs", + "2.0.2") + override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: @Since("1.4.0") @@ -113,7 +126,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) // Parameters from GBTRegressorParams: - /** @group setParam */ + /** + * Note that the loss-based impurity is currently NOT compatible with absolute loss. 
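A companion sketch for the regressor (illustrative only, not part of this patch), covering the default SGB configuration, the TreeBoost configuration, and the absolute-loss limitation noted above:

import org.apache.spark.ml.regression.GBTRegressor

// Default behavior: squared-error loss ("gaussian" is an alias) with variance impurity,
// i.e. plain Stochastic Gradient Boosting.
val sgb = new GBTRegressor()
  .setLossType("gaussian")
  .setImpurity("variance")

// TreeBoost behavior for squared-error loss: loss-based impurity.
val treeBoost = new GBTRegressor()
  .setLossType("squared")
  .setImpurity("loss-based")

// Pairing "loss-based" impurity with "absolute"/"laplace" loss is not supported until an
// absolute-loss impurity exists (SPARK-4240) and is expected to fail during training.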
+ * @group setParam + */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) @@ -122,7 +138,8 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size - val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val boostingStrategy = super.getOldBoostingStrategy( + categoricalFeatures, OldAlgo.Regression, getOldImpurity) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) @@ -141,11 +158,17 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { - - /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + /** + * Accessor for supported loss settings: squared (L2), absolute (L1), + * gaussian (alias for squared), laplace (alias for absolute) + * */ @Since("1.4.0") final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + /** Accessor for support entropy settings: loss-based or variance */ + @Since("2.1") + final val supportedImpurities: Array[String] = GBTRegressorParams.supportedImpurities + @Since("2.0.0") override def load(path: String): GBTRegressor = super.load(path) } @@ -163,7 +186,7 @@ class GBTRegressionModel private[ml]( private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], override val numFeatures: Int) - extends PredictionModel[Vector, GBTRegressionModel] + extends RegressionModel[Vector, GBTRegressionModel] with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 61091bb803e4..bea5794dda50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import org.apache.spark.ml.tree.impurity._ import org.apache.spark.mllib.tree.impurity._ @@ -33,11 +34,13 @@ private[spark] class DTStatsAggregator( /** * [[ImpurityAggregator]] instance specifying the impurity type. 
- */ - val impurityAggregator: ImpurityAggregator = metadata.impurity match { + */ + private val impurityAggregator: ImpurityAggregator = metadata.impurity match { + // TODO(SPARK-16728): this switch should be replaced by a virtual call case Gini => new GiniAggregator(metadata.numClasses) case Entropy => new EntropyAggregator(metadata.numClasses) case Variance => new VarianceAggregator() + case ApproxBernoulliImpurity => new ApproxBernoulliAggregator() case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d..07f7198e2f77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -258,11 +258,11 @@ private[spark] object GradientBoostedTrees extends Logging { val baseLearnerWeights = new Array[Double](numIterations) val loss = boostingStrategy.loss val learningRate = boostingStrategy.learningRate - // Prepare strategy for individual trees, which use regression with variance impurity. + // Prepare strategy for individual trees, which all use regression. + // TODO(SPARK-16728): changing the strategy here is confusing and should be avoided val treeStrategy = boostingStrategy.treeStrategy.copy val validationTol = boostingStrategy.validationTol treeStrategy.algo = OldAlgo.Regression - treeStrategy.impurity = OldVariance treeStrategy.assertValid() // Cache input @@ -327,6 +327,8 @@ private[spark] object GradientBoostedTrees extends Logging { // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. + // Note: For loss-based impurities, which have optimized loss-based leaf predictions, + // using a constant learning rate is correct. baseLearnerWeights(m) = learningRate predError = updatePredictionError( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impurity/ApproxBernoulliImpurity.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impurity/ApproxBernoulliImpurity.scala new file mode 100644 index 000000000000..db671257160b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impurity/ApproxBernoulliImpurity.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.mllib.tree.impurity._ + +/** + * [[ApproxBernoulliImpurity]] currently uses variance as a (proxy) impurity measure + * during tree construction. The main purpose of the class is to have an alternative + * leaf prediction calculation. + * + * Only data with examples each of weight 1.0 is supported. + * + * Class for calculating variance during regression. + */ +@Since("2.1") +private[spark] object ApproxBernoulliImpurity extends Impurity { + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value, or 0 if totalCount = 0 + */ + @Since("2.1") + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = + throw new UnsupportedOperationException("ApproxBernoulliImpurity.calculate") + + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 + */ + @Since("2.1") + @DeveloperApi + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + Variance.calculate(count, sum, sumSquares) + } +} + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + */ +private[spark] class ApproxBernoulliAggregator + extends ImpurityAggregator(statsSize = 4) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + allStats(offset) += instanceWeight + allStats(offset + 1) += instanceWeight * label + allStats(offset + 2) += instanceWeight * label * label + allStats(offset + 3) += instanceWeight * Math.abs(label) + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): ApproxBernoulliCalculator = { + new ApproxBernoulliCalculator(allStats.view(offset, offset + statsSize).toArray) + } +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[spark] class ApproxBernoulliCalculator(stats: Array[Double]) + extends ImpurityCalculator(stats) { + + require(stats.length == 4, + s"ApproxBernoulliCalculator requires sufficient statistics array stats to be of length 4," + + s" but was given array of length ${stats.length}.") + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: ApproxBernoulliCalculator = new ApproxBernoulliCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. 
+ */ + def calculate(): Double = ApproxBernoulliImpurity.calculate(stats(0), stats(1), stats(2)) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats(0).toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + // Per Friedman 1999, we use a single Newton-Raphson step from gamma = 0 to find the + // optimal leaf prediction, the solution gamma to the minimization problem: + // L = sum((p_i, y_i) in leaf) 2 log(1 + exp(-2 y_i (p_i + gamma))) + // Where p_i is the previous GBT model's prediction for point i. The above with the factor of + // 2 is equivalent to the LogLoss defined in Spark. + // + // We solve this problem by iterative root-finding for the gradients themselves (this turns + // out to be equivalent to Newton's optimization method anyway). + // + // The single NR step from 0 yields the solution H^{-1} s, where H is the Hessian + // and s is the gradient for L wrt gamma above at gamma = 0. + // + // The derivative of the i-th term wrt gamma is + // - 4 y_i / (1 + E), where E = exp(2 y_i (p_i + gamma)) + // At gamma = 0 it's equivalent to the gradient of LogLoss for i, the sum of which is + // stored in stats(1). + // + // The Hessian of the i-th term wrt to gamma is + // 8 y_i^2 E / (1 + E)^2 = 8 y_i^2 / (1 + E) - 8 y_i^2 / (1 + E)^2 + // At gamma = 0, the latter term is the half of the square of the gradient of the LogLoss for i. + // Since y_i is one of {-1, +1}, the first term is the absolute value of the gradient of the + // LogLoss for i times 2. These statistics are stored in stats(2) and stats(3), respectively. + stats(1) / (2 * stats(3) - stats(2) / 2) + } + + override def toString: String = { + s"ApproxBernoulliAggregator(cnt = ${stats(0)}, sum = ${stats(1)}," + + s"sum2 = ${stats(2)}, sumAbs = ${stats(3)})" + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index d3cbc363799a..74b9d1631113 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -336,24 +336,31 @@ private[ml] object DecisionTreeModelReadWrite { import sparkSession.implicits._ implicit val format = DefaultFormats - // Get impurity to construct ImpurityCalculator for each node - val impurityType: String = { - val impurityJson: JValue = metadata.getParamValue("impurity") - Param.jsonDecode[String](compact(render(impurityJson))) - } - val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), metadata) } /** * Given all data for all nodes in a tree, rebuild the tree. 
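To make the derivation in ApproxBernoulliCalculator.predict above concrete, here is a self-contained check (plain Scala, no Spark dependency, with made-up predictions and labels) that the closed form stats(1) / (2 * stats(3) - stats(2) / 2) coincides with a single Newton-Raphson step -sum(grad) / sum(hess) taken at gamma = 0, when the aggregated "labels" are the negated LogLoss gradients:

object BernoulliLeafCheck {
  // Gradient and Hessian of Spark's LogLoss, 2 * log(1 + exp(-2 * y * pred)), w.r.t. pred.
  private def grad(pred: Double, y: Double): Double = -4 * y / (1 + math.exp(2 * y * pred))
  private def hess(pred: Double, y: Double): Double = {
    val e = math.exp(2 * y * pred)
    8 * y * y * e / math.pow(1 + e, 2)
  }

  def main(args: Array[String]): Unit = {
    val preds = Seq(0.3, -1.2, 0.0, 2.0)   // previous ensemble predictions p_i (made up)
    val ys    = Seq(1.0, -1.0, 1.0, -1.0)  // true labels in {-1, +1}
    // Pseudo-labels fed to the regression tree: the negated LogLoss gradients.
    val labels = preds.zip(ys).map { case (p, y) => -grad(p, y) }
    // The four sufficient statistics accumulated by ApproxBernoulliAggregator.
    val stats = Array(labels.size.toDouble, labels.sum,
      labels.map(l => l * l).sum, labels.map(math.abs).sum)
    val leafFromStats = stats(1) / (2 * stats(3) - stats(2) / 2)
    val newtonStep = -preds.zip(ys).map { case (p, y) => grad(p, y) }.sum /
      preds.zip(ys).map { case (p, y) => hess(p, y) }.sum
    // The two agree exactly: per point, 2 * |grad| - grad * grad / 2 equals the Hessian term.
    println(f"closed form: $leafFromStats%.6f, Newton step: $newtonStep%.6f")
  }
}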
* @param data Unsorted node data - * @param impurityType Impurity type for this tree + * @param metadata metadata for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], metadata: DefaultParamsReader.Metadata): Node = { + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val loss: String = if (impurityType == "loss-based") { + val lossJson: JValue = metadata.getParamValue("lossType") + Param.jsonDecode[String](compact(render(lossJson))) + } else { + "" + } + // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -365,7 +372,7 @@ private[ml] object DecisionTreeModelReadWrite { // traversal, this guarantees that child nodes will be built before parent nodes. val finalNodes = new Array[Node](nodes.length) nodes.reverseIterator.foreach { case n: NodeData => - val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) + val impurityStats = ImpurityCalculator.getCalculator(impurityType, loss, n.impurityStats) val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) @@ -431,12 +438,6 @@ private[ml] object EnsembleModelReadWrite { implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) - // Get impurity to construct ImpurityCalculator for each node - val impurityType: String = { - val impurityJson: JValue = metadata.getParamValue("impurity") - Param.jsonDecode[String](compact(render(impurityJson))) - } - val treesMetadataPath = new Path(path, "treesMetadata").toString val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { @@ -454,7 +455,7 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, metadata) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 57c7e44e9760..5b5a05351f47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.tree.impurity._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} @@ -40,6 +41,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * Maximum depth of the tree (>= 0). 
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * (default = 5) + * * @group param */ final val maxDepth: IntParam = @@ -52,6 +54,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * on features at each node. More bins give higher granularity. * Must be >= 2 and >= number of categories in any categorical feature. * (default = 32) + * * @group param */ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + @@ -64,6 +67,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * the split will be discarded as invalid. * Should be >= 1. * (default = 1) + * * @group param */ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + @@ -74,6 +78,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** * Minimum information gain for a split to be considered at a tree node. * (default = 0.0) + * * @group param */ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", @@ -83,6 +88,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be * split per iteration, and its aggregates may exceed this size. * (default = 256 MB) + * * @group expertParam */ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", @@ -95,6 +101,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * Caching can speed up training of deeper trees. Users can set how often should the * cache be checkpointed or disable it by setting checkpointInterval. * (default = false) + * * @group expertParam */ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + @@ -151,6 +158,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * [[org.apache.spark.SparkContext]]. * Must be >= 1. * (default = 10) + * * @group setParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) @@ -183,18 +191,12 @@ private[ml] trait DecisionTreeParams extends PredictorParams */ private[ml] trait TreeClassifierParams extends Params { + // Impurity should be overridden when setting a default. This should be a def, but has + // to be a val to maintain the proper documentation. /** - * Criterion used for information gain calculation (case-insensitive). - * Supported: "entropy" and "gini". - * (default = gini) * @group param */ - final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + - " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) - - setDefault(impurity -> "gini") + val impurity: Param[String] = new Param(this, "", "") /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -220,26 +222,36 @@ private[ml] object TreeClassifierParams { final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } +private[ml] trait TreeClassifierParamsWithDefault extends TreeClassifierParams { + /** + * Criterion used for information gain calculation (case-insensitive). + * Also used for terminal leaf value prediction. 
+ * Supported: "gini" (default) and "entropy" + * + * @group param + */ + override val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + + setDefault(impurity -> "gini") +} + private[ml] trait DecisionTreeClassifierParams - extends DecisionTreeParams with TreeClassifierParams + extends DecisionTreeParams with TreeClassifierParamsWithDefault /** * Parameters for Decision Tree-based regression algorithms. */ private[ml] trait TreeRegressorParams extends Params { + // Impurity should be overridden when setting a default. This should be a def, but has + // to be a val to maintain the proper documentation. /** - * Criterion used for information gain calculation (case-insensitive). - * Supported: "variance". - * (default = variance) * @group param */ - final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + - " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) - - setDefault(impurity -> "variance") + val impurity: Param[String] = new Param(this, "", "") /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -264,8 +276,24 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait TreeRegressorParamsWithDefault extends TreeRegressorParams { + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance" (default) + * + * @group param + */ + override val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeRegressorParams.supportedImpurities + .contains(value.toLowerCase)) + + setDefault(impurity -> "variance") +} + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams - with TreeRegressorParams with HasVarianceCol { + with TreeRegressorParamsWithDefault with HasVarianceCol { override protected def validateAndTransformSchema( schema: StructType, @@ -290,6 +318,7 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. * (default = 1.0) + * * @group param */ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", @@ -340,10 +369,10 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { * - sqrt: recommended by Breiman manual for random forests * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest * package. + * * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for * random forests]] - * * @group param */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", @@ -376,6 +405,7 @@ private[ml] trait HasNumTrees extends Params { * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. 
* TODO: Change to always do bootstrapping (simpler). SPARK-7130 * (default = 20) + * * @group param */ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", @@ -403,16 +433,16 @@ private[spark] object RandomForestParams { } private[ml] trait RandomForestClassifierParams - extends RandomForestParams with TreeClassifierParams + extends RandomForestParams with TreeClassifierParamsWithDefault private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeClassifierParams + with HasFeatureSubsetStrategy with TreeClassifierParamsWithDefault private[ml] trait RandomForestRegressorParams - extends RandomForestParams with TreeRegressorParams + extends RandomForestParams with TreeRegressorParamsWithDefault private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeRegressorParams + with HasFeatureSubsetStrategy with TreeRegressorParamsWithDefault /** * Parameters for Gradient-Boosted Tree algorithms. @@ -441,6 +471,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each * estimator. * (default = 0.1) + * * @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -454,8 +485,9 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( categoricalFeatures: Map[Int, Int], - oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, oldImpurity) // NOTE: The old API does not support "seed" so we ignore it. new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) } @@ -465,17 +497,48 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS } private[ml] object GBTClassifierParams { - // The losses below should be lowercase. - /** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + // The values below should be lowercase. + /** Accessor for supported loss settings: logistic, bernoulli */ + final val supportedLossTypes: Array[String] = Array("logistic", "bernoulli") + /** Accessor for support entropy settings: loss-based or variance */ + final val supportedImpurities: Array[String] = Array("loss-based", "variance") + final def getLossBasedImpurity(loss: String): OldImpurity = loss match { + case "logistic" | "bernoulli" => ApproxBernoulliImpurity + case _ => throw new RuntimeException( + s"GBTClassifier does not have loss-based impurity for loss ${loss}") + } } private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + /** + * Criterion used for information gain calculation (case-insensitive). + * Also used for terminal leaf value prediction. + * Supported: "loss-based" (default) and "variance" + * + * @group expertParam + */ + override val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). 
Supported options:" + + s" ${GBTClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedImpurities.contains(value.toLowerCase)) + + /** (private[ml]) convert new impurity to old impurity */ + override private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "loss-based" => GBTClassifierParams.getLossBasedImpurity($(lossType)) + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"GBTClassifier was given unrecognized impurity: $impurity") + } + } + /** * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) + * Supported: "logistic" (default), "bernoulli" (alias for "logistic") + * * @group param */ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + @@ -483,7 +546,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) - setDefault(lossType -> "logistic") + setDefault(lossType -> "logistic", impurity -> "loss-based") /** @group getParam */ def getLossType: String = $(lossType).toLowerCase @@ -491,7 +554,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { getLossType match { - case "logistic" => OldLogLoss + case "bernoulli" | "logistic" => OldLogLoss case _ => // Should never happen because of check in setter method. throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") @@ -501,16 +564,55 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam private[ml] object GBTRegressorParams { // The losses below should be lowercase. - /** Accessor for supported loss settings: squared (L2), absolute (L1) */ - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + /** + * Accessor for supported loss settings: squared (L2), absolute (L1), + * gaussian (alias for squared), laplace (alias for absolute) + * */ + final val supportedLossTypes: Array[String] = Array( + "squared", "absolute", "gaussian", "laplace") + /** Accessor for supported impurity settings: loss-based or variance */ + final val supportedImpurities: Array[String] = Array("loss-based", "variance") + + final def getLossBasedImpurity(loss: String): OldImpurity = loss match { + case "gaussian" | "squared" => OldVariance + case "laplace" | "absolute" => throw new RuntimeException( + "GBTRegressor does not yet support loss-based impurity for absolute or laplace loss") + // TODO(SPARK-4240) use ApproxLaplaceImpurity here once it is implemented. + case _ => throw new RuntimeException( + s"GBTRegressor does not have loss-based impurity for loss ${loss}") + } } private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + /** + * Criterion used for information gain calculation (case-insensitive). + * Also used for terminal leaf value prediction. + * Supported: "loss-based" and "variance" (default) + * + * @group expertParam + */ + override val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). 
Supported options:" + + s" ${GBTRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedImpurities.contains(value.toLowerCase)) + + override private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "loss-based" => GBTRegressorParams.getLossBasedImpurity($(lossType)) + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"GBTRegressor was given unrecognized impurity: $impurity") + } + } + /** * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) + * Supported: "gaussian" (default), "squared" (alias for "gaussian"), "laplace", + * "absolute" (alias for "laplace"). + * * @group param */ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + @@ -518,7 +620,7 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) - setDefault(lossType -> "squared") + setDefault(lossType -> "gaussian", impurity -> "variance") /** @group getParam */ def getLossType: String = $(lossType).toLowerCase @@ -526,8 +628,8 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError + case "squared" | "gaussian" => OldSquaredError + case "absolute" | "laplace" => OldAbsoluteError case _ => // Should never happen because of check in setter method. 
throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d0a..43088fa883b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -808,6 +808,7 @@ private[python] class PythonMLLibAPI extends Serializable { boostingStrategy.treeStrategy.setMaxDepth(maxDepth) boostingStrategy.treeStrategy.setMaxBins(maxBins) boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + boostingStrategy.treeStrategy.setImpurity(Variance) val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) try { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b34e1b1b56c4..355aeaa314fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -21,6 +21,7 @@ import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since +import org.apache.spark.ml.tree.impurity.ApproxBernoulliImpurity import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} @@ -144,7 +145,12 @@ class Strategy @Since("1.3.0") ( s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + s" Valid settings: Gini, Entropy") case Regression => - require(impurity == Variance, + // Regression is used under-the-hood in the GBTClassifier, so all of its impurities + // could be valid as regression impurities. + // TODO(SPARK-16728): the above is only necessary since mllib.tree doesn't have + // spark.ml's usual Param.isValid()-checking mechanism. It should be removed with + // the resolution of SPARK-16728. + require(Set(Variance, ApproxBernoulliImpurity).contains(impurity), s"DecisionTree Strategy given invalid impurity for Regression: $impurity." + s" Valid settings: Variance") case _ => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index a5bdc2c6d2c9..e06a3425fd64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.impurity import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.ml.tree.impurity.ApproxBernoulliCalculator /** * Trait for calculating information gain. 
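For reference, an illustrative sketch (not part of this patch) of the old RDD-based API, whose behavior is unchanged; the PythonMLLibAPI hunk above pins its tree strategy to Variance for the same reason. The RDD name below is an assumption:

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.impurity.Variance

val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.treeStrategy.setImpurity(Variance)  // the old API remains SGB-only
boostingStrategy.treeStrategy.setMaxDepth(3)
// val model = GradientBoostedTrees.train(trainingRdd, boostingStrategy)  // trainingRdd: RDD[LabeledPoint]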
@@ -31,6 +32,7 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: * information calculation for multiclass classification + * * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 @@ -42,6 +44,7 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: * information calculation for regression + * * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels @@ -56,12 +59,14 @@ trait Impurity extends Serializable { * Interface for updating views of a vector of sufficient statistics, * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. + * * @param statsSize Length of the vector of sufficient statistics for one bin. */ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { /** * Merge the stats from one bin into another. + * * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for (node, feature, bin) which is modified by the merge. * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified. @@ -76,6 +81,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser /** * Update stats for one (node, feature, bin) with the given label. + * * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ @@ -83,6 +89,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser /** * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ @@ -93,6 +100,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * Stores statistics for one (node, feature, bin) for calculating impurity. * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific * (node, feature, bin). + * * @param stats Array of sufficient statistics for a (node, feature, bin). */ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { @@ -181,13 +189,26 @@ private[spark] object ImpurityCalculator { /** * Create an [[ImpurityCalculator]] instance of the given impurity type and with - * the given stats. + * the given stats. If impurity is "loss-based", then the loss should be specified as well. */ - def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { + def getCalculator(impurity: String, + loss: String, + stats: Array[Double]): ImpurityCalculator = { impurity match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) + // TODO(vlad17): this dependency into spark.ml is only necessary until SPARK-16728 is resolved + // At that point, we can cleanse ourselves of this case-derived-class anti-pattern, + // in turn preventing duplicating hacks like this from ever occurring. 
+ case "loss-based" => loss match { + case "gaussian" | "squared" => new VarianceCalculator(stats) + case "bernoulli" | "logistic" => new ApproxBernoulliCalculator(stats) + case _ => + throw new IllegalArgumentException( + s"ImpurityCalculator builder found impurity type $impurity but could not recognize" + + s"loss $loss") + } case _ => throw new IllegalArgumentException( s"ImpurityCalculator builder did not recognize impurity type: $impurity") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 3492709677d4..5551714cdbf3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,15 +19,18 @@ package org.apache.spark.ml.classification import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.tree.impurity.ApproxBernoulliAggregator +import org.apache.spark.ml.util.{DefaultReadWriteTest, GBTSuiteHelper, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -52,8 +55,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext override def beforeAll() { super.beforeAll() - data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) - .map(_.asML) + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints( + numFeatures = 10, numInstances = 100), 2).map(_.asML) trainData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) .map(_.asML) @@ -70,21 +73,68 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } - test("Binary classification with continuous features: Log Loss") { - val categoricalFeatures = Map.empty[Int, Int] + test("GBT-specific param defaults") { + val gbt = new GBTClassifier() + assert(gbt.getImpurity === "loss-based") + assert(gbt.getLossType === "logistic") + } + + test("GBT-specific param support") { + val gbt = new GBTClassifier() + for (impurity <- GBTClassifier.supportedImpurities) { + gbt.setImpurity(impurity) + } + for (lossType <- GBTClassifier.supportedLossTypes) { + gbt.setLossType(lossType) + } + } + + test("Binary classification: Variance-based impurity + Log Loss") { + // Using a non-loss-based impurity we can just check for equivalence with the old API testCombinations.foreach { case (maxIter, learningRate, subsamplingRate) => val gbt = new GBTClassifier() .setMaxDepth(2) .setSubsamplingRate(subsamplingRate) + .setImpurity("variance") .setLossType("logistic") .setMaxIter(maxIter) 
.setStepSize(learningRate) .setSeed(123) - compareAPIs(data, None, gbt, categoricalFeatures) + compareAPIs(data, None, gbt, Map.empty[Int, Int]) } } + test("approximate bernoulli impurity") { + def grad(pred: Double, label: Double): Double = { + -4 * label / (1 + math.exp(2 * label * pred)) + } + def hess(pred: Double, label: Double): Double = { + val numerator = 8 * math.exp(2 * label * pred) * math.pow(label, 2) + val denominator = math.pow(1 + Math.exp(2 * label * pred), 2) + numerator / denominator + } + val variance = (responses: Seq[Double]) => + GBTSuiteHelper.computeCalculator(responses, new VarianceAggregator).calculate() + val newtonRaphson = (prediction: Double, labels: Seq[Double], responses: Seq[Double]) => + -labels.map(grad(prediction, _)).sum / labels.map(hess(prediction, _)).sum + + GBTSuiteHelper.verifyCalculator( + new ApproxBernoulliAggregator(), + LogLoss, + expectedImpurity = variance, + expectedPrediction = newtonRaphson) + } + + test("Binary classification: Loss-based impurity + Log Loss") { + val impurityName = "bernoulli" + val loss = LogLoss + val expectedImpurity = new ApproxBernoulliAggregator() + + GBTSuiteHelper.verifyGBTConstruction( + spark, classification = true, impurityName, loss, expectedImpurity) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString @@ -125,6 +175,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext GBTClassifier.supportedLossTypes.foreach { loss => val gbt = new GBTClassifier() .setMaxIter(maxIter) + .setImpurity("variance") .setMaxDepth(2) .setLossType(loss) .setValidationTol(0.0) @@ -176,7 +227,6 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("Feature importance with toy data") { val numClasses = 2 val gbt = new GBTClassifier() - .setImpurity("Gini") .setMaxDepth(3) .setMaxIter(5) .setSubsamplingRate(1.0) @@ -210,11 +260,16 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val gbt = new GBTClassifier() val rdd = TreeTests.getTreeReadWriteData(sc) - val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + // Test for all different impurity types. + for (impurity <- Seq(Some("loss-based"), Some("variance"), None)) { + val allParamSettings = TreeTests.allParamSettings ++ + Map("lossType" -> "logistic") ++ + impurity.map("impurity" -> _).toMap - val continuousData: DataFrame = - TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + } } } @@ -223,15 +278,18 @@ private object GBTClassifierSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + * + * The old API only supports variance-based impurity, so gbt should have that setting. 
*/ def compareAPIs( data: RDD[LabeledPoint], validationData: Option[RDD[LabeledPoint]], gbt: GBTClassifier, categoricalFeatures: Map[Int, Int]): Unit = { + assert(gbt.getImpurity === "variance") val numFeatures = data.first().features.size val oldBoostingStrategy = - gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification, Variance) val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML)) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dcf3f9a1ea9b..656986032bcc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -21,10 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, GBTSuiteHelper, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.loss.SquaredError import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -59,23 +61,72 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext .map(_.asML) } - test("Regression with continuous features") { - val categoricalFeatures = Map.empty[Int, Int] - GBTRegressor.supportedLossTypes.foreach { loss => - testCombinations.foreach { - case (maxIter, learningRate, subsamplingRate) => - val gbt = new GBTRegressor() - .setMaxDepth(2) - .setSubsamplingRate(subsamplingRate) - .setLossType(loss) - .setMaxIter(maxIter) - .setStepSize(learningRate) - .setSeed(123) - compareAPIs(data, None, gbt, categoricalFeatures) - } + test("GBT-specific param defaults") { + val gbt = new GBTRegressor() + assert(gbt.getImpurity === "variance") + assert(gbt.getLossType === "gaussian") + } + + test("GBT-specific param support") { + val gbt = new GBTRegressor() + for (impurity <- GBTRegressor.supportedImpurities) { + gbt.setImpurity(impurity) + } + for (lossType <- GBTRegressor.supportedLossTypes) { + gbt.setLossType(lossType) + } + } + + def verifyVarianceImpurityAgainstOldAPI(loss: String): Unit = { + // Using a non-loss-based impurity we can just check for equivalence with the old API + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setImpurity("variance") + .setLossType(loss) + .setMaxIter(maxIter) + .setStepSize(learningRate) + .setSeed(123) + compareAPIs(data, None, gbt, Map.empty[Int, Int]) } } + test("Regression: Variance-based impurity and L2 loss") { + verifyVarianceImpurityAgainstOldAPI("squared") + } + + test("Regression: Variance-based impurity and L1 loss") { + verifyVarianceImpurityAgainstOldAPI("absolute") + } + + test("variance impurity") { 
+ val variance = (responses: Seq[Double]) => { + val sum = responses.sum + val sum2 = responses.map(math.pow(_, 2)).sum + val n = responses.size + sum2 / n - math.pow(sum / n, 2) + } + val mean = (prediction: Double, labels: Seq[Double], responses: Seq[Double]) => + responses.map(_ - prediction).sum / responses.size + + GBTSuiteHelper.verifyCalculator( + new VarianceAggregator(), + SquaredError, + expectedImpurity = variance, + expectedPrediction = mean) + } + + test("Regression: loss-based impurity and L2 loss") { + val impurityName = "gaussian" + val loss = SquaredError + val expectedImpurity = new VarianceAggregator() + + GBTSuiteHelper.verifyGBTConstruction( + spark, classification = false, impurityName, loss, expectedImpurity) + } + test("GBTRegressor behaves reasonably on toy data") { val df = Seq( LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), @@ -134,6 +185,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext GBTRegressor.supportedLossTypes.foreach { loss => val gbt = new GBTRegressor() .setMaxIter(maxIter) + .setImpurity("variance") .setMaxDepth(2) .setLossType(loss) .setValidationTol(0.0) @@ -181,10 +233,15 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val gbt = new GBTRegressor() val rdd = TreeTests.getTreeReadWriteData(sc) - val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") - val continuousData: DataFrame = - TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + // Test for all different impurity types. + for (impurity <- Seq(Some("loss-based"), Some("variance"), None)) { + val allParamSettings = TreeTests.allParamSettings ++ + Map("lossType" -> "squared") ++ + impurity.map("impurity" -> _).toMap + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + } } } @@ -193,14 +250,18 @@ private object GBTRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + * + * The old API only supports variance-based impurity, so gbt should have that. */ def compareAPIs( data: RDD[LabeledPoint], validationData: Option[RDD[LabeledPoint]], gbt: GBTRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + assert(gbt.getImpurity == "variance") val numFeatures = data.first().features.size - val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldBoostingStrategy = gbt.getOldBoostingStrategy( + categoricalFeatures, OldAlgo.Regression, Variance) val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML)) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/GBTSuiteHelper.scala b/mllib/src/test/scala/org/apache/spark/ml/util/GBTSuiteHelper.scala new file mode 100644 index 000000000000..7e99927a3ab3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/GBTSuiteHelper.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
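The variance closure in the test above is the closed form E[X^2] - (E[X])^2, which is exactly what can be maintained in one pass from three running statistics: count, sum, and sum of squares. A minimal illustration of that bookkeeping (plain Scala with invented names, not the mllib aggregator API):

    final case class VarStats(count: Double, sum: Double, sumSq: Double) {
      def add(x: Double): VarStats = VarStats(count + 1, sum + x, sumSq + x * x)
      def mean: Double = sum / count
      def variance: Double = sumSq / count - math.pow(sum / count, 2)  // same closed form as the test
    }

    object VarStatsExample {
      def main(args: Array[String]): Unit = {
        val responses = Seq(1.0, -1.0, 1.0, 1.0, -1.0, 1.0)
        val stats = responses.foldLeft(VarStats(0.0, 0.0, 0.0))(_.add(_))
        println(s"count=${stats.count} mean=${stats.mean} variance=${stats.variance}")
      }
    }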
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalactic.TolerantNumerics
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.GBTClassifier
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.regression.GBTRegressor
+import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.tree.impurity.{ImpurityAggregator, ImpurityCalculator}
+import org.apache.spark.mllib.tree.loss.Loss
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+
+object GBTSuiteHelper extends SparkFunSuite {
+  /**
+   * @param labels set of GBT labels
+   * @param agg the aggregator to use
+   * @return the calculator obtained by aggregating the labels
+   */
+  def computeCalculator(
+      labels: Seq[Double],
+      agg: ImpurityAggregator): ImpurityCalculator = {
+    val stats = new Array[Double](agg.statsSize)
+    labels.foreach(label => agg.update(stats, offset = 0, label, instanceWeight = 1))
+    agg.getCalculator(stats, offset = 0)
+  }
+
+  /**
+   * Makes sure that the given aggregator reports the expected impurity for a variety
+   * of different label sets.
+   *
+   * @param actualAgg the aggregator to test
+   * @param loss the loss function for the responses
+   * @param expectedImpurity a function from the set of responses to the expected impurity
+   * @param expectedPrediction a function from the triplet
+   *                           (previous prediction, labels, responses) to
+   *                           the expected prediction
+   */
+  def verifyCalculator(
+      actualAgg: ImpurityAggregator,
+      loss: Loss,
+      expectedImpurity: Seq[Double] => Double,
+      expectedPrediction: (Double, Seq[Double], Seq[Double]) => Double): Unit = {
+    val npoints = 6
+    for (cutoff <- 0 to npoints) withClue(s"for cutoff $cutoff\n") {
+      val labels = (0 until npoints).map(x => if (x < cutoff) -1.0 else 1.0)
+      val prediction = 0
+      val pseudoResiduals = labels.map(x => -loss.gradient(prediction, x))
+
+      val calculator = computeCalculator(pseudoResiduals, actualAgg)
+      withClue(s"for calculator $calculator\n") {
+        assert(calculator.count === npoints)
+        assert(calculator.calculate() ~== expectedImpurity(pseudoResiduals) absTol 1e-3)
+        assert(calculator.predict ~==
+          expectedPrediction(prediction, labels, pseudoResiduals) absTol 1e-3)
+      }
+    }
+  }
+
+  /**
+   * Makes sure that a GBT is constructed in accordance with its impurity and loss function.
+   *
+   * @param spark the spark session
+   * @param classification whether to use a [[GBTClassifier]] or [[GBTRegressor]]
+   * @param impurityName name of the impurity to use
+   * @param loss loss function for gradients
+   * @param expectedImpurity expected impurity aggregator for statistics
+   */
+  def verifyGBTConstruction(
+      spark: SparkSession,
+      classification: Boolean,
+      impurityName: String,
+      loss: Loss,
+      expectedImpurity: ImpurityAggregator): Unit = {
+    // We create a dataset that can be optimally classified with an initial tree
+    // and one round of gradient boosting. The first tree will not be a perfect classifier,
+    // so the leaf node predictions will differ for different impurity measures. This is expected
+    // to be tested on depth-2 trees (7 nodes max). The generated trees should do no
+    // sub-sampling.
+    //
+    // The error is slight enough to force a certain tree structure the first round,
+    // but still give nontrivial results in both cases.
+
+    val data = new ArrayBuffer[LabeledPoint]()
+
+    // A depth-2 tree can separate all four (feature0, feature1) cells.
+    def addPoints(npoints: Int, label: Double, features: Vector): Unit = {
+      for (_ <- 0 until npoints) {
+        data += new LabeledPoint(label, features)
+      }
+    }
+
+    // Adds 9 points of label 'label' and 1 point of label '1.0 - label'
+    def addMixedSection(label: Double, feature0: Double, feature1: Double): Unit = {
+      val pointsPerSection = 10
+      val offPoints = 1
+      val features = Vectors.dense(feature0, feature1)
+      val offFeatures = Vectors.dense(feature0, feature1)
+      addPoints(pointsPerSection - offPoints, label, features)
+      addPoints(offPoints, 1.0 - label, offFeatures)
+    }
+
+    for (feature0 <- Seq(0.0, 1.0); feature1 <- Seq(0.0, 1.0)) {
+      val xor = if (feature0 == feature1) 0.0 else 1.0
+      addMixedSection(label = xor, feature0, feature1)
+    }
+    addMixedSection(label = 0.0, feature0 = 0.0, feature1 = 0.0)
+    addMixedSection(label = 1.0, feature0 = 0.0, feature1 = 1.0)
+    addMixedSection(label = 1.0, feature0 = 1.0, feature1 = 0.0)
+    addMixedSection(label = 0.0, feature0 = 1.0, feature1 = 1.0)
+
+    // Make splitting on feature 0 slightly more attractive for an initial split
+    // than the others by making it a slightly informative identity predictor of the label,
+    // while keeping other features uninformative. Note that feature 1's predictive power,
+    // when conditioned on feature 0, is maintained.
+    for (feature1 <- Seq(0.0, 1.0); label <- Seq(0.0, 1.0)) {
+      addPoints(npoints = 5, label, Vectors.dense(label, feature1))
+    }
+
+    // Convert the input dataframe to a more convenient format to check our results against.
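The dataset built above is meant to carry a specific signal: every (feature0, feature1) cell is 90% XOR-labeled with 10% label noise, each mixed section is added twice (once by the loop and once by the explicit calls), and the extra identity-predictor points make feature 0 alone mildly informative while feature 1 alone stays uninformative. Under that reading, the standalone re-derivation below (plain tuples, no Spark types, names invented for the example) recovers the intended per-feature label rates:

    object XorDatasetBalance {
      def main(args: Array[String]): Unit = {
        // (label, feature0, feature1) triples standing in for LabeledPoint rows.
        val data = scala.collection.mutable.ArrayBuffer[(Double, Double, Double)]()

        def addMixed(label: Double, f0: Double, f1: Double): Unit = {
          data ++= Seq.fill(9)((label, f0, f1))   // 9 points with the section's label
          data += ((1.0 - label, f0, f1))         // 1 mislabeled point at the same features
        }

        for (f0 <- Seq(0.0, 1.0); f1 <- Seq(0.0, 1.0)) {
          val xor = if (f0 == f1) 0.0 else 1.0
          addMixed(xor, f0, f1)
          addMixed(xor, f0, f1)   // the suite adds each mixed section a second time explicitly
        }
        // Identity-predictor points: feature 0 equals the label, feature 1 varies freely.
        for (f1 <- Seq(0.0, 1.0); label <- Seq(0.0, 1.0); _ <- 0 until 5) {
          data += ((label, label, f1))
        }

        def labelRate(select: ((Double, Double, Double)) => Boolean): Double = {
          val selected = data.filter(select)
          selected.map(_._1).sum / selected.size
        }
        println(f"P(label=1 | feature0=0) = ${labelRate(t => t._2 == 0.0)}%.2f")  // 0.40
        println(f"P(label=1 | feature0=1) = ${labelRate(t => t._2 == 1.0)}%.2f")  // 0.60
        println(f"P(label=1 | feature1=0) = ${labelRate(t => t._3 == 0.0)}%.2f")  // 0.50
        println(f"P(label=1 | feature1=1) = ${labelRate(t => t._3 == 1.0)}%.2f")  // 0.50
      }
    }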
+ val rawInput = spark.createDataFrame(data) + val vectorAsArray = udf((v: Vector) => v.toArray) + val input = rawInput.select( + col("*"), + vectorAsArray(col("features"))(0).as("feature0"), + vectorAsArray(col("features"))(1).as("feature1")) + + // Classification/regression ambivalent tree retrieval + val infoGain = 1.0 / data.size + val (trees, treeWeights) = if (classification) { + val model = new GBTClassifier() + .setMaxDepth(2) + .setMinInstancesPerNode(1) + .setMinInfoGain(infoGain) + .setSubsamplingRate(1.0) + .setMaxIter(2) + .setStepSize(1) + .setLossType(impurityName) + .fit(rawInput) + (model.trees, model.treeWeights) + } else { + val model = new GBTRegressor() + .setMaxDepth(2) + .setMinInstancesPerNode(1) + .setMinInfoGain(infoGain) + .setSubsamplingRate(1.0) + .setMaxIter(2) + .setStepSize(1) + .setLossType(impurityName) + .fit(rawInput) + (model.trees, model.treeWeights) + } + + assert(trees.length === 2) + assert(treeWeights === Array(1.0, 1.0)) + + // A "feature" with index below 0 is a label + def pointFilter(featureMap: Seq[(Int, Int)]): String = { + if (featureMap.isEmpty) return "true" + val sqlConditions = featureMap.map({ + case (idx, value) if idx < 0 => s"label = $value" + case (idx, value) => s"feature$idx = $value" + }) + sqlConditions.mkString(" and ") + } + + var relabeledDF: DataFrame = null + def trueCalculator(featureMap: (Int, Int)*) = { + implicit val encoder = Encoders.scalaDouble + val df = relabeledDF.where(pointFilter(featureMap)) + val labels = df.select("label").as[Double].collect() + computeCalculator(labels, expectedImpurity) + } + + def verifyImpurity(actualImpurity: ImpurityCalculator, featureMap: Seq[(Int, Int)]): Unit = { + implicit val approxEquals = TolerantNumerics.tolerantDoubleEquality(1e-3) + val expectedCalculator = trueCalculator(featureMap: _*) + withClue(s"actualImpurity $actualImpurity\nexpectedImpurity $expectedCalculator\n\n") { + assert(actualImpurity.count === expectedCalculator.count) + assert(actualImpurity.calculate() ~== expectedCalculator.calculate() absTol 1e-3) + assert(actualImpurity.predict ~== expectedCalculator.predict absTol 1e-3) + } + } + + def verifyInternalNode(node: Node, feature: Int, featureMap: (Int, Int)*): InternalNode = { + withClue(s"node $node\n\nlocation ${featureMap.mkString(" ")}\n\n") { + assert(node.isInstanceOf[InternalNode]) + val internal = node.asInstanceOf[InternalNode] + assert(internal.split.featureIndex === feature) + verifyImpurity(internal.impurityStats, featureMap) + internal + } + } + + def verifyLeafNode(node: Node, featureMap: (Int, Int)*): Unit = { + withClue(s"node $node\n\nlocation ${featureMap.mkString(" ")}\n\n") { + assert(node.isInstanceOf[LeafNode]) + val leaf = node.asInstanceOf[LeafNode] + verifyImpurity(leaf.impurityStats, featureMap) + } + } + + val oldLabel = input.withColumnRenamed("label", "oldlabel") + val transformedLabel = if (classification) col("oldlabel") * 2 - 1 else col("oldlabel") + relabeledDF = oldLabel.withColumn("label", transformedLabel) + withClue(s"Tree 0:\n\n${trees.head.toDebugString}\n") { + val root = verifyInternalNode(trees.head.rootNode, 0) + val left = verifyInternalNode(root.leftChild, 1, 0 -> 0) + val right = verifyInternalNode(root.rightChild, 1, 0 -> 1) + verifyLeafNode(left.leftChild, 0 -> 0, 1 -> 0) + verifyLeafNode(left.rightChild, 0 -> 0, 1 -> 1) + verifyLeafNode(right.leftChild, 0 -> 1, 1 -> 0) + verifyLeafNode(right.rightChild, 0 -> 1, 1 -> 1) + } + + def computeGain(splitFeature: Int, featureMap: (Int, Int)*): Double = { + val preSplit 
= trueCalculator(featureMap: _*) + val postSplit = Seq(0, 1).map(value => { + val half = trueCalculator((splitFeature -> value) +: featureMap: _*) + half.calculate() * half.count + }).sum / preSplit.count + preSplit.calculate() - postSplit + } + + // The second tree's structure is going to be sensitive to the actual loss + val gradient = udf((pred: Double, label: Double) => -loss.gradient(pred, label)) + relabeledDF = trees.head.transform(oldLabel) + .withColumn("label", gradient(col("prediction"), transformedLabel)) + withClue(s"Tree 1:\n\n${trees.last.toDebugString}\n") { + val rootFeature = Seq(0, 1).maxBy(computeGain(_)) + if (computeGain(rootFeature) < infoGain) { + verifyLeafNode(trees.last.rootNode) + return + } + + val root = verifyInternalNode(trees.last.rootNode, rootFeature) + val otherFeature = 1 - rootFeature + + for (splitValue <- Seq(0, 1)) { + val genericChild = if (splitValue == 0) root.leftChild else root.rightChild + if (computeGain(otherFeature, rootFeature -> splitValue) < infoGain) { + verifyLeafNode(genericChild, rootFeature -> splitValue) + } else { + val child = verifyInternalNode(genericChild, otherFeature, rootFeature -> splitValue) + verifyLeafNode(child.leftChild, rootFeature -> splitValue, otherFeature -> 0) + verifyLeafNode(child.rightChild, rootFeature -> splitValue, otherFeature -> 1) + } + } + } + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 350b144f8294..621decafac4c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -864,6 +864,18 @@ object MimaExcludes { // [SPARK-12221] Add CPU time to metrics ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") + ) ++ Seq( + // [SPARK-16718] GBT model public changes to package-private class + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.getOldBoostingStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.getOldBoostingStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.getOldBoostingStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.getOldBoostingStrategy"), +ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeRegressorParamsWithDefault.org$apache$spark$ml$tree$TreeRegressorParamsWithDefault$_setter_$impurity_="), +ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeClassifierParamsWithDefault.org$apache$spark$ml$tree$TreeClassifierParamsWithDefault$_setter_$impurity_="), +ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeClassifierParamsWithDefault.org$apache$spark$ml$tree$TreeClassifierParamsWithDefault$_setter_$impurity_="), +ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeRegressorParamsWithDefault.org$apache$spark$ml$tree$TreeRegressorParamsWithDefault$_setter_$impurity_="), +ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeClassifierParamsWithDefault.org$apache$spark$ml$tree$TreeClassifierParamsWithDefault$_setter_$impurity_="), 
+ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.TreeRegressorParamsWithDefault.org$apache$spark$ml$tree$TreeRegressorParamsWithDefault$_setter_$impurity_=") ) } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d9ff356b9403..f4087edb421a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -562,7 +562,7 @@ class GBTParams(TreeEnsembleParams): .. versionadded:: 1.4.0 """ - supportedLossTypes = ["logistic"] + supportedLossTypes = ["logistic", "bernoulli"] @inherit_doc @@ -889,6 +889,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. versionadded:: 1.4.0 + + .. versionchanged:: 2.0.2 """ lossType = Param(Params._dummy(), "lossType", diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9233d2e7e1a7..416a0f13bf5e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -633,7 +633,7 @@ class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. """ - supportedLossTypes = ["squared", "absolute"] + supportedLossTypes = ["squared", "absolute", "gaussian", "laplace"] @inherit_doc @@ -992,6 +992,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. versionadded:: 1.4.0 + + .. versionchanged:: 2.0.2 """ lossType = Param(Params._dummy(), "lossType", @@ -1003,20 +1005,20 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, + checkpointInterval=10, lossType="gaussian", maxIter=20, stepSize=0.1, seed=None, impurity="variance"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ + checkpointInterval=10, lossType="gaussian", maxIter=20, stepSize=0.1, seed=None, \ impurity="variance") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, + checkpointInterval=10, lossType="gaussian", maxIter=20, stepSize=0.1, impurity="variance") kwargs = self.__init__._input_kwargs self.setParams(**kwargs)
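Taken together, the Scala and Python changes describe the user-facing surface: GBTRegressor defaults to the "gaussian" loss alias with "variance" impurity, "loss-based" impurity requests TreeBoost-style behavior, and "variance" keeps the pre-2.0.2 Stochastic Gradient Boosting behavior. A usage-level sketch of those settings (spark-shell style; it assumes this patch is applied, a SparkSession named spark is in scope, and the tiny dataset is made up for illustration):

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.GBTRegressor

    val training = spark.createDataFrame(Seq(
      LabeledPoint(10.0, Vectors.dense(1.0, 2.0)),
      LabeledPoint(20.0, Vectors.dense(2.0, 1.0)),
      LabeledPoint(15.0, Vectors.dense(1.0, 1.0)),
      LabeledPoint(30.0, Vectors.dense(3.0, 0.0))))

    // Defaults asserted by the suite: lossType "gaussian", impurity "variance".
    val defaultModel = new GBTRegressor().setMaxDepth(2).setMaxIter(5).setSeed(42).fit(training)

    // TreeBoost-style behavior with a loss where it genuinely differs from plain SGB;
    // "laplace" is the absolute-error alias listed in the Python params above.
    val treeBoostModel = new GBTRegressor()
      .setImpurity("loss-based")
      .setLossType("laplace")
      .setMaxDepth(2)
      .setMaxIter(5)
      .setSeed(42)
      .fit(training)

    // Pre-2.0.2 behavior (plain Stochastic Gradient Boosting), requested explicitly.
    val sgbModel = new GBTRegressor().setImpurity("variance").setLossType("absolute")
      .setMaxDepth(2).setMaxIter(5).setSeed(42).fit(training)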