
Commit 6372a2b

Updated decision tree examples to use random split. Tested all of them.

1 parent ad3e695
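The substance of the commit: the decision tree examples previously trained and evaluated on the same cached dataset and reported training error; the patch switches them to a 70/30 random train/test split and reports test error instead. A minimal sketch of the split pattern in Scala, assuming a live SparkContext `sc`; the explicit seed is an assumption added for reproducibility (the patched examples let Spark pick one):

import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Hold out roughly 30% of the records for testing; randomSplit
// normalizes the weights if they do not sum to 1.
val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
val (trainingData, testData) = (splits(0), splits(1))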

File tree

2 files changed: +79 -63 lines changed


docs/mllib-decision-tree.md

Lines changed: 78 additions & 62 deletions
@@ -151,7 +151,7 @@ The example below demonstrates how to load a
 [LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
 parse it as an RDD of `LabeledPoint` and then
 perform classification using a decision tree with Gini impurity as an impurity measure and a
-maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
+maximum tree depth of 5. The test error is calculated to measure the algorithm accuracy.
 
 <div class="codetabs">
 
@@ -161,8 +161,10 @@ import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Split the data into training and test sets (30% held out for testing)
+val splits = data.randomSplit(Array(0.7, 0.3))
+val (trainingData, testData) = (splits(0), splits(1))
 
 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -172,25 +174,24 @@ val impurity = "gini"
 val maxDepth = 5
 val maxBins = 32
 
-val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
-  maxDepth, maxBins)
+val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
+  impurity, maxDepth, maxBins)
 
-// Evaluate model on training instances and compute training error
-val labelAndPreds = data.map { point =>
+// Evaluate model on test instances and compute test error
+val labelAndPreds = testData.map { point =>
   val prediction = model.predict(point.features)
   (point.label, prediction)
 }
-val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
-println("Training Error = " + trainErr)
-println("Learned classification tree model:\n" + model)
+val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
+println("Test Error = " + testErr)
+println("Learned classification tree model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
 <div data-lang="java">
 {% highlight java %}
 import java.util.HashMap;
 import scala.Tuple2;
-import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -206,9 +207,12 @@ SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
 String datapath = "data/mllib/sample_libsvm_data.txt";
-JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Split the data into training and test sets (30% held out for testing)
+JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+JavaRDD<LabeledPoint> trainingData = splits[0];
+JavaRDD<LabeledPoint> testData = splits[1];
 
 // Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -219,24 +223,26 @@ Integer maxDepth = 5;
 Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
-final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
+final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
-// Evaluate model on training instances and compute training error
+// Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
-  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override
+    public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainErr =
+Double testErr =
   1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
-    @Override public Boolean call(Tuple2<Double, Double> pl) {
+    @Override
+    public Boolean call(Tuple2<Double, Double> pl) {
       return !pl._1().equals(pl._2());
     }
-  }).count() / data.count();
-System.out.println("Training error: " + trainErr);
-System.out.println("Learned classification tree model:\n" + model);
+  }).count() / testData.count();
+System.out.println("Test Error: " + testErr);
+System.out.println("Learned classification tree model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
@@ -247,21 +253,22 @@ from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
-# Cache the data since we will use it again to compute training error.
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
 
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
+model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                      impurity='gini', maxDepth=5, maxBins=32)
 
-# Evaluate model on training instances and compute training error
-predictions = model.predict(data.map(lambda x: x.features))
-labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
-trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
-print('Training Error = ' + str(trainErr))
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
+print('Test Error = ' + str(testErr))
 print('Learned classification tree model:')
-print(model)
+print(model.toDebugString())
 {% endhighlight %}
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
@@ -288,8 +295,10 @@ import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Split the data into training and test sets (30% held out for testing)
+val splits = data.randomSplit(Array(0.7, 0.3))
+val (trainingData, testData) = (splits(0), splits(1))
 
 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -298,17 +307,17 @@ val impurity = "variance"
 val maxDepth = 5
 val maxBins = 32
 
-val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
+val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
 
-// Evaluate model on training instances and compute training error
-val labelsAndPredictions = data.map { point =>
+// Evaluate model on test instances and compute test error
+val labelsAndPredictions = testData.map { point =>
   val prediction = model.predict(point.features)
   (point.label, prediction)
 }
-val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
-println("Training Mean Squared Error = " + trainMSE)
-println("Learned regression tree model:\n" + model)
+val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
+println("Test Mean Squared Error = " + testMSE)
+println("Learned regression tree model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
@@ -328,14 +337,17 @@ import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;
 
-// Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-String datapath = "data/mllib/sample_libsvm_data.txt";
-JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
-
 SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
+// Load and parse the data file.
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Split the data into training and test sets (30% held out for testing)
+JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+JavaRDD<LabeledPoint> trainingData = splits[0];
+JavaRDD<LabeledPoint> testData = splits[1];
+
 // Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
@@ -344,29 +356,32 @@ Integer maxDepth = 5;
 Integer maxBins = 32;
 
 // Train a DecisionTree model.
-final DecisionTreeModel model = DecisionTree.trainRegressor(data,
+final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
-// Evaluate model on training instances and compute training error
+// Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
-  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override
+    public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainMSE =
+Double testMSE =
   predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
-    @Override public Double call(Tuple2<Double, Double> pl) {
+    @Override
+    public Double call(Tuple2<Double, Double> pl) {
       Double diff = pl._1() - pl._2();
       return diff * diff;
     }
   }).reduce(new Function2<Double, Double, Double>() {
-    @Override public Double call(Double a, Double b) {
+    @Override
+    public Double call(Double a, Double b) {
       return a + b;
     }
-  }) / data.count();
-System.out.println("Training Mean Squared Error: " + trainMSE);
-System.out.println("Learned regression tree model:\n" + model);
+  }) / testData.count();
+System.out.println("Test Mean Squared Error: " + testMSE);
+System.out.println("Learned regression tree model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
@@ -377,21 +392,22 @@ from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
-# Cache the data since we will use it again to compute training error.
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
 
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
+model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                     impurity='variance', maxDepth=5, maxBins=32)
 
-# Evaluate model on training instances and compute training error
-predictions = model.predict(data.map(lambda x: x.features))
-labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
-trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
-print('Training Mean Squared Error = ' + str(trainMSE))
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
+print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression tree model:')
-print(model)
+print(model.toDebugString())
 {% endhighlight %}
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
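
A note on the dropped `.cache()`: the old examples cached the full dataset because it was traversed again to compute training error. After this patch nothing is cached, yet decision tree training still makes one pass over the training data per level of the tree, so callers may want to cache just the training split. A minimal sketch, with variable names following the patched Scala examples; the caching line is a suggestion, not part of this commit:

val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Cache only the split that is read repeatedly during training;
// the test split is traversed once during evaluation.
trainingData.cache()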

docs/mllib-gbt.md

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
 
 <div data-lang="scala">
 {% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.GradientBoostedTrees
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
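
For context on this one-line fix: the surrounding mllib-gbt.md example trains a gradient-boosted trees model, so importing `DecisionTree` was simply the wrong class. A hedged sketch of how the corrected import is used with the MLlib API of this era; the `BoostingStrategy` settings shown are assumptions for illustration, not taken from the patch:

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Default strategy for binary classification; a small iteration count
// keeps the illustration quick (an assumption, not from the patch).
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3
val model = GradientBoostedTrees.train(data, boostingStrategy)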
