-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15784][ML]:Add Power Iteration Clustering to spark.ml #15770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
b80bb1f
75004e8
e1d9a33
f8343e0
c62a2c0
1277f75
f50873d
88a9ae0
0618815
04fddbd
b49f4c7
d3f86d0
655bc67
d5975bc
f012624
bef0594
a4bee89
0f97907
015383a
2d29570
af549e8
9b4f3d5
e35fe54
73485d8
bd5ca5d
3b0f71c
752b685
cfa18af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,216 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.clustering | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.Transformer | ||
| import org.apache.spark.ml.linalg.Vector | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} | ||
| import org.apache.spark.mllib.clustering.PowerIterationClustering.Assignment | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.functions.col | ||
| import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} | ||
|
|
||
| /** | ||
| * Common params for PowerIterationClustering | ||
| */ | ||
| private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter | ||
| with HasFeaturesCol with HasPredictionCol with HasWeightCol { | ||
|
|
||
| /** | ||
| * The number of clusters to create (k). Must be > 1. Default: 2. | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
| final val k = new IntParam(this, "k", "The number of clusters to create. " + | ||
|
||
| "Must be > 1.", ParamValidators.gt(1)) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.0") | ||
| def getK: Int = $(k) | ||
|
|
||
| /** | ||
| * Param for the initialization algorithm. This can be either "random" to use a random vector | ||
| * as vertex properties, or "degree" to use normalized sum similarities. Default: random. | ||
| */ | ||
| @Since("2.3.0") | ||
| final val initMode = { | ||
| val allowedParams = ParamValidators.inArray(Array("random", "degree")) | ||
| new Param[String](this, "initMode", "The initialization algorithm. " + | ||
| "Supported options: 'random' and 'degree'.", allowedParams) | ||
| } | ||
|
|
||
| /** @group expertGetParam */ | ||
| @Since("2.3.0") | ||
| def getInitMode: String = $(initMode) | ||
|
|
||
| /** | ||
| * Param for the column name for ids returned by PowerIterationClustering.transform(). | ||
| * Default: "id" | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
| val idCol = new Param[String](this, "id", "column name for ids.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.0") | ||
| def getIdCol: String = $(idCol) | ||
|
|
||
| /** | ||
| * Param for the column name for neighbors required by PowerIterationClustering.transform(). | ||
| * Default: "neighbor" | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
| val neighborCol = new Param[String](this, "neighbor", "column name for neighbors.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.0") | ||
| def getNeighborCol: String = $(neighborCol) | ||
|
|
||
| /** | ||
| * Validates the input schema | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: No need for doc like this which is explained by the method title |
||
| * @param schema input schema | ||
| */ | ||
| protected def validateSchema(schema: StructType): Unit = { | ||
|
||
| SchemaUtils.checkColumnType(schema, $(idCol), LongType) | ||
| SchemaUtils.checkColumnType(schema, $(predictionCol), IntegerType) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by | ||
| * <a href=http://www.icml2010.org/papers/387.pdf>Lin and Cohen</a>. From the abstract: | ||
| * PIC finds a very low-dimensional embedding of a dataset using truncated power | ||
| * iteration on a normalized pair-wise similarity matrix of the data. | ||
| * | ||
| * Note that we implement [[PowerIterationClustering]] as a transformer. The [[transform]] is an | ||
| * expensive operation, because it uses PIC algorithm to cluster the whole input dataset. | ||
| * | ||
| * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering> | ||
| * Spectral clustering (Wikipedia)</a> | ||
| */ | ||
| @Since("2.3.0") | ||
| @Experimental | ||
| class PowerIterationClustering private[clustering] ( | ||
| @Since("2.3.0") override val uid: String) | ||
| extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { | ||
|
|
||
| setDefault( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: It'd be nice to put these defaults right next the Param definitions in PowerIterationClusteringParams so that the default specified in the docstring is close to the default specified by setDefault (to make sure they stay in sync). |
||
| k -> 2, | ||
| maxIter -> 20, | ||
| initMode -> "random", | ||
| idCol -> "id", | ||
| weightCol -> "weight", | ||
| neighborCol -> "neighbor") | ||
|
|
||
| @Since("2.3.0") | ||
| override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) | ||
|
|
||
| @Since("2.3.0") | ||
| def this() = this(Identifiable.randomUID("PowerIterationClustering")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Put constructors first, before other methods (like copy), to match the style of the rest of MLlib. |
||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setFeaturesCol(value: String): this.type = set(featuresCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setPredictionCol(value: String): this.type = set(predictionCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setK(value: Int): this.type = set(k, value) | ||
|
|
||
| /** @group expertSetParam */ | ||
| @Since("2.3.0") | ||
| def setInitMode(value: String): this.type = set(initMode, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setMaxIter(value: Int): this.type = set(maxIter, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setIdCol(value: String): this.type = set(idCol, value) | ||
|
|
||
| /** | ||
| * Sets the value of param [[weightCol]]. | ||
| * Default is "weight" | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.3.0") | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
|
|
||
| /** | ||
| * Sets the value of param [[neighborCol]]. | ||
| * Default is "neighbor" | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.3.0") | ||
| def setNeighborCol(value: String): this.type = set(neighborCol, value) | ||
|
|
||
| @Since("2.3.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| val sparkSession = dataset.sparkSession | ||
| val rdd: RDD[(Long, Long, Double)] = | ||
| dataset.select(col($(idCol)), col($(neighborCol)), col($(weightCol))).rdd.flatMap { | ||
| case Row(id: Long, nbr: Vector, weight: Vector) => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The PIC require input graph matrix to be symmetric, and the weight should be non-negative. It is better to check them here. But checking symmetric seems cost too much, I have no good idea for now. cc @jkbradley Do you have some thoughts ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think checking symmetric is too much for PIC in this data format. Maybe, we can omit the check and put a comment and INFO on console to let users take care of it. @WeichenXu123
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I agree.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree about not checking for symmetry as long as we document it. But I do have one suggestion: Let's take neighbors and weights as Arrays, not Vectors. That may help prevent users from mistakenly passing in feature Vectors. |
||
| require(nbr.size == weight.size, | ||
| "The length of neighbor list must be equal to the the length of the weight list.") | ||
| val ids = Array.fill(nbr.size)(id) | ||
| for (i <- 0 until ids.size) yield (ids(i), nbr(i).toLong, weight(i))} | ||
|
||
| val algorithm = new MLlibPowerIterationClustering() | ||
| .setK($(k)) | ||
| .setInitializationMode($(initMode)) | ||
| .setMaxIterations($(maxIter)) | ||
| val model = algorithm.run(rdd) | ||
|
|
||
| val rows: RDD[Row] = model.assignments.map { | ||
| case assignment: Assignment => Row(assignment.id, assignment.cluster) | ||
| } | ||
|
|
||
| val schema = transformSchema(new StructType(Array(StructField($(idCol), LongType), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should not need to explicitly create a schema here. |
||
| StructField($(predictionCol), IntegerType)))) | ||
| val result = sparkSession.createDataFrame(rows, schema) | ||
|
|
||
| dataset.join(result, "id") | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| validateSchema(schema) | ||
| schema | ||
| } | ||
|
|
||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { | ||
|
|
||
| @Since("2.3.0") | ||
| override def load(path: String): PowerIterationClustering = super.load(path) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.clustering | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.util.DefaultReadWriteTest | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} | ||
|
|
||
| class PowerIterationClusteringSuite extends SparkFunSuite | ||
| with MLlibTestSparkContext with DefaultReadWriteTest { | ||
|
|
||
| @transient var data: Dataset[_] = _ | ||
| @transient var malData: Dataset[_] = _ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not used |
||
| final val r1 = 1.0 | ||
| final val n1 = 10 | ||
| final val r2 = 4.0 | ||
| final val n2 = 40 | ||
|
|
||
| override def beforeAll(): Unit = { | ||
| super.beforeAll() | ||
|
|
||
| data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) | ||
| } | ||
|
|
||
| test("default parameters") { | ||
| val pic = new PowerIterationClustering() | ||
|
|
||
| assert(pic.getK === 2) | ||
| assert(pic.getMaxIter === 20) | ||
| assert(pic.getInitMode === "random") | ||
| assert(pic.getFeaturesCol === "features") | ||
| assert(pic.getPredictionCol === "prediction") | ||
| assert(pic.getIdCol === "id") | ||
| assert(pic.getWeightCol === "weight") | ||
| assert(pic.getNeighborCol === "neighbor") | ||
| } | ||
|
|
||
| test("set parameters") { | ||
| val pic = new PowerIterationClustering() | ||
| .setK(9) | ||
| .setMaxIter(33) | ||
| .setInitMode("degree") | ||
| .setFeaturesCol("test_feature") | ||
| .setPredictionCol("test_prediction") | ||
| .setIdCol("test_id") | ||
| .setWeightCol("test_weight") | ||
| .setNeighborCol("test_neighbor") | ||
|
|
||
| assert(pic.getK === 9) | ||
| assert(pic.getMaxIter === 33) | ||
| assert(pic.getInitMode === "degree") | ||
| assert(pic.getFeaturesCol === "test_feature") | ||
| assert(pic.getPredictionCol === "test_prediction") | ||
| assert(pic.getIdCol === "test_id") | ||
| assert(pic.getWeightCol === "test_weight") | ||
| assert(pic.getNeighborCol === "test_neighbor") | ||
| } | ||
|
|
||
| test("parameters validation") { | ||
| intercept[IllegalArgumentException] { | ||
| new PowerIterationClustering().setK(1) | ||
| } | ||
| intercept[IllegalArgumentException] { | ||
| new PowerIterationClustering().setInitMode("no_such_a_mode") | ||
| } | ||
| } | ||
|
|
||
| test("power iteration clustering") { | ||
|
||
| val n = n1 + n2 | ||
|
|
||
| val model = new PowerIterationClustering() | ||
| .setK(2) | ||
| .setMaxIter(40) | ||
| val result = model.transform(data) | ||
|
|
||
| val predictions = Array.fill(2)(mutable.Set.empty[Long]) | ||
| result.select("id", "prediction").collect().foreach { | ||
| case Row(id: Long, cluster: Integer) => predictions(cluster) += id | ||
| } | ||
| assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) | ||
|
|
||
| val result2 = new PowerIterationClustering() | ||
| .setK(2) | ||
| .setMaxIter(10) | ||
| .setInitMode("degree") | ||
| .transform(data) | ||
| val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) | ||
| result2.select("id", "prediction").collect().foreach { | ||
| case Row(id: Long, cluster: Integer) => predictions2(cluster) += id | ||
| } | ||
| assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) | ||
|
|
||
| val expectedColumns = Array("id", "prediction") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to check this since it's already checks above by result2.select(...) |
||
| expectedColumns.foreach { column => | ||
| assert(result2.columns.contains(column)) | ||
| } | ||
| } | ||
|
|
||
| test("read/write") { | ||
| val t = new PowerIterationClustering() | ||
| .setK(4) | ||
| .setMaxIter(100) | ||
| .setInitMode("degree") | ||
| .setFeaturesCol("test_feature") | ||
| .setPredictionCol("test_prediction") | ||
| .setIdCol("test_id") | ||
| testDefaultReadWrite(t) | ||
| } | ||
| } | ||
|
|
||
| object PowerIterationClusteringSuite { | ||
|
|
||
| case class TestRow2(id: Long, neighbor: Vector, weight: Vector) | ||
| /** Generates a circle of points. */ | ||
| private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { | ||
| Array.tabulate(n) { i => | ||
| val theta = 2.0 * math.Pi * i / n | ||
| (r * math.cos(theta), r * math.sin(theta)) | ||
| } | ||
| } | ||
|
|
||
| /** Computes Gaussian similarity. */ | ||
| private def sim(x: (Double, Double), y: (Double, Double)): Double = { | ||
| val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) | ||
| math.exp(-dist2 / 2.0) | ||
| } | ||
|
|
||
| def generatePICData(spark: SparkSession, r1: Double, r2: Double, | ||
| n1: Int, n2: Int): DataFrame = { | ||
| // Generate two circles following the example in the PIC paper. | ||
| val n = n1 + n2 | ||
| val points = genCircle(r1, n1) ++ genCircle(r2, n2) | ||
|
|
||
| val similarities = for (i <- 1 until n) yield { | ||
| val neighbor = for (j <- 0 until i) yield { | ||
| j.toLong | ||
| } | ||
| val weight = for (j <- 0 until i) yield { | ||
| sim(points(i), points(j)) | ||
| } | ||
| (i.toLong, neighbor.toArray, weight.toArray) | ||
| } | ||
|
|
||
| val sc = spark.sparkContext | ||
|
|
||
| val rdd = sc.parallelize(similarities).map{ | ||
| case (id: Long, nbr: Array[Long], weight: Array[Double]) => | ||
| TestRow2(id, Vectors.dense(nbr.map(i => i.toDouble)), Vectors.dense(weight))} | ||
| spark.createDataFrame(rdd) | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not use weightCol, which is for instance weights, not for this kind of adjacency. Let's add a new Param here, perhaps called neighborWeightCol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, featuresCol is not used, so it should be removed.