Skip to content

Commit 57eee9f

Browse files
committed
Created JavaDecisionTree example from example in docs, and corrected doc example as needed.
1 parent d939a92 commit 57eee9f

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

docs/mllib-decision-tree.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,14 @@ println("Training Error = " + trainErr)
167167

168168
<div data-lang="java">
169169
{% highlight java %}
170+
import scala.Tuple2;
170171
import org.apache.spark.api.java.JavaPairRDD;
171172
import org.apache.spark.api.java.JavaRDD;
172173
import org.apache.spark.api.java.function.Function;
173174
import org.apache.spark.api.java.function.PairFunction;
174175
import org.apache.spark.mllib.regression.LabeledPoint;
175176
import org.apache.spark.mllib.tree.DecisionTree;
176177
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
177-
import scala.Tuple2;
178178

179179
JavaRDD<LabeledPoint> data = ... // data set
180180

@@ -186,7 +186,7 @@ String impurity = "gini";
186186
Integer maxDepth = 5;
187187
Integer maxBins = 100;
188188

189-
final DecisionTreeModel model = DecisionTree.trainClassifier(data.rdd(), numClasses,
189+
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
190190
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
191191

192192
// Evaluate model on training instances and compute training error
@@ -198,9 +198,11 @@ JavaPairRDD<Double, Double> predictionAndLabel =
198198
});
199199
Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
200200
@Override public Boolean call(Tuple2<Double, Double> pl) {
201-
return pl._1() != pl._2();
201+
return !pl._1().equals(pl._2());
202202
}
203203
}).count() / data.count();
204+
System.out.print("Training error: " + trainErr);
205+
System.out.print("Learned model:\n" + model);
204206
{% endhighlight %}
205207
</div>
206208

@@ -289,7 +291,7 @@ String impurity = "variance";
289291
Integer maxDepth = 5;
290292
Integer maxBins = 100;
291293

292-
final DecisionTreeModel model = DecisionTree.trainRegressor(data.rdd(),
294+
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
293295
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
294296

295297
// Evaluate model on training instances and compute training error
@@ -305,6 +307,8 @@ Double trainMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Do
305307
return diff * diff;
306308
}
307309
}).sum() / data.count();
310+
System.out.print("Training Mean Squared Error: " + trainMSE);
311+
System.out.print("Learned model:\n" + model);
308312
{% endhighlight %}
309313
</div>
310314

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib;
19+
20+
import java.util.HashMap;
21+
22+
import scala.reflect.ClassTag;
23+
import scala.Tuple2;
24+
25+
import org.apache.spark.api.java.function.Function2;
26+
import org.apache.spark.api.java.JavaPairRDD;
27+
import org.apache.spark.api.java.JavaRDD;
28+
import org.apache.spark.api.java.JavaSparkContext;
29+
import org.apache.spark.api.java.function.Function;
30+
import org.apache.spark.api.java.function.PairFunction;
31+
import org.apache.spark.mllib.regression.LabeledPoint;
32+
import org.apache.spark.mllib.tree.DecisionTree;
33+
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
34+
import org.apache.spark.mllib.util.MLUtils;
35+
import org.apache.spark.SparkConf;
36+
37+
38+
/**
39+
* Classification and regression using decision trees.
40+
*/
41+
public final class JavaDecisionTree {
42+
43+
public static void main(String[] args) {
44+
if (args.length != 1) {
45+
System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
46+
System.exit(1);
47+
}
48+
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
49+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
50+
String datapath = args[0];
51+
52+
JavaRDD<LabeledPoint> data = JavaRDD.fromRDD(MLUtils.loadLibSVMFile(sc.sc(), datapath));
53+
54+
// Compute the number of classes from the data.
55+
Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
56+
@Override public Double call(LabeledPoint p) {
57+
return p.label();
58+
}
59+
}).countByValue().size();
60+
// Empty categoricalFeaturesInfo indicates all features are continuous.
61+
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
62+
String impurity = "gini";
63+
Integer maxDepth = 5;
64+
Integer maxBins = 100;
65+
66+
// Train a DecisionTree model for classification.
67+
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
68+
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
69+
70+
// Evaluate model on training instances and compute training error
71+
JavaPairRDD<Double, Double> predictionAndLabel =
72+
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
73+
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
74+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
75+
}
76+
});
77+
Double trainErr =
78+
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
79+
@Override public Boolean call(Tuple2<Double, Double> pl) {
80+
return !pl._1().equals(pl._2());
81+
}
82+
}).count() / data.count();
83+
System.out.print("Training error: " + trainErr);
84+
System.out.print("Learned classification tree model:\n" + model);
85+
86+
// Train a DecisionTree model for regression.
87+
impurity = "variance";
88+
89+
final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
90+
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
91+
92+
// Evaluate model on training instances and compute training error
93+
JavaPairRDD<Double, Double> regressorPredictionAndLabel =
94+
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
95+
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
96+
return new Tuple2<Double, Double>(regressionModel.predict(p.features()), p.label());
97+
}
98+
});
99+
Double trainMSE =
100+
regressorPredictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
101+
@Override public Double call(Tuple2<Double, Double> pl) {
102+
Double diff = pl._1() - pl._2();
103+
return diff * diff;
104+
}
105+
}).reduce(new Function2<Double, Double, Double>() {
106+
@Override public Double call(Double a, Double b) {
107+
return a + b;
108+
}
109+
}) / data.count();
110+
System.out.print("Training Mean Squared Error: " + trainMSE);
111+
System.out.print("Learned regression tree model:\n" + regressionModel);
112+
113+
sc.stop();
114+
}
115+
}

0 commit comments

Comments
 (0)