Skip to content

Commit 9880bb4

Browse files
jkbradleymengxr
authored andcommitted
[SPARK-4580] [SPARK-4610] [mllib] [docs] Documentation for tree ensembles + DecisionTree API fix
Major changes: * Added programming guide sections for tree ensembles * Added examples for tree ensembles * Updated DecisionTree programming guide with more info on parameters * **API change**: Standardized the tree parameter for the number of classes (for classification) Minor changes: * Updated decision tree documentation * Updated existing tree and tree ensemble examples * Use train/test split, and compute test error instead of training error. * Fixed decision_tree_runner.py to actually use the number of classes it computes from data. (small bug fix) Note: I know this is a lot of lines, but most is covered by: * Programming guide sections for gradient boosting and random forests. (The changes are probably best viewed by generating the docs locally.) * New examples (which were copied from the programming guide) * The "numClasses" renaming I have run all examples and relevant unit tests. CC: mengxr manishamde codedeft Author: Joseph K. Bradley <[email protected]> Author: Joseph K. Bradley <[email protected]> Closes #3461 from jkbradley/ensemble-docs and squashes the following commits: 70a75f3 [Joseph K. Bradley] updated forest vs boosting comparison d1de753 [Joseph K. Bradley] Added note about toString and toDebugString for DecisionTree to migration guide 8e87f8f [Joseph K. Bradley] Combined GBT and RandomForest guides into one ensembles guide 6fab846 [Joseph K. Bradley] small fixes based on review b9f8576 [Joseph K. Bradley] updated decision tree doc 375204c [Joseph K. Bradley] fixed python style 2b60b6e [Joseph K. Bradley] merged Java RandomForest examples into 1 file. added header. Fixed small bug in same example in the programming guide. 706d332 [Joseph K. Bradley] updated python DT runner to print full model if it is small c76c823 [Joseph K. Bradley] added migration guide for mllib abe5ed7 [Joseph K. Bradley] added examples for random forest in Java and Python to examples folder 07fc11d [Joseph K. Bradley] Renamed numClassesForClassification to numClasses everywhere in trees and ensembles. This is a breaking API change, but it was necessary to correct an API inconsistency in Spark 1.1 (where Python DecisionTree used numClasses but Scala used numClassesForClassification). cdfdfbc [Joseph K. Bradley] added examples for GBT 6372a2b [Joseph K. Bradley] updated decision tree examples to use random split. tested all of them. ad3e695 [Joseph K. Bradley] added gbt and random forest to programming guide. still need to update their examples (cherry picked from commit 657a888) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 4259ca8 commit 9880bb4

File tree

19 files changed

+1140
-182
lines changed

19 files changed

+1140
-182
lines changed

docs/mllib-decision-tree.md

Lines changed: 144 additions & 97 deletions
Large diffs are not rendered by default.

docs/mllib-ensembles.md

Lines changed: 653 additions & 0 deletions
Large diffs are not rendered by default.

docs/mllib-guide.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv
1616
* random data generation
1717
* [Classification and regression](mllib-classification-regression.html)
1818
* [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html)
19-
* [decision trees](mllib-decision-tree.html)
2019
* [naive Bayes](mllib-naive-bayes.html)
20+
* [decision trees](mllib-decision-tree.html)
21+
* [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees)
2122
* [Collaborative filtering](mllib-collaborative-filtering.html)
2223
* alternating least squares (ALS)
2324
* [Clustering](mllib-clustering.html)
@@ -60,6 +61,32 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4
6061

6162
# Migration Guide
6263

64+
## From 1.1 to 1.2
65+
66+
The only API changes in MLlib v1.2 are in
67+
[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
68+
which continues to be an experimental API in MLlib 1.2:
69+
70+
1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number
71+
of classes. In MLlib v1.1, this argument was called `numClasses` in Python and
72+
`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`.
73+
This `numClasses` parameter is specified either via
74+
[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
75+
or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
76+
static `trainClassifier` and `trainRegressor` methods.
77+
78+
2. *(Breaking change)* The API for
79+
[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed.
80+
This should generally not affect user code, unless the user manually constructs decision trees
81+
(instead of using the `trainClassifier` or `trainRegressor` methods).
82+
The tree `Node` now includes more information, including the probability of the predicted label
83+
(for classification).
84+
85+
3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`.
86+
87+
Examples in the Spark distribution and examples in the
88+
[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly.
89+
6390
## From 1.0 to 1.1
6491

6592
The only API changes in MLlib v1.1 are in

examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public static void main(String[] args) {
7373
return p.label();
7474
}
7575
}).countByValue().size();
76-
boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
76+
boostingStrategy.treeStrategy().setNumClasses(numClasses);
7777

7878
// Train a GradientBoosting model for classification.
7979
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
import scala.Tuple2;
21+
22+
import java.util.HashMap;
23+
24+
import org.apache.spark.SparkConf;
25+
import org.apache.spark.api.java.JavaPairRDD;
26+
import org.apache.spark.api.java.JavaRDD;
27+
import org.apache.spark.api.java.JavaSparkContext;
28+
import org.apache.spark.api.java.function.Function;
29+
import org.apache.spark.api.java.function.Function2;
30+
import org.apache.spark.api.java.function.PairFunction;
31+
import org.apache.spark.mllib.regression.LabeledPoint;
32+
import org.apache.spark.mllib.tree.RandomForest;
33+
import org.apache.spark.mllib.tree.model.RandomForestModel;
34+
import org.apache.spark.mllib.util.MLUtils;
35+
36+
public final class JavaRandomForestExample {
37+
38+
/**
39+
* Note: This example illustrates binary classification.
40+
* For information on multiclass classification, please refer to the JavaDecisionTree.java
41+
* example.
42+
*/
43+
private static void testClassification(JavaRDD<LabeledPoint> trainingData,
44+
JavaRDD<LabeledPoint> testData) {
45+
// Train a RandomForest model.
46+
// Empty categoricalFeaturesInfo indicates all features are continuous.
47+
Integer numClasses = 2;
48+
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
49+
Integer numTrees = 3; // Use more in practice.
50+
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
51+
String impurity = "gini";
52+
Integer maxDepth = 4;
53+
Integer maxBins = 32;
54+
Integer seed = 12345;
55+
56+
final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
57+
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
58+
seed);
59+
60+
// Evaluate model on test instances and compute test error
61+
JavaPairRDD<Double, Double> predictionAndLabel =
62+
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
63+
@Override
64+
public Tuple2<Double, Double> call(LabeledPoint p) {
65+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
66+
}
67+
});
68+
Double testErr =
69+
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
70+
@Override
71+
public Boolean call(Tuple2<Double, Double> pl) {
72+
return !pl._1().equals(pl._2());
73+
}
74+
}).count() / testData.count();
75+
System.out.println("Test Error: " + testErr);
76+
System.out.println("Learned classification forest model:\n" + model.toDebugString());
77+
}
78+
79+
private static void testRegression(JavaRDD<LabeledPoint> trainingData,
80+
JavaRDD<LabeledPoint> testData) {
81+
// Train a RandomForest model.
82+
// Empty categoricalFeaturesInfo indicates all features are continuous.
83+
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
84+
Integer numTrees = 3; // Use more in practice.
85+
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
86+
String impurity = "variance";
87+
Integer maxDepth = 4;
88+
Integer maxBins = 32;
89+
Integer seed = 12345;
90+
91+
final RandomForestModel model = RandomForest.trainRegressor(trainingData,
92+
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
93+
seed);
94+
95+
// Evaluate model on test instances and compute test error
96+
JavaPairRDD<Double, Double> predictionAndLabel =
97+
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
98+
@Override
99+
public Tuple2<Double, Double> call(LabeledPoint p) {
100+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
101+
}
102+
});
103+
Double testMSE =
104+
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
105+
@Override
106+
public Double call(Tuple2<Double, Double> pl) {
107+
Double diff = pl._1() - pl._2();
108+
return diff * diff;
109+
}
110+
}).reduce(new Function2<Double, Double, Double>() {
111+
@Override
112+
public Double call(Double a, Double b) {
113+
return a + b;
114+
}
115+
}) / testData.count();
116+
System.out.println("Test Mean Squared Error: " + testMSE);
117+
System.out.println("Learned regression forest model:\n" + model.toDebugString());
118+
}
119+
120+
public static void main(String[] args) {
121+
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample");
122+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
123+
124+
// Load and parse the data file.
125+
String datapath = "data/mllib/sample_libsvm_data.txt";
126+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
127+
// Split the data into training and test sets (30% held out for testing)
128+
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
129+
JavaRDD<LabeledPoint> trainingData = splits[0];
130+
JavaRDD<LabeledPoint> testData = splits[1];
131+
132+
System.out.println("\nRunning example of classification using RandomForest\n");
133+
testClassification(trainingData, testData);
134+
135+
System.out.println("\nRunning example of regression using RandomForest\n");
136+
testRegression(trainingData, testData);
137+
sc.stop();
138+
}
139+
}

examples/src/main/python/mllib/decision_tree_runner.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def reindexClassLabels(data):
106106

107107
def usage():
108108
print >> sys.stderr, \
109-
"Usage: decision_tree_runner [libsvm format data filepath]\n" + \
110-
" Note: This only supports binary classification."
109+
"Usage: decision_tree_runner [libsvm format data filepath]"
111110
exit(1)
112111

113112

@@ -127,16 +126,20 @@ def usage():
127126

128127
# Re-index class labels if needed.
129128
(reindexedData, origToNewLabels) = reindexClassLabels(points)
129+
numClasses = len(origToNewLabels)
130130

131131
# Train a classifier.
132132
categoricalFeaturesInfo = {} # no categorical features
133-
model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
133+
model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
134134
categoricalFeaturesInfo=categoricalFeaturesInfo)
135135
# Print learned tree and stats.
136136
print "Trained DecisionTree for classification:"
137-
print " Model numNodes: %d\n" % model.numNodes()
138-
print " Model depth: %d\n" % model.depth()
139-
print " Training accuracy: %g\n" % getAccuracy(model, reindexedData)
140-
print model
137+
print " Model numNodes: %d" % model.numNodes()
138+
print " Model depth: %d" % model.depth()
139+
print " Training accuracy: %g" % getAccuracy(model, reindexedData)
140+
if model.numNodes() < 20:
141+
print model.toDebugString()
142+
else:
143+
print model
141144

142145
sc.stop()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Random Forest classification and regression using MLlib.
20+
21+
Note: This example illustrates binary classification.
22+
For information on multiclass classification, please refer to the decision_tree_runner.py
23+
example.
24+
"""
25+
26+
import sys
27+
28+
from pyspark.context import SparkContext
29+
from pyspark.mllib.tree import RandomForest
30+
from pyspark.mllib.util import MLUtils
31+
32+
33+
def testClassification(trainingData, testData):
34+
# Train a RandomForest model.
35+
# Empty categoricalFeaturesInfo indicates all features are continuous.
36+
# Note: Use larger numTrees in practice.
37+
# Setting featureSubsetStrategy="auto" lets the algorithm choose.
38+
model = RandomForest.trainClassifier(trainingData, numClasses=2,
39+
categoricalFeaturesInfo={},
40+
numTrees=3, featureSubsetStrategy="auto",
41+
impurity='gini', maxDepth=4, maxBins=32)
42+
43+
# Evaluate model on test instances and compute test error
44+
predictions = model.predict(testData.map(lambda x: x.features))
45+
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
46+
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\
47+
/ float(testData.count())
48+
print('Test Error = ' + str(testErr))
49+
print('Learned classification forest model:')
50+
print(model.toDebugString())
51+
52+
53+
def testRegression(trainingData, testData):
54+
# Train a RandomForest model.
55+
# Empty categoricalFeaturesInfo indicates all features are continuous.
56+
# Note: Use larger numTrees in practice.
57+
# Setting featureSubsetStrategy="auto" lets the algorithm choose.
58+
model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
59+
numTrees=3, featureSubsetStrategy="auto",
60+
impurity='variance', maxDepth=4, maxBins=32)
61+
62+
# Evaluate model on test instances and compute test error
63+
predictions = model.predict(testData.map(lambda x: x.features))
64+
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
65+
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\
66+
/ float(testData.count())
67+
print('Test Mean Squared Error = ' + str(testMSE))
68+
print('Learned regression forest model:')
69+
print(model.toDebugString())
70+
71+
72+
if __name__ == "__main__":
73+
if len(sys.argv) > 1:
74+
print >> sys.stderr, "Usage: random_forest_example"
75+
exit(1)
76+
sc = SparkContext(appName="PythonRandomForestExample")
77+
78+
# Load and parse the data file into an RDD of LabeledPoint.
79+
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
80+
# Split the data into training and test sets (30% held out for testing)
81+
(trainingData, testData) = data.randomSplit([0.7, 0.3])
82+
83+
print('\nRunning example of classification using RandomForest\n')
84+
testClassification(trainingData, testData)
85+
86+
print('\nRunning example of regression using RandomForest\n')
87+
testRegression(trainingData, testData)
88+
89+
sc.stop()

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ object DecisionTreeRunner {
278278
impurity = impurityCalculator,
279279
maxDepth = params.maxDepth,
280280
maxBins = params.maxBins,
281-
numClassesForClassification = numClasses,
281+
numClasses = numClasses,
282282
minInstancesPerNode = params.minInstancesPerNode,
283283
minInfoGain = params.minInfoGain,
284284
useNodeIdCache = params.useNodeIdCache,

examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ object GradientBoostedTreesRunner {
103103
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
104104

105105
val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
106-
boostingStrategy.treeStrategy.numClassesForClassification = numClasses
106+
boostingStrategy.treeStrategy.numClasses = numClasses
107107
boostingStrategy.numIterations = params.numIterations
108108
boostingStrategy.treeStrategy.maxDepth = params.maxDepth
109109

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ class PythonMLLibAPI extends Serializable {
477477
algo = algo,
478478
impurity = impurity,
479479
maxDepth = maxDepth,
480-
numClassesForClassification = numClasses,
480+
numClasses = numClasses,
481481
maxBins = maxBins,
482482
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
483483
minInstancesPerNode = minInstancesPerNode,
@@ -513,7 +513,7 @@ class PythonMLLibAPI extends Serializable {
513513
algo = algo,
514514
impurity = impurity,
515515
maxDepth = maxDepth,
516-
numClassesForClassification = numClasses,
516+
numClasses = numClasses,
517517
maxBins = maxBins,
518518
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
519519
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)

0 commit comments

Comments
 (0)