Commit e575bd8

Merge remote-tracking branch 'upstream/master' into als-bugfix
2 parents 9401b16 + 5b3b6f6

15 files changed (+1059, -295 lines)

examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoosting;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
import org.apache.spark.mllib.util.MLUtils;

/**
 * Classification and regression using gradient-boosted decision trees.
 */
public final class JavaGradientBoostedTrees {

  private static void usage() {
    System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
        " <Classification/Regression>");
    System.exit(-1);
  }

  public static void main(String[] args) {
    String datapath = "data/mllib/sample_libsvm_data.txt";
    String algo = "Classification";
    if (args.length >= 1) {
      datapath = args[0];
    }
    if (args.length >= 2) {
      algo = args[1];
    }
    if (args.length > 2) {
      usage();
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

    // Set parameters.
    // Note: All features are treated as continuous.
    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
    boostingStrategy.setNumIterations(10);
    boostingStrategy.weakLearnerParams().setMaxDepth(5);

    if (algo.equals("Classification")) {
      // Compute the number of classes from the data.
      Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
        @Override public Double call(LabeledPoint p) {
          return p.label();
        }
      }).countByValue().size();
      boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression

      // Train a GradientBoosting model for classification.
      final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);

      // Evaluate model on training instances and compute training error.
      JavaPairRDD<Double, Double> predictionAndLabel =
          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
            }
          });
      Double trainErr =
          1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            @Override public Boolean call(Tuple2<Double, Double> pl) {
              return !pl._1().equals(pl._2());
            }
          }).count() / data.count();
      System.out.println("Training error: " + trainErr);
      System.out.println("Learned classification tree model:\n" + model);
    } else if (algo.equals("Regression")) {
      // Train a GradientBoosting model for regression.
      final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);

      // Evaluate model on training instances and compute training error.
      JavaPairRDD<Double, Double> predictionAndLabel =
          data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
            @Override public Tuple2<Double, Double> call(LabeledPoint p) {
              return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
            }
          });
      Double trainMSE =
          predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
            @Override public Double call(Tuple2<Double, Double> pl) {
              Double diff = pl._1() - pl._2();
              return diff * diff;
            }
          }).reduce(new Function2<Double, Double, Double>() {
            @Override public Double call(Double a, Double b) {
              return a + b;
            }
          }) / data.count();
      System.out.println("Training Mean Squared Error: " + trainMSE);
      System.out.println("Learned regression tree model:\n" + model);
    } else {
      usage();
    }

    sc.stop();
  }
}
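A note on running the Java example above (an illustrative invocation, not part of the commit; it assumes the standard Spark examples packaging):

{{{
./bin/run-example org.apache.spark.examples.mllib.JavaGradientBoostedTrees data/mllib/sample_libsvm_data.txt Classification
}}}

Both arguments are optional: main() falls back to the bundled sample libsvm file and to Classification when they are omitted.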

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 45 additions & 19 deletions
@@ -154,20 +154,30 @@ object DecisionTreeRunner {
     }
   }
 
-  def run(params: Params) {
-
-    val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
-    val sc = new SparkContext(conf)
-
-    println(s"DecisionTreeRunner with parameters:\n$params")
-
+  /**
+   * Load training and test data from files.
+   * @param input  Path to input dataset.
+   * @param dataFormat  "libsvm" or "dense"
+   * @param testInput  Path to test dataset.
+   * @param algo  Classification or Regression
+   * @param fracTest  Fraction of input data to hold out for testing. Ignored if testInput given.
+   * @return  (training dataset, test dataset, number of classes),
+   *          where the number of classes is inferred from data (and set to 0 for Regression)
+   */
+  private[mllib] def loadDatasets(
+      sc: SparkContext,
+      input: String,
+      dataFormat: String,
+      testInput: String,
+      algo: Algo,
+      fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
     // Load training data and cache it.
-    val origExamples = params.dataFormat match {
-      case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+    val origExamples = dataFormat match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
     }
     // For classification, re-index classes if needed.
-    val (examples, classIndexMap, numClasses) = params.algo match {
+    val (examples, classIndexMap, numClasses) = algo match {
       case Classification => {
         // classCounts: class --> # examples in class
         val classCounts = origExamples.map(_.label).countByValue()
@@ -205,14 +215,14 @@
     }
 
     // Create training, test sets.
-    val splits = if (params.testInput != "") {
+    val splits = if (testInput != "") {
       // Load testInput.
       val numFeatures = examples.take(1)(0).features.size
-      val origTestExamples = params.dataFormat match {
-        case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
-        case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+      val origTestExamples = dataFormat match {
+        case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+        case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
       }
-      params.algo match {
+      algo match {
         case Classification => {
           // classCounts: class --> # examples in class
           val testExamples = {
@@ -229,17 +239,31 @@
       }
     } else {
       // Split input into training, test.
-      examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+      examples.randomSplit(Array(1.0 - fracTest, fracTest))
     }
     val training = splits(0).cache()
     val test = splits(1).cache()
+
     val numTraining = training.count()
     val numTest = test.count()
-
     println(s"numTraining = $numTraining, numTest = $numTest.")
 
     examples.unpersist(blocking = false)
 
+    (training, test, numClasses)
+  }
+
+  def run(params: Params) {
+
+    val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+    val sc = new SparkContext(conf)
+
+    println(s"DecisionTreeRunner with parameters:\n$params")
+
+    // Load training and test data and cache it.
+    val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+      params.testInput, params.algo, params.fracTest)
+
     val impurityCalculator = params.impurity match {
       case Gini => impurity.Gini
       case Entropy => impurity.Entropy
@@ -338,7 +362,9 @@
   /**
    * Calculates the mean squared error for regression.
    */
-  private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+  private[mllib] def meanSquaredError(
+      tree: WeightedEnsembleModel,
+      data: RDD[LabeledPoint]): Double = {
     data.map { y =>
       val err = tree.predict(y.features) - y.label
       err * err
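The hunk above ends mid-method. As a readability aid (a minimal sketch, not part of the commit), the MSE pattern it truncates completes by averaging the squared errors, assuming `tree` and `data` as in the new signature:

{{{
// Square each per-point prediction error, then average over the RDD;
// mean() is available on RDD[Double] via DoubleRDDFunctions.
data.map { y =>
  val err = tree.predict(y.features) - y.label
  err * err
}.mean()
}}}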
examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.tree.GradientBoosting
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils

/**
 * An example runner for Gradient Boosting using decision trees as weak learners. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 *
 * Note: This script treats all features as real-valued (not categorical).
 *       To include categorical features, modify categoricalFeaturesInfo.
 */
object GradientBoostedTrees {

  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      algo: String = "Classification",
      maxDepth: Int = 5,
      numIterations: Int = 10,
      fracTest: Double = 0.2) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("GradientBoostedTrees") {
      head("GradientBoostedTrees: an example decision tree app.")
      opt[String]("algo")
        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = x))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("numIterations")
        .text(s"number of iterations of boosting, default: ${defaultParams.numIterations}")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If given option testInput, " +
          s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[String]("testInput")
        .text(s"input path to test dataset. If given, option fracTest is ignored." +
          s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      opt[String]("<dataFormat>")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest > 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
        } else {
          success
        }
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {

    val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
    val sc = new SparkContext(conf)

    println(s"GradientBoostedTrees with parameters:\n$params")

    // Load training and test data and cache it.
    val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
      params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)

    val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
    boostingStrategy.numClassesForClassification = numClasses
    boostingStrategy.numIterations = params.numIterations
    boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth

    val randomSeed = Utils.random.nextInt()
    if (params.algo == "Classification") {
      val startTime = System.nanoTime()
      val model = GradientBoosting.trainClassifier(training, boostingStrategy)
      val elapsedTime = (System.nanoTime() - startTime) / 1e9
      println(s"Training time: $elapsedTime seconds")
      if (model.totalNumNodes < 30) {
        println(model.toDebugString) // Print full model.
      } else {
        println(model) // Print model summary.
      }
      val trainAccuracy =
        new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
          .precision
      println(s"Train accuracy = $trainAccuracy")
      val testAccuracy =
        new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
      println(s"Test accuracy = $testAccuracy")
    } else if (params.algo == "Regression") {
      val startTime = System.nanoTime()
      val model = GradientBoosting.trainRegressor(training, boostingStrategy)
      val elapsedTime = (System.nanoTime() - startTime) / 1e9
      println(s"Training time: $elapsedTime seconds")
      if (model.totalNumNodes < 30) {
        println(model.toDebugString) // Print full model.
      } else {
        println(model) // Print model summary.
      }
      val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
      println(s"Train mean squared error = $trainMSE")
      val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
      println(s"Test mean squared error = $testMSE")
    }

    sc.stop()
  }
}
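Tying the parser options together, an illustrative invocation (the data paths are hypothetical, not from the commit):

{{{
./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees \
  --algo Regression --numIterations 20 --maxDepth 4 \
  --testInput data/mllib/test.txt data/mllib/train.txt
}}}

Only the positional <input> argument is required; per the option help above, fracTest is ignored whenever --testInput is given.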
