@@ -151,7 +151,7 @@ The example below demonstrates how to load a
 [LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),
 parse it as an RDD of `LabeledPoint` and then
 perform classification using a decision tree with Gini impurity as an impurity measure and a
-maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy.
+maximum tree depth of 5. The test error is calculated to measure the algorithm's accuracy.
 
 <div class="codetabs">
 
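Since the examples now evaluate on a random 30% hold-out, the reported test error varies from run to run. A minimal PySpark sketch, assuming a hypothetical fixed seed of 12345, shows how `randomSplit` can be seeded to make the split, and hence the reported error, reproducible:

{% highlight python %}
from pyspark.mllib.util import MLUtils

data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# The seed value below is illustrative only; any fixed value makes the
# 70/30 split deterministic across runs.
(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=12345)
{% endhighlight %}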
@@ -161,8 +161,10 @@ import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Split the data into training and test sets (30% held out for testing)
+val splits = data.randomSplit(Array(0.7, 0.3))
+val (trainingData, testData) = (splits(0), splits(1))
 
 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -172,25 +174,24 @@ val impurity = "gini"
 val maxDepth = 5
 val maxBins = 32
 
-val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
-  maxDepth, maxBins)
+val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
+  impurity, maxDepth, maxBins)
 
-// Evaluate model on training instances and compute training error
-val labelAndPreds = data.map { point =>
+// Evaluate model on test instances and compute test error
+val labelAndPreds = testData.map { point =>
   val prediction = model.predict(point.features)
   (point.label, prediction)
 }
-val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
-println("Training Error = " + trainErr)
-println("Learned classification tree model:\n" + model)
+val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
+println("Test Error = " + testErr)
+println("Learned classification tree model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
 <div data-lang="java">
 {% highlight java %}
 import java.util.HashMap;
 import scala.Tuple2;
-import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -206,9 +207,12 @@ SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
 String datapath = "data/mllib/sample_libsvm_data.txt";
-JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Split the data into training and test sets (30% held out for testing)
+JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+JavaRDD<LabeledPoint> trainingData = splits[0];
+JavaRDD<LabeledPoint> testData = splits[1];
 
 // Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -219,24 +223,26 @@ Integer maxDepth = 5;
 Integer maxBins = 32;
 
 // Train a DecisionTree model for classification.
-final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
+final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
-// Evaluate model on training instances and compute training error
+// Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
-  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override
+    public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainErr =
+Double testErr =
   1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
-    @Override public Boolean call(Tuple2<Double, Double> pl) {
+    @Override
+    public Boolean call(Tuple2<Double, Double> pl) {
       return !pl._1().equals(pl._2());
     }
-  }).count() / data.count();
-System.out.println("Training error: " + trainErr);
-System.out.println("Learned classification tree model:\n" + model);
+  }).count() / testData.count();
+System.out.println("Test Error: " + testErr);
+System.out.println("Learned classification tree model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
@@ -247,21 +253,22 @@ from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
-# Cache the data since we will use it again to compute training error.
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
 
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
+model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                      impurity='gini', maxDepth=5, maxBins=32)
 
-# Evaluate model on training instances and compute training error
-predictions = model.predict(data.map(lambda x: x.features))
-labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
-trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
-print('Training Error = ' + str(trainErr))
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
+print('Test Error = ' + str(testErr))
 print('Learned classification tree model:')
-print(model)
+print(model.toDebugString())
 {% endhighlight %}
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
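As a rough sketch of the distinction, reusing the `model` and `testData` bound in the example above:

{% highlight python %}
# Batch prediction: one predict() call on the whole RDD of feature vectors,
# paired back with the labels via zip().
predictions = model.predict(testData.map(lambda lp: lp.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

# Per-point prediction inside a transformation (avoid): every element would
# need its own call into the JVM-backed model.
# labelsAndPredictions = testData.map(lambda lp: (lp.label, model.predict(lp.features)))
{% endhighlight %}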
@@ -288,8 +295,10 @@ import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Split the data into training and test sets (30% held out for testing)
+val splits = data.randomSplit(Array(0.7, 0.3))
+val (trainingData, testData) = (splits(0), splits(1))
 
 // Train a DecisionTree model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
@@ -298,17 +307,17 @@ val impurity = "variance"
 val maxDepth = 5
 val maxBins = 32
 
-val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
+val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
   maxDepth, maxBins)
 
-// Evaluate model on training instances and compute training error
-val labelsAndPredictions = data.map { point =>
+// Evaluate model on test instances and compute test error
+val labelsAndPredictions = testData.map { point =>
   val prediction = model.predict(point.features)
   (point.label, prediction)
 }
-val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
-println("Training Mean Squared Error = " + trainMSE)
-println("Learned regression tree model:\n" + model)
+val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
+println("Test Mean Squared Error = " + testMSE)
+println("Learned regression tree model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
@@ -328,14 +337,17 @@ import org.apache.spark.mllib.tree.model.DecisionTreeModel;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;
 
-// Load and parse the data file.
-// Cache the data since we will use it again to compute training error.
-String datapath = "data/mllib/sample_libsvm_data.txt";
-JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
-
 SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
+// Load and parse the data file.
+String datapath = "data/mllib/sample_libsvm_data.txt";
+JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
+// Split the data into training and test sets (30% held out for testing)
+JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+JavaRDD<LabeledPoint> trainingData = splits[0];
+JavaRDD<LabeledPoint> testData = splits[1];
+
 // Set parameters.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
@@ -344,29 +356,32 @@ Integer maxDepth = 5;
 Integer maxBins = 32;
 
 // Train a DecisionTree model.
-final DecisionTreeModel model = DecisionTree.trainRegressor(data,
+final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
   categoricalFeaturesInfo, impurity, maxDepth, maxBins);
 
-// Evaluate model on training instances and compute training error
+// Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
-  data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
-    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+  testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override
+    public Tuple2<Double, Double> call(LabeledPoint p) {
       return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
   });
-Double trainMSE =
+Double testMSE =
   predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
-    @Override public Double call(Tuple2<Double, Double> pl) {
+    @Override
+    public Double call(Tuple2<Double, Double> pl) {
      Double diff = pl._1() - pl._2();
       return diff * diff;
     }
   }).reduce(new Function2<Double, Double, Double>() {
-    @Override public Double call(Double a, Double b) {
+    @Override
+    public Double call(Double a, Double b) {
       return a + b;
     }
-  }) / data.count();
+  }) / testData.count();
-System.out.println("Training Mean Squared Error: " + trainMSE);
-System.out.println("Learned regression tree model:\n" + model);
+System.out.println("Test Mean Squared Error: " + testMSE);
+System.out.println("Learned regression tree model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
@@ -377,21 +392,22 @@ from pyspark.mllib.tree import DecisionTree
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
-# Cache the data since we will use it again to compute training error.
-data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
+data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
 
 # Train a DecisionTree model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
+model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                     impurity='variance', maxDepth=5, maxBins=32)
 
-# Evaluate model on training instances and compute training error
-predictions = model.predict(data.map(lambda x: x.features))
-labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
-trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
-print('Training Mean Squared Error = ' + str(trainMSE))
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
+print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression tree model:')
-print(model)
+print(model.toDebugString())
 {% endhighlight %}
 
 Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
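The same batch pattern applies to regression; a sketch reusing the `model` and `testData` from the example above:

{% highlight python %}
# One predict() call for the whole test set, then zip the predictions back
# with the labels and average the squared differences to get the test MSE.
predictions = model.predict(testData.map(lambda lp: lp.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) ** 2).mean()
{% endhighlight %}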