175 changes: 171 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,12 +22,19 @@ import java.{util => ju}
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

import org.apache.spark.Logging
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Reader
import org.apache.spark.ml.util.Writer
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
* an identity transformer.
*/
@Experimental
class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {

def this() = this(Identifiable.randomUID("pipeline"))

@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
"Cannot have duplicate components in a pipeline.")
theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
}

override def write: Writer = new Pipeline.PipelineWriter(this)
}

object Pipeline extends Readable[Pipeline] {

override def read: Reader[Pipeline] = new PipelineReader

override def load(path: String): Pipeline = read.load(path)

private[ml] class PipelineWriter(instance: Pipeline) extends Writer {

SharedReadWrite.validateStages(instance.getStages)

override protected def saveImpl(path: String): Unit =
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
}

private[ml] class PipelineReader extends Reader[Pipeline] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.Pipeline"

override def load(path: String): Pipeline = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
new Pipeline(uid).setStages(stages)
}
}

/** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
private[ml] object SharedReadWrite {

import org.json4s.JsonDSL._

/** Check that all stages are Writable */
def validateStages(stages: Array[PipelineStage]): Unit = {
stages.foreach {
case stage: Writable => // good
case other =>
throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
s" ${other.uid} of type ${other.getClass}")
}
}

/**
* Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
* - save metadata to path/metadata
* - save stages to stages/IDX_UID
*/
def saveImpl(
instance: Params,
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit = {
// Copied and edited from DefaultParamsWriter.saveMetadata
// TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
val uid = instance.uid
val cls = instance.getClass.getName
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
val metadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)

// Save stages
val stagesDir = new Path(path, "stages").toString
stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
}
}
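// Illustration (not part of this patch): for a Pipeline saved to <path> whose two
// stages have hypothetical uids "tok_a1b2" and "hash_c3d4", saveImpl and getStagePath
// below produce a layout roughly like:
//   <path>/metadata               JSON with class, timestamp, sparkVersion, uid, paramMap.stageUids
//   <path>/stages/0_tok_a1b2/     written by that stage's own Writer
//   <path>/stages/1_hash_c3d4/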

/**
* Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
* @return (UID, list of stages)
*/
def load(
expectedClassName: String,
sc: SparkContext,
path: String): (String, Array[PipelineStage]) = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
val stageUids: Array[String] = metadata.params match {
case JObject(pairs) =>
if (pairs.length != 1) {
// Should not happen unless file is corrupted or we have a bug.
throw new RuntimeException(
s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
}
pairs.head match {
case ("stageUids", jsonValue) =>
jsonValue.extract[Seq[String]].toArray
case (paramName, jsonValue) =>
// Should not happen unless file is corrupted or we have a bug.
throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
s" in metadata: ${metadata.metadataStr}")
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
}
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
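// Reflectively call the stage class's static `read` method (the forwarder generated
// from its Readable companion object) and use the returned Reader to load the stage
// from its own subdirectory.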
val cls = Utils.classForName(stageMetadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
}
(metadata.uid, stages)
}

/** Get path for saving the given stage. */
def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
val stageIdxDigits = numStages.toString.length
val idxFormat = s"%0${stageIdxDigits}d"
val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
new Path(stagesDir, stageDir).toString
}
}
}

/**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
class PipelineModel private[ml] (
override val uid: String,
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
extends Model[PipelineModel] with Writable with Logging {

/** A Java/Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}

override def write: Writer = new PipelineModel.PipelineModelWriter(this)
}

object PipelineModel extends Readable[PipelineModel] {

import Pipeline.SharedReadWrite

override def read: Reader[PipelineModel] = new PipelineModelReader

override def load(path: String): PipelineModel = read.load(path)

private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {

SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])

override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
}

private[ml] class PipelineModelReader extends Reader[PipelineModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.PipelineModel"

override def load(path: String): PipelineModel = {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
val transformers = stages map {
case stage: Transformer => stage
case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
}
new PipelineModel(uid, transformers)
}
}
}
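
A minimal usage sketch of the API this diff adds (not part of the patch). Here `someWritableStage` stands for any PipelineStage that mixes in Writable, such as the WritableStage helper defined in the test suite below, and `trainingData` is a placeholder DataFrame; the save paths are arbitrary:

    import org.apache.spark.ml.{Pipeline, PipelineModel}

    val pipeline = new Pipeline().setStages(Array(someWritableStage))
    pipeline.write.save("/tmp/pipeline-demo")          // PipelineWriter first checks all stages are Writable
    val restored: Pipeline = Pipeline.load("/tmp/pipeline-demo")

    // A fitted PipelineModel round-trips the same way:
    // val model = pipeline.fit(trainingData)
    // model.write.save("/tmp/pipeline-model-demo")
    // val restoredModel = PipelineModel.load("/tmp/pipeline-model-demo")
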
mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -161,6 +161,8 @@ trait Readable[T] {

/**
* Reads an ML instance from the input path, a shortcut of `read.load(path)`.
*
* Note: Implementing classes should override this to be Java-friendly.
*/
@Since("1.6.0")
def load(path: String): T = read.load(path)
@@ -187,7 +189,7 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
* - paramMap
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
val uid = instance.uid
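
A small sketch of the jsonEncode requirement noted above (not part of the patch). It assumes an integer param; the uid "my_uid" and the param itself are made up, and it relies on IntParam supporting jsonEncode/jsonDecode, which the WritableStage test below already exercises via DefaultParamsWriter:

    import org.apache.spark.ml.param.IntParam

    val maxIter = new IntParam("my_uid", "maxIter", "maximum number of iterations")
    val encoded: String = maxIter.jsonEncode(10)    // "10"
    val decoded: Int = maxIter.jsonDecode(encoded)  // 10
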
120 changes: 117 additions & 3 deletions mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,19 +17,25 @@

package org.apache.spark.ml

import java.io.File

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.{FileSystem, Path}
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

class PipelineSuite extends SparkFunSuite {
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

abstract class MyModel extends Model[MyModel]

@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
assert(pipelineModel1.uid === "pipeline1")
assert(pipelineModel1.stages === stages)
}

test("Pipeline read/write") {
val writableStage = new WritableStage("writableStage").setIntParam(56)
val pipeline = new Pipeline().setStages(Array(writableStage))

val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
assert(pipeline2.getStages.length === 1)
assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
assert(writableStage.getIntParam === writableStage2.getIntParam)
}

test("Pipeline read/write with non-Writable stage") {
val unWritableStage = new UnWritableStage("unwritableStage")
val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
intercept[UnsupportedOperationException] {
unWritablePipeline.write
}
}
}

test("PipelineModel read/write") {
val writableStage = new WritableStage("writableStage").setIntParam(56)
val pipeline =
new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))

val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
assert(pipeline2.stages.length === 1)
assert(pipeline2.stages(0).isInstanceOf[WritableStage])
val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
assert(writableStage.getIntParam === writableStage2.getIntParam)

val path = new File(tempDir, pipeline.uid).getPath
val stagesDir = new Path(path, "stages").toString
val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
s" to be saved to path: $expectedStagePath")
}

test("PipelineModel read/write: getStagePath") {
val stageUid = "myStage"
val stagesDir = new Path("pipeline", "stages").toString
def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
assert(path === expected)
}
testStage(0, 1, "0")
testStage(0, 9, "0")
testStage(0, 10, "00")
testStage(1, 10, "01")
testStage(12, 999, "012")
}

test("PipelineModel read/write with non-Writable stage") {
val unWritableStage = new UnWritableStage("unwritableStage")
val unWritablePipeline =
new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
intercept[UnsupportedOperationException] {
unWritablePipeline.write
}
}
}
}


/** Used to test [[Pipeline]] with [[Writable]] stages */
class WritableStage(override val uid: String) extends Transformer with Writable {

final val intParam: IntParam = new IntParam(this, "intParam", "doc")

def getIntParam: Int = $(intParam)

def setIntParam(value: Int): this.type = set(intParam, value)

setDefault(intParam -> 0)

override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)

override def write: Writer = new DefaultParamsWriter(this)

override def transform(dataset: DataFrame): DataFrame = dataset

override def transformSchema(schema: StructType): StructType = schema
}

object WritableStage extends Readable[WritableStage] {

override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]

override def load(path: String): WritableStage = read.load(path)
}

/** Used to test [[Pipeline]] with non-[[Writable]] stages */
class UnWritableStage(override val uid: String) extends Transformer {

final val intParam: IntParam = new IntParam(this, "intParam", "doc")

setDefault(intParam -> 0)

override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)

override def transform(dataset: DataFrame): DataFrame = dataset

override def transformSchema(schema: StructType): StructType = schema
}