[SPARK-11689] [ML] Add user guide and example code for LDA under spark.ml #9722
Changes from all commits
aeea790
09b59de
8a6d2d6
195c85c
f737bff
5e22624
c794096
@@ -0,0 +1,30 @@
---
layout: global
title: Clustering - ML
displayTitle: <a href="ml-guide.html">ML</a> - Clustering
---

In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html).

## Latent Dirichlet allocation (LDA)

`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`,
and generates an `LDAModel` as the base model. Expert users may cast an `LDAModel` generated by
`EMLDAOptimizer` to a `DistributedLDAModel` if needed.

<div class="codetabs">

Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details.

<div data-lang="scala" markdown="1">
{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
</div>

</div>
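
A note on the optimizer choice described in the doc text above: in the spark.ml API the optimizer is selected through LDA's string-valued `optimizer` param rather than by passing the `EMLDAOptimizer`/`OnlineLDAOptimizer` classes directly. The fragment below is a minimal sketch of that selection and of the expert-level cast, assuming a DataFrame `dataset` with a "features" column of count vectors like the one built in the example code further down; it is an illustration, not part of this diff.

import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel}

// Assumes `dataset` is a DataFrame with a "features" column of count vectors,
// as constructed in the Scala/Java examples included in this PR.
val lda = new LDA()
  .setK(10)
  .setMaxIter(10)
  .setOptimizer("em")  // "em" or "online"

val model: LDAModel = lda.fit(dataset)

// With the EM optimizer the fitted model is distributed, so expert users
// may cast it to reach DistributedLDAModel-specific methods.
val distributedModel = model.asInstanceOf[DistributedLDAModel]
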
@@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g.,
provide class probabilities, and linear models provide model summaries.

* [Feature extraction, transformation, and selection](ml-features.html)
* [Clustering](ml-clustering.html)

Contributor
Just noticed that "Feature extraction..." is not alphabetized (sorry about my earlier comment!).

Contributor (Author)
That's quite all right.
* [Decision Trees for classification and regression](ml-decision-tree.html)
* [Ensembles](ml-ensembles.html)
* [Linear methods with elastic net regularization](ml-linear-methods.html)

@@ -950,4 +951,4 @@ model.transform(test)
{% endhighlight %}
</div>

</div>
</div>
@@ -0,0 +1,94 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * An example demonstrating LDA
 * Run with
 * <pre>
 * bin/run-example ml.JavaLDAExample
 * </pre>
 */
public class JavaLDAExample {

  private static class ParseVector implements Function<String, Row> {
    private static final Pattern separator = Pattern.compile(" ");

    @Override
    public Row call(String line) {
      String[] tok = separator.split(line);
      double[] point = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        point[i] = Double.parseDouble(tok[i]);
Contributor
We're expecting a text file containing count vectors here? Seems a bit odd. IMO an example taking a document of text and using pipelines to generate the features would be more natural, e.g. https://gist.github.com/feynmanliang/3b6555758a27adcb527d

Contributor (Author)
I changed the scala one. For the java example I keep it as it is in the mllib model.

Contributor
I think it might be confusing when a reader of the docs gets two different examples after flipping between languages. I'm really sorry, but do you mind changing it back so that they match (we can keep the examples using count vectors).

Contributor (Author)
Sure.
      }
      Vector[] points = {Vectors.dense(point)};
      return new GenericRow(points);
    }
  }

  public static void main(String[] args) {

    String inputFile = "data/mllib/sample_lda_data.txt";

    // Creates a Spark context and a SQL context
    SparkConf conf = new SparkConf().setAppName("JavaLDAExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // Loads data
    JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParseVector());
    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
    StructType schema = new StructType(fields);
    DataFrame dataset = sqlContext.createDataFrame(points, schema);

    // Trains an LDA model
    LDA lda = new LDA()
      .setK(10)
      .setMaxIter(10);
    LDAModel model = lda.fit(dataset);

    System.out.println(model.logLikelihood(dataset));
    System.out.println(model.logPerplexity(dataset));

    // Shows the result
    DataFrame topics = model.describeTopics(3);
    topics.show(false);
    model.transform(dataset).show(false);

    jsc.stop();
  }
}
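
The review thread above suggests an alternative example that starts from raw text and uses a pipeline to produce the count vectors (see the linked gist); the thread settles on keeping matching count-vector examples, but a rough Scala sketch of that alternative might look like the one below. The object name, sample corpus, column names, and parameter values are illustrative assumptions, not part of this PR or the gist.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer}
import org.apache.spark.sql.SQLContext

object PipelineLDASketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("PipelineLDASketch")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Hypothetical corpus of raw text documents, one per row.
    val corpus = sqlContext.createDataFrame(Seq(
      (0L, "spark implements latent dirichlet allocation"),
      (1L, "topic models describe documents as mixtures of topics")
    )).toDF("id", "text")

    // Tokenize, build count vectors, then fit LDA, all inside one Pipeline.
    val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("tokens")
    val vectorizer = new CountVectorizer().setInputCol("tokens").setOutputCol("features")
    val lda = new LDA().setK(5).setMaxIter(10).setFeaturesCol("features")

    val pipeline = new Pipeline().setStages(Array(tokenizer, vectorizer, lda))
    val model = pipeline.fit(corpus)

    // The transformed output includes the per-document topic distribution.
    model.transform(corpus).show(false)

    sc.stop()
  }
}
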
@@ -0,0 +1,77 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

// scalastyle:off println
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
// $example on$
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StructField, StructType}
// $example off$

/**
 * An example demonstrating LDA in the spark.ml pipeline API.
 * Run with
 * {{{
 * bin/run-example ml.LDAExample
 * }}}
 */
object LDAExample {

  final val FEATURES_COL = "features"

  def main(args: Array[String]): Unit = {

    val input = "data/mllib/sample_lda_data.txt"
    // Creates a Spark context and a SQL context
    val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    // Loads data
    val rowRDD = sc.textFile(input).filter(_.nonEmpty)
      .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
Contributor
Ditto about input format being a text file of count vectors.
    val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
    val dataset = sqlContext.createDataFrame(rowRDD, schema)

    // Trains an LDA model
    val lda = new LDA()
      .setK(10)
      .setMaxIter(10)
      .setFeaturesCol(FEATURES_COL)
    val model = lda.fit(dataset)
    val transformed = model.transform(dataset)

    val ll = model.logLikelihood(dataset)
    val lp = model.logPerplexity(dataset)

    // describeTopics
    val topics = model.describeTopics(3)

    // Shows the result
    topics.show(false)
    transformed.show(false)

    // $example off$
    sc.stop()
  }
}
// scalastyle:on println
Please link the API docs for each language inside its code tab, before the example.
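
For reference, the layout this comment asks for would look roughly like the following in the Markdown file above, with each language's API-docs link moved inside its own tab (a sketch rearranging the existing lines, not part of this diff):

<div class="codetabs">

<div data-lang="scala" markdown="1">

Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details.

{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
</div>

</div>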