From 1749aecff0d52a3485ebe6c4852619f24608556d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 21 Oct 2015 16:26:01 -0700 Subject: [PATCH 01/18] Try adding PMMLExportable to ML with KMeans --- .../apache/spark/ml/clustering/KMeans.scala | 15 +++- .../apache/spark/ml/pmml/PMMLExportable.scala | 90 +++++++++++++++++++ .../spark/mllib/pmml/PMMLExportable.scala | 2 +- 3 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be6300239..145e159e9e77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,9 +17,12 @@ package org.apache.spark.ml.clustering +import javax.xml.transform.stream.StreamResult + import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.pmml.PMMLExportable import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -94,7 +97,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams + with PMMLExportable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +133,14 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } 
+ + /** + * Export the model to stream result in PMML format + */ + @Since("1.6.0") + override def toPMML(streamResult: StreamResult): Unit = { + parentModel.toPMML(streamResult) + } } /** @@ -209,4 +221,3 @@ class KMeans @Since("1.5.0") ( validateAndTransformSchema(schema) } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala new file mode 100644 index 000000000000..f446f73fba30 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.pmml + +import java.io.{File, OutputStream, StringWriter} +import javax.xml.transform.stream.StreamResult + +import org.jpmml.model.JAXBUtil + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory + +/** + * :: DeveloperApi :: + * Export model to the PMML format + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). 
+ * Based on [[org.apache.spark.mllib.pmml.Exportable]] + */ +@DeveloperApi +@Since("1.6.0") +trait PMMLExportable { + + /** + * Export the model to the stream result in PMML format. + */ + private[spark] def toPMML(streamResult: StreamResult): Unit + + /** + * :: Experimental :: + * Export the model to a local file in PMML format + */ + @Experimental + @Since("1.6.0") + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * :: Experimental :: + * Export the model to a directory on a distributed file system in PMML format. + * Models should override if they may contain more data than + * is reasonable to store locally. + */ + @Experimental + @Since("1.6.0") + def toPMML(sc: SparkContext, path: String): Unit = { + val pmml = toPMML() + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) + } + + /** + * :: Experimental :: + * Export the model to the OutputStream in PMML format + */ + @Experimental + @Since("1.6.0") + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * :: Experimental :: + * Export the model to a String in PMML format + */ + @Experimental + @Since("1.6.0") + def toPMML(): String = { + val writer = new StringWriter + toPMML(new StreamResult(writer)) + writer.toString + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 274ac7c99553..1bba5735d452 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -39,7 +39,7 @@ trait PMMLExportable { /** * Export the model to the stream result in PMML format */ - private def toPMML(streamResult: StreamResult): Unit = { + private[spark] def toPMML(streamResult: StreamResult): Unit = { val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) 
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) } From bc1b508d2af1dc54ddb1a3a077a2c48cff64a91b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 21 Oct 2015 17:59:52 -0700 Subject: [PATCH 02/18] Everything is better with tests --- .../org/apache/spark/ml/pmml/PMMLUtils.scala | 42 +++++++++++++++++++ .../spark/ml/clustering/KMeansSuite.scala | 10 +++++ 2 files changed, 52 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala new file mode 100644 index 000000000000..d629526d0789 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.pmml + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). 
+ */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d16..d4c6ab722404 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.pmml._ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) @@ -106,4 +107,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) } + + test("pmml export") { + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = kmeans.fit(dataset) + val pmmlStr = model.toPMML() + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getDataDictionary.getNumberOfFields === 3) + } } From adf0b367855b83560d0444dabc62ae102340a0af Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Oct 2015 10:53:50 -0700 Subject: [PATCH 03/18] Move the PMML loading utils into test and make it clear they are test only (since actual PMML model evaluation is to be done through a spark-packages project and previous JIRA decided re-loading PMML is out of project scope) --- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 2 +- 
.../scala/org/apache/spark/ml/util}/PMMLUtils.scala | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) rename mllib/src/{main/scala/org/apache/spark/ml/pmml => test/scala/org/apache/spark/ml/util}/PMMLUtils.scala (87%) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index d4c6ab722404..6b81a0221560 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.ml.pmml._ +import org.apache.spark.ml.util.PMMLUtils import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) diff --git a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala similarity index 87% rename from mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala rename to mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala index d629526d0789..dbdc69f95d84 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.ml.pmml +package org.apache.spark.ml.util import java.io.StringReader import javax.xml.bind.Unmarshaller @@ -25,14 +25,15 @@ import org.jpmml.model.{ImportFilter, JAXBUtil} import org.xml.sax.InputSource /** - * Utils for working with PMML. + * Testing utils for working with PMML. 
* Predictive Model Markup Language (PMML) is an XML-based file format * developed by the Data Mining Group (www.dmg.org). */ private[spark] object PMMLUtils { /** * :: Experimental :: - * Load a PMML model from a string + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. */ def loadFromString(input: String): PMML = { val is = new StringReader(input) From 461c1ce37c69110bdc563ef8c21dbab3f0bb5490 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 11 Jan 2016 14:19:51 -0800 Subject: [PATCH 04/18] Fix import ordering from automerge --- .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9effb668f197..7ccd60328c63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.clustering import javax.xml.transform.stream.StreamResult + import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} -import org.apache.spark.ml.pmml.PMMLExportable import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.pmml.PMMLExportable import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} From b514421683170d8c29ee3d39cb50abb59ff74816 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 20 Jan 2016 14:55:45 -0800 Subject: [PATCH 05/18] Import order fix --- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff 
--git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ac3dfc057f1c..d98dc9ad13a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, PMMLUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.ml.util.PMMLUtils import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) From 4f693d1e5bb3356d9d922788e94beae5c94b06e4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 7 Mar 2016 14:32:39 -0800 Subject: [PATCH 06/18] Just make the old ML one inheret from the new one and override the one required method for the factory logic --- .../apache/spark/ml/pmml/PMMLExportable.scala | 1 - .../spark/mllib/pmml/PMMLExportable.scala | 49 ++----------------- 2 files changed, 3 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala index f446f73fba30..48227a70fb6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala @@ -86,5 +86,4 @@ trait PMMLExportable { toPMML(new StreamResult(writer)) writer.toString } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 1bba5735d452..00a804dd5d55 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -24,6 +24,7 @@ import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.ml.pmml.{PMMLExportable => NewPMMLExportable} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -34,57 +35,13 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory */ @DeveloperApi @Since("1.4.0") -trait PMMLExportable { +trait PMMLExportable extends NewPMMLExportable { /** * Export the model to the stream result in PMML format */ - private[spark] def toPMML(streamResult: StreamResult): Unit = { + private[spark] override def toPMML(streamResult: StreamResult): Unit = { val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) } - - /** - * :: Experimental :: - * Export the model to a local file in PMML format - */ - @Experimental - @Since("1.4.0") - def toPMML(localPath: String): Unit = { - toPMML(new StreamResult(new File(localPath))) - } - - /** - * :: Experimental :: - * Export the model to a directory on a distributed file system in PMML format - */ - @Experimental - @Since("1.4.0") - def toPMML(sc: SparkContext, path: String): Unit = { - val pmml = toPMML() - sc.parallelize(Array(pmml), 1).saveAsTextFile(path) - } - - /** - * :: Experimental :: - * Export the model to the OutputStream in PMML format - */ - @Experimental - @Since("1.4.0") - def toPMML(outputStream: OutputStream): Unit = { - toPMML(new StreamResult(outputStream)) - } - - /** - * :: Experimental :: - * Export the model to a String in PMML format - */ - @Experimental - @Since("1.4.0") - def toPMML(): String = { - val writer = new StringWriter - toPMML(new StreamResult(writer)) - writer.toString - } - } From 4f3ac08219c6cc4b73e2dbc3610e8879c74d022d Mon Sep 17 
00:00:00 2001 From: Holden Karau Date: Mon, 7 Mar 2016 15:26:50 -0800 Subject: [PATCH 07/18] Fix ambigious reference --- .../apache/spark/examples/mllib/PMMLModelExportExample.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb1..fbe8cf7a40f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println("PMML Model:\n" + clusters.toPMML()) // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") From 8042e4f657d965dea217cfef2b2f197582fdfb30 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 13 Apr 2016 16:18:16 -0700 Subject: [PATCH 08/18] Exclude the methods we've moved around --- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a30581eb487c..e9ba9506697d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -628,7 +628,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-14475] Propagate user-defined context from driver to executors ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty") - ) + ) ++ Seq( + // [SPARK-11171] [SPARK-11237] Add PMML export for ML + ProblemFilters.exclude[UpdateForwarderBodyProblem]("org.apache.spark.mllib.pmml.PMMLExportable.toPMML"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.pmml.PMMLExportable.toPMML")) case v if v.startsWith("1.6") => Seq( 
MimaBuild.excludeSparkPackage("deploy"), From cf335bf06a36e84fd9f85d5494b0b07e06d6b25c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 22 Jun 2016 13:38:01 -0700 Subject: [PATCH 09/18] Start refactoring to the more generic writer --- .../apache/spark/ml/clustering/KMeans.scala | 16 ++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 35 ++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a09a318b6a5e..9752702ecfd2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -105,7 +105,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable with PMMLExportable { + extends Model[KMeansModel] with KMeansParams with MLWritable with PMMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -214,7 +214,21 @@ object KMeansModel extends MLReadable[KMeansModel] { /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + override val supportedFormats = List("native", "pmml") + override protected def saveImpl(path: String): Unit = { + if (format == "native") { + saveNative(path) + } else { + savePMML(path) + } + } + + private def savePMML(path: String): Unit = { + parentModel.toPMML(sc, path) + } + + private def saveNative(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: cluster centers diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
index 90b8d7df7b49..d0e5fcf74387 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -98,7 +98,23 @@ abstract class MLWriter extends BaseReadWrite with Logging { s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") } } - saveImpl(path) + if (source in supportedFormmats) { + saveImpl(path) + } else { + throw new IllegalArgumentException(s"Format ${source} is not supported in this model") + } + } + + protected val supportedFormats = List("native") + + protected var source = "native" + /** + * Specifies the underlying output data format. Default is "native" with some models also + * supporting "pmml". + */ + def format(source: String): MLWriter = { + this.source = source + this } /** @@ -144,6 +160,23 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * :: Experimental :: + * + * Trait for classes that can be exported to PMML. + */ +@Experimental +@Since("2.1.0") +trait PMMLWritable extends MLWritable { + /** + * Save this ML instance to the input path in PMML format. A shortcut of + * `write.format("pmml").save(path)`. 
+ */ + def toPMML(path: String): Unit = { + write.format("pmml").save(path) + } +} + /** * :: Experimental :: * From 59dd4c6e6c043493e027f20c2f0433d1f56c6d5c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 22 Jun 2016 14:41:29 -0700 Subject: [PATCH 10/18] More progress towards the new API --- .../org/apache/spark/ml/clustering/KMeans.scala | 17 ++++------------- .../org/apache/spark/ml/util/ReadWrite.scala | 7 ++++--- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9752702ecfd2..862193acbc34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -151,14 +151,6 @@ class KMeansModel private[ml] ( parentModel.computeCost(data) } - /** - * Export the model to stream result in PMML format - */ - @Since("1.6.0") - override def toPMML(streamResult: StreamResult): Unit = { - parentModel.toPMML(streamResult) - } - /** * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. 
* @@ -217,15 +209,14 @@ object KMeansModel extends MLReadable[KMeansModel] { override val supportedFormats = List("native", "pmml") override protected def saveImpl(path: String): Unit = { - if (format == "native") { - saveNative(path) - } else { - savePMML(path) + source match { + case "native" => saveNative(path) + case "pmml" => savePMML(path) } } private def savePMML(path: String): Unit = { - parentModel.toPMML(sc, path) + instance.parentModel.toPMML(sc, path) } private def saveNative(path: String): Unit = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d0e5fcf74387..56e9462f34bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -98,10 +98,11 @@ abstract class MLWriter extends BaseReadWrite with Logging { s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") } } - if (source in supportedFormmats) { + if (supportedFormats.contains(source)) { saveImpl(path) } else { - throw new IllegalArgumentException(s"Format ${source} is not supported in this model") + throw new IllegalArgumentException(s"Format ${source} is not supported by this model " + + s"try one of (${supportedFormats})") } } @@ -113,7 +114,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { * supporting "pmml". 
*/ def format(source: String): MLWriter = { - this.source = source + this.source = source.toLowerCase() this } From b5a57ea297ce3637ff2f5d37344457a426b22d1e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 22 Jun 2016 14:56:03 -0700 Subject: [PATCH 11/18] Revert some uneeded changes with the new approach --- .../mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/ml/clustering/KMeans.scala | 1 - .../apache/spark/ml/pmml/PMMLExportable.scala | 89 ------------------- .../spark/mllib/pmml/PMMLExportable.scala | 52 ++++++++++- project/MimaExcludes.scala | 6 -- 5 files changed, 49 insertions(+), 101 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index fbe8cf7a40f9..d74d74a37fb1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML()) + println("PMML Model:\n" + clusters.toPMML) // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 862193acbc34..cc33299cddc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -27,7 +27,6 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import 
org.apache.spark.ml.pmml.PMMLExportable import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} diff --git a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala deleted file mode 100644 index 48227a70fb6d..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.pmml - -import java.io.{File, OutputStream, StringWriter} -import javax.xml.transform.stream.StreamResult - -import org.jpmml.model.JAXBUtil - -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} -import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory - -/** - * :: DeveloperApi :: - * Export model to the PMML format - * Predictive Model Markup Language (PMML) is an XML-based file format - * developed by the Data Mining Group (www.dmg.org). 
- * Based on [[org.apache.spark.mllib.pmml.Exportable]] - */ -@DeveloperApi -@Since("1.6.0") -trait PMMLExportable { - - /** - * Export the model to the stream result in PMML format. - */ - private[spark] def toPMML(streamResult: StreamResult): Unit - - /** - * :: Experimental :: - * Export the model to a local file in PMML format - */ - @Experimental - @Since("1.6.0") - def toPMML(localPath: String): Unit = { - toPMML(new StreamResult(new File(localPath))) - } - - /** - * :: Experimental :: - * Export the model to a directory on a distributed file system in PMML format. - * Models should override if they may contain more data than - * is reasonable to store locally. - */ - @Experimental - @Since("1.6.0") - def toPMML(sc: SparkContext, path: String): Unit = { - val pmml = toPMML() - sc.parallelize(Array(pmml), 1).saveAsTextFile(path) - } - - /** - * :: Experimental :: - * Export the model to the OutputStream in PMML format - */ - @Experimental - @Since("1.6.0") - def toPMML(outputStream: OutputStream): Unit = { - toPMML(new StreamResult(outputStream)) - } - - /** - * :: Experimental :: - * Export the model to a String in PMML format - */ - @Experimental - @Since("1.6.0") - def toPMML(): String = { - val writer = new StringWriter - toPMML(new StreamResult(writer)) - writer.toString - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 00a804dd5d55..17eee59da29a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -24,7 +24,6 @@ import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} -import org.apache.spark.ml.pmml.{PMMLExportable => NewPMMLExportable} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -35,13 +34,58 @@ import 
org.apache.spark.mllib.pmml.export.PMMLModelExportFactory */ @DeveloperApi @Since("1.4.0") -trait PMMLExportable extends NewPMMLExportable { +trait PMMLExportable { /** - * Export the model to the stream result in PMML format + * Export the model to the stream result in PMML format. */ - private[spark] override def toPMML(streamResult: StreamResult): Unit = { + private def toPMML(streamResult: StreamResult): Unit = { val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) } + + /** + * :: Experimental :: + * Export the model to a local file in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * :: Experimental :: + * Export the model to a directory on a distributed file system in PMML format. + * Models should override if they may contain more data than + * is reasonable to store locally. + */ + @Experimental + @Since("1.4.0") + def toPMML(sc: SparkContext, path: String): Unit = { + val pmml = toPMML() + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) + } + + /** + * :: Experimental :: + * Export the model to the OutputStream in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * :: Experimental :: + * Export the model to a String in PMML format + */ + @Experimental + @Since("1.4.0") + def toPMML(): String = { + val writer = new StringWriter + toPMML(new StreamResult(writer)) + writer.toString + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 08d1c31ea94a..a6209d78e168 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -643,12 +643,6 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") ) ++ Seq( // [SPARK-14475] Propagate user-defined 
context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty") - ) ++ Seq( - // [SPARK-11171] [SPARK-11237] Add PMML export for ML - ProblemFilters.exclude[UpdateForwarderBodyProblem]("org.apache.spark.mllib.pmml.PMMLExportable.toPMML"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.pmml.PMMLExportable.toPMML") - ) ++ Seq( ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), // [SPARK-14617] Remove deprecated APIs in TaskMetrics ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), From b7edccf46236bb366326a7ed02a7023a52700a9f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 22 Jun 2016 15:57:35 -0700 Subject: [PATCH 12/18] Basic test pass --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 2 -- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 7 ++++++- project/MimaExcludes.scala | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index cc33299cddc7..b3e96e3d591f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.clustering -import javax.xml.transform.stream.StreamResult - import org.apache.hadoop.fs.Path import org.apache.spark.SparkException diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32413c380896..cb61151c492a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, 
PMMLUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.util.Utils private[clustering] case class TestRow(features: Vector) @@ -122,7 +123,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) - val pmmlStr = model.toPMML() + val tempDir = Utils.createTempDir() + val exportPath = tempDir.getPath() + "/pmml" + model.toPMML("file://" + exportPath) + val pmmlStr = sc.textFile(exportPath).collect.mkString("\n") + Utils.deleteRecursively(tempDir) val pmmlModel = PMMLUtils.loadFromString(pmmlStr) assert(pmmlModel.getDataDictionary.getNumberOfFields === 3) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a6209d78e168..74e126ee49dd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -643,7 +643,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") ) ++ Seq( // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), // [SPARK-14617] Remove deprecated APIs in TaskMetrics ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), From a41b4746a2d0129db9b63cc36c15cf58b4319bf1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 22 Jun 2016 20:40:29 -0700 Subject: [PATCH 13/18] Refactor tests a bit --- .../spark/ml/clustering/KMeansSuite.scala | 25 
+++++---- .../spark/ml/util/PMMLReadWriteTest.scala | 54 +++++++++++++++++++ 2 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index cb61151c492a..a7c3eb48838f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.clustering +import org.dmg.pmml.PMML + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, PMMLUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -27,7 +29,8 @@ import org.apache.spark.util.Utils private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -123,13 +126,17 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) - val tempDir = Utils.createTempDir() - val exportPath = tempDir.getPath() + "/pmml" - model.toPMML("file://" + exportPath) - val pmmlStr = sc.textFile(exportPath).collect.mkString("\n") - Utils.deleteRecursively(tempDir) - val pmmlModel = PMMLUtils.loadFromString(pmmlStr) - assert(pmmlModel.getDataDictionary.getNumberOfFields 
=== 3) + def checkModel(pmml: PMML): Unit = { + assert(pmml.getDataDictionary.getNumberOfFields === 3) + } + testPMMLWrite(sc, model, checkModel) + } + + test("generic pmml export") { + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = kmeans.fit(dataset) + } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 000000000000..2ac0e4b8a218 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. 
Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with PMMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.toPMML(path) + intercept[IOException] { + instance.toPMML(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName() == "Apache Spark MLlib") + } +} From 0dd6c94f2c425c1a396cc2bb3a9387e17156c3eb Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 1 Jul 2016 17:13:29 -0700 Subject: [PATCH 14/18] Make the writer have a convenience function to set the type to pmml (and do so in a type safe way) --- .../apache/spark/ml/clustering/KMeans.scala | 17 ++------ .../org/apache/spark/ml/util/ReadWrite.scala | 39 +++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f309c92eb256..2dd9f4b0347d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -156,7 +156,7 @@ class KMeansModel private[ml] ( * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: PMMLWriter = new KMeansModel.KMeansModelWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -201,22 +201,13 @@ object KMeansModel extends MLReadable[KMeansModel] { private case class OldData(clusterCenters: Array[OldVector]) /**
[[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends PMMLWriter { - override val supportedFormats = List("native", "pmml") - - override protected def saveImpl(path: String): Unit = { - source match { - case "native" => saveNative(path) - case "pmml" => savePMML(path) - } - } - - private def savePMML(path: String): Unit = { + override protected def savePMML(path: String): Unit = { instance.parentModel.toPMML(sc, path) } - private def saveNative(path: String): Unit = { + override protected def saveNative(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: cluster centers diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index aba2a411b393..880a28e9254e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -122,6 +122,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { protected val supportedFormats = List("native") protected var source = "native" + /** * Specifies the underlying output data format. Default is "native" with some models also * supporting "pmml". @@ -154,6 +155,36 @@ abstract class MLWriter extends BaseReadWrite with Logging { override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } +/** + * :: Experimental :: + * + * Abstract class for utility classes that can save ML instances in both native Spark and PMML. 
+ */ +@Experimental +@Since("1.6.0") +abstract class PMMLWriter extends MLWriter { + protected override val supportedFormats = List("pmml", "native") + + /** + * Specifies the underlying output data format as PMML + */ + def pmmml(): MLWriter = { + this.source = "pmml" + this + } + + override protected def saveImpl(path: String): Unit = { + source match { + case "native" => saveNative(path) + case "pmml" => savePMML(path) + } + } + + protected def savePMML(path: String): Unit + + protected def saveNative(path: String): Unit +} + /** * :: Experimental :: * @@ -185,10 +216,18 @@ trait MLWritable { @Experimental @Since("2.1.0") trait PMMLWritable extends MLWritable { + /** + * Returns an [[PMMLWriter]] instance for this ML instance capable of saving to both native Spark + * and PMML. + */ + @Since("2.1.0") + override def write: PMMLWriter + /** * Save this ML instance to the input path in PMML format. A shortcut of * `write.format("pmml").save(path)`. */ + @Since("2.1.0") def toPMML(path: String): Unit = { write.format("pmml").save(path) } From 49f8a8d540860e28de8a356616f78a64ab112909 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 3 Jul 2016 02:08:51 -0700 Subject: [PATCH 15/18] Add pmml() to the writer --- .../scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala index 2ac0e4b8a218..1cfa9325f61f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -43,8 +43,11 @@ trait PMMLReadWriteTest extends TempDirectory { self: Suite => val path = new File(subdir, uid).getPath instance.toPMML(path) - intercept[IOException] { + intercept[IOException] { instance.toPMML(path) + } + intercept[IOException] { + instance.write.pmml().save(path) } 
instance.write.format("pmml").overwrite().save(path) val pmmlStr = sc.textFile(path).collect.mkString("\n") From e6845f1c37481b270359a6a4291c07dcaace242e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 3 Jul 2016 02:43:19 -0700 Subject: [PATCH 16/18] pmmml -> pmml --- mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 880a28e9254e..06bde43d5cbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -168,7 +168,7 @@ abstract class PMMLWriter extends MLWriter { /** * Specifies the underlying output data format as PMML */ - def pmmml(): MLWriter = { + def pmml(): MLWriter = { this.source = "pmml" this } From 9170b3fc0b4142258c06f19eb2b4d65890e81fa9 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 5 Jul 2016 12:44:22 -0700 Subject: [PATCH 17/18] Add a MIMA exclusion for the saveImpl change and fix inadvertent spacing change --- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 74e126ee49dd..ea9454c2c836 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -643,7 +643,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") ) ++ Seq( // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), // [SPARK-14617] Remove deprecated APIs in TaskMetrics ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), @@ -787,6 +787,9 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema") + ) ++ Seq( + // [SPARK-11171][SPARK-11237][SPARK-11241] Add PMML exportable to ML (TODO move to Spark 2.1 once master is updated to 2.1) + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.ml.util.MLWriter.saveImpl") ) case v if v.startsWith("1.6") => Seq( From 8103b76032012c521005a2b4149dc1e3ac26cd67 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 1 Aug 2016 11:26:49 -0700 Subject: [PATCH 18/18] Remove TODO since we have moved it to Spark 2.1 now that master is on 2.1 --- project/MimaExcludes.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d3c3a43505b7..172f8a69ebf6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -40,7 +40,7 @@ object MimaExcludes { // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references") ) ++ Seq( - // [SPARK-11171][SPARK-11237][SPARK-11241] Add PMML exportable to ML (TODO move to Spark 2.1 once master is updated to 2.1) + // [SPARK-11171][SPARK-11237][SPARK-11241] Add PMML exportable to ML ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.ml.util.MLWriter.saveImpl") ) }