@@ -19,16 +19,16 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * Entry in vocabulary
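The import changes swap scala.util.Random for Spark's internal XORShiftRandom and drop the now-unused HashPartitioner and StorageLevel imports. XORShiftRandom is a cheap xorshift generator behind java.util.Random's API; roughly, each step mixes the seed like this (a sketch of the general xorshift technique with commonly used shift constants, not necessarily Spark's exact implementation):

// Sketch of one 64-bit xorshift step; the 21/35/4 shift triple is the
// textbook choice. A generator like XORShiftRandom applies a step of
// this shape to its internal seed each time bits are requested.
def xorshift64(state: Long): Long = {
  var x = state
  x ^= x << 21   // push low bits up
  x ^= x >>> 35  // fold high bits back down (unsigned shift)
  x ^= x << 4    // final scramble
  x
}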
@@ -94,12 +94,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
-        x._1,
-        x._2,
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
@@ -246,31 +246,32 @@ class Word2Vec(
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
+    val seed = 5875483L
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
-
-    for (iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match {
+
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
-            var wc = wordCount
+            var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
               lwc = wordCount
-              alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+              // TODO: discount by iteration?
+              alpha =
+                startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
               if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
               logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
             }
             wc += sentence.size
             var pos = 0
             while (pos < sentence.size) {
               val word = sentence(pos)
-              // TODO: fix random seed
-              val b = Random.nextInt(window)
+              val b = random.nextInt(window)
               // Train Skip-gram
               var a = b
               while (a < window * 2 + 1 - b) {
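This hunk resolves the old `// TODO: fix random seed`: window sampling now draws from a per-partition XORShiftRandom whose seed mixes a fixed base seed with the partition index `idx` and the iteration number `k`, so every (partition, iteration) pair gets a distinct yet reproducible stream. A standalone sketch of the same seeding pattern (java.util.Random stands in, since XORShiftRandom is Spark-internal; `data` and `baseSeed` are placeholder names):

import java.util.Random
import org.apache.spark.rdd.RDD

// Derive one reproducible RNG per (partition, iteration) pair by mixing
// the partition index and the iteration count into a fixed base seed,
// the same expression used in the patch above.
def jitterPartitions(data: RDD[Double], baseSeed: Long, k: Int): RDD[Double] =
  data.mapPartitionsWithIndex { (idx, iter) =>
    val random = new Random(baseSeed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
    iter.map(x => x + random.nextGaussian() * 0.01)
  }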
@@ -280,7 +281,7 @@ class Word2Vec(
                     val lastWord = sentence(c)
                     val l1 = lastWord * layer1Size
                     val neu1e = new Array[Float](layer1Size)
-                    // Hierarchical softmax
+                    // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
                       val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +304,44 @@ class Word2Vec(
               pos += 1
             }
             (syn0, syn1, lwc, wc)
-          },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-            val n = syn0_1.length
-            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-            blas.sscal(n, weight1, syn0_1, 1)
-            blas.sscal(n, weight1, syn1_1, 1)
-            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-        })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
       i += 1
     }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    new Word2VecModel(modelRDD)
+
+    new Word2VecModel(word2VecMap.toMap)
   }
 }
 
 /**
  * Word2Vec model
- */
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
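In the new treeReduce, two partial models are merged as a convex combination weighted by the number of words each partial trained on: sscal scales the left arrays in place by weight1, then saxpy adds weight2 times the right arrays. The same arithmetic written with a plain loop, as a hypothetical helper for clarity (not part of the patch):

// In-place convex combination of two partial weight arrays, weighted by
// the word counts wcA and wcB they were trained on; equivalent to the
// blas.sscal + blas.saxpy pair in the treeReduce above.
def weightedMerge(a: Array[Float], wcA: Int, b: Array[Float], wcB: Int): Array[Float] = {
  val wA = 1.0f * wcA / (wcA + wcB)
  val wB = 1.0f * wcB / (wcA + wcB)
  var i = 0
  while (i < a.length) {
    a(i) = wA * a(i) + wB * b(i)
    i += 1
  }
  a
}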
@@ -357,11 +358,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    val result = model.lookup(word)
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
@@ -392,14 +394,14 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case (w, vec) =>
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-      .sortByKey(ascending = false)
-      .take(num + 1)
-      .map(_.swap)
-      .tail
-
-    topK
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
   }
 }
 
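With the model held as a plain in-memory Map, transform becomes a hash lookup and findSynonyms a local sort; `take(num + 1).tail` keeps num results while dropping the top hit, which for an in-vocabulary query word is the word itself. Hypothetical usage, assuming `model` is a Word2VecModel produced by Word2Vec's fit method:

// Illustrative queries against a fitted model; "spark" is assumed to be
// in the training vocabulary (transform throws IllegalStateException
// otherwise).
val vec = model.transform("spark")            // dense Vector for one word
val synonyms = model.findSynonyms("spark", 5) // five nearest neighbors
synonyms.foreach { case (word, sim) => println(s"$word: $sim") }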