Skip to content

Commit e5c923b

Browse files
committed
fix random seed in word2vec; move model to local
1 parent 184048f commit e5c923b

File tree

2 files changed

+56
-54
lines changed

2 files changed

+56
-54
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ package org.apache.spark.mllib.feature
1919

2020
import scala.collection.mutable
2121
import scala.collection.mutable.ArrayBuffer
22-
import scala.util.Random
2322

2423
import com.github.fommil.netlib.BLAS.{getInstance => blas}
25-
import org.apache.spark.{HashPartitioner, Logging}
24+
25+
import org.apache.spark.Logging
2626
import org.apache.spark.SparkContext._
2727
import org.apache.spark.annotation.Experimental
2828
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2929
import org.apache.spark.mllib.rdd.RDDFunctions._
3030
import org.apache.spark.rdd._
31-
import org.apache.spark.storage.StorageLevel
31+
import org.apache.spark.util.random.XORShiftRandom
3232

3333
/**
3434
* Entry in vocabulary
@@ -94,12 +94,12 @@ class Word2Vec(
9494
private var vocabHash = mutable.HashMap.empty[String, Int]
9595
private var alpha = startingAlpha
9696

97-
private def learnVocab(words:RDD[String]): Unit = {
97+
private def learnVocab(words: RDD[String]): Unit = {
9898
vocab = words.map(w => (w, 1))
9999
.reduceByKey(_ + _)
100100
.map(x => VocabWord(
101-
x._1,
102-
x._2,
101+
x._1,
102+
x._2,
103103
new Array[Int](MAX_CODE_LENGTH),
104104
new Array[Int](MAX_CODE_LENGTH),
105105
0))
@@ -246,31 +246,32 @@ class Word2Vec(
246246
}
247247

248248
val newSentences = sentences.repartition(parallelism).cache()
249+
val seed = 5875483L
250+
val initRandom = new XORShiftRandom(seed)
249251
var syn0Global =
250-
Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
252+
Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
251253
var syn1Global = new Array[Float](vocabSize * layer1Size)
252-
253-
for(iter <- 1 to numIterations) {
254-
val (aggSyn0, aggSyn1, _, _) =
255-
// TODO: broadcast temp instead of serializing it directly
256-
// or initialize the model in each executor
257-
newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
258-
seqOp = (c, v) => (c, v) match {
254+
255+
for (k <- 1 to numIterations) {
256+
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
257+
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
258+
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
259259
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
260260
var lwc = lastWordCount
261-
var wc = wordCount
261+
var wc = wordCount
262262
if (wordCount - lastWordCount > 10000) {
263263
lwc = wordCount
264-
alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
264+
// TODO: discount by iteration?
265+
alpha =
266+
startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
265267
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
266268
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
267269
}
268270
wc += sentence.size
269271
var pos = 0
270272
while (pos < sentence.size) {
271273
val word = sentence(pos)
272-
// TODO: fix random seed
273-
val b = Random.nextInt(window)
274+
val b = random.nextInt(window)
274275
// Train Skip-gram
275276
var a = b
276277
while (a < window * 2 + 1 - b) {
@@ -280,7 +281,7 @@ class Word2Vec(
280281
val lastWord = sentence(c)
281282
val l1 = lastWord * layer1Size
282283
val neu1e = new Array[Float](layer1Size)
283-
// Hierarchical softmax
284+
// Hierarchical softmax
284285
var d = 0
285286
while (d < bcVocab.value(word).codeLen) {
286287
val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +304,44 @@ class Word2Vec(
303304
pos += 1
304305
}
305306
(syn0, syn1, lwc, wc)
306-
},
307-
combOp = (c1, c2) => (c1, c2) match {
308-
case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
309-
val n = syn0_1.length
310-
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
311-
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
312-
blas.sscal(n, weight1, syn0_1, 1)
313-
blas.sscal(n, weight1, syn1_1, 1)
314-
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
315-
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
316-
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
317-
})
307+
}
308+
Iterator(model)
309+
}
310+
val (aggSyn0, aggSyn1, _, _) =
311+
partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
312+
val n = syn0_1.length
313+
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
314+
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
315+
blas.sscal(n, weight1, syn0_1, 1)
316+
blas.sscal(n, weight1, syn1_1, 1)
317+
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
318+
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
319+
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
320+
}
318321
syn0Global = aggSyn0
319322
syn1Global = aggSyn1
320323
}
321324
newSentences.unpersist()
322325

323-
val wordMap = new Array[(String, Array[Float])](vocabSize)
326+
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
324327
var i = 0
325328
while (i < vocabSize) {
326329
val word = bcVocab.value(i).word
327330
val vector = new Array[Float](layer1Size)
328331
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
329-
wordMap(i) = (word, vector)
332+
word2VecMap += word -> vector
330333
i += 1
331334
}
332-
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
333-
.partitionBy(new HashPartitioner(modelPartitionNum))
334-
.persist(StorageLevel.MEMORY_AND_DISK)
335-
336-
new Word2VecModel(modelRDD)
335+
336+
new Word2VecModel(word2VecMap.toMap)
337337
}
338338
}
339339

340340
/**
341341
* Word2Vec model
342-
*/
343-
class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
342+
*/
343+
class Word2VecModel private[mllib] (
344+
private val model: Map[String, Array[Float]]) extends Serializable {
344345

345346
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
346347
require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +358,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
357358
* @return vector representation of word
358359
*/
359360
def transform(word: String): Vector = {
360-
val result = model.lookup(word)
361-
if (result.isEmpty) {
362-
throw new IllegalStateException(s"$word not in vocabulary")
361+
model.get(word) match {
362+
case Some(vec) =>
363+
Vectors.dense(vec.map(_.toDouble))
364+
case None =>
365+
throw new IllegalStateException(s"$word not in vocabulary")
363366
}
364-
else Vectors.dense(result(0).map(_.toDouble))
365367
}
366368

367369
/**
@@ -392,14 +394,14 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
392394
*/
393395
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
394396
require(num > 0, "Number of similar words should > 0")
395-
val topK = model.map { case(w, vec) =>
396-
(cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
397-
.sortByKey(ascending = false)
398-
.take(num + 1)
399-
.map(_.swap)
400-
.tail
401-
402-
topK
397+
// TODO: optimize top-k
398+
val fVector = vector.toArray.map(_.toFloat)
399+
model.mapValues(vec => cosineSimilarity(fVector, vec))
400+
.toSeq
401+
.sortBy(- _._2)
402+
.take(num + 1)
403+
.tail
404+
.toArray
403405
}
404406
}
405407

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
4646

4747
test("Word2VecModel") {
4848
val num = 2
49-
val localModel = Seq(
49+
val word2VecMap = Map(
5050
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
5151
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
5252
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
5353
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
5454
)
55-
val model = new Word2VecModel(sc.parallelize(localModel, 2))
55+
val model = new Word2VecModel(word2VecMap)
5656
val syms = model.findSynonyms("china", num)
5757
assert(syms.length == num)
5858
assert(syms(0)._1 == "taiwan")

0 commit comments

Comments
 (0)