-
Notifications
You must be signed in to change notification settings - Fork 29k
[ML][SPARK-23783][SPARK-11239] Add PMML export to Spark ML pipelines #19876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
43ae30f
9fec08f
0075bf4
c68880d
c2108df
de86190
8b1c752
72b509f
b8362a4
6e9cdc3
b8844c7
c265200
8fba2e5
6411054
4047239
cd330f3
41312e7
9075626
cb6fd70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| org.apache.spark.ml.regression.InternalLinearRegressionModelWriter | ||
| org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path | |
| import org.apache.spark.SparkException | ||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml.PredictorParams | ||
| import org.apache.spark.ml.{PipelineStage, PredictorParams} | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.linalg.BLAS._ | ||
|
|
@@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ | |
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.evaluation.RegressionMetrics | ||
| import org.apache.spark.mllib.linalg.VectorImplicits._ | ||
| import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} | ||
| import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.{DataType, DoubleType, StructType} | ||
| import org.apache.spark.storage.StorageLevel | ||
|
|
@@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] ( | |
| @Since("1.3.0") val intercept: Double, | ||
| @Since("2.3.0") val scale: Double) | ||
| extends RegressionModel[Vector, LinearRegressionModel] | ||
| with LinearRegressionParams with MLWritable { | ||
| with LinearRegressionParams with GeneralMLWritable { | ||
|
|
||
| private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = | ||
| this(uid, coefficients, intercept, 1.0) | ||
|
|
@@ -710,15 +711,58 @@ class LinearRegressionModel private[ml] ( | |
| } | ||
|
|
||
| /** | ||
| * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. | ||
| * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. | ||
| * | ||
| * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. | ||
| * An option to save [[summary]] may be added in the future. | ||
| * | ||
| * This also does not save the [[parent]] currently. | ||
| */ | ||
| @Since("1.6.0") | ||
| override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) | ||
| override def write: GeneralMLWriter = new GeneralMLWriter(this) | ||
| } | ||
|
|
||
| /** A writer for LinearRegression that handles the "internal" (or default) format */ | ||
| private class InternalLinearRegressionModelWriter | ||
| extends MLWriterFormat with MLFormatRegister { | ||
|
|
||
| override def format(): String = "internal" | ||
| override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector, scale: Double) | ||
|
|
||
| override def write(path: String, sparkSession: SparkSession, | ||
| optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { | ||
| val instance = stage.asInstanceOf[LinearRegressionModel] | ||
| val sc = sparkSession.sparkContext | ||
| // Save metadata and Params | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| // Save model data: intercept, coefficients, scale | ||
| val data = Data(instance.intercept, instance.coefficients, instance.scale) | ||
| val dataPath = new Path(path, "data").toString | ||
| sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
| } | ||
|
|
||
| /** A writer for LinearRegression that handles the "pmml" format */ | ||
| private class PMMLLinearRegressionModelWriter | ||
| extends MLWriterFormat with MLFormatRegister { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be two space indentation
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing this out, I'll fix it in a follow up.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've included this in #20907
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||
|
|
||
| override def format(): String = "pmml" | ||
|
|
||
| override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector) | ||
|
|
||
| override def write(path: String, sparkSession: SparkSession, | ||
| optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { | ||
| val sc = sparkSession.sparkContext | ||
| // Construct the MLLib model which knows how to write to PMML. | ||
| val instance = stage.asInstanceOf[LinearRegressionModel] | ||
| val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) | ||
| // Save PMML | ||
| oldModel.toPMML(sc, path) | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
@@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { | |
| @Since("1.6.0") | ||
| override def load(path: String): LinearRegressionModel = super.load(path) | ||
|
|
||
| /** [[MLWriter]] instance for [[LinearRegressionModel]] */ | ||
| private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) | ||
| extends MLWriter with Logging { | ||
|
|
||
| private case class Data(intercept: Double, coefficients: Vector, scale: Double) | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| // Save metadata and Params | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| // Save model data: intercept, coefficients, scale | ||
| val data = Data(instance.intercept, instance.coefficients, instance.scale) | ||
| val dataPath = new Path(path, "data").toString | ||
| sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
| } | ||
|
|
||
| private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { | ||
|
|
||
| /** Checked against metadata when loading model */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,18 +18,20 @@ | |
| package org.apache.spark.ml.util | ||
|
|
||
| import java.io.IOException | ||
| import java.util.Locale | ||
| import java.util.{Locale, ServiceLoader} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.collection.mutable | ||
| import scala.util.{Failure, Success, Try} | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s._ | ||
| import org.json4s.{DefaultFormats, JObject} | ||
| import org.json4s.JsonDSL._ | ||
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.annotation.{DeveloperApi, Since} | ||
| import org.apache.spark.{SparkContext, SparkException} | ||
| import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml._ | ||
| import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} | ||
|
|
@@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { | |
| } | ||
|
|
||
| /** | ||
| * Abstract class for utility classes that can save ML instances. | ||
| * Abstract class to be implemented by objects that provide ML exportability. | ||
| * | ||
| * A new instance of this class will be instantiated each time a save call is made. | ||
| * | ||
| * Must have a valid zero argument constructor which will be called to instantiate. | ||
| * | ||
| * @since 2.4.0 | ||
| */ | ||
| @InterfaceStability.Unstable | ||
| @Since("2.4.0") | ||
| trait MLWriterFormat { | ||
| /** | ||
| * Function to write the provided pipeline stage out. | ||
| * | ||
| * @param path The path to write the result out to. | ||
| * @param session SparkSession associated with the write request. | ||
| * @param optionMap User provided options stored as strings. | ||
| * @param stage The pipeline stage to be saved. | ||
| */ | ||
| @Since("2.4.0") | ||
| def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], | ||
| stage: PipelineStage): Unit | ||
| } | ||
|
|
||
| /** | ||
| * ML export formats for should implement this trait so that users can specify a shortname rather | ||
| * than the fully qualified class name of the exporter. | ||
| * | ||
| * A new instance of this class will be instantiated each time a save call is made. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment about zero arg constructor requirement
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| * | ||
| * @since 2.4.0 | ||
| */ | ||
| @InterfaceStability.Unstable | ||
| @Since("2.4.0") | ||
| trait MLFormatRegister extends MLWriterFormat { | ||
| /** | ||
| * The string that represents the format that this format provider uses. This is, along with | ||
| * stageName, is overridden by children to provide a nice alias for the writer. For example: | ||
| * | ||
| * {{{ | ||
| * override def format(): String = | ||
| * "pmml" | ||
| * }}} | ||
| * Indicates that this format is capable of saving a pmml model. | ||
| * | ||
| * Must have a valid zero argument constructor which will be called to instantiate. | ||
| * | ||
| * Format discovery is done using a ServiceLoader so make sure to list your format in | ||
| * META-INF/services. | ||
| * @since 2.4.0 | ||
| */ | ||
| @Since("2.4.0") | ||
| def format(): String | ||
|
|
||
| /** | ||
| * The string that represents the stage type that this writer supports. This is, along with | ||
| * format, is overridden by children to provide a nice alias for the writer. For example: | ||
| * | ||
| * {{{ | ||
| * override def stageName(): String = | ||
| * "org.apache.spark.ml.regression.LinearRegressionModel" | ||
| * }}} | ||
| * Indicates that this format is capable of saving Spark's own PMML model. | ||
| * | ||
| * Format discovery is done using a ServiceLoader so make sure to list your format in | ||
| * META-INF/services. | ||
| * @since 2.4.0 | ||
| */ | ||
| @Since("2.4.0") | ||
| def stageName(): String | ||
|
|
||
| private[ml] def shortName(): String = s"${format()}+${stageName()}" | ||
| } | ||
|
|
||
| /** | ||
| * Abstract class for utility classes that can save ML instances in Spark's internal format. | ||
| */ | ||
| @Since("1.6.0") | ||
| abstract class MLWriter extends BaseReadWrite with Logging { | ||
|
|
@@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { | |
| @Since("1.6.0") | ||
| protected def saveImpl(path: String): Unit | ||
|
|
||
| /** | ||
| * Overwrites if the output path already exists. | ||
| */ | ||
| @Since("1.6.0") | ||
| def overwrite(): this.type = { | ||
| shouldOverwrite = true | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Map to store extra options for this writer. | ||
| */ | ||
|
|
@@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { | |
| this | ||
| } | ||
|
|
||
| // override for Java compatibility | ||
| @Since("1.6.0") | ||
| override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) | ||
|
|
||
| // override for Java compatibility | ||
| @Since("1.6.0") | ||
| override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) | ||
| } | ||
|
|
||
| /** | ||
| * A ML Writer which delegates based on the requested format. | ||
| */ | ||
| @InterfaceStability.Unstable | ||
| @Since("2.4.0") | ||
| class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { | ||
| private var source: String = "internal" | ||
|
|
||
| /** | ||
| * Overwrites if the output path already exists. | ||
| * Specifies the format of ML export (e.g. "pmml", "internal", or | ||
| * the fully qualified class name for export). | ||
| */ | ||
| @Since("1.6.0") | ||
| def overwrite(): this.type = { | ||
| shouldOverwrite = true | ||
| @Since("2.4.0") | ||
| def format(source: String): this.type = { | ||
| this.source = source | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Dispatches the save to the correct MLFormat. | ||
| */ | ||
| @Since("2.4.0") | ||
| @throws[IOException]("If the input path already exists but overwrite is not enabled.") | ||
| @throws[SparkException]("If multiple sources for a given short name format are found.") | ||
| override protected def saveImpl(path: String): Unit = { | ||
| val loader = Utils.getContextOrSparkClassLoader | ||
| val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) | ||
| val stageName = stage.getClass.getName | ||
| val targetName = s"$source+$stageName" | ||
| val formats = serviceLoader.asScala.toList | ||
| val shortNames = formats.map(_.shortName()) | ||
| val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { | ||
| // requested name did not match any given registered alias | ||
| case Nil => | ||
| Try(loader.loadClass(source)) match { | ||
| case Success(writer) => | ||
| // Found the ML writer using the fully qualified path | ||
| writer | ||
| case Failure(error) => | ||
| throw new SparkException( | ||
| s"Could not load requested format $source for $stageName ($targetName) had $formats" + | ||
| s"supporting $shortNames", error) | ||
| } | ||
| case head :: Nil => | ||
| head.getClass | ||
| case _ => | ||
| // Multiple sources | ||
| throw new SparkException( | ||
| s"Multiple writers found for $source+$stageName, try using the class name of the writer") | ||
| } | ||
| if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { | ||
| val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will fail, non-intuitively, if anyone ever extends class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormatwill raise
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, we have the same issue with the DataFormat provider though. I don't think there is a way around this while keeping the DF like interface that lets us be pluggable in the way folks want (but if there is a way to require a 0 argument constructor in the concrete class with a trait I'm interested). I think given the folks who we general expect to be writing these formats that reasonable, but I'll add a comment about this in the doc? |
||
| writer.write(path, sparkSession, optionMap, stage) | ||
| } else { | ||
| throw new SparkException(s"ML source $source is not a valid MLWriterFormat") | ||
| } | ||
| } | ||
|
|
||
| // override for Java compatibility | ||
| override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) | ||
|
|
||
|
|
@@ -162,6 +306,19 @@ trait MLWritable { | |
| def save(path: String): Unit = write.save(path) | ||
| } | ||
|
|
||
| /** | ||
| * Trait for classes that provide `GeneralMLWriter`. | ||
| */ | ||
| @Since("2.4.0") | ||
| @InterfaceStability.Unstable | ||
| trait GeneralMLWritable extends MLWritable { | ||
| /** | ||
| * Returns an `MLWriter` instance for this ML instance. | ||
| */ | ||
| @Since("2.4.0") | ||
| override def write: GeneralMLWriter | ||
| } | ||
|
|
||
| /** | ||
| * :: DeveloperApi :: | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 | ||
| org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 | ||
| org.apache.spark.ml.util.FakeLinearRegressionWriterWithName |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The doc above this is wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed