
Commit 26d70bd

yu-iskw authored and jkbradley committed
[SPARK-12215][ML][DOC] User guide section for KMeans in spark.ml
cc jkbradley

Author: Yu ISHIKAWA <[email protected]>

Closes #10244 from yu-iskw/SPARK-12215.
1 parent 22f6cd8 commit 26d70bd

File tree: 3 files changed, +100 -28 lines changed

docs/ml-clustering.md

Lines changed: 71 additions & 0 deletions
```diff
@@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-c
 * This will become a table of contents (this text will be scraped).
 {:toc}
 
+## K-means
+
+[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the
+most commonly used clustering algorithms that cluster the data points into a
+predefined number of clusters. The MLlib implementation includes a parallelized
+variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method
+called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf).
+
+`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model.
+
+### Input Columns
+
+<table class="table">
+  <thead>
+    <tr>
+      <th align="left">Param name</th>
+      <th align="left">Type(s)</th>
+      <th align="left">Default</th>
+      <th align="left">Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>featuresCol</td>
+      <td>Vector</td>
+      <td>"features"</td>
+      <td>Feature vector</td>
+    </tr>
+  </tbody>
+</table>
+
+### Output Columns
+
+<table class="table">
+  <thead>
+    <tr>
+      <th align="left">Param name</th>
+      <th align="left">Type(s)</th>
+      <th align="left">Default</th>
+      <th align="left">Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>predictionCol</td>
+      <td>Int</td>
+      <td>"prediction"</td>
+      <td>Predicted cluster center</td>
+    </tr>
+  </tbody>
+</table>
+
+### Example
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
+</div>
+
+</div>
+
+
 ## Latent Dirichlet allocation (LDA)
 
 `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`,
```
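As context for the new guide text: the `KMeans` estimator it documents is driven through a handful of setters, and its default `initMode` is the kmeans|| initialization mentioned above. Below is a minimal sketch of the typical `fit`/`clusterCenters` flow, assuming the Spark 1.6-era API (`SQLContext`, `mllib.linalg.Vectors`) and the `initMode`/`initSteps`/`maxIter` params; it is an illustration, not part of the commit.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object KMeansGuideSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("KMeansGuideSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Two well-separated groups; "features" is the default input column
    // from the Input Columns table above.
    val dataset = sqlContext.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 0.0)),
      (2, Vectors.dense(0.2, 0.2)),
      (3, Vectors.dense(9.0, 9.0)),
      (4, Vectors.dense(9.2, 9.2))
    )).toDF("id", "features")

    // "k-means||" (the default) selects the parallelized k-means++
    // initialization referenced in the guide text.
    val kmeans = new KMeans()
      .setK(2)
      .setInitMode("k-means||")
      .setInitSteps(5)
      .setMaxIter(20)

    val model = kmeans.fit(dataset) // Estimator => KMeansModel
    model.clusterCenters.foreach(println)

    sc.stop()
  }
}
```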

examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java

Lines changed: 6 additions & 2 deletions
```diff
@@ -23,18 +23,20 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+// $example on$
 import org.apache.spark.ml.clustering.KMeansModel;
 import org.apache.spark.ml.clustering.KMeans;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+// $example off$
 
 
 /**
@@ -74,6 +76,7 @@ public static void main(String[] args) {
     JavaSparkContext jsc = new JavaSparkContext(conf);
     SQLContext sqlContext = new SQLContext(jsc);
 
+    // $example on$
     // Loads data
     JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
     StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
@@ -91,6 +94,7 @@ public static void main(String[] args) {
     for (Vector center: centers) {
      System.out.println(center);
     }
+    // $example off$
 
     jsc.stop();
   }
```
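The point of the import reshuffling and the new `// $example on$` / `// $example off$` markers is that the `{% include_example %}` tags in ml-clustering.md publish only the marked region, so moving the `SQLContext` and `GenericRow` imports outside it keeps them out of the guide. Roughly, the docs build does something like the sketch below; the real logic lives in the docs build's include-example plugin, and `ExampleSnippet` is a hypothetical name.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

// Hypothetical sketch of the docs build's marker-based extraction:
// keep only the lines between "$example on$" and "$example off$".
object ExampleSnippet {
  def extract(path: String): String = {
    val kept = ArrayBuffer.empty[String]
    var inExample = false
    for (line <- Source.fromFile(path).getLines()) {
      if (line.contains("$example off$")) inExample = false
      else if (line.contains("$example on$")) inExample = true
      else if (inExample) kept += line
    }
    kept.mkString("\n")
  }

  def main(args: Array[String]): Unit =
    println(extract("examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java"))
}
```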

examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala

Lines changed: 23 additions & 26 deletions
```diff
@@ -17,57 +17,54 @@
 
 package org.apache.spark.examples.ml
 
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
-import org.apache.spark.ml.clustering.KMeans
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.types.{StructField, StructType}
+// scalastyle:off println
 
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors
+// $example off$
+import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * An example demonstrating a k-means clustering.
  * Run with
  * {{{
- * bin/run-example ml.KMeansExample <file> <k>
+ * bin/run-example ml.KMeansExample
  * }}}
  */
 object KMeansExample {
 
-  final val FEATURES_COL = "features"
-
   def main(args: Array[String]): Unit = {
-    if (args.length != 2) {
-      // scalastyle:off println
-      System.err.println("Usage: ml.KMeansExample <file> <k>")
-      // scalastyle:on println
-      System.exit(1)
-    }
-    val input = args(0)
-    val k = args(1).toInt
-
     // Creates a Spark context and a SQL context
     val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
     val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
 
-    // Loads data
-    val rowRDD = sc.textFile(input).filter(_.nonEmpty)
-      .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
-    val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
-    val dataset = sqlContext.createDataFrame(rowRDD, schema)
+    // $example on$
+    // Creates a DataFrame
+    val dataset: DataFrame = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(0.0, 0.0, 0.0)),
+      (2, Vectors.dense(0.1, 0.1, 0.1)),
+      (3, Vectors.dense(0.2, 0.2, 0.2)),
+      (4, Vectors.dense(9.0, 9.0, 9.0)),
+      (5, Vectors.dense(9.1, 9.1, 9.1)),
+      (6, Vectors.dense(9.2, 9.2, 9.2))
+    )).toDF("id", "features")
 
     // Trains a k-means model
     val kmeans = new KMeans()
-      .setK(k)
-      .setFeaturesCol(FEATURES_COL)
+      .setK(2)
+      .setFeaturesCol("features")
+      .setPredictionCol("prediction")
     val model = kmeans.fit(dataset)
 
     // Shows the result
-    // scalastyle:off println
     println("Final Centers: ")
     model.clusterCenters.foreach(println)
-    // scalastyle:on println
+    // $example off$
 
     sc.stop()
   }
 }
+// scalastyle:on println
```
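With k = 2 and this toy data, the two printed centers should land near (0.1, 0.1, 0.1) and (9.1, 9.1, 9.1), the means of the two point groups. The example prints only the centers; to also inspect the per-row assignments in the "prediction" column documented in the Output Columns table, one could append the lines below after `kmeans.fit(dataset)` inside the example region (a sketch, not part of the commit):

```scala
// `transform` appends the Int "prediction" column from the guide's
// Output Columns table: the index of each row's nearest cluster center.
val predictions = model.transform(dataset)
predictions.select("id", "prediction").show()
// Rows 1-3 share one cluster index and rows 4-6 the other;
// which index is which depends on initialization.
```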
