Closed
Changes from 4 commits
40 commits
1749aec
Try adding PMMLExportable to ML with KMeans
holdenk Oct 21, 2015
bc1b508
Everything is better with tests
holdenk Oct 22, 2015
adf0b36
Move the PMML loading utils into test and make it clear they are test…
holdenk Oct 23, 2015
494ecbf
Merge in master (now both pmml and native export)
holdenk Dec 1, 2015
41611b8
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Dec 30, 2015
9525283
Merge in master
holdenk Jan 11, 2016
461c1ce
Fix import ordering from automerge
holdenk Jan 11, 2016
9aa5265
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 20, 2016
b514421
Import order fix
holdenk Jan 20, 2016
57c303d
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 21, 2016
6e2efc2
Merge branch 'SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans' …
holdenk Jan 21, 2016
90b0e22
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jan 25, 2016
2adc069
Merge in master
holdenk Mar 7, 2016
4f693d1
Just make the old ML one inherit from the new one and override the on…
holdenk Mar 7, 2016
4f3ac08
Fix ambiguous reference
holdenk Mar 7, 2016
46879d3
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Mar 24, 2016
a99880d
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 5, 2016
2fb0857
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 13, 2016
8042e4f
Exclude the methods we've moved around
holdenk Apr 13, 2016
bebd0e7
Merge in master
holdenk Apr 14, 2016
344d5a0
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Apr 15, 2016
a30b6c5
Merge in master
holdenk Jun 21, 2016
cf335bf
Start refactoring to the more generic writer
holdenk Jun 22, 2016
59dd4c6
More progress towards the new API
holdenk Jun 22, 2016
b5a57ea
Revert some unneeded changes with the new approach
holdenk Jun 22, 2016
b7edccf
Basic test pass
holdenk Jun 22, 2016
a41b474
Refactor tests a bit
holdenk Jun 23, 2016
1146e45
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jul 1, 2016
0dd6c94
Make the writer have a convenience function to set the type to pmml (an…
holdenk Jul 2, 2016
49f8a8d
Add pmml() to the writer
holdenk Jul 3, 2016
e6845f1
pmmml -> pmml
holdenk Jul 3, 2016
0b042e8
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Jul 5, 2016
9170b3f
Add a MIMA exclusion for the saveImpl change and fix inadvertent spac…
holdenk Jul 5, 2016
8579c1b
Merge in master
holdenk Jul 18, 2016
c8573f0
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Aug 1, 2016
8103b76
Remove TODO since we have moved it to Spark 2.1 now that master is on…
holdenk Aug 1, 2016
bdcfbd1
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Aug 3, 2016
00173aa
Update master
holdenk Aug 5, 2016
0e8c523
Merge branch 'master' into SPARK-11171-SPARK-11237-Add-PMML-export-fo…
holdenk Sep 8, 2016
9cb8994
Merge in master
holdenk Nov 16, 2016
14 changes: 12 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,11 +17,13 @@

package org.apache.spark.ml.clustering

import javax.xml.transform.stream.StreamResult
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.ml.pmml.PMMLExportable
import org.apache.spark.ml.util._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
@@ -96,7 +98,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
- extends Model[KMeansModel] with KMeansParams with MLWritable {
+ extends Model[KMeansModel] with KMeansParams with MLWritable with PMMLExportable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -132,6 +134,15 @@ class KMeansModel private[ml] (
parentModel.computeCost(data)
}


/**
* Export the model to stream result in PMML format
*/
@Since("1.6.0")
override def toPMML(streamResult: StreamResult): Unit = {
parentModel.toPMML(streamResult)
}

@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
}
@@ -264,4 +275,3 @@ object KMeans extends DefaultParamsReadable[KMeans] {
@Since("1.6.0")
override def load(path: String): KMeans = super.load(path)
}

90 changes: 90 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala
@@ -0,0 +1,90 @@
Member: This looks like a copy-paste of org.apache.spark.mllib.pmml. Should we deprecate the mllib one and use the new one in the ml package?

Contributor Author: I think that might be good. The main difference is that this avoids the factory implementation that the MLlib API used.

Contributor Author: (This uses the same public-facing API, per the JIRA discussion about the lack of complaints from users of the old API.)

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.pmml

import java.io.{File, OutputStream, StringWriter}
import javax.xml.transform.stream.StreamResult

import org.jpmml.model.JAXBUtil

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory

/**
* :: DeveloperApi ::
* Export model to the PMML format
* Predictive Model Markup Language (PMML) is an XML-based file format
* developed by the Data Mining Group (www.dmg.org).
* Based on [[org.apache.spark.mllib.pmml.Exportable]]
*/
@DeveloperApi
@Since("1.6.0")
trait PMMLExportable {

/**
* Export the model to the stream result in PMML format.
*/
private[spark] def toPMML(streamResult: StreamResult): Unit

/**
* :: Experimental ::
* Export the model to a local file in PMML format
*/
@Experimental
@Since("1.6.0")
def toPMML(localPath: String): Unit = {
toPMML(new StreamResult(new File(localPath)))
}

/**
* :: Experimental ::
* Export the model to a directory on a distributed file system in PMML format.
* Models should override if they may contain more data than
* is reasonable to store locally.
*/
@Experimental
@Since("1.6.0")
def toPMML(sc: SparkContext, path: String): Unit = {
val pmml = toPMML()
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
}

/**
* :: Experimental ::
* Export the model to the OutputStream in PMML format
*/
@Experimental
@Since("1.6.0")
def toPMML(outputStream: OutputStream): Unit = {
toPMML(new StreamResult(outputStream))
}

/**
* :: Experimental ::
* Export the model to a String in PMML format
*/
@Experimental
@Since("1.6.0")
def toPMML(): String = {
val writer = new StringWriter
toPMML(new StreamResult(writer))
writer.toString
}

}
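
The trait above funnels every public overload into the single StreamResult-based method, so an implementer supplies one method and gets the rest. A minimal sketch of that shape, using only the JDK and no Spark classes, might look like the following. The names `PmmlExportableSketch`, `ToyModel`, and `OverloadDemo` are hypothetical, and the XML written here is a fixed placeholder rather than real model output.

```scala
import java.io.{ByteArrayOutputStream, OutputStream, StringReader, StringWriter}
import javax.xml.transform.TransformerFactory
import javax.xml.transform.stream.{StreamResult, StreamSource}

// Hypothetical stand-in for the trait in this diff: one abstract sink-based
// method, plus convenience overloads that all delegate to it.
trait PmmlExportableSketch {
  // The only method an implementer has to supply.
  def toPMML(streamResult: StreamResult): Unit

  // Write to any OutputStream by wrapping it in a StreamResult.
  def toPMML(outputStream: OutputStream): Unit =
    toPMML(new StreamResult(outputStream))

  // Render to a String by capturing the output in a StringWriter.
  def toPMML(): String = {
    val writer = new StringWriter
    toPMML(new StreamResult(writer))
    writer.toString
  }
}

// Toy model (not a Spark class): copies a fixed placeholder document into
// whatever StreamResult it is given, using the JDK identity transformer.
class ToyModel extends PmmlExportableSketch {
  private val xml = "<PMML version=\"4.2\"><Header/></PMML>"

  override def toPMML(streamResult: StreamResult): Unit = {
    val transformer = TransformerFactory.newInstance().newTransformer()
    transformer.transform(new StreamSource(new StringReader(xml)), streamResult)
  }
}

object OverloadDemo {
  def main(args: Array[String]): Unit = {
    val model = new ToyModel
    // String overload.
    println(model.toPMML())
    // OutputStream overload.
    val bytes = new ByteArrayOutputStream()
    model.toPMML(bytes)
    println(bytes.toString("UTF-8"))
  }
}
```

Keeping the abstract method sink-based means a new destination only needs a new way to build a `StreamResult`, not a change to any model.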
@@ -39,7 +39,7 @@ trait PMMLExportable {
/**
* Export the model to the stream result in PMML format
*/
- private def toPMML(streamResult: StreamResult): Unit = {
+ private[spark] def toPMML(streamResult: StreamResult): Unit = {
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
}
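
The review thread on PMMLExportable.scala contrasts the factory lookup used by this mllib method with the override-based approach in the ml trait. A rough sketch of the two styles, with hypothetical toy types and placeholder strings in place of real PMML, could read as follows. None of these names exist in Spark.

```scala
// Hypothetical toy types; none of these are Spark classes.
sealed trait FactoryModel
final case class FactoryKMeans(centers: Seq[Seq[Double]]) extends FactoryModel
final case class FactoryLinear(weights: Seq[Double]) extends FactoryModel

// Style 1: a central factory picks the exporter by matching on the model type,
// so every new model type needs a new case here.
object FactoryStyle {
  def render(model: FactoryModel): String = model match {
    case FactoryKMeans(centers) =>
      "<ClusteringModel numberOfClusters='" + centers.size + "'/>"
    case FactoryLinear(weights) =>
      "<RegressionModel numberOfCoefficients='" + weights.size + "'/>"
  }
}

// Style 2: each model implements the export method itself.
trait SelfExporting {
  def toPmmlFragment: String
}

final case class SelfKMeans(centers: Seq[Seq[Double]]) extends SelfExporting {
  override def toPmmlFragment: String =
    "<ClusteringModel numberOfClusters='" + centers.size + "'/>"
}

// A wrapper can forward to the model it wraps, in the same way the new
// KMeansModel in this diff forwards to its parentModel.
final class WrappedKMeans(parent: SelfKMeans) extends SelfExporting {
  override def toPmmlFragment: String = parent.toPmmlFragment
}

object StyleDemo {
  def main(args: Array[String]): Unit = {
    val centers = Seq(Seq(0.0, 0.0), Seq(1.0, 1.0))
    println(FactoryStyle.render(FactoryKMeans(centers)))
    println(new WrappedKMeans(SelfKMeans(centers)).toPmmlFragment)
  }
}
```

Both calls in `StyleDemo` print the same placeholder fragment; the difference is only where the dispatch lives.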
@@ -22,6 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.ml.util.PMMLUtils
import org.apache.spark.sql.{DataFrame, SQLContext}

private[clustering] case class TestRow(features: Vector)
@@ -99,6 +100,16 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(model.computeCost(dataset) < 0.1)
}


test("pmml export") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
val pmmlStr = model.toPMML()
val pmmlModel = PMMLUtils.loadFromString(pmmlStr)
assert(pmmlModel.getDataDictionary.getNumberOfFields === 3)
}

test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
43 changes: 43 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util

import java.io.StringReader
import javax.xml.bind.Unmarshaller
import javax.xml.transform.Source

import org.dmg.pmml._
import org.jpmml.model.{ImportFilter, JAXBUtil}
import org.xml.sax.InputSource

/**
* Testing utils for working with PMML.
* Predictive Model Markup Language (PMML) is an XML-based file format
* developed by the Data Mining Group (www.dmg.org).
*/
private[spark] object PMMLUtils {
/**
* :: Experimental ::
* Load a PMML model from a string. Note: this is for testing only; PMML model evaluation is
* supported through external Spark packages.
*/
def loadFromString(input: String): PMML = {
val is = new StringReader(input)
val transformed = ImportFilter.apply(new InputSource(is))
JAXBUtil.unmarshalPMML(transformed)
}
}
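
The helper above parses exported PMML back into JPMML objects so a test can inspect the data dictionary. Where JPMML is not on the classpath, a comparable check can be sketched with the JDK DOM parser alone. This swaps DOM parsing in for the JPMML unmarshalling used in the diff. `PmmlDomCheck` and `countDataFields` are hypothetical names, and the sample document is hand-written for illustration, not output produced by Spark.

```scala
import java.io.StringReader
import javax.xml.parsers.DocumentBuilderFactory
import org.xml.sax.InputSource

object PmmlDomCheck {
  // Count <DataField> elements in a PMML string using only the JDK DOM parser.
  def countDataFields(pmml: String): Int = {
    val factory = DocumentBuilderFactory.newInstance()
    // Namespace awareness lets us match on the local element name.
    factory.setNamespaceAware(true)
    val doc = factory.newDocumentBuilder().parse(new InputSource(new StringReader(pmml)))
    // "*" matches elements in any namespace.
    doc.getElementsByTagNameNS("*", "DataField").getLength
  }

  // Hand-written sample with three fields; not real exporter output.
  val sample: String =
    """<PMML xmlns="http://www.dmg.org/PMML-4_2" version="4.2">
      |  <DataDictionary numberOfFields="3">
      |    <DataField name="field_0" optype="continuous" dataType="double"/>
      |    <DataField name="field_1" optype="continuous" dataType="double"/>
      |    <DataField name="field_2" optype="continuous" dataType="double"/>
      |  </DataDictionary>
      |</PMML>""".stripMargin

  def main(args: Array[String]): Unit = {
    println(countDataFields(sample))
  }
}
```

A DOM count like this only checks document structure. It does not validate the document against the PMML schema the way a typed JPMML load would.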