From e60b5476ada57b17aff46771b36867e62fabffe9 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 16 Nov 2015 16:02:53 -0800 Subject: [PATCH 1/5] Added save/load to LogisticRegression Estimator, and improved unit test for it --- .../classification/LogisticRegression.scala | 17 +++++--- .../org/apache/spark/ml/util/ReadWrite.scala | 1 + .../ml/classification/ClassifierSuite.scala | 32 ++++++++++++++ .../LogisticRegressionSuite.scala | 37 ++++++++++++---- .../ProbabilisticClassifierSuite.scala | 14 ++++++ .../spark/ml/util/DefaultReadWriteTest.scala | 43 +++++++++++++++++++ .../apache/spark/ml/util/TempDirectory.scala | 7 ++- 7 files changed, 136 insertions(+), 15 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a88f52674102..31dcdbc352fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Logging { + with LogisticRegressionParams with Writable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,6 +385,12 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object LogisticRegression extends Readable[LogisticRegression] { + override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] } /** @@ -518,12 +524,12 @@ class LogisticRegressionModel private[ml] ( * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. 
*/ - override def write: Writer = new LogisticRegressionWriter(this) + override def write: Writer = new LogisticRegressionModelWriter(this) } /** [[Writer]] instance for [[LogisticRegressionModel]] */ -private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel) +private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends Writer with Logging { private case class Data( @@ -546,13 +552,14 @@ private[classification] class LogisticRegressionWriter(instance: LogisticRegress object LogisticRegressionModel extends Readable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader + override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader override def load(path: String): LogisticRegressionModel = read.load(path) } -private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] { +private[classification] class LogisticRegressionModelReader + extends Reader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 3169c9e9af5b..dddb72af5ba7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -217,6 +217,7 @@ private[ml] object DefaultParamsWriter { * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type + * TODO: Consider adding check for correct class name. */ private[ml] class DefaultParamsReader[T] extends Reader[T] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala new file mode 100644 index 000000000000..d0e3fe7ad14b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +object ClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "rawPredictionCol" -> "myRawPrediction" + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 51b06b7eb6d5..48ce1bb63068 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -873,15 +873,34 @@ class LogisticRegressionSuite } test("read/write") { - // Set some Params to make sure set Params are serialized. + def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + assert(model.numClasses === model2.numClasses) + assert(model.numFeatures === model2.numFeatures) + } val lr = new LogisticRegression() - .setElasticNetParam(0.1) - .setMaxIter(2) - .fit(dataset) - val lr2 = testDefaultReadWrite(lr) - assert(lr.intercept === lr2.intercept) - assert(lr.coefficients.toArray === lr2.coefficients.toArray) - assert(lr.numClasses === lr2.numClasses) - assert(lr.numFeatures === lr2.numFeatures) + testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, + checkModelData) } } + +object LogisticRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6), + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> false, + "tol" -> 0.8, + "standardization" -> false, + "threshold" -> 0.6 + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index fb5f00e0646c..cfa75ecf387c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) } } + +object ProbabilisticClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6) + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index c37f0503f133..66a20a816d0e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,8 +22,10 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Model, Estimator} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame trait DefaultReadWriteTest extends TempDirectory { self: Suite => @@ -69,6 +71,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => assert(another.uid === instance.uid) another } + + /** + * Default test for Estimator, Model pairs: + * - Explicitly set Params, and train model + * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Check Params on Estimator and Model + * @param estimator Estimator to test + * @param dataset Dataset to pass to [[Estimator.fit()]] + * @param testParams Set of [[Param]] values to set in estimator + * @param checkModelData Method which takes the original and loaded [[Model]] and compares their + * data. This method does not need to check [[Param]] values. + * @tparam E Type of [[Estimator]] + * @tparam M Type of [[Model]] produced by estimator + */ + def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + estimator: E, + dataset: DataFrame, + testParams: Map[String, Any], + checkModelData: (M, M) => Unit): Unit = { + // Set some Params to make sure set Params are serialized. + testParams.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + val model = estimator.fit(dataset) + + // Test Estimator save/load + val estimator2 = testDefaultReadWrite(estimator) + testParams.foreach { case (p, v) => + val param = estimator.getParam(p) + assert(estimator.get(param).get === estimator2.get(param).get) + } + + deleteTempDir() + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + testParams.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + } } class MyParams(override val uid: String) extends Params with Writable { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 2742026a69c2..d83244627c2d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -39,7 +39,12 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => } override def afterAll(): Unit = { - Utils.deleteRecursively(_tempDir) + deleteTempDir() super.afterAll() } + + /** Delete [[tempDir]] */ + def deleteTempDir(): Unit = { + Utils.deleteRecursively(_tempDir) + } } From 52c22fe336daef032fc7c88a369ab98ac25470af Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 16 Nov 2015 16:11:15 -0800 Subject: [PATCH 2/5] Moved LogisticRegressionReader/Writer to within LogisticRegressionModel --- .../classification/LogisticRegression.scala | 80 +++++++++---------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 31dcdbc352fa..34c738de527b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -524,29 +524,7 @@ class LogisticRegressionModel private[ml] ( * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. */ - override def write: Writer = new LogisticRegressionModelWriter(this) -} - - -/** [[Writer]] instance for [[LogisticRegressionModel]] */ -private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { - - private case class Data( - numClasses: Int, - numFeatures: Int, - intercept: Double, - coefficients: Vector) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: numClasses, numFeatures, intercept, coefficients - val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, - instance.coefficients) - val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) - } + override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } @@ -555,30 +533,50 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader override def load(path: String): LogisticRegressionModel = read.load(path) -} + /** [[Writer]] instance for [[LogisticRegressionModel]] */ + private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + extends Writer with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } -private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + private[classification] class LogisticRegressionModelReader + extends Reader[LogisticRegressionModel] { - /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" - override def load(path: String): LogisticRegressionModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + override def load(path: String): LogisticRegressionModel = { + val metadata = 
DefaultParamsReader.loadMetadata(path, sc, className) - val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - .select("numClasses", "numFeatures", "intercept", "coefficients").head() - // We will need numClasses, numFeatures in the future for multinomial logreg support. - // val numClasses = data.getInt(0) - // val numFeatures = data.getInt(1) - val intercept = data.getDouble(2) - val coefficients = data.getAs[Vector](3) - val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. + // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) - model + DefaultParamsReader.getAndSetParams(model, metadata) + model + } } } From b4b828c54cc24137a857daf2d031205e4dc7b80f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 17 Nov 2015 11:34:38 -0800 Subject: [PATCH 3/5] Changed testDefaultReadWrite to create a subdir in tempDir with a random name, to avoid conflicts between calls --- .../spark/ml/classification/LogisticRegression.scala | 2 ++ .../apache/spark/ml/util/DefaultReadWriteTest.scala | 11 ++++++++--- .../org/apache/spark/ml/util/TempDirectory.scala | 7 +------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 34c738de527b..71c2533bcbf4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -523,6 +523,8 @@ class LogisticRegressionModel private[ml] ( * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. */ override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 66a20a816d0e..3e41c450184b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.apache.hadoop.fs.Path import org.scalatest.Suite import org.apache.spark.SparkFunSuite @@ -31,6 +32,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Checks "overwrite" option and params. + * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name + * in order to avoid conflicts from multiple calls to this method. * @param instance ML instance to test saving/loading * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. 
* @tparam T ML instance type @@ -40,7 +43,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance: T, testParams: Boolean = true): T = { val uid = instance.uid - val path = new File(tempDir, uid).getPath + val subdirName = Identifiable.randomUID("test") + val subdir = new Path(tempDir.getPath, subdirName).toString + val path = new File(subdir, uid).getPath instance.save(path) intercept[IOException] { @@ -77,6 +82,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * - Explicitly set Params, and train model * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model * - Check Params on Estimator and Model + * + * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. * @param estimator Estimator to test * @param dataset Dataset to pass to [[Estimator.fit()]] * @param testParams Set of [[Param]] values to set in estimator @@ -103,8 +110,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => assert(estimator.get(param).get === estimator2.get(param).get) } - deleteTempDir() - // Test Model save/load val model2 = testDefaultReadWrite(model) testParams.foreach { case (p, v) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index d83244627c2d..2742026a69c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -39,12 +39,7 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => } override def afterAll(): Unit = { - deleteTempDir() - super.afterAll() - } - - /** Delete [[tempDir]] */ - def deleteTempDir(): Unit = { Utils.deleteRecursively(_tempDir) + super.afterAll() } } From 49ad64862ca2631d983b192da7c7e8a3b4559624 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 17 Nov 2015 13:03:22 -0800 Subject: [PATCH 4/5] removed unnecessary part of PipelineSuite read/write test --- .../src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 484026b1ba9a..7f5c3895acb0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(pipeline2.stages(0).isInstanceOf[WritableStage]) val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] assert(writableStage.getIntParam === writableStage2.getIntParam) - - val path = new File(tempDir, pipeline.uid).getPath - val stagesDir = new Path(path, "stages").toString - val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir) - assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)), - s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" + - s" to be saved to path: $expectedStagePath") } test("PipelineModel read/write: getStagePath") { From 1f61d8609e48314585c7171c3c86d70e10fde6cd Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 17 Nov 2015 13:04:48 -0800 Subject: [PATCH 5/5] removed Path import from DefaultReadWriteTest --- .../scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 3e41c450184b..dd1e8acce941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.util import java.io.{File, IOException} -import org.apache.hadoop.fs.Path import org.scalatest.Suite import org.apache.spark.SparkFunSuite @@ -44,7 +43,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => testParams: Boolean = true): T = { val uid = instance.uid val subdirName = Identifiable.randomUID("test") - val subdir = new Path(tempDir.getPath, subdirName).toString + + val subdir = new File(tempDir, subdirName) val path = new File(subdir, uid).getPath instance.save(path)
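
---

For context only (not part of the patch series above): the commits add `Writable`/`Readable` support to the `LogisticRegression` estimator and reorganize the model's writer/reader. Below is a minimal usage sketch of the resulting API, assuming a Spark 1.6-era session with an existing `training` DataFrame containing "label" and "features" columns; the file paths and the dataset name are placeholders, not values from the patches.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// `training` is assumed to already exist as a DataFrame with "label"/"features" columns;
// the /tmp paths below are illustrative only.
val lr = new LogisticRegression().setMaxIter(2).setRegParam(0.01)

// Patch 1/5 makes the estimator itself Writable: save it and read it back.
lr.save("/tmp/lr-estimator")
val lr2 = LogisticRegression.read.load("/tmp/lr-estimator")

// The fitted model was already Writable; patch 2/5 moves its writer/reader
// into the LogisticRegressionModel companion object.
val model = lr.fit(training)
model.save("/tmp/lr-model")
val model2 = LogisticRegressionModel.load("/tmp/lr-model")
```

Note that, per the ScalaDoc added in these patches, saving a `LogisticRegressionModel` does not currently persist the training `summary` or the `parent` estimator.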