Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.{PipelineStage, PredictorParams}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.BLAS._
Expand All @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -482,7 +483,7 @@ class LinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("1.3.0") val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
with LinearRegressionParams with MLWritable {
with LinearRegressionParams with GeneralMLWritable {

private var trainingSummary: Option[LinearRegressionTrainingSummary] = None

Expand Down Expand Up @@ -554,7 +555,49 @@ class LinearRegressionModel private[ml] (
* This also does not save the [[parent]] currently.
*/
@Since("1.6.0")
override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
override def write: GeneralMLWriter = new GeneralMLWriter(this)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc above this is wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}

/**
 * [[MLWriterFormat]] that persists a [[LinearRegressionModel]] in the "internal" (default)
 * format: params metadata via `DefaultParamsWriter.saveMetadata`, plus the intercept and
 * coefficients written as a single-partition Parquet dataset under `path`/data.
 */
private class InternalLinearRegressionModelWriter
  extends MLWriterFormat with MLFormatRegister {

  /** Alias under which this writer is registered. */
  override def shortName(): String =
    "internal+org.apache.spark.ml.regression.LinearRegressionModel"

  /** Row layout of the persisted model data. */
  private case class Data(intercept: Double, coefficients: Vector)

  override def write(path: String, sparkSession: SparkSession,
      optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
    val model = stage.asInstanceOf[LinearRegressionModel]
    // Metadata and Params are saved first, directly under `path`.
    DefaultParamsWriter.saveMetadata(model, path, sparkSession.sparkContext)
    // Model data (intercept, coefficients) goes to the "data" sub-path as one partition.
    val modelData = Seq(Data(model.intercept, model.coefficients))
    val target = new Path(path, "data").toString
    sparkSession.createDataFrame(modelData).repartition(1).write.parquet(target)
  }
}

/** A writer for LinearRegression that handles the "pmml" format */
private class PMMLLinearRegressionModelWriter()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be wrong, but I think we prefer just omitting the ()?

extends MLWriterFormat with MLFormatRegister {

/** Alias under which this writer is registered. */
override def shortName(): String =
  "pmml+org.apache.spark.ml.regression.LinearRegressionModel"

// NOTE: the unused `Data` case class that used to live here was removed; the PMML path
// never builds a DataFrame and only delegates to the MLlib model below.

/**
 * Exports the given stage as PMML by building an MLlib linear regression model from the
 * stage's coefficients and intercept and calling its `toPMML`.
 *
 * @param path destination handed to `toPMML`
 * @param sparkSession session whose SparkContext is handed to `toPMML`
 * @param optionMap extra writer options; not consulted by this format
 * @param stage the stage to export; cast to [[LinearRegressionModel]]
 */
override def write(path: String, sparkSession: SparkSession,
    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
  val sc = sparkSession.sparkContext
  // Construct the MLLib model which knows how to write to PMML.
  val instance = stage.asInstanceOf[LinearRegressionModel]
  val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept)
  // Save PMML
  oldModel.toPMML(sc, path)
}
}

@Since("1.6.0")
Expand All @@ -566,22 +609,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
@Since("1.6.0")
override def load(path: String): LinearRegressionModel = super.load(path)

/** [[MLWriter]] instance for [[LinearRegressionModel]] */
private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
  extends MLWriter with Logging {

  /** Row layout of the persisted model data. */
  private case class Data(intercept: Double, coefficients: Vector)

  /** Writes params metadata under `path` and the model data as Parquet under `path`/data. */
  override protected def saveImpl(path: String): Unit = {
    // Metadata and Params
    DefaultParamsWriter.saveMetadata(instance, path, sc)
    // Model data: intercept and coefficients, written as a single partition
    val target = new Path(path, "data").toString
    val modelData = Seq(Data(instance.intercept, instance.coefficients))
    sparkSession.createDataFrame(modelData).repartition(1).write.parquet(target)
  }
}

private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {

/** Checked against metadata when loading model */
Expand Down
136 changes: 128 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
package org.apache.spark.ml.util

import java.io.IOException
import java.util.Locale
import java.util.{Locale, ServiceLoader}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
Expand Down Expand Up @@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
protected final def sc: SparkContext = sparkSession.sparkContext
}

/**
* ML export formats should implement this trait so that users can specify a shortname rather
* than the fully qualified class name of the exporter.
*
* A new instance of this class will be instantiated each time a save call is made.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this supposed to be retained from the DataSourceRegister?

*
* @since 2.3.0
*/
@InterfaceStability.Evolving
trait MLFormatRegister {
/**
* The string that represents the format that this format provider uses. This is
* overridden by children to provide a nice alias for the model format. For example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"data source" -> "model format"?

*
* {{{
* override def shortName(): String =
* "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about making a second abstract field def stageName(): String, instead of having it packed into one string?

* }}}
* Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml.
*
* Format discovery is done using a ServiceLoader so make sure to list your format in
* META-INF/services.
* @since 2.3.0
*/
def shortName(): String
}

/**
* Implemented by objects that provide ML exportability.
*
* A new instance of this class will be instantiated each time a save call is made.
*
* @since 2.3.0
*/
@InterfaceStability.Evolving
trait MLWriterFormat {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need the actual since annotations here, though?

/**
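* Function to write the provided pipeline stage out.
*
* @param path The path to write the result out to.
* @param session SparkSession associated with the write request.
* @param optionMap Extra options for this writer, as string key/value pairs.
* @param stage The pipeline stage to be saved.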
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add a full doc here with param annotations. Also should it be "Function to write ..."?

*/
// Explicit `: Unit` result type: procedure-style declarations (no result type) are deprecated,
// and every implementation of this method already declares `Unit`.
def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
    stage: PipelineStage): Unit
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type?

}

/**
* Abstract class for utility classes that can save ML instances.
*/
@deprecated("Use GeneralMLWriter instead. Will be removed in Spark 3.0.0", "2.3.0")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm debating if this should be deprecated in 2.4 and just have this as a new option in 2.3. What do you think @sethah / @MLnick ?

@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {

protected var shouldOverwrite: Boolean = false

/**
Expand All @@ -110,6 +155,15 @@ abstract class MLWriter extends BaseReadWrite with Logging {
@Since("1.6.0")
protected def saveImpl(path: String): Unit

/**
* Overwrites if the output path already exists.
*
* Sets the `shouldOverwrite` flag and returns this writer so that calls can be chained.
*/
@Since("1.6.0")
def overwrite(): this.type = {
shouldOverwrite = true
this
}

/**
* Map to store extra options for this writer.
*/
Expand All @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging {
this
}

// override for Java compatibility
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since tags here


// override for Java compatibility
override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
}

/**
* A ML Writer which delegates based on the requested format.
*/
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps for another PR, but maybe we could add a method here:

  def pmml(path: String): Unit = {
    this.source = "pmml"
    save(path)
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I don't think that belongs in the base GeneralMLWriter, but we could make a trait for writers which support PMML to mix in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The follow up issue to track this is https://issues.apache.org/jira/browse/SPARK-11241

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need @Since("2.3.0") here?

private var source: String = "internal"

/**
* Overwrites if the output path already exists.
* Specifies the format of ML export (e.g. PMML, internal, or
Copy link
Contributor

@sethah sethah Jan 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to e.g. "pmml", "internal", or the fully qualified class name for export).

* the fully qualified class name for export).
*/
@Since("1.6.0")
def overwrite(): this.type = {
shouldOverwrite = true
@Since("2.3.0")
def format(source: String): this.type = {
this.source = source
this
}

/**
* Dispatches the save to the correct MLFormat.
*/
@Since("2.3.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
@throws[SparkException]("If multiple sources for a given short name format are found.")
override protected def saveImpl(path: String) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type

val loader = Utils.getContextOrSparkClassLoader
val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
val stageName = stage.getClass.getName
val targetName = s"${source}+${stageName}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't need brackets

val formats = serviceLoader.asScala.toList
val shortNames = formats.map(_.shortName())
// Resolve the writer class: first by registered alias ("<format>+<stage class>"), then by
// treating the requested format as a fully qualified class name.
val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
  // requested name did not match any given registered alias
  case Nil =>
    Try(loader.loadClass(source)) match {
      case Success(writer) =>
        // Found the ML writer using the fully qualified path
        writer
      case Failure(error) =>
        // Trailing space before the continuation keeps "$formats" and "supporting" apart.
        throw new SparkException(
          s"Could not load requested format $source for $stageName ($targetName) had $formats " +
          s"supporting $shortNames", error)
    }
  case head :: Nil =>
    head.getClass
  case _ =>
    // Multiple sources
    throw new SparkException(
      s"Multiple writers found for $source+$stageName, try using the class name of the writer")
}
if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail, non-intuitively, if anyone ever extends MLWriterFormat with a constructor that has more than zero arguments. Meaning:

class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormat

will raise java.lang.NoSuchMethodException: org.apache.spark.ml.regression.DummyLinearRegressionWriter.<init>()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, we have the same issue with the DataFormat provider though. I don't think there is a way around this while keeping the DF like interface that lets us be pluggable in the way folks want (but if there is a way to require a 0 argument constructor in the concrete class with a trait I'm interested).

I think, given the folks who we generally expect to be writing these formats, that's reasonable, but I'll add a comment about this in the doc?

writer.write(path, sparkSession, optionMap, stage)
} else {
// The `s` interpolator is required so the requested format name is substituted into the message.
throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: need string interpolation here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I've added a test for this error message.

}
}

// override for Java compatibility
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)

Expand Down Expand Up @@ -162,6 +270,18 @@ trait MLWritable {
def save(path: String): Unit = write.save(path)
}

/**
* Trait for classes that provide `GeneralMLWriter`.
*
* Narrows the result type of `write` from `MLWritable` so that callers receive a
* `GeneralMLWriter` and can choose an export format through its `format` method.
*/
@Since("2.3.0")
trait GeneralMLWritable extends MLWritable {
/**
* Returns a `GeneralMLWriter` instance for this ML instance.
*/
@Since("2.3.0")
override def write: GeneralMLWriter
}

/**
* :: DeveloperApi ::
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@

package org.apache.spark.ml.regression

import scala.collection.JavaConverters._
import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel}

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils, PMMLReadWriteTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, Row}

class LinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
with PMMLReadWriteTest {

import testImplicits._

Expand Down Expand Up @@ -994,6 +998,38 @@ class LinearRegressionSuite
LinearRegressionSuite.allParamSettings, checkModelData)
}

test("pmml export") {
  val model = new LinearRegression().fit(datasetWithWeight)

  // Checks the exported document: data dictionary size, the first field, and that the
  // first two numeric predictors carry the model's coefficients.
  def checkModel(pmml: PMML): Unit = {
    val dataDictionary = pmml.getDataDictionary
    assert(dataDictionary.getNumberOfFields === 3)

    val dataFields = dataDictionary.getDataFields.asScala
    assert(dataFields(0).getName().toString === "field_0")
    assert(dataFields(0).getOpType() == OpType.CONTINUOUS)

    val regressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
    val predictors = regressionModel.getRegressionTables.get(0).getNumericPredictors
    val exportedCoefficients = predictors.asScala.map(_.getCoefficient()).toList
    assert(exportedCoefficients(0) ~== model.coefficients(0) relTol 1E-3)
    assert(exportedCoefficients(1) ~== model.coefficients(1) relTol 1E-3)
  }

  testPMMLWrite(sc, model, checkModel)
}

test("unsupported export format") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great to have a test that verifies that this works with third party implementations. Specifically, that something like model.write.format("org.apache.spark.ml.MyDummyWriter").save(path) works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll put a dummy writer in test so it doesn't clog up our class space.

val lr = new LinearRegression()
val model = lr.fit(datasetWithWeight)
intercept[SparkException] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this and the one below it test the same thing? I think we could remove the first one.

model.write.format("boop").save("boop")
}
intercept[SparkException] {
model.write.format("com.holdenkarau.boop").save("boop")
}
intercept[SparkException] {
model.write.format("org.apache.spark.SparkContext").save("boop2")
}
}

test("should support all NumericType labels and weights, and not support other types") {
for (solver <- Seq("auto", "l-bfgs", "normal")) {
val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
Expand Down
Loading