-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-8874] [ML] Add missing methods in Word2Vec #7263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,15 +18,17 @@ | |
| package org.apache.spark.ml.feature | ||
|
|
||
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util.{Identifiable, SchemaUtils} | ||
| import org.apache.spark.mllib.feature | ||
| import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} | ||
| import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} | ||
| import org.apache.spark.mllib.linalg.BLAS._ | ||
| import org.apache.spark.sql.DataFrame | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.SQLContext | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
|
|
@@ -146,6 +148,40 @@ class Word2VecModel private[ml] ( | |
| wordVectors: feature.Word2VecModel) | ||
| extends Model[Word2VecModel] with Word2VecBase { | ||
|
|
||
|
|
||
| /** | ||
| * Returns a dataframe with two fields, "word" and "vector", with "word" being a String | ||
| * and "vector" the DenseVector that it is mapped to. | ||
| * | ||
| * Declared as a transient lazy val so that the DataFrame is only built on first access | ||
| * (not every time a model is constructed) and is not carried along as a serialized field. | ||
| */ | ||
| @transient lazy val getVectors: DataFrame = { | ||
| val sc = SparkContext.getOrCreate() | ||
| val sqlContext = SQLContext.getOrCreate(sc) | ||
| import sqlContext.implicits._ | ||
| // Convert every component to Double so each word's values can be wrapped by Vectors.dense. | ||
| val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) | ||
| sc.parallelize(wordVec.toSeq).toDF("word", "vector") | ||
| } | ||
|
|
||
| /** | ||
| * Find the "num" number of words closest in similarity to the given word. | ||
| * Returns a dataframe with the words and the cosine similarities between the | ||
| * synonyms and the given word. | ||
| */ | ||
| def findSynonyms(word: String, num: Int): DataFrame = { | ||
| // Look up the word's vector first, then delegate to the vector-based overload. | ||
| val wordVector = wordVectors.transform(word) | ||
| findSynonyms(wordVector, num) | ||
| } | ||
|
|
||
| /** | ||
| * Find "num" number of words closest in similarity to the given vector representation | ||
| * of a word. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Find "num" number of words closest in similarity to the given vector representation of a word." |
||
| * Returns a dataframe with the words and the cosine similarities between the | ||
| * synonyms and the given word vector. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "given word" --> "given word vector" |
||
| */ | ||
| def findSynonyms(word: Vector, num: Int): DataFrame = { | ||
| val sparkContext = SparkContext.getOrCreate() | ||
| val sqlContext = SQLContext.getOrCreate(sparkContext) | ||
| import sqlContext.implicits._ | ||
| // Pairs returned by the wrapped model; they become the two DataFrame columns below. | ||
| val synonyms = wordVectors.findSynonyms(word, num) | ||
| sparkContext.parallelize(synonyms).toDF("word", "similarity") | ||
| } | ||
|
|
||
| /** | ||
| * Sets the `inputCol` param to the given value via `set`. | ||
| * @param value value to assign to `inputCol` | ||
| * @group setParam | ||
| */ | ||
| def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") | ||
| } | ||
| } | ||
|
|
||
| // Fits a Word2Vec model on a tiny corpus and compares model.getVectors to expected vectors. | ||
| test("getVectors") { | ||
|
|
||
| val sqlContext = new SQLContext(sc) | ||
| import sqlContext.implicits._ | ||
|
|
||
| // Corpus: the same sentence twice, each split on single spaces into tokens. | ||
| val sentence = "a b " * 100 + "a c " * 10 | ||
| val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) | ||
|
|
||
| // Expected vector for each word; compared below against the fitted model's output. | ||
| val codes = Map( | ||
| "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Were these values computed using another library? If so, it'd be nice to document that. Also, can you avoid having multiple copies of them in the unit test, and instead put them in a static object instead?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I think I just used the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it's fine. We need to beef up the mllib tests at some point. I'll make a JIRA for that. |
||
| "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), | ||
| "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) | ||
| ) | ||
| // Order by word so the expected vectors line up with the model output sorted by "word". | ||
| val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } | ||
|
|
||
| // Two identical columns; only "text" is configured as the input column below. | ||
| val docDF = doc.zip(doc).toDF("text", "alsotext") | ||
|
|
||
| val model = new Word2Vec() | ||
| .setVectorSize(3) | ||
| .setInputCol("text") | ||
| .setOutputCol("result") | ||
| .setSeed(42L) | ||
| .fit(docDF) | ||
|
|
||
| // Pull the vectors out of the DataFrame, ordered by word. | ||
| val realVectors = model.getVectors.sort("word").select("vector").map { | ||
| case Row(v: Vector) => v | ||
| }.collect() | ||
|
|
||
| // Pairwise comparison with an absolute tolerance of 1E-5. | ||
| realVectors.zip(expectedVectors).foreach { | ||
| case (real, expected) => | ||
| assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") | ||
| } | ||
| } | ||
|
|
||
| // Fits a Word2Vec model on a tiny corpus and checks the synonyms returned for the word "a". | ||
| test("findSynonyms") { | ||
|
|
||
| val sqlContext = new SQLContext(sc) | ||
| import sqlContext.implicits._ | ||
|
|
||
| // Corpus: the same sentence twice, each split on single spaces into tokens. | ||
| val sentence = "a b " * 100 + "a c " * 10 | ||
| val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) | ||
| val docDF = doc.zip(doc).toDF("text", "alsotext") | ||
|
|
||
| val model = new Word2Vec() | ||
| .setVectorSize(3) | ||
| .setInputCol("text") | ||
| .setOutputCol("result") | ||
| .setSeed(42L) | ||
| .fit(docDF) | ||
|
|
||
| val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How were these computed? |
||
| // Split the (word, similarity) rows into two parallel collections. | ||
| val (synonyms, similarity) = model.findSynonyms("a", 2).map { | ||
| case Row(w: String, sim: Double) => (w, sim) | ||
| }.collect().unzip | ||
|
|
||
| assert(synonyms.toArray === Array("b", "c")) | ||
| // The assertions are side effects only, so iterate with foreach rather than map; | ||
| // each similarity must be within a relative error of 1E-5 of its expected value. | ||
| expectedSimilarity.zip(similarity).foreach { | ||
| case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) | ||
| } | ||
|
|
||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"no." --> "number of"