@@ -0,0 +1,2 @@
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
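These two service-provider entries register the new writers with Java's ServiceLoader (the Scaladoc below notes that formats must be listed in META-INF/services); GeneralMLWriter matches a requested format against each provider's format()+stageName() pair to pick the right writer.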
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.{PipelineStage, PredictorParams}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.BLAS._
@@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] (
@Since("1.3.0") val intercept: Double,
@Since("2.3.0") val scale: Double)
extends RegressionModel[Vector, LinearRegressionModel]
with LinearRegressionParams with MLWritable {
with LinearRegressionParams with GeneralMLWritable {

private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
this(uid, coefficients, intercept, 1.0)
@@ -710,15 +711,58 @@ class LinearRegressionModel private[ml] (
}

/**
* Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
* Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
* This also does not save the [[parent]] currently.
*/
@Since("1.6.0")
override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
override def write: GeneralMLWriter = new GeneralMLWriter(this)
Contributor: The doc above this is wrong.

Contributor Author: fixed
}

/** A writer for LinearRegression that handles the "internal" (or default) format */
private class InternalLinearRegressionModelWriter
extends MLWriterFormat with MLFormatRegister {

override def format(): String = "internal"
override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"

private case class Data(intercept: Double, coefficients: Vector, scale: Double)

override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[LinearRegressionModel]
val sc = sparkSession.sparkContext
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: intercept, coefficients, scale
val data = Data(instance.intercept, instance.coefficients, instance.scale)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
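Note that this internal writer reproduces the save logic of the private LinearRegressionModelWriter removed further down in this diff; the default format's on-disk layout is unchanged, only the entry point moves behind the MLWriterFormat interface.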

/** A writer for LinearRegression that handles the "pmml" format */
private class PMMLLinearRegressionModelWriter
extends MLWriterFormat with MLFormatRegister {
Contributor: Should be two space indentation:

    extends MLWriterFormat with MLFormatRegister {

Contributor Author: Thanks for pointing this out, I'll fix it in a follow up.

Contributor Author: I've included this in #20907

Contributor: Thanks!
override def format(): String = "pmml"

override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"

private case class Data(intercept: Double, coefficients: Vector)

override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val sc = sparkSession.sparkContext
// Construct the MLLib model which knows how to write to PMML.
val instance = stage.asInstanceOf[LinearRegressionModel]
val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept)
// Save PMML
oldModel.toPMML(sc, path)
}
}
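With both writers registered, exports can be requested by short name through the model's GeneralMLWriter. A minimal usage sketch (the training DataFrame and the output paths are hypothetical):

    val model = new LinearRegression().fit(training)  // `training` is a hypothetical DataFrame
    model.write.save("/tmp/lr-internal")              // default "internal" format, same as before
    model.write.format("pmml").save("/tmp/lr-pmml")   // PMML via PMMLLinearRegressionModelWriter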

@Since("1.6.0")
@@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
@Since("1.6.0")
override def load(path: String): LinearRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[LinearRegressionModel]] */
private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
extends MLWriter with Logging {

private case class Data(intercept: Double, coefficients: Vector, scale: Double)

override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: intercept, coefficients, scale
val data = Data(instance.intercept, instance.coefficients, instance.scale)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {

/** Checked against metadata when loading model */
173 changes: 165 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -18,18 +18,20 @@
package org.apache.spark.ml.util

import java.io.IOException
import java.util.Locale
import java.util.{Locale, ServiceLoader}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
@@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite {
}

/**
* Abstract class for utility classes that can save ML instances.
* Trait to be implemented by objects that provide ML exportability.
*
* A new instance of this class will be instantiated each time a save call is made.
*
* Must have a valid zero argument constructor which will be called to instantiate.
*
* @since 2.4.0
*/
@InterfaceStability.Unstable
@Since("2.4.0")
trait MLWriterFormat {
/**
* Function to write the provided pipeline stage out.
*
* @param path The path to write the result out to.
* @param session SparkSession associated with the write request.
* @param optionMap User provided options stored as strings.
* @param stage The pipeline stage to be saved.
*/
@Since("2.4.0")
def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
stage: PipelineStage): Unit
}

/**
* ML export formats should implement this trait so that users can specify a short name rather
* than the fully qualified class name of the exporter.
*
* A new instance of this class will be instantiated each time a save call is made.
Contributor Author: Add a comment about zero arg constructor requirement

Contributor Author: done
*
* @since 2.4.0
*/
@InterfaceStability.Unstable
@Since("2.4.0")
trait MLFormatRegister extends MLWriterFormat {
/**
* The string that represents the format that this format provider uses. This, along with
* stageName, is overridden by children to provide a nice alias for the writer. For example:
*
* {{{
* override def format(): String =
* "pmml"
* }}}
* Indicates that this format is capable of saving a PMML model.
*
* Must have a valid zero argument constructor which will be called to instantiate.
*
* Format discovery is done using a ServiceLoader so make sure to list your format in
* META-INF/services.
* @since 2.4.0
*/
@Since("2.4.0")
def format(): String

/**
* The string that represents the stage type that this writer supports. This, along with
* format, is overridden by children to provide a nice alias for the writer. For example:
*
* {{{
* override def stageName(): String =
* "org.apache.spark.ml.regression.LinearRegressionModel"
* }}}
* Indicates that this writer is capable of saving Spark's own LinearRegressionModel.
*
* Format discovery is done using a ServiceLoader so make sure to list your format in
* META-INF/services.
* @since 2.4.0
*/
@Since("2.4.0")
def stageName(): String

private[ml] def shortName(): String = s"${format()}+${stageName()}"
}
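For illustration, a hypothetical third-party exporter could plug into this trait as follows. The class, its package, and the "csv" alias are invented for this sketch; as the doc above says, it would still need a zero-argument constructor plus a META-INF/services entry to be discoverable:

    package com.example

    import scala.collection.mutable

    import org.apache.spark.ml.PipelineStage
    import org.apache.spark.ml.regression.LinearRegressionModel
    import org.apache.spark.ml.util.{MLFormatRegister, MLWriterFormat}
    import org.apache.spark.sql.SparkSession

    class CsvLinearRegressionModelWriter extends MLWriterFormat with MLFormatRegister {
      override def format(): String = "csv"
      override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"

      override def write(path: String, session: SparkSession,
          optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
        val model = stage.asInstanceOf[LinearRegressionModel]
        // Persist the intercept followed by the coefficients as one CSV line.
        val line = (model.intercept +: model.coefficients.toArray).mkString(",")
        session.sparkContext.parallelize(Seq(line), numSlices = 1).saveAsTextFile(path)
      }
    }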

/**
* Abstract class for utility classes that can save ML instances in Spark's internal format.
*/
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {
@@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging {
@Since("1.6.0")
protected def saveImpl(path: String): Unit

/**
* Overwrites if the output path already exists.
*/
@Since("1.6.0")
def overwrite(): this.type = {
shouldOverwrite = true
this
}

/**
* Map to store extra options for this writer.
*/
@@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging {
this
}

// override for Java compatibility
@Since("1.6.0")
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

// override for Java compatibility
@Since("1.6.0")
override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
* An MLWriter which delegates based on the requested format.
*/
@InterfaceStability.Unstable
@Since("2.4.0")
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
private var source: String = "internal"

/**
* Overwrites if the output path already exists.
* Specifies the format of ML export (e.g. "pmml", "internal", or
* the fully qualified class name for export).
*/
@Since("1.6.0")
def overwrite(): this.type = {
shouldOverwrite = true
@Since("2.4.0")
def format(source: String): this.type = {
this.source = source
this
}
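The source string may also be a fully qualified class name, which is the fallback path in saveImpl below; for example, reusing the hypothetical writer from the earlier sketch:

    model.write.format("com.example.CsvLinearRegressionModelWriter").save("/tmp/lr-csv")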

/**
* Dispatches the save to the correct MLFormat.
*/
@Since("2.4.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
@throws[SparkException]("If multiple sources for a given short name format are found.")
override protected def saveImpl(path: String): Unit = {
val loader = Utils.getContextOrSparkClassLoader
val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
val stageName = stage.getClass.getName
val targetName = s"$source+$stageName"
val formats = serviceLoader.asScala.toList
val shortNames = formats.map(_.shortName())
val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
// requested name did not match any given registered alias
case Nil =>
Try(loader.loadClass(source)) match {
case Success(writer) =>
// Found the ML writer using the fully qualified path
writer
case Failure(error) =>
throw new SparkException(
s"Could not load requested format $source for $stageName ($targetName) had $formats" +
s"supporting $shortNames", error)
}
case head :: Nil =>
head.getClass
case _ =>
// Multiple sources
throw new SparkException(
s"Multiple writers found for $source+$stageName, try using the class name of the writer")
}
if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
Contributor: This will fail, non-intuitively, if anyone ever extends MLWriterFormat with a constructor that has more than zero arguments. Meaning:

    class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormat

will raise java.lang.NoSuchMethodException: org.apache.spark.ml.regression.DummyLinearRegressionWriter.<init>()

Contributor Author: True, we have the same issue with the DataFormat provider though. I don't think there is a way around this while keeping the DF-like interface that lets us be pluggable in the way folks want (but if there is a way to require a 0-argument constructor in the concrete class with a trait, I'm interested).

I think given the folks who we generally expect to be writing these formats that's reasonable, but I'll add a comment about this in the doc?

writer.write(path, sparkSession, optionMap, stage)
} else {
throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
}
}

// override for Java compatibility
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

@@ -162,6 +306,19 @@ trait MLWritable {
def save(path: String): Unit = write.save(path)
}

/**
* Trait for classes that provide `GeneralMLWriter`.
*/
@Since("2.4.0")
@InterfaceStability.Unstable
trait GeneralMLWritable extends MLWritable {
/**
* Returns a `GeneralMLWriter` instance for this ML instance.
*/
@Since("2.4.0")
override def write: GeneralMLWriter
}

/**
* :: DeveloperApi ::
*
@@ -0,0 +1,3 @@
org.apache.spark.ml.util.DuplicateLinearRegressionWriter1
org.apache.spark.ml.util.DuplicateLinearRegressionWriter2
org.apache.spark.ml.util.FakeLinearRegressionWriterWithName
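These test-only registrations appear to back the new dispatch logic's tests: the two Duplicate writers exercise the "multiple writers found" error path in GeneralMLWriter.saveImpl, while FakeLinearRegressionWriterWithName exercises short-name lookup.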
@@ -17,18 +17,23 @@

package org.apache.spark.ml.regression

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel}

import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.{DataFrame, Row}

class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {

class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {

import testImplicits._

@@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
LinearRegressionSuite.allParamSettings, checkModelData)
}

test("pmml export") {
val lr = new LinearRegression()
val model = lr.fit(datasetWithWeight)
def checkModel(pmml: PMML): Unit = {
val dd = pmml.getDataDictionary
assert(dd.getNumberOfFields === 3)
val fields = dd.getDataFields.asScala
assert(fields(0).getName().toString === "field_0")
assert(fields(0).getOpType() == OpType.CONTINUOUS)
val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
}
testPMMLWrite(sc, model, checkModel)
}

test("should support all NumericType labels and weights, and not support other types") {
for (solver <- Seq("auto", "l-bfgs", "normal")) {
val lr = new LinearRegression().setMaxIter(1).setSolver(solver)