diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b294ac..bdbea42cfb23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -104,7 +104,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
     private val parentModel: MLlibKMeansModel)
-  extends Model[KMeansModel] with KMeansParams with MLWritable {
+  extends Model[KMeansModel] with KMeansParams with MLWritable with PMMLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -160,7 +160,7 @@ class KMeansModel private[ml] (
    *
    */
   @Since("1.6.0")
-  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+  override def write: PMMLWriter = new KMeansModel.KMeansModelWriter(this)
 
   private var trainingSummary: Option[KMeansSummary] = None
 
@@ -205,9 +205,13 @@ object KMeansModel extends MLReadable[KMeansModel] {
   private case class OldData(clusterCenters: Array[OldVector])
 
-  /** [[MLWriter]] instance for [[KMeansModel]] */
-  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+  /** [[PMMLWriter]] instance for [[KMeansModel]] */
+  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends PMMLWriter {
 
-    override protected def saveImpl(path: String): Unit = {
+    override protected def savePMML(path: String): Unit = {
+      instance.parentModel.toPMML(sc, path)
+    }
+
+    override protected def saveNative(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       // Save model data: cluster centers
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index bc4f9e6716ee..2381df4e8a9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -111,7 +111,25 @@ abstract class MLWriter extends BaseReadWrite with Logging {
         s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
       }
     }
-    saveImpl(path)
+    if (supportedFormats.contains(source)) {
+      saveImpl(path)
+    } else {
+      throw new IllegalArgumentException(s"Format $source is not supported by this model. " +
+        s"Supported formats are: ${supportedFormats.mkString(", ")}.")
+    }
+  }
+
+  protected val supportedFormats = List("native")
+
+  protected var source = "native"
+
+  /**
+   * Specifies the underlying output data format. The default is "native"; some models also
+   * support "pmml".
+   */
+  def format(source: String): MLWriter = {
+    this.source = source.toLowerCase()
+    this
   }
 
   /**
@@ -137,6 +155,36 @@ abstract class MLWriter extends BaseReadWrite with Logging {
   override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
 }
 
+/**
+ * :: Experimental ::
+ *
+ * Abstract class for utility classes that can save ML instances in both native Spark and PMML.
+ */
+@Experimental
+@Since("2.1.0")
+abstract class PMMLWriter extends MLWriter {
+  protected override val supportedFormats = List("pmml", "native")
+
+  /**
+   * Specifies the underlying output data format as PMML.
+   */
+  def pmml(): MLWriter = {
+    this.source = "pmml"
+    this
+  }
+
+  override protected def saveImpl(path: String): Unit = {
+    source match {
+      case "native" => saveNative(path)
+      case "pmml" => savePMML(path)
+    }
+  }
+
+  protected def savePMML(path: String): Unit
+
+  protected def saveNative(path: String): Unit
+}
+
 /**
  * :: Experimental ::
  *
@@ -163,6 +211,31 @@ trait MLWritable {
 /**
- * :: DeveloperApi ::
+ * :: Experimental ::
  *
+ * Trait for classes that can be exported to PMML.
+ */
+@Experimental
+@Since("2.1.0")
+trait PMMLWritable extends MLWritable {
+  /**
+   * Returns a [[PMMLWriter]] instance for this ML instance, capable of saving to both the
+   * native Spark format and PMML.
+   */
+  @Since("2.1.0")
+  override def write: PMMLWriter
+
+  /**
+   * Save this ML instance to the input path in PMML format. A shortcut for
+   * `write.format("pmml").save(path)`.
+   */
+  @Since("2.1.0")
+  def toPMML(path: String): Unit = {
+    write.format("pmml").save(path)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ *
  * Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores
  * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
  * a default implementation of writing saved instances of the class.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
index 5d61796f1de6..73bf00294b43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala
@@ -37,7 +37,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
 trait PMMLExportable {
 
   /**
-   * Export the model to the stream result in PMML format
+   * Export the model to the stream result in PMML format.
    */
   private def toPMML(streamResult: StreamResult): Unit = {
     val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
@@ -53,7 +53,9 @@ trait PMMLExportable {
   }
 
   /**
-   * Export the model to a directory on a distributed file system in PMML format
+   * Export the model to a directory on a distributed file system in PMML format.
+   * Models should override this if they may contain more data than is reasonable
+   * to store locally.
    */
   @Since("1.4.0")
   def toPMML(sc: SparkContext, path: String): Unit = {
@@ -78,5 +80,4 @@ trait PMMLExportable {
     toPMML(new StreamResult(writer))
     writer.toString
   }
-
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 73972557d263..ed36d46cae57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,17 +17,21 @@
 
 package org.apache.spark.ml.clustering
 
+import org.dmg.pmml.PMML
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.util.Utils
 
 private[clustering] case class TestRow(features: Vector)
 
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
+  with PMMLReadWriteTest {
 
   final val k = 5
   @transient var dataset: Dataset[_] = _
@@ -125,6 +129,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(clusterSizes.forall(_ >= 0))
   }
 
+  test("pmml export") {
+    val predictionColName = "kmeans_prediction"
+    val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
+    val model = kmeans.fit(dataset)
+    def checkModel(pmml: PMML): Unit = {
+      assert(pmml.getDataDictionary.getNumberOfFields === 3)
+    }
+    testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("generic pmml export") {
+    val predictionColName = "kmeans_prediction"
+    val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
+    val model = kmeans.fit(dataset)
+    // Exercise the generic format-based writer API rather than the pmml() shortcut.
+    val path = new java.io.File(Utils.createTempDir(), "pmml").getPath
+    model.write.format("pmml").save(path)
+    val pmml = PMMLUtils.loadFromString(sc.textFile(path).collect().mkString("\n"))
+    assert(pmml.getDataDictionary.getNumberOfFields === 3)
+  }
+
   test("KMeansModel transform with non-default feature and prediction cols") {
     val featuresColName = "kmeans_model_features"
     val predictionColName = "kmeans_model_prediction"
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
new file mode 100644
index 000000000000..1cfa9325f61f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with PMMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.toPMML(path) + intercept[IOException] { + instance.toPMML(path) + } + intercept[IOException] { + instance.write.pmml().save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName() == "Apache Spark MLlib") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 000000000000..dbdc69f95d84 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. 
+ */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 12f7ed202b9d..92c27a7f9237 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -42,6 +42,13 @@ object MimaExcludes { // [SPARK-14743] Improve delegation token handling in secure cluster ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"), // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references") + ) ++ + Seq( + // [SPARK-11171][SPARK-11237][SPARK-11241] Add PMML exportable to ML + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.ml.util.MLWriter.saveImpl") + ) ++ + Seq( ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"), // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"),