Skip to content

Commit d5377a9

Browse files
committed
use syn0Global and syn1Global to represent model
1 parent cad2011 commit d5377a9

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class Word2Vec extends Serializable with Logging {
236236
b = 0
237237
while (b < i) {
238238
vocab(a).code(i - b - 1) = code(b)
239-
vocab(a).point(i - b) = point(b)
239+
vocab(a).point(i - b) = point(b) - vocabSize
240240
b += 1
241241
}
242242
a += 1
@@ -285,15 +285,17 @@ class Word2Vec extends Serializable with Logging {
285285

286286
val newSentences = sentences.repartition(numPartitions).cache()
287287
val initRandom = new XORShiftRandom(seed)
288-
var synGlobal =
289-
Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
288+
var syn0Global =
289+
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
290+
var syn1Global = new Array[Float](vocabSize * vectorSize)
290291
var alpha = startingAlpha
291292
for (k <- 1 to numIterations) {
292293
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
293294
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
294-
val synModify = new Array[Int](2 * vocabSize)
295-
val model = iter.foldLeft((synGlobal, 0, 0)) {
296-
case ((syn, lastWordCount, wordCount), sentence) =>
295+
val syn0Modify = new Array[Int](vocabSize)
296+
val syn1Modify = new Array[Int](vocabSize)
297+
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
298+
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
297299
var lwc = lastWordCount
298300
var wc = wordCount
299301
if (wordCount - lastWordCount > 10000) {
@@ -324,45 +326,55 @@ class Word2Vec extends Serializable with Logging {
324326
val inner = bcVocab.value(word).point(d)
325327
val l2 = inner * vectorSize
326328
// Propagate hidden -> output
327-
var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
329+
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
328330
if (f > -MAX_EXP && f < MAX_EXP) {
329331
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
330332
f = expTable.value(ind)
331333
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
332-
blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
333-
blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
334-
synModify(inner) += 1
334+
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
335+
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
336+
syn1Modify(inner) += 1
335337
}
336338
d += 1
337339
}
338-
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1)
339-
synModify(lastWord) += 1
340+
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
341+
syn0Modify(lastWord) += 1
340342
}
341343
}
342344
a += 1
343345
}
344346
pos += 1
345347
}
346-
(syn, lwc, wc)
348+
(syn0, syn1, lwc, wc)
347349
}
348-
val synLocal = model._1
350+
val syn0Local = model._1
351+
val syn1Local = model._2
349352
val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
350353
var index = 0
351-
while(index < 2 * vocabSize) {
352-
if (synModify(index) != 0) {
353-
synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize))
354+
while(index < vocabSize) {
355+
if (syn0Modify(index) != 0) {
356+
synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
357+
}
358+
if (syn1Modify(index) != 0) {
359+
synOut.update(index + vocabSize,
360+
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
354361
}
355362
index += 1
356363
}
357364
Iterator(synOut)
358365
}
359366
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
360-
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
367+
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
361368
v1
362369
}.collect()
363370
var i = 0
364371
while (i < synAgg.length) {
365-
Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
372+
val index = synAgg(i)._1
373+
if (index < vocabSize) {
374+
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
375+
} else {
376+
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
377+
}
366378
i += 1
367379
}
368380
}
@@ -373,7 +385,7 @@ class Word2Vec extends Serializable with Logging {
373385
while (i < vocabSize) {
374386
val word = bcVocab.value(i).word
375387
val vector = new Array[Float](vectorSize)
376-
Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize)
388+
Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
377389
word2VecMap += word -> vector
378390
i += 1
379391
}

0 commit comments

Comments
 (0)