26 commits
a4566a8
Added new documentation for TreeBoost, top-level calls
vlad17 Jul 22, 2016
ce46b10
Implemented ApproxBernoulliImpurity
vlad17 Jul 22, 2016
0311896
Added approximate Bernoulli impurity (L_2 treeboost)
vlad17 Jul 25, 2016
75e9ceb
Added marker saying Laplace Impurity is not yet supported (requires i…
vlad17 Jul 26, 2016
fef6dbb
Updated docs to reflect lack of L1 impurity support
vlad17 Jul 26, 2016
e294499
Fixed urls
vlad17 Jul 26, 2016
6713814
Removed ApproxLaplaceImpurity
vlad17 Jul 26, 2016
33fe35a
Fix reader docs
vlad17 Jul 26, 2016
1a8cbfe
Fixed a bunch of bugs + tested wrt old behavior
vlad17 Jul 26, 2016
347f220
Completed tests for reading/writing new impurities
vlad17 Jul 27, 2016
7731a8f
Finished tests
vlad17 Jul 27, 2016
89acdfc
Added R's gbm as a direct comparison to GBTClassifier
vlad17 Aug 6, 2016
b44f2b1
Got rid of direct R comparison
vlad17 Aug 7, 2016
8449f9c
Direct behavior-checking test (for GBTClassifier)
vlad17 Aug 8, 2016
bc696ee
Added analogous test for GBTReressor
vlad17 Aug 8, 2016
775d991
Cleaned up style-related things
vlad17 Aug 8, 2016
da39cec
Removed weight requirement
vlad17 Aug 9, 2016
ecf08c6
Fixed or ignored binary incompat issues
vlad17 Aug 9, 2016
2d5036f
Changed to variance impurity as default for GBTRegressor
vlad17 Aug 9, 2016
29b1158
defined default python behavior to be mllib (was undefined before)
vlad17 Aug 9, 2016
d36dab3
Modified defaults on pyspark side, too
vlad17 Aug 10, 2016
ca2f505
Addressed partial-pass comments
vlad17 Sep 13, 2016
361606d
Addressed comments (except auto thing)
vlad17 Oct 13, 2016
d3b948b
re-added Mima excludes that got squashed in merge
vlad17 Oct 23, 2016
5f54f4d
More mima excludes, added lots of warnings to not use impurity
vlad17 Nov 1, 2016
4e20a70
Removed depreciation of impurity value for GBTs
vlad17 Nov 1, 2016
@@ -40,20 +40,31 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
* It supports both continuous and categorical features.
*
* The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
* The implementation offers both Stochastic Gradient Boosting, as in J.H. Friedman 1999,
* "Stochastic Gradient Boosting", and TreeBoost, as in Friedman 1999,
* "Greedy Function Approximation: A Gradient Boosting Machine".
*
* Notes on Gradient Boosting vs. TreeBoost:
* - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
* Notes on Stochastic Gradient Boosting (SGB) vs. TreeBoost:
* - TreeBoost algorithms are a subset of SGB algorithms.
* - Both algorithms learn tree ensembles by minimizing loss functions.
* - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
* based on the loss function, whereas the original gradient boosting method does not.
* - We expect to implement TreeBoost in the future:
* [https://issues.apache.org/jira/browse/SPARK-4240]
* - TreeBoost has two additional properties that general SGB trees don't:
* - The loss function gradients are directly used as an approximate impurity measure.
* - The value reported at a leaf is obtained by optimizing the loss function on
* that leaf node's partition of the data, rather than just being the mean.
* - In the case of squared error loss, variance impurity and mean leaf estimates happen
* to make the SGB and TreeBoost algorithms identical.
*
* [[GBTClassifier]] will use the `"loss-based"` impurity by default, conforming to
* TreeBoost behavior. For SGB, set impurity to `"variance"`.
*
* Currently, however, even TreeBoost behavior uses variance impurity for split selection,
* for ease and speed. This is the approach `R`'s
* [[https://cran.r-project.org/web/packages/gbm/index.html gbm package]] takes.
*/
@Since("1.4.0")
class GBTClassifier @Since("1.4.0") (
@@ -91,14 +102,16 @@ class GBTClassifier @Since("1.4.0") (
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)

/**
* The impurity setting is ignored for GBT models.
* Individual trees are built using impurity "Variance."
* Setting the impurity is currently only offered as a way to recover pre-2.0.2 Spark GBT
* behavior (which is Stochastic Gradient Boosting): set the impurity to `"variance"` for this.
* @param value new impurity value
* @return this
*/
@Since("1.4.0")
override def setImpurity(value: String): this.type = {
logWarning("GBTClassifier.setImpurity should NOT be used")
this
}
@deprecated(
"Control over impurity will be removed, as it is an implementation detail of GBTs",
"2.0.2")
override def setImpurity(value: String): this.type = super.setImpurity(value)

// Parameters from TreeEnsembleParams:

Expand Down Expand Up @@ -136,7 +149,8 @@ class GBTClassifier @Since("1.4.0") (
LabeledPoint(label, features)
}
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val boostingStrategy = super.getOldBoostingStrategy(
categoricalFeatures, OldAlgo.Classification, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
@@ -156,11 +170,14 @@ class GBTClassifier @Since("1.4.0") (

@Since("1.4.0")
object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {

/** Accessor for supported loss settings: logistic */
/** Accessor for supported loss settings: logistic, bernoulli */
@Since("1.4.0")
final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes

/** Accessor for supported impurity settings: loss-based or variance */
@Since("2.1")
final val supportedImpurities: Array[String] = GBTClassifierParams.supportedImpurities

@Since("2.0.0")
override def load(path: String): GBTClassifier = super.load(path)
}
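
The companion object now lists the supported impurity names alongside the loss types. A small hedged sketch of how a caller might validate a setting before applying it; the `requested` value is purely illustrative:

val requested = "loss-based"
require(GBTClassifier.supportedImpurities.contains(requested),
  s"Unsupported impurity: $requested; expected one of " +
    GBTClassifier.supportedImpurities.mkString(", "))
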
@@ -23,7 +23,6 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
@@ -42,21 +41,31 @@ import org.apache.spark.sql.functions._
* learning algorithm for regression.
* It supports both continuous and categorical features.
*
* The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
* The implementation offers both Stochastic Gradient Boosting, as in J.H. Friedman 1999,
* "Stochastic Gradient Boosting", and TreeBoost, as in Friedman 1999,
* "Greedy Function Approximation: A Gradient Boosting Machine".
*
* Notes on Gradient Boosting vs. TreeBoost:
* - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
* Notes on Stochastic Gradient Boosting (SGB) vs. TreeBoost:
* - TreeBoost algorithms are a subset of SGB algorithms.
* - Both algorithms learn tree ensembles by minimizing loss functions.
* - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
* based on the loss function, whereas the original gradient boosting method does not.
* - When the loss is SquaredError, these methods give the same result, but they could differ
* for other loss functions.
* - We expect to implement TreeBoost in the future:
* [https://issues.apache.org/jira/browse/SPARK-4240]
* - TreeBoost has two additional properties that general SGB trees don't:
* - The loss function gradients are directly used as an approximate impurity measure.
* - The value reported at a leaf is obtained by optimizing the loss function on
* that leaf node's partition of the data, rather than just being the mean.
* - In the case of squared error loss, variance impurity and mean leaf estimates happen
* to make the SGB and TreeBoost algorithms identical.
*
* [[GBTRegressor]] will use the usual `"variance"` impurity by default, conforming to
* SGB behavior. For TreeBoost, set the impurity to `"loss-based"`. Note that TreeBoost is
* currently incompatible with absolute error.
*
* Currently, however, even TreeBoost behavior uses variance impurity for split selection,
* for ease and speed. This is the approach `R`'s
* [[https://cran.r-project.org/web/packages/gbm/index.html gbm package]] takes.
*/
@Since("1.4.0")
class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
extends Regressor[Vector, GBTRegressor, GBTRegressionModel]
with GBTRegressorParams with DefaultParamsWritable with Logging {

@Since("1.4.0")
@@ -88,14 +97,18 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)

/**
* The impurity setting is ignored for GBT models.
* Individual trees are built using impurity "Variance."
* Note that the loss-based impurity is currently NOT compatible with absolute loss.
*
* Setting the impurity is currently only offered as a way to recover pre-2.0.2 Spark GBT
* behavior (which is Stochastic Gradient Boosting): set the impurity to `"variance"` for this.
* @param value new impurity value
* @return this
*/
@Since("1.4.0")
override def setImpurity(value: String): this.type = {
logWarning("GBTRegressor.setImpurity should NOT be used")
this
}
@deprecated(
"Control over impurity will be removed, as it is an implementation detail of GBTs",
"2.0.2")
override def setImpurity(value: String): this.type = super.setImpurity(value)

// Parameters from TreeEnsembleParams:
@Since("1.4.0")
@@ -113,7 +126,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)

// Parameters from GBTRegressorParams:

/** @group setParam */
/**
* Note that the loss-based impurity is currently NOT compatible with absolute loss.
* @group setParam
*/
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)

@@ -122,7 +138,8 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val boostingStrategy = super.getOldBoostingStrategy(
categoricalFeatures, OldAlgo.Regression, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
@@ -141,11 +158,17 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)

@Since("1.4.0")
object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {

/** Accessor for supported loss settings: squared (L2), absolute (L1) */
/**
* Accessor for supported loss settings: squared (L2), absolute (L1),
* gaussian (alias for squared), laplace (alias for absolute)
*/
@Since("1.4.0")
final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes

/** Accessor for supported impurity settings: loss-based or variance */
@Since("2.1")
final val supportedImpurities: Array[String] = GBTRegressorParams.supportedImpurities

@Since("2.0.0")
override def load(path: String): GBTRegressor = super.load(path)
}
@@ -163,7 +186,7 @@ class GBTRegressionModel private[ml](
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
extends RegressionModel[Vector, GBTRegressionModel]
with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

@@ -17,6 +17,7 @@

package org.apache.spark.ml.tree.impl

import org.apache.spark.ml.tree.impurity._
import org.apache.spark.mllib.tree.impurity._


@@ -33,11 +34,13 @@ private[spark] class DTStatsAggregator(

/**
* [[ImpurityAggregator]] instance specifying the impurity type.
*/
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
*/
private val impurityAggregator: ImpurityAggregator = metadata.impurity match {
// TODO(SPARK-16728): this switch should be replaced by a virtual call
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
case ApproxBernoulliImpurity => new ApproxBernoulliAggregator()
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}

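
The TODO above (SPARK-16728) suggests replacing this pattern match with a virtual call on the impurity object. A hypothetical sketch of that shape, using simplified stand-in types and an invented `makeAggregator` method that are not part of the actual mllib API:

// Hypothetical sketch only: each impurity singleton builds its own aggregator.
trait ImpurityAgg
trait ImpurityKind { def makeAggregator(numClasses: Int): ImpurityAgg }

case object GiniKind extends ImpurityKind {
  def makeAggregator(numClasses: Int): ImpurityAgg = new ImpurityAgg {}
}
case object VarianceKind extends ImpurityKind {
  def makeAggregator(numClasses: Int): ImpurityAgg = new ImpurityAgg {}
}

// With such a method on the real Impurity objects, the match above would reduce to:
//   private val impurityAggregator = metadata.impurity.makeAggregator(metadata.numClasses)
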
@@ -258,11 +258,11 @@ private[spark] object GradientBoostedTrees extends Logging {
val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
// Prepare strategy for individual trees, which all use regression.
// TODO(SPARK-16728): changing the strategy here is confusing and should be avoided
val treeStrategy = boostingStrategy.treeStrategy.copy
val validationTol = boostingStrategy.validationTol
treeStrategy.algo = OldAlgo.Regression
treeStrategy.impurity = OldVariance
treeStrategy.assertValid()

// Cache input
@@ -327,6 +327,8 @@
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
// Note: For loss-based impurities, which have optimized loss-based leaf predictions,
// using a constant learning rate is correct.
baseLearnerWeights(m) = learningRate

predError = updatePredictionError(
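
To restate the comment above about constant learning rates in Friedman's notation (explanatory only, not part of this change): in TreeBoost the leaf values are already loss-optimized, so the per-tree weight can simply be the shrinkage factor.

F_m(x) = F_{m-1}(x) + \nu \, \gamma_{jm} \quad \text{for } x \in R_{jm},
\qquad
\gamma_{jm} = \arg\min_{\gamma} \sum_{x_i \in R_{jm}} L\bigl(y_i,\, F_{m-1}(x_i) + \gamma\bigr)

Because each \gamma_{jm} already minimizes the loss on its leaf's partition of the data, no additional per-tree line search is needed, and `baseLearnerWeights(m)` can stay at the constant `learningRate` (\nu).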