1717
1818package org .apache .spark .mllib .clustering
1919
20- import org .apache .spark .{Logging , SparkException }
20+ import org .json4s .JsonDSL ._
21+ import org .json4s ._
22+ import org .json4s .jackson .JsonMethods ._
23+
2124import org .apache .spark .annotation .Experimental
2225import org .apache .spark .api .java .JavaRDD
2326import org .apache .spark .graphx ._
2427import org .apache .spark .graphx .impl .GraphImpl
2528import org .apache .spark .mllib .linalg .Vectors
26- import org .apache .spark .mllib .util .MLUtils
29+ import org .apache .spark .mllib .util .{ Loader , MLUtils , Saveable }
2730import org .apache .spark .rdd .RDD
31+ import org .apache .spark .sql .{Row , SQLContext }
2832import org .apache .spark .util .random .XORShiftRandom
33+ import org .apache .spark .{Logging , SparkContext , SparkException }
2934
3035/**
3136 * :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
3843@ Experimental
3944class PowerIterationClusteringModel (
4045 val k : Int ,
41- val assignments : RDD [PowerIterationClustering .Assignment ]) extends Serializable
46+ val assignments : RDD [PowerIterationClustering .Assignment ]) extends Saveable with Serializable {
47+
48+ override def save (sc : SparkContext , path : String ): Unit = {
49+ PowerIterationClusteringModel .SaveLoadV1_0 .save(sc, this , path)
50+ }
51+
52+ override protected def formatVersion : String = " 1.0"
53+ }
54+
55+ object PowerIterationClusteringModel extends Loader [PowerIterationClusteringModel ] {
56+ override def load (sc : SparkContext , path : String ): PowerIterationClusteringModel = {
57+ PowerIterationClusteringModel .SaveLoadV1_0 .load(sc, path)
58+ }
59+
60+ private [clustering]
61+ object SaveLoadV1_0 {
62+
63+ private val thisFormatVersion = " 1.0"
64+
65+ private [clustering]
66+ val thisClassName = " org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
67+
68+ def save (sc : SparkContext , model : PowerIterationClusteringModel , path : String ): Unit = {
69+ val sqlContext = new SQLContext (sc)
70+ import sqlContext .implicits ._
71+
72+ val metadata = compact(render(
73+ (" class" -> thisClassName) ~ (" version" -> thisFormatVersion) ~ (" k" -> model.k)))
74+ sc.parallelize(Seq (metadata), 1 ).saveAsTextFile(Loader .metadataPath(path))
75+
76+ val dataRDD = model.assignments.toDF()
77+ dataRDD.saveAsParquetFile(Loader .dataPath(path))
78+ }
79+
80+ def load (sc : SparkContext , path : String ): PowerIterationClusteringModel = {
81+ implicit val formats = DefaultFormats
82+ val sqlContext = new SQLContext (sc)
83+
84+ val (className, formatVersion, metadata) = Loader .loadMetadata(sc, path)
85+ assert(className == thisClassName)
86+ assert(formatVersion == thisFormatVersion)
87+
88+ val k = (metadata \ " k" ).extract[Int ]
89+ val assignments = sqlContext.parquetFile(Loader .dataPath(path))
90+ Loader .checkSchema[PowerIterationClustering .Assignment ](assignments.schema)
91+
92+ val assignmentsRDD = assignments.map {
93+ case Row (id : Long , cluster : Int ) => PowerIterationClustering .Assignment (id, cluster)
94+ }
95+
96+ new PowerIterationClusteringModel (k, assignmentsRDD)
97+ }
98+ }
99+ }
42100
43101/**
44102 * :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
135193 val v = powerIter(w, maxIterations)
136194 val assignments = kMeans(v, k).mapPartitions({ iter =>
137195 iter.map { case (id, cluster) =>
138- new Assignment (id, cluster)
196+ Assignment (id, cluster)
139197 }
140198 }, preservesPartitioning = true )
141199 new PowerIterationClusteringModel (k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
152210 * @param cluster assigned cluster id
153211 */
154212 @ Experimental
155- class Assignment (val id : Long , val cluster : Int ) extends Serializable
213+ case class Assignment (id : Long , cluster : Int )
156214
157215 /**
158216 * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
0 commit comments