Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
1749aec
Try adding PMMLExportable to ML with KMeans
holdenk Oct 21, 2015
bc1b508
Everything is better with tests
holdenk Oct 22, 2015
adf0b36
Move the PMML loading utils into test and make it clear they are test…
holdenk Oct 23, 2015
494ecbf
Merge in master (now both pmml and native export)
holdenk Dec 1, 2015
41611b8
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Dec 30, 2015
9525283
Merge in master
holdenk Jan 11, 2016
461c1ce
Fix import ordering from automerge
holdenk Jan 11, 2016
9aa5265
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 20, 2016
b514421
Import order fix
holdenk Jan 20, 2016
57c303d
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 21, 2016
6e2efc2
Merge branch 'SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans' …
holdenk Jan 21, 2016
90b0e22
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 25, 2016
2adc069
Merge in master
holdenk Mar 7, 2016
4f693d1
Just make the old ML one inheret from the new one and override the on…
holdenk Mar 7, 2016
4f3ac08
Fix ambigious reference
holdenk Mar 7, 2016
46879d3
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Mar 24, 2016
a99880d
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 5, 2016
2fb0857
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 13, 2016
8042e4f
Exclude the methods we've moved around
holdenk Apr 13, 2016
bebd0e7
Merge in master
holdenk Apr 14, 2016
344d5a0
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 15, 2016
a30b6c5
Merge in master
holdenk Jun 21, 2016
cf335bf
Start refactoring to the more generic writer
holdenk Jun 22, 2016
59dd4c6
More progress towards the new API
holdenk Jun 22, 2016
b5a57ea
Revert some uneeded changes with the new approach
holdenk Jun 22, 2016
b7edccf
Basic test pass
holdenk Jun 22, 2016
a41b474
Refactor tests a bit
holdenk Jun 23, 2016
1146e45
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jul 1, 2016
0dd6c94
Make the writer have a convience function to set the type to pmml (an…
holdenk Jul 2, 2016
49f8a8d
Add pmml() to the writer
holdenk Jul 3, 2016
e6845f1
pmmml -> pmml
holdenk Jul 3, 2016
0b042e8
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jul 5, 2016
9170b3f
Add a MIMA exclusion for the saveImpl change and fix inadvertant spac…
holdenk Jul 5, 2016
8579c1b
Merge in master
holdenk Jul 18, 2016
c8573f0
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Aug 1, 2016
8103b76
Remove TODO since we have moved it to Spark 2.1 now that master is on…
holdenk Aug 1, 2016
bdcfbd1
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Aug 3, 2016
00173aa
Update master
holdenk Aug 5, 2016
0e8c523
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Sep 8, 2016
9cb8994
Merge in master
holdenk Nov 16, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {
extends Model[KMeansModel] with KMeansParams with MLWritable with PMMLWritable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
Expand Down Expand Up @@ -160,7 +160,7 @@ class KMeansModel private[ml] (
*
*/
@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
override def write: PMMLWriter = new KMeansModel.KMeansModelWriter(this)

private var trainingSummary: Option[KMeansSummary] = None

Expand Down Expand Up @@ -205,9 +205,13 @@ object KMeansModel extends MLReadable[KMeansModel] {
private case class OldData(clusterCenters: Array[OldVector])

/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends PMMLWriter {

override protected def saveImpl(path: String): Unit = {
override protected def savePMML(path: String): Unit = {
instance.parentModel.toPMML(sc, path)
}

override protected def saveNative(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
Expand Down
75 changes: 74 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,25 @@ abstract class MLWriter extends BaseReadWrite with Logging {
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
}
}
saveImpl(path)
if (supportedFormats.contains(source)) {
saveImpl(path)
} else {
throw new IllegalArgumentException(s"Format ${source} is not supported by this model " +
s"try one of (${supportedFormats})")
}
}

protected val supportedFormats = List("native")

protected var source = "native"

/**
* Specifies the underlying output data format. Default is "native" with some models also
* supporting "pmml".
*/
def format(source: String): MLWriter = {
this.source = source.toLowerCase()
this
}

/**
Expand All @@ -137,6 +155,36 @@ abstract class MLWriter extends BaseReadWrite with Logging {
override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
* :: Experimental ::
*
* Abstract class for utility classes that can save ML instances in both native Spark and PMML.
*/
@Experimental
@Since("1.6.0")
abstract class PMMLWriter extends MLWriter {
protected override val supportedFormats = List("pmml", "native")

/**
* Specifies the underlying output data format as PMML
*/
def pmml(): MLWriter = {
this.source = "pmml"
this
}

override protected def saveImpl(path: String): Unit = {
source match {
case "native" => saveNative(path)
case "pmml" => savePMML(path)
}
}

protected def savePMML(path: String): Unit

protected def saveNative(path: String): Unit
}

/**
* :: Experimental ::
*
Expand All @@ -163,6 +211,31 @@ trait MLWritable {
/**
* :: DeveloperApi ::
*
* Trait for classes that can be exported to PMML.
*/
@Experimental
@Since("2.1.0")
trait PMMLWritable extends MLWritable {
/**
* Returns an [[PMMLWriter]] instance for this ML instance capable of saving to both native Spark
* and PMML.
*/
@Since("2.1.0")
override def write: PMMLWriter

/**
* Save this ML instance to the input path in PMML format. A shortcut of
* `write.format("pmml").save(path)`.
*/
@Since("2.1.0")
def toPMML(path: String): Unit = {
write.format("pmml").save(path)
}
}

/**
* :: Experimental ::
*
* Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores
* all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
* a default implementation of writing saved instances of the class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
trait PMMLExportable {

/**
* Export the model to the stream result in PMML format
* Export the model to the stream result in PMML format.
*/
private def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
Expand All @@ -53,7 +53,9 @@ trait PMMLExportable {
}

/**
* Export the model to a directory on a distributed file system in PMML format
* Export the model to a directory on a distributed file system in PMML format.
* Models should override if they may contain more data than
* is reasonable to store locally.
*/
@Since("1.4.0")
def toPMML(sc: SparkContext, path: String): Unit = {
Expand All @@ -78,5 +80,4 @@ trait PMMLExportable {
toPMML(new StreamResult(writer))
writer.toString
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@

package org.apache.spark.ml.clustering

import org.dmg.pmml.PMML

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.util.Utils

private[clustering] case class TestRow(features: Vector)

class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
with PMMLReadWriteTest {

final val k = 5
@transient var dataset: Dataset[_] = _
Expand Down Expand Up @@ -125,6 +129,24 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusterSizes.forall(_ >= 0))
}


test("pmml export") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
def checkModel(pmml: PMML): Unit = {
assert(pmml.getDataDictionary.getNumberOfFields === 3)
}
testPMMLWrite(sc, model, checkModel)
}

test("generic pmml export") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)

}

test("KMeansModel transform with non-default feature and prediction cols") {
val featuresColName = "kmeans_model_features"
val predictionColName = "kmeans_model_prediction"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import java.io.{File, IOException}

import org.dmg.pmml.PMML
import org.scalatest.Suite

import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset

trait PMMLReadWriteTest extends TempDirectory { self: Suite =>
/**
* Test PMML export. Requires exported model is small enough to be loaded locally.
* Checks that the model can be exported and the result is valid PMML, but does not check
* the specific contents of the model.
*/
def testPMMLWrite[T <: Params with PMMLWritable](sc: SparkContext, instance: T,
checkModelData: PMML => Unit): Unit = {
val uid = instance.uid
val subdirName = Identifiable.randomUID("pmml-")

val subdir = new File(tempDir, subdirName)
val path = new File(subdir, uid).getPath

instance.toPMML(path)
intercept[IOException] {
instance.toPMML(path)
}
intercept[IOException] {
instance.write.pmml().save(path)
}
instance.write.format("pmml").overwrite().save(path)
val pmmlStr = sc.textFile(path).collect.mkString("\n")
val pmmlModel = PMMLUtils.loadFromString(pmmlStr)
assert(pmmlModel.getHeader().getApplication().getName() == "Apache Spark MLlib")
}
}
43 changes: 43 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util

import java.io.StringReader
import javax.xml.bind.Unmarshaller
import javax.xml.transform.Source

import org.dmg.pmml._
import org.jpmml.model.{ImportFilter, JAXBUtil}
import org.xml.sax.InputSource

/**
* Testing utils for working with PMML.
* Predictive Model Markup Language (PMML) is an XML-based file format
* developed by the Data Mining Group (www.dmg.org).
*/
private[spark] object PMMLUtils {
/**
* :: Experimental ::
* Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported
* through external spark-packages.
*/
def loadFromString(input: String): PMML = {
val is = new StringReader(input)
val transformed = ImportFilter.apply(new InputSource(is))
JAXBUtil.unmarshalPMML(transformed)
}
}
7 changes: 7 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ object MimaExcludes {
// [SPARK-14743] Improve delegation token handling in secure cluster
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"),
// [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
) ++
Seq(
// [SPARK-11171][SPARK-11237][SPARK-11241] Add PMML exportable to ML
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.ml.util.MLWriter.saveImpl")
) ++
Seq(
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
// [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"),
Expand Down