Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 54 additions & 52 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ package org.apache.spark.mllib.feature

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.{HashPartitioner, Logging}

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom

/**
* Entry in vocabulary
Expand Down Expand Up @@ -94,12 +94,12 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha

private def learnVocab(words:RDD[String]): Unit = {
private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
x._1,
x._2,
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
Expand Down Expand Up @@ -246,31 +246,32 @@ class Word2Vec(
}

val newSentences = sentences.repartition(parallelism).cache()
val seed = 5875483L
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does more than fix the seed for unit tests, but for every call. Is it not a bit better to make the RNG injectable via a discreet package-private setter and let the tests inject a seeded RNG?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added setters and made seed configurable.

val initRandom = new XORShiftRandom(seed)
var syn0Global =
Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
var syn1Global = new Array[Float](vocabSize * layer1Size)

for(iter <- 1 to numIterations) {
val (aggSyn0, aggSyn1, _, _) =
// TODO: broadcast temp instead of serializing it directly
// or initialize the model in each executor
newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
seqOp = (c, v) => (c, v) match {

for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
// TODO: discount by iteration?
alpha =
startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
// TODO: fix random seed
val b = Random.nextInt(window)
val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
Expand All @@ -280,7 +281,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Float](layer1Size)
// Hierarchical softmax
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
Expand All @@ -303,44 +304,44 @@ class Word2Vec(
pos += 1
}
(syn0, syn1, lwc, wc)
},
combOp = (c1, c2) => (c1, c2) match {
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
blas.sscal(n, weight1, syn0_1, 1)
blas.sscal(n, weight1, syn1_1, 1)
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
})
}
Iterator(model)
}
val (aggSyn0, aggSyn1, _, _) =
partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
blas.sscal(n, weight1, syn0_1, 1)
blas.sscal(n, weight1, syn1_1, 1)
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
}
syn0Global = aggSyn0
syn1Global = aggSyn1
}
newSentences.unpersist()

val wordMap = new Array[(String, Array[Float])](vocabSize)
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
wordMap(i) = (word, vector)
word2VecMap += word -> vector
i += 1
}
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
.partitionBy(new HashPartitioner(modelPartitionNum))
.persist(StorageLevel.MEMORY_AND_DISK)

new Word2VecModel(modelRDD)

new Word2VecModel(word2VecMap.toMap)
}
}

/**
* Word2Vec model
*/
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
*/
class Word2VecModel private[mllib] (
private val model: Map[String, Array[Float]]) extends Serializable {

private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
Expand All @@ -357,11 +358,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
* @return vector representation of word
*/
def transform(word: String): Vector = {
val result = model.lookup(word)
if (result.isEmpty) {
throw new IllegalStateException(s"$word not in vocabulary")
model.get(word) match {
case Some(vec) =>
Vectors.dense(vec.map(_.toDouble))
case None =>
throw new IllegalStateException(s"$word not in vocabulary")
}
else Vectors.dense(result(0).map(_.toDouble))
}

/**
Expand Down Expand Up @@ -392,14 +394,14 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
val topK = model.map { case(w, vec) =>
(cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
.sortByKey(ascending = false)
.take(num + 1)
.map(_.swap)
.tail

topK
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
model.mapValues(vec => cosineSimilarity(fVector, vec))
.toSeq
.sortBy(- _._2)
.take(num + 1)
.tail
.toArray
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {

test("Word2VecModel") {
val num = 2
val localModel = Seq(
val word2VecMap = Map(
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
)
val model = new Word2VecModel(sc.parallelize(localModel, 2))
val model = new Word2VecModel(word2VecMap)
val syms = model.findSynonyms("china", num)
assert(syms.length == num)
assert(syms(0)._1 == "taiwan")
Expand Down