Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {


/**
 * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
 * "vector" being the dense vector that the word is mapped to.
 *
 * Declared `@transient lazy` so that the SparkContext/SQLContext lookup and the DataFrame
 * construction happen on first access rather than while the model is being constructed,
 * and so that the cached DataFrame is excluded if the model instance is serialized.
 */
@transient lazy val getVectors: DataFrame = {
  val sc = SparkContext.getOrCreate()
  val sqlContext = SQLContext.getOrCreate(sc)
  // Brings the implicit conversions that provide `toDF` into scope.
  import sqlContext.implicits._
  // Convert every component to Double so each word maps to a dense ml-style vector.
  val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
  sc.parallelize(wordVec.toSeq).toDF("word", "vector")
}

/**
* Find "num" number of words closest in similarity to the given word.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"no." --> "number of"

* Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word.
*/
def findSynonyms(word: String, num: Int): DataFrame = {
  // Resolve the word to its vector through the wrapped model, then delegate to the
  // Vector-based overload, which performs the search and builds the DataFrame.
  val wordVector = wordVectors.transform(word)
  findSynonyms(wordVector, num)
}

/**
* Find "num" number of words closest in similarity to the given vector representation of a word.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Find "num" number of words closest in similarity to the given vector representation of a word."

* Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word vector.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"given word" --> "given word vector"

*/
def findSynonyms(word: Vector, num: Int): DataFrame = {
  val sparkContext = SparkContext.getOrCreate()
  val sqlCtx = SQLContext.getOrCreate(sparkContext)
  // Brings the implicit conversions that provide `toDF` into scope.
  import sqlCtx.implicits._
  // Ask the wrapped model for the `num` closest words together with their similarities.
  val synonyms = wordVectors.findSynonyms(word, num)
  sparkContext.parallelize(synonyms).toDF("word", "similarity")
}

/** @group setParam */
def setInputCol(value: String): this.type = {
  // Record the input column name under the `inputCol` param and hand back the result of `set`.
  set(inputCol, value)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
}
}

// Verifies getVectors on a fitted model: fits a 3-dimensional model (seed 42) on a tiny corpus
// and checks that the "vector" column, ordered by "word", matches the expected vectors within
// an absolute tolerance of 1E-5.
test("getVectors") {

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Corpus: "a b " repeated 100 times followed by "a c " repeated 10 times; the RDD holds two
// copies of that sentence, each split on single spaces.
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))

// Expected 3-dimensional vector for each word, keyed by word. Further down they are sorted by
// word so they line up with the model output, which is sorted on the "word" column.
// NOTE(review): the origin of these constants is not documented in the code; the discussion
// embedded below suggests they were produced with the mllib Word2Vec. Confirm.
val codes = Map(
"a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were these values computed using another library? If so, it'd be nice to document that.

Also, can you avoid having multiple copies of them in the unit test, and instead put them in a static object instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think I just used the Word2Vec in mllib . Should I test it with another library or is this sufficient?
This effectively checks that the getVectors implementation in mllib is consistent with ml.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's fine. We need to beef up the mllib tests at some point. I'll make a JIRA for that.

"b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
"c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
)
val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }

val docDF = doc.zip(doc).toDF("text", "alsotext")

val model = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)

val realVectors = model.getVectors.sort("word").select("vector").map {
case Row(v: Vector) => v
}.collect()

realVectors.zip(expectedVectors).foreach {
case (real, expected) =>
assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
}
}

// Verifies findSynonyms for a query word: fits a 3-dimensional model (seed 42) on a tiny corpus,
// then checks both the returned words and their similarities for the query word "a".
test("findSynonyms") {

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Corpus: "a b " repeated 100 times followed by "a c " repeated 10 times, duplicated into two
// sentences and split on single spaces.
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
// Zip the RDD with itself to get a two-column DataFrame; only "text" is used as input below.
val docDF = doc.zip(doc).toDF("text", "alsotext")

val model = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)

// Expected similarities for the two returned synonyms, in the order they are returned.
val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How were these computed?

// Collect the (word, similarity) rows and split them into two parallel collections.
val (synonyms, similarity) = model.findSynonyms("a", 2).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip

assert(synonyms.toArray === Array("b", "c"))
// Similarities are compared by relative error against the expected values.
// NOTE(review): map is used here only for the side effect of assert and its result is
// discarded; foreach would state the intent more directly.
expectedSimilarity.zip(similarity).map {
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
}

}
}