242 changes: 145 additions & 97 deletions docs/mllib-decision-tree.md

Large diffs are not rendered by default.

308 changes: 308 additions & 0 deletions docs/mllib-gbt.md
@@ -0,0 +1,308 @@
---
layout: global
title: Gradient-Boosted Trees - MLlib
displayTitle: <a href="mllib-guide.html">MLlib</a> - Gradient-Boosted Trees
---

* Table of contents
{:toc}

[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting)
are ensembles of [decision trees](mllib-decision-tree.html).
GBTs iteratively train decision trees in order to minimize a loss function.
Like decision trees, GBTs handle categorical features,
extend to the multiclass classification setting, do not require
feature scaling, and are able to capture non-linearities and feature interactions.

MLlib supports GBTs for binary classification and for regression,
using both continuous and categorical features.
MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees.

*Note*: MLlib's GBT implementation does not yet support multiclass classification. For multiclass problems, please use
[decision trees](mllib-decision-tree.html) or [Random Forests](mllib-random-forest.html).

## Basic algorithm

Gradient boosting iteratively trains a sequence of decision trees.
On each iteration, the algorithm uses the current ensemble to predict the label of each training instance and then compares the prediction with the true label. The dataset is re-labeled to put more emphasis on training instances with poor predictions. Thus, in the next iteration, the decision tree will help correct for previous mistakes.
**Contributor:** The weight-based relabeling sounds more like AdaBoost.

**Member Author:** True, I'll change "weight" to "emphasis."


The specific mechanism for re-labeling instances is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data.
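
To make the re-labeling idea concrete, here is a minimal sketch of the boosting loop for squared-error regression in plain Scala. This is illustrative pseudocode, not MLlib code: the `trainTree` parameter is a hypothetical stand-in for decision tree training, and the initial model is simplified to predict 0 rather than, e.g., the mean label.

{% highlight scala %}
case class Instance(features: Array[Double], label: Double)

def boost(
    data: Seq[Instance],
    trainTree: Seq[Instance] => (Array[Double] => Double), // hypothetical tree learner
    numIterations: Int,
    learningRate: Double): Array[Double] => Double = {
  var trees = Vector.empty[Array[Double] => Double]
  def predict(x: Array[Double]): Double = trees.map(t => t(x)).sum
  for (_ <- 1 to numIterations) {
    // Re-label each instance with its current error (the pseudo-residual for
    // squared error), so poorly predicted instances get larger targets.
    val relabeled = data.map(p => Instance(p.features, p.label - predict(p.features)))
    val tree = trainTree(relabeled)
    // Shrink the new tree's contribution by the learning rate and add it.
    trees = trees :+ ((x: Array[Double]) => learningRate * tree(x))
  }
  predict _
}
{% endhighlight %}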

### Comparison with Random Forests
**Contributor:** I really like this section since this is very useful information. We should try to add some graphs here in a separate PR. However, shouldn't this be in a separate section under Ensembles, comparing both RF and boosting algorithms in terms of performance and accuracy?

**Member Author:** Should we have a new Ensembles section in the guide? It might be pretty short. (And I haven't seen experimental results in the guide; would they belong elsewhere?)

Eventually, I could imagine either (a) an Ensembles section once we have more ensemble algs or (b) a section in the guide covering all algorithms and how to choose between them.


Both GBTs and [Random Forests](mllib-random-forest.html) are algorithms for learning ensembles of trees, but the training processes are different. There are several practical trade-offs:

* GBTs may be able to achieve the same accuracy using fewer trees, so the resulting model may be smaller (making test-time prediction faster).
* GBTs train one tree at a time, so they can take longer to train than random forests. Random Forests can train multiple trees in parallel.
* On the other hand, it is often reasonable to use smaller trees with GBTs than with Random Forests, and training smaller trees takes less time.
* Random Forests can be less prone to overfitting. Training more trees in a Random Forest reduces the likelihood of overfitting, but training more trees with GBTs increases the likelihood of overfitting. (In statistical language, Random Forests reduce variance by using more trees, whereas GBTs reduce bias by using more trees.)

In short, both algorithms can be effective, and the choice should be based on the particular dataset. GBTs may be preferable when test-time prediction speed is important.
**Contributor:** I think we should avoid this lest people start quoting us in the future. :-) There are a lot of results comparing RF and boosting, and the results are mixed.

**Member Author:** I'll say less : )


### Losses

The table below lists the losses currently supported by GBTs in MLlib.
Note that each loss is applicable to one of classification or regression, not both.

Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$.

<table class="table">
<thead>
<tr><th>Loss</th><th>Task</th><th>Formula</th><th>Description</th></tr>
</thead>
<tbody>
<tr>
<td>Log Loss</td>
<td>Classification</td>
<td>$2 \sum_{i=1}^{N} \log(1+\exp(-2 y_i F(x_i)))$</td><td>Twice binomial negative log likelihood.</td>
</tr>
<tr>
<td>Squared Error</td>
<td>Regression</td>
<td>$\sum_{i=1}^{N} (y_i - F(x_i))^2$</td><td>Also called L2 loss. Default loss for regression tasks.</td>
</tr>
<tr>
<td>Absolute Error</td>
<td>Regression</td>
<td>$\sum_{i=1}^{N} |y_i - F(x_i)|$</td><td>Also called L1 loss. Can be more robust to outliers than Squared Error.</td>
</tr>
</tbody>
</table>
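
As a quick sanity check of these formulas, the toy snippet below evaluates each loss on a handful of predictions. This is plain Scala for illustration, not MLlib code; classification labels are taken in $\{-1, +1\}$ to match the log loss formula above.

{% highlight scala %}
// True labels y_i and model predictions F(x_i) for three instances.
val y = Array(1.0, -1.0, 1.0)
val f = Array(0.8, -0.3, -0.5)

// Log loss: twice the binomial negative log likelihood.
val logLoss = 2.0 * (y zip f).map { case (yi, fi) =>
  math.log(1.0 + math.exp(-2.0 * yi * fi))
}.sum

// Squared (L2) and absolute (L1) error for regression.
val squaredError  = (y zip f).map { case (yi, fi) => (yi - fi) * (yi - fi) }.sum
val absoluteError = (y zip f).map { case (yi, fi) => math.abs(yi - fi) }.sum
{% endhighlight %}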

## Usage tips

We include a few guidelines for using GBTs by discussing the various parameters.
We omit some decision tree parameters since those are covered in the [decision tree guide](mllib-decision-tree.html).

* **`loss`**: See the section above for information on losses and their applicability to tasks (classification vs. regression). Different losses can give significantly different results, depending on the dataset.

* **`numIterations`**: This sets the number of trees in the ensemble. Each iteration produces one tree. Increasing this number makes the model more expressive, improving training data accuracy. However, test-time accuracy may suffer if this is too large.

* **`learningRate`**: This parameter should not need to be tuned. If the algorithm behavior seems unstable, decreasing this value may improve stability.

* **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) parameter. A minimal configuration sketch follows this list.
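
Putting these parameters together, here is a minimal configuration sketch in Scala. It assumes the `BoostingStrategy` setters shown in the examples below and the `AbsoluteError` loss object from `org.apache.spark.mllib.tree.loss`; treat it as a sketch rather than a definitive recipe.

{% highlight scala %}
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.AbsoluteError

// Start from the task's defaults, then override the parameters discussed above.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 100  // number of trees in the ensemble
boostingStrategy.learningRate = 0.1   // usually left at its default
boostingStrategy.loss = AbsoluteError // e.g., for robustness to outliers
{% endhighlight %}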


## Examples

GBTs currently have APIs in Scala and Java. Examples in both languages are shown below.

### Classification

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint` and then
perform classification using Gradient-Boosted Trees with log loss.
The test error is calculated to measure the algorithm's accuracy.

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClassesForClassification = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification GBT model:\n" + model.toDebugString)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setNumClassesForClassification(2);
boostingStrategy.getTreeStrategy().setMaxDepth(5);
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

final GradientBoostedTreesModel model =
  GradientBoostedTrees.train(trainingData, boostingStrategy);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override
    public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double testErr =
  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override
    public Boolean call(Tuple2<Double, Double> pl) {
      return !pl._1().equals(pl._2());
    }
  }).count() / testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification GBT model:\n" + model.toDebugString());
{% endhighlight %}
</div>

</div>

### Regression

The example below demonstrates how to load a
[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
parse it as an RDD of `LabeledPoint` and then
perform regression using Gradient-Boosted Trees with Squared Error as the loss.
The Mean Squared Error (MSE) is computed at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

<div class="codetabs">

<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Evaluate model on test instances and compute test Mean Squared Error (MSE)
val labelsAndPredictions = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression GBT model:\n" + model.toDebugString)
{% endhighlight %}
</div>

<div data-lang="java">
{% highlight java %}
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
JavaSparkContext sc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Train a GradientBoostedTrees model.
// The defaultParams for Regression use SquaredError by default.
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
boostingStrategy.getTreeStrategy().setMaxDepth(5);
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

final GradientBoostedTreesModel model =
  GradientBoostedTrees.train(trainingData, boostingStrategy);

// Evaluate model on test instances and compute test Mean Squared Error (MSE)
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
    @Override
    public Tuple2<Double, Double> call(LabeledPoint p) {
      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
    }
  });
Double testMSE =
  predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
    @Override
    public Double call(Tuple2<Double, Double> pl) {
      Double diff = pl._1() - pl._2();
      return diff * diff;
    }
  }).reduce(new Function2<Double, Double, Double>() {
    @Override
    public Double call(Double a, Double b) {
      return a + b;
    }
  }) / testData.count(); // divide by the test set size, not the full dataset
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression GBT model:\n" + model.toDebugString());
{% endhighlight %}
</div>

</div>
28 changes: 27 additions & 1 deletion docs/mllib-guide.md
@@ -16,8 +16,10 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv
* random data generation
* [Classification and regression](mllib-classification-regression.html)
* [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html)
* [naive Bayes](mllib-naive-bayes.html)
* [decision trees](mllib-decision-tree.html)
* [random forests](mllib-random-forest.html)
* [gradient-boosted trees](mllib-gbt.html)
* [Collaborative filtering](mllib-collaborative-filtering.html)
* alternating least squares (ALS)
* [Clustering](mllib-clustering.html)
@@ -60,6 +62,30 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4

# Migration Guide

## From 1.1 to 1.2

The only API changes in MLlib v1.2 are in
[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
which continues to be an experimental API in MLlib 1.2:

1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number
of classes. In MLlib v1.1, this argument was called `numClasses` in Python and
`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`.
This `numClasses` parameter is specified either via
[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
static `trainClassifier` and `trainRegressor` methods. (A before/after sketch of this rename follows at the end of this section.)

2. *(Breaking change)* The API for
[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed.
This should generally not affect user code, unless the user manually constructs decision trees
(instead of using the `trainClassifier` or `trainRegressor` methods).
The tree `Node` now includes more information, including the probability of the predicted label
(for classification).

Examples in the Spark distribution and examples in the
[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly.
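
For illustration, here is a before/after sketch of the Scala rename described in item 1 above. The surrounding parameter values are arbitrary placeholders, not recommendations.

{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree

// MLlib v1.1 (Scala):
val oldModel = DecisionTree.trainClassifier(data, numClassesForClassification = 2,
  categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini", maxDepth = 5, maxBins = 32)

// MLlib v1.2:
val model = DecisionTree.trainClassifier(data, numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini", maxDepth = 5, maxBins = 32)
{% endhighlight %}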

## From 1.0 to 1.1

The only API changes in MLlib v1.1 are in