-
Notifications
You must be signed in to change notification settings - Fork 29k
[ML][SPARK-23783][SPARK-11239] Add PMML export to Spark ML pipelines #19876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
43ae30f
9fec08f
0075bf4
c68880d
c2108df
de86190
8b1c752
72b509f
b8362a4
6e9cdc3
b8844c7
c265200
8fba2e5
6411054
4047239
cd330f3
41312e7
9075626
cb6fd70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| org.apache.spark.ml.regression.InternalLinearRegressionModelWriter | ||
| org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path | |
| import org.apache.spark.SparkException | ||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml.PredictorParams | ||
| import org.apache.spark.ml.{PipelineStage, PredictorParams} | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.linalg.BLAS._ | ||
|
|
@@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ | |
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.evaluation.RegressionMetrics | ||
| import org.apache.spark.mllib.linalg.VectorImplicits._ | ||
| import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} | ||
| import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.DoubleType | ||
| import org.apache.spark.storage.StorageLevel | ||
|
|
@@ -482,7 +483,7 @@ class LinearRegressionModel private[ml] ( | |
| @Since("2.0.0") val coefficients: Vector, | ||
| @Since("1.3.0") val intercept: Double) | ||
| extends RegressionModel[Vector, LinearRegressionModel] | ||
| with LinearRegressionParams with MLWritable { | ||
| with LinearRegressionParams with GeneralMLWritable { | ||
|
|
||
| private var trainingSummary: Option[LinearRegressionTrainingSummary] = None | ||
|
|
||
|
|
@@ -554,7 +555,49 @@ class LinearRegressionModel private[ml] ( | |
| * This also does not save the [[parent]] currently. | ||
| */ | ||
| @Since("1.6.0") | ||
| override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) | ||
| override def write: GeneralMLWriter = new GeneralMLWriter(this) | ||
| } | ||
|
|
||
| /** A writer for LinearRegression that handles the "internal" (or default) format */ | ||
| private class InternalLinearRegressionModelWriter() | ||
| extends MLWriterFormat with MLFormatRegister { | ||
|
|
||
| override def shortName(): String = | ||
| "internal+org.apache.spark.ml.regression.LinearRegressionModel" | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector) | ||
|
|
||
| override def write(path: String, sparkSession: SparkSession, | ||
| optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { | ||
| val instance = stage.asInstanceOf[LinearRegressionModel] | ||
| val sc = sparkSession.sparkContext | ||
| // Save metadata and Params | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| // Save model data: intercept, coefficients | ||
| val data = Data(instance.intercept, instance.coefficients) | ||
| val dataPath = new Path(path, "data").toString | ||
| sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
| } | ||
|
|
||
| /** A writer for LinearRegression that handles the "pmml" format */ | ||
| private class PMMLLinearRegressionModelWriter() | ||
|
||
| extends MLWriterFormat with MLFormatRegister { | ||
|
|
||
| override def shortName(): String = | ||
| "pmml+org.apache.spark.ml.regression.LinearRegressionModel" | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector) | ||
|
|
||
| override def write(path: String, sparkSession: SparkSession, | ||
| optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { | ||
| val sc = sparkSession.sparkContext | ||
| // Construct the MLLib model which knows how to write to PMML. | ||
| val instance = stage.asInstanceOf[LinearRegressionModel] | ||
| val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) | ||
| // Save PMML | ||
| oldModel.toPMML(sc, path) | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
@@ -566,22 +609,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { | |
| @Since("1.6.0") | ||
| override def load(path: String): LinearRegressionModel = super.load(path) | ||
|
|
||
| /** [[MLWriter]] instance for [[LinearRegressionModel]] */ | ||
| private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) | ||
| extends MLWriter with Logging { | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector) | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| // Save metadata and Params | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| // Save model data: intercept, coefficients | ||
| val data = Data(instance.intercept, instance.coefficients) | ||
| val dataPath = new Path(path, "data").toString | ||
| sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
| } | ||
|
|
||
| private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { | ||
|
|
||
| /** Checked against metadata when loading model */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,18 +18,20 @@ | |
| package org.apache.spark.ml.util | ||
|
|
||
| import java.io.IOException | ||
| import java.util.Locale | ||
| import java.util.{Locale, ServiceLoader} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.collection.mutable | ||
| import scala.util.{Failure, Success, Try} | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s._ | ||
| import org.json4s.{DefaultFormats, JObject} | ||
| import org.json4s.JsonDSL._ | ||
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.annotation.{DeveloperApi, Since} | ||
| import org.apache.spark.{SparkContext, SparkException} | ||
| import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml._ | ||
| import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} | ||
|
|
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite { | |
| protected final def sc: SparkContext = sparkSession.sparkContext | ||
| } | ||
|
|
||
| /** | ||
| * ML export formats should implement this trait so that users can specify a short name rather | ||
| * than the fully qualified class name of the exporter. | ||
| * | ||
| * A new instance of this class will be instantiated each time a DDL call is made. | ||
|
||
| * | ||
| * @since 2.3.0 | ||
| */ | ||
| @InterfaceStability.Evolving | ||
| trait MLFormatRegister { | ||
| /** | ||
| * The string that represents the format that this data source provider uses. This is | ||
| * overridden by children to provide a nice alias for the data source. For example: | ||
|
||
| * | ||
| * {{{ | ||
| * override def shortName(): String = | ||
| * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" | ||
|
||
| * }}} | ||
| * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. | ||
| * | ||
| * Format discovery is done using a ServiceLoader so make sure to list your format in | ||
| * META-INF/services. | ||
| * @since 2.3.0 | ||
| */ | ||
| def shortName(): String | ||
| } | ||
|
|
||
| /** | ||
| * Implemented by objects that provide ML exportability. | ||
| * | ||
| * A new instance of this class will be instantiated each time a DDL call is made. | ||
| * | ||
| * @since 2.3.0 | ||
| */ | ||
| @InterfaceStability.Evolving | ||
| trait MLWriterFormat { | ||
|
||
| /** | ||
| * Function to write the provided pipeline stage out. | ||
|
||
| */ | ||
| def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], | ||
| stage: PipelineStage) | ||
|
||
| } | ||
|
|
||
| /** | ||
| * Abstract class for utility classes that can save ML instances. | ||
| */ | ||
| @deprecated("Use GeneralMLWriter instead. Will be removed in Spark 3.0.0", "2.3.0") | ||
|
||
| @Since("1.6.0") | ||
| abstract class MLWriter extends BaseReadWrite with Logging { | ||
|
|
||
| protected var shouldOverwrite: Boolean = false | ||
|
|
||
| /** | ||
|
|
@@ -110,6 +155,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { | |
| @Since("1.6.0") | ||
| protected def saveImpl(path: String): Unit | ||
|
|
||
| /** | ||
| * Overwrites if the output path already exists. | ||
| */ | ||
| @Since("1.6.0") | ||
| def overwrite(): this.type = { | ||
| shouldOverwrite = true | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Map to store extra options for this writer. | ||
| */ | ||
|
|
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { | |
| this | ||
| } | ||
|
|
||
| // override for Java compatibility | ||
| override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since tags here |
||
|
|
||
| // override for Java compatibility | ||
| override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) | ||
| } | ||
|
|
||
| /** | ||
| * A ML Writer which delegates based on the requested format. | ||
| */ | ||
| class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps for another PR, but maybe we could add a method here: def pmml(path: String): Unit = {
this.source = "pmml"
save(path)
}
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I don't think that belongs in the base GeneralMLWriter, but we could make a trait for writers which support PMML to mix in?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The follow up issue to track this is https://issues.apache.org/jira/browse/SPARK-11241
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need |
||
| private var source: String = "internal" | ||
|
|
||
| /** | ||
| * Overwrites if the output path already exists. | ||
| * Specifies the format of ML export (e.g. PMML, internal, or | ||
|
||
| * the fully qualified class name for export). | ||
| */ | ||
| @Since("1.6.0") | ||
| def overwrite(): this.type = { | ||
| shouldOverwrite = true | ||
| @Since("2.3.0") | ||
| def format(source: String): this.type = { | ||
| this.source = source | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Dispatches the save to the correct MLFormat. | ||
| */ | ||
| @Since("2.3.0") | ||
| @throws[IOException]("If the input path already exists but overwrite is not enabled.") | ||
| @throws[SparkException]("If multiple sources for a given short name format are found.") | ||
| override protected def saveImpl(path: String) = { | ||
|
||
| val loader = Utils.getContextOrSparkClassLoader | ||
| val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) | ||
| val stageName = stage.getClass.getName | ||
| val targetName = s"${source}+${stageName}" | ||
|
||
| val formats = serviceLoader.asScala.toList | ||
| val shortNames = formats.map(_.shortName()) | ||
| val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { | ||
| // requested name did not match any given registered alias | ||
| case Nil => | ||
| Try(loader.loadClass(source)) match { | ||
| case Success(writer) => | ||
| // Found the ML writer using the fully qualified path | ||
| writer | ||
| case Failure(error) => | ||
| throw new SparkException( | ||
| s"Could not load requested format $source for $stageName ($targetName) had $formats " + | ||
| s"supporting $shortNames", error) | ||
| } | ||
| case head :: Nil => | ||
| head.getClass | ||
| case _ => | ||
| // Multiple sources | ||
| throw new SparkException( | ||
| s"Multiple writers found for $source+$stageName, try using the class name of the writer") | ||
| } | ||
| if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { | ||
| val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will fail, non-intuitively, if anyone ever extends class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormatwill raise
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, we have the same issue with the DataFormat provider though. I don't think there is a way around this while keeping the DF like interface that lets us be pluggable in the way folks want (but if there is a way to require a 0 argument constructor in the concrete class with a trait I'm interested). I think given the folks who we general expect to be writing these formats that reasonable, but I'll add a comment about this in the doc? |
||
| writer.write(path, sparkSession, optionMap, stage) | ||
| } else { | ||
| throw new SparkException(s"ML source $source is not a valid MLWriterFormat") | ||
|
||
| } | ||
| } | ||
|
|
||
| // override for Java compatibility | ||
| override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) | ||
|
|
||
|
|
@@ -162,6 +270,18 @@ trait MLWritable { | |
| def save(path: String): Unit = write.save(path) | ||
| } | ||
|
|
||
| /** | ||
| * Trait for classes that provide `GeneralMLWriter`. | ||
| */ | ||
| @Since("2.3.0") | ||
| trait GeneralMLWritable extends MLWritable { | ||
| /** | ||
| * Returns an `MLWriter` instance for this ML instance. | ||
| */ | ||
| @Since("2.3.0") | ||
| override def write: GeneralMLWriter | ||
| } | ||
|
|
||
| /** | ||
| * :: DeveloperApi :: | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,20 +17,24 @@ | |
|
|
||
| package org.apache.spark.ml.regression | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.util.Random | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} | ||
|
|
||
| import org.apache.spark.{SparkException, SparkFunSuite} | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} | ||
| import org.apache.spark.ml.param.{ParamMap, ParamsSuite} | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils, PMMLReadWriteTest} | ||
| import org.apache.spark.ml.util.TestingUtils._ | ||
| import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} | ||
| import org.apache.spark.sql.{DataFrame, Row} | ||
|
|
||
| class LinearRegressionSuite | ||
| extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { | ||
| extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest | ||
| with PMMLReadWriteTest { | ||
|
|
||
| import testImplicits._ | ||
|
|
||
|
|
@@ -994,6 +998,38 @@ class LinearRegressionSuite | |
| LinearRegressionSuite.allParamSettings, checkModelData) | ||
| } | ||
|
|
||
| test("pmml export") { | ||
| val lr = new LinearRegression() | ||
| val model = lr.fit(datasetWithWeight) | ||
| def checkModel(pmml: PMML): Unit = { | ||
| val dd = pmml.getDataDictionary | ||
| assert(dd.getNumberOfFields === 3) | ||
| val fields = dd.getDataFields.asScala | ||
| assert(fields(0).getName().toString === "field_0") | ||
| assert(fields(0).getOpType() == OpType.CONTINUOUS) | ||
| val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] | ||
| val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors | ||
| val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList | ||
| assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) | ||
| assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) | ||
| } | ||
| testPMMLWrite(sc, model, checkModel) | ||
| } | ||
|
|
||
| test("unsupported export format") { | ||
|
||
| val lr = new LinearRegression() | ||
| val model = lr.fit(datasetWithWeight) | ||
| intercept[SparkException] { | ||
|
||
| model.write.format("boop").save("boop") | ||
| } | ||
| intercept[SparkException] { | ||
| model.write.format("com.holdenkarau.boop").save("boop") | ||
| } | ||
| intercept[SparkException] { | ||
| model.write.format("org.apache.spark.SparkContext").save("boop2") | ||
| } | ||
| } | ||
|
|
||
| test("should support all NumericType labels and weights, and not support other types") { | ||
| for (solver <- Seq("auto", "l-bfgs", "normal")) { | ||
| val lr = new LinearRegression().setMaxIter(1).setSolver(solver) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The doc above this is wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed