Skip to content

Commit abe5ed7

Browse files
committed
added examples for random forest in Java and Python to examples folder
1 parent 07fc11d commit abe5ed7

File tree

6 files changed

+231
-6
lines changed

6 files changed

+231
-6
lines changed

docs/mllib-decision-tree.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ MLlib supports decision trees for binary and multiclass classification and for r
1919
using both continuous and categorical features. The implementation partitions data by rows,
2020
allowing distributed training with millions of instances.
2121

22-
Ensembles of trees are described in [random forests](mllib-random-forest.html) and
23-
[gradient-boosted trees](mllib-gbt.html).
22+
Ensembles of trees are described in the [Random Forest guide](mllib-random-forest.html) and
23+
[Gradient-Boosted Trees guide](mllib-gbt.html).
2424

2525
## Basic algorithm
2626

docs/mllib-gbt.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = fea
5959
<tr>
6060
<td>Squared Error</td>
6161
<td>Regression</td>
62-
<td>$\sum_{i=1}^{N} \frac{1}{2} (y_i - F(x_i))^2$</td><td>Also called L2 loss. Default loss for regression tasks.</td>
62+
<td>$\sum_{i=1}^{N} (y_i - F(x_i))^2$</td><td>Also called L2 loss. Default loss for regression tasks.</td>
6363
</tr>
6464
<tr>
6565
<td>Absolute Error</td>
@@ -85,6 +85,8 @@ We omit some decision tree parameters since those are covered in the [decision t
8585

8686
## Examples
8787

88+
GBTs currently have APIs in Scala and Java. Examples in both languages are shown below.
89+
8890
### Classification
8991

9092
The example below demonstrates how to load a
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package org.apache.spark.examples.mllib;
2+
3+
import scala.Tuple2;
4+
5+
import java.util.HashMap;
6+
7+
import org.apache.spark.SparkConf;
8+
import org.apache.spark.api.java.JavaPairRDD;
9+
import org.apache.spark.api.java.JavaRDD;
10+
import org.apache.spark.api.java.JavaSparkContext;
11+
import org.apache.spark.api.java.function.Function;
12+
import org.apache.spark.api.java.function.PairFunction;
13+
import org.apache.spark.mllib.regression.LabeledPoint;
14+
import org.apache.spark.mllib.tree.RandomForest;
15+
import org.apache.spark.mllib.tree.model.RandomForestModel;
16+
import org.apache.spark.mllib.util.MLUtils;
17+
18+
public final class JavaRandomForestClassification {
19+
20+
public static void main(String[] args) {
21+
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification");
22+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
23+
24+
// Load and parse the data file.
25+
String datapath = "data/mllib/sample_libsvm_data.txt";
26+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
27+
// Split the data into training and test sets (30% held out for testing)
28+
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
29+
JavaRDD<LabeledPoint> trainingData = splits[0];
30+
JavaRDD<LabeledPoint> testData = splits[1];
31+
32+
// Train a RandomForest model.
33+
// Empty categoricalFeaturesInfo indicates all features are continuous.
34+
Integer numClasses = 2;
35+
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
36+
Integer numTrees = 3; // Use more in practice.
37+
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
38+
String impurity = "gini";
39+
Integer maxDepth = 4;
40+
Integer maxBins = 32;
41+
Integer seed = 12345;
42+
43+
final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
44+
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
45+
seed);
46+
47+
// Evaluate model on test instances and compute test error
48+
JavaPairRDD<Double, Double> predictionAndLabel =
49+
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
50+
@Override
51+
public Tuple2<Double, Double> call(LabeledPoint p) {
52+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
53+
}
54+
});
55+
Double testErr =
56+
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
57+
@Override
58+
public Boolean call(Tuple2<Double, Double> pl) {
59+
return !pl._1().equals(pl._2());
60+
}
61+
}).count() / testData.count();
62+
System.out.println("Test Error: " + testErr);
63+
System.out.println("Learned classification forest model:\n" + model.toDebugString());
64+
}
65+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package org.apache.spark.examples.mllib;
2+
3+
import scala.Tuple2;
4+
5+
import java.util.HashMap;
6+
7+
import org.apache.spark.SparkConf;
8+
import org.apache.spark.api.java.JavaPairRDD;
9+
import org.apache.spark.api.java.JavaRDD;
10+
import org.apache.spark.api.java.JavaSparkContext;
11+
import org.apache.spark.api.java.function.Function;
12+
import org.apache.spark.api.java.function.Function2;
13+
import org.apache.spark.api.java.function.PairFunction;
14+
import org.apache.spark.mllib.regression.LabeledPoint;
15+
import org.apache.spark.mllib.tree.RandomForest;
16+
import org.apache.spark.mllib.tree.model.RandomForestModel;
17+
import org.apache.spark.mllib.util.MLUtils;
18+
19+
public final class JavaRandomForestRegression {
20+
21+
public static void main(String[] args) {
22+
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegression");
23+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
24+
25+
// Load and parse the data file.
26+
String datapath = "data/mllib/sample_libsvm_data.txt";
27+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
28+
// Split the data into training and test sets (30% held out for testing)
29+
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
30+
JavaRDD<LabeledPoint> trainingData = splits[0];
31+
JavaRDD<LabeledPoint> testData = splits[1];
32+
33+
// Train a RandomForest model.
34+
// Empty categoricalFeaturesInfo indicates all features are continuous.
35+
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
36+
Integer numTrees = 3; // Use more in practice.
37+
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
38+
String impurity = "variance";
39+
Integer maxDepth = 4;
40+
Integer maxBins = 32;
41+
Integer seed = 12345;
42+
43+
final RandomForestModel model = RandomForest.trainRegressor(trainingData,
44+
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
45+
seed);
46+
47+
// Evaluate model on test instances and compute test error
48+
JavaPairRDD<Double, Double> predictionAndLabel =
49+
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
50+
@Override
51+
public Tuple2<Double, Double> call(LabeledPoint p) {
52+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
53+
}
54+
});
55+
Double testMSE =
56+
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
57+
@Override
58+
public Double call(Tuple2<Double, Double> pl) {
59+
Double diff = pl._1() - pl._2();
60+
return diff * diff;
61+
}
62+
}).reduce(new Function2<Double, Double, Double>() {
63+
@Override
64+
public Double call(Double a, Double b) {
65+
return a + b;
66+
}
67+
}) / data.count();
68+
System.out.println("Test Mean Squared Error: " + testMSE);
69+
System.out.println("Learned regression forest model:\n" + model.toDebugString());
70+
}
71+
}

examples/src/main/python/mllib/decision_tree_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def reindexClassLabels(data):
106106

107107
def usage():
108108
print >> sys.stderr, \
109-
"Usage: decision_tree_runner [libsvm format data filepath]\n" + \
110-
" Note: This only supports binary classification."
109+
"Usage: decision_tree_runner [libsvm format data filepath]"
111110
exit(1)
112111

113112

@@ -127,10 +126,11 @@ def usage():
127126

128127
# Re-index class labels if needed.
129128
(reindexedData, origToNewLabels) = reindexClassLabels(points)
129+
numClasses = len(origToNewLabels)
130130

131131
# Train a classifier.
132132
categoricalFeaturesInfo = {} # no categorical features
133-
model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
133+
model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses,
134134
categoricalFeaturesInfo=categoricalFeaturesInfo)
135135
# Print learned tree and stats.
136136
print "Trained DecisionTree for classification:"
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Random Forest classification and regression using MLlib.
20+
21+
Note: This example illustrates binary classification.
22+
For information on multiclass classification, please refer to the decision_tree_runner.py
23+
example.
24+
"""
25+
26+
import sys
27+
28+
from pyspark.context import SparkContext
29+
from pyspark.mllib.tree import RandomForest
30+
from pyspark.mllib.util import MLUtils
31+
32+
def testClassification(trainingData, testData):
33+
# Train a RandomForest model.
34+
# Empty categoricalFeaturesInfo indicates all features are continuous.
35+
# Note: Use larger numTrees in practice.
36+
# Setting featureSubsetStrategy="auto" lets the algorithm choose.
37+
model = RandomForest.trainClassifier(trainingData, numClasses=2,
38+
categoricalFeaturesInfo={},
39+
numTrees=3, featureSubsetStrategy="auto",
40+
impurity='gini', maxDepth=4, maxBins=32)
41+
42+
# Evaluate model on test instances and compute test error
43+
predictions = model.predict(testData.map(lambda x: x.features))
44+
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
45+
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\
46+
/ float(testData.count())
47+
print('Test Error = ' + str(testErr))
48+
print('Learned classification forest model:')
49+
print(model.toDebugString())
50+
51+
def testRegression(trainingData, testData):
52+
# Train a RandomForest model.
53+
# Empty categoricalFeaturesInfo indicates all features are continuous.
54+
# Note: Use larger numTrees in practice.
55+
# Setting featureSubsetStrategy="auto" lets the algorithm choose.
56+
model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
57+
numTrees=3, featureSubsetStrategy="auto",
58+
impurity='variance', maxDepth=4, maxBins=32)
59+
60+
# Evaluate model on test instances and compute test error
61+
predictions = model.predict(testData.map(lambda x: x.features))
62+
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
63+
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\
64+
/ float(testData.count())
65+
print('Test Mean Squared Error = ' + str(testMSE))
66+
print('Learned regression forest model:')
67+
print(model.toDebugString())
68+
69+
70+
if __name__ == "__main__":
71+
if len(sys.argv) > 1:
72+
print >> sys.stderr, "Usage: random_forest_example"
73+
exit(1)
74+
sc = SparkContext(appName="PythonRandomForestExample")
75+
76+
# Load and parse the data file into an RDD of LabeledPoint.
77+
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
78+
# Split the data into training and test sets (30% held out for testing)
79+
(trainingData, testData) = data.randomSplit([0.7, 0.3])
80+
81+
print('\nRunning example of classification using RandomForest\n')
82+
testClassification(trainingData, testData)
83+
84+
print('\nRunning example of regression using RandomForest\n')
85+
testRegression(trainingData, testData)
86+
87+
sc.stop()

0 commit comments

Comments
 (0)