diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 2761aeb789621..cdd9d4999fa1b 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -34,7 +34,7 @@ net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically, it is defined as a convex combination of the $L_1$ and the $L_2$ regularization terms: `\[ -\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 \]` By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ regularization as special cases. For example, if a [linear @@ -95,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { SparkContext sc = new SparkContext(conf); SQLContext sql = new SQLContext(sc); - String path = "sample_libsvm_data.txt"; + String path = "data/mllib/sample_libsvm_data.txt"; // Load training data DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); @@ -103,7 +103,7 @@ public class LogisticRegressionWithElasticNetExample { LogisticRegression lr = new LogisticRegression() .setMaxIter(10) .setRegParam(0.3) - .setElasticNetParam(0.8) + .setElasticNetParam(0.8); // Fit the model LogisticRegressionModel lrModel = lr.fit(training); @@ -158,10 +158,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight scala %} +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example val trainingSummary = lrModel.summary -// Obtain the loss per iteration. +// Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory objectiveHistory.foreach(loss => println(loss)) @@ -173,17 +175,14 @@ val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. val roc = binarySummary.roc roc.show() -roc.select("FPR").show() println(binarySummary.areaUnderROC) -// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with -// this selected threshold. +// Set the model threshold to maximize F-Measure val fMeasure = binarySummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). select("threshold").head().getDouble(0) -logReg.setThreshold(bestThreshold) -logReg.fit(logRegDataFrame) +lrModel.setThreshold(bestThreshold) {% endhighlight %} @@ -199,8 +198,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight java %} +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.sql.functions; + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); +LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -222,20 +225,131 @@ System.out.println(binarySummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. DataFrame fMeasure = binarySummary.fMeasureByThreshold(); -double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); +double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). select("threshold").head().getDouble(0); -logReg.setThreshold(bestThreshold); -logReg.fit(logRegDataFrame); +lrModel.setThreshold(bestThreshold); {% endhighlight %} +