|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.examples.mllib |
| 19 | + |
| 20 | +import scopt.OptionParser |
| 21 | + |
| 22 | +import org.apache.spark.{SparkConf, SparkContext} |
| 23 | +import org.apache.spark.mllib.evaluation.MulticlassMetrics |
| 24 | +import org.apache.spark.mllib.tree.GradientBoosting |
| 25 | +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} |
| 26 | +import org.apache.spark.util.Utils |
| 27 | + |
/**
 * An example runner for Gradient Boosting using decision trees as weak learners. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit
 * your app.
 *
 * Note: This script treats all features as real-valued (not categorical).
 *       To include categorical features, modify categoricalFeaturesInfo.
 */
object GradientBoostedTrees {

  /**
   * Command-line parameters of the example.
   *
   * @param input         input path to labeled examples (required positional argument)
   * @param testInput     input path to a test dataset; if given, `fracTest` is ignored
   * @param dataFormat    data format: "libsvm" (default) or "dense"
   * @param algo          algorithm name, one of the `Algo` values
   * @param maxDepth      max depth of each tree
   * @param numIterations number of iterations of boosting
   * @param fracTest      fraction of data to hold out for testing, must be in [0,1]
   */
  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      algo: String = "Classification",
      maxDepth: Int = 5,
      numIterations: Int = 10,
      fracTest: Double = 0.2) extends AbstractParams[Params]

  /**
   * Parses the command line into [[Params]] and runs the example.
   * Exits the JVM with status 1 when the arguments cannot be parsed or fail validation.
   */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("GradientBoostedTrees") {
      head("GradientBoostedTrees: an example decision tree app.")
      opt[String]("algo")
        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = x))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("numIterations")
        .text(s"number of iterations of boosting, default: ${defaultParams.numIterations}")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
          s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[String]("testInput")
        .text(s"input path to test dataset.  If given, option fracTest is ignored." +
          s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      // Named options must use a plain name so the flag is `--dataFormat`; the
      // angle-bracket form is reserved for positional arguments such as <input> below.
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest > 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
        } else {
          success
        }
      }
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case None => sys.exit(1)
    }
  }

  /**
   * Loads the datasets, trains a gradient-boosted model for the requested algorithm and
   * prints the training time, the model (full debug string when it has fewer than 30
   * nodes, summary otherwise) and train/test metrics.
   *
   * The SparkContext created here is always stopped, even when loading, training or
   * evaluation throws.
   *
   * @param params parsed command-line parameters
   */
  def run(params: Params): Unit = {

    val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
    val sc = new SparkContext(conf)

    try {
      println(s"GradientBoostedTrees with parameters:\n$params")

      // Load training and test data and cache it.
      val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
        params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)

      val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
      boostingStrategy.numClassesForClassification = numClasses
      boostingStrategy.numIterations = params.numIterations
      boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth

      if (params.algo == "Classification") {
        val startTime = System.nanoTime()
        val model = GradientBoosting.trainClassifier(training, boostingStrategy)
        val elapsedTime = (System.nanoTime() - startTime) / 1e9
        println(s"Training time: $elapsedTime seconds")
        if (model.totalNumNodes < 30) {
          println(model.toDebugString) // Print full model.
        } else {
          println(model) // Print model summary.
        }
        // The overall precision reported by MulticlassMetrics is printed as accuracy.
        val trainAccuracy =
          new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
            .precision
        println(s"Train accuracy = $trainAccuracy")
        val testAccuracy =
          new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label)))
            .precision
        println(s"Test accuracy = $testAccuracy")
      } else if (params.algo == "Regression") {
        val startTime = System.nanoTime()
        val model = GradientBoosting.trainRegressor(training, boostingStrategy)
        val elapsedTime = (System.nanoTime() - startTime) / 1e9
        println(s"Training time: $elapsedTime seconds")
        if (model.totalNumNodes < 30) {
          println(model.toDebugString) // Print full model.
        } else {
          println(model) // Print model summary.
        }
        val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
        println(s"Train mean squared error = $trainMSE")
        val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
        println(s"Test mean squared error = $testMSE")
      }
    } finally {
      // Release cluster resources regardless of how the run ended.
      sc.stop()
    }
  }
}
0 commit comments