
Commit 1e340c3

yinxusen authored and mengxr committed
[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel
See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988).

Author: Xusen Yin <[email protected]>

Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits:

cb1ecfa [Xusen Yin] change Assignment into case class
b1dd24c [Xusen Yin] add test suite
63c3923 [Xusen Yin] add save load for power iteration clustering
1 parent 6cc5b3e commit 1e340c3

File tree: 2 files changed (+97, −5 lines)


mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala

Lines changed: 63 additions & 5 deletions
@@ -17,15 +17,20 @@
 
 package org.apache.spark.mllib.clustering
 
-import org.apache.spark.{Logging, SparkException}
+import org.json4s.JsonDSL._
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.{Logging, SparkContext, SparkException}
 
 /**
  * :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
 @Experimental
 class PowerIterationClusteringModel(
     val k: Int,
-    val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
+    val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
+  override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+    PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
+
+    def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val dataRDD = model.assignments.toDF()
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val k = (metadata \ "k").extract[Int]
+      val assignments = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
+
+      val assignmentsRDD = assignments.map {
+        case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
+      }
+
+      new PowerIterationClusteringModel(k, assignmentsRDD)
+    }
+  }
+}
 
 /**
  * :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
     val v = powerIter(w, maxIterations)
     val assignments = kMeans(v, k).mapPartitions({ iter =>
       iter.map { case (id, cluster) =>
-        new Assignment(id, cluster)
+        Assignment(id, cluster)
      }
    }, preservesPartitioning = true)
    new PowerIterationClusteringModel(k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
    * @param cluster assigned cluster id
    */
   @Experimental
-  class Assignment(val id: Long, val cluster: Int) extends Serializable
+  case class Assignment(id: Long, cluster: Int)
 
   /**
    * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
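
Taken together, `save` writes the model's metadata (class name, format version, and `k`) as a single JSON record and the assignments as a Parquet file, and `load` reverses the process, validating the class name, version, and schema along the way. A minimal usage sketch of the new round trip, assuming a live `SparkContext` named `sc`; the similarity data and `"myModelPath"` are placeholders:

```scala
import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel}

// Train on an RDD of (srcId, dstId, similarity) triples.
val similarities = sc.parallelize(Seq(
  (0L, 1L, 1.0), (1L, 2L, 1.0), (0L, 2L, 1.0), (3L, 4L, 1.0), (2L, 3L, 0.1)))
val model = new PowerIterationClustering()
  .setK(2)
  .setMaxIterations(20)
  .run(similarities)

// Persist: metadata as JSON under Loader.metadataPath(path),
// assignments as Parquet under Loader.dataPath(path).
model.save(sc, "myModelPath")

// Reload later; assignments come back as an RDD[PowerIterationClustering.Assignment].
val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath")
```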

mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -18,12 +18,15 @@
 package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable
+import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkContext
 import org.apache.spark.graphx.{Edge, Graph}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
       assert(x ~== u1(i.toInt) absTol 1e-14)
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val model = PowerIterationClusteringSuite.createModel(sc, 3, 10)
+    try {
+      model.save(sc, path)
+      val sameModel = PowerIterationClusteringModel.load(sc, path)
+      PowerIterationClusteringSuite.checkEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}
+
+object PowerIterationClusteringSuite extends FunSuite {
+  def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
+    val assignments = sc.parallelize(
+      (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
+    new PowerIterationClusteringModel(k, assignments)
+  }
+
+  def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
+    assert(a.k === b.k)
+
+    val aAssignments = a.assignments.map(x => (x.id, x.cluster))
+    val bAssignments = b.assignments.map(x => (x.id, x.cluster))
+    val unequalElements = aAssignments.join(bAssignments).filter {
+      case (id, (c1, c2)) => c1 != c2 }.count()
+    assert(unequalElements === 0L)
+  }
 }
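
The commit's switch of `Assignment` from a plain class to a case class is what both files rely on: `save` needs `model.assignments.toDF()`, and `sqlContext.implicits._` can only derive a DataFrame schema for Product types such as case classes. It also enables the `Assignment(p, ...)` construction without `new` in `createModel` above and the `Row` pattern match in `load`. A small illustration of the derived schema, as a hypothetical REPL session assuming a `SparkContext` named `sc`:

```scala
import org.apache.spark.mllib.clustering.PowerIterationClustering.Assignment
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// toDF() compiles because Assignment is now a case class (a Product
// whose fields give the column names and types).
val df = sc.parallelize(Seq(Assignment(0L, 1), Assignment(1L, 0))).toDF()
df.printSchema()
// Expected output, roughly:
// root
//  |-- id: long (nullable = false)
//  |-- cluster: integer (nullable = false)
```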
