@@ -236,7 +236,7 @@ class Word2Vec extends Serializable with Logging {
236236 b = 0
237237 while (b < i) {
238238 vocab(a).code(i - b - 1 ) = code(b)
239- vocab(a).point(i - b) = point(b)
239+ vocab(a).point(i - b) = point(b) - vocabSize
240240 b += 1
241241 }
242242 a += 1
@@ -285,15 +285,17 @@ class Word2Vec extends Serializable with Logging {
285285
286286 val newSentences = sentences.repartition(numPartitions).cache()
287287 val initRandom = new XORShiftRandom (seed)
288- var synGlobal =
289- Array .fill[Float ](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f ) / vectorSize)
288+ var syn0Global =
289+ Array .fill[Float ](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f ) / vectorSize)
290+ var syn1Global = new Array [Float ](vocabSize * vectorSize)
290291 var alpha = startingAlpha
291292 for (k <- 1 to numIterations) {
292293 val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
293294 val random = new XORShiftRandom (seed ^ ((idx + 1 ) << 16 ) ^ ((- k - 1 ) << 8 ))
294- val synModify = new Array [Int ](2 * vocabSize)
295- val model = iter.foldLeft((synGlobal, 0 , 0 )) {
296- case ((syn, lastWordCount, wordCount), sentence) =>
295+ val syn0Modify = new Array [Int ](vocabSize)
296+ val syn1Modify = new Array [Int ](vocabSize)
297+ val model = iter.foldLeft((syn0Global, syn1Global, 0 , 0 )) {
298+ case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
297299 var lwc = lastWordCount
298300 var wc = wordCount
299301 if (wordCount - lastWordCount > 10000 ) {
@@ -324,45 +326,55 @@ class Word2Vec extends Serializable with Logging {
324326 val inner = bcVocab.value(word).point(d)
325327 val l2 = inner * vectorSize
326328 // Propagate hidden -> output
327- var f = blas.sdot(vectorSize, syn , l1, 1 , syn , l2, 1 )
329+ var f = blas.sdot(vectorSize, syn0 , l1, 1 , syn1 , l2, 1 )
328330 if (f > - MAX_EXP && f < MAX_EXP ) {
329331 val ind = ((f + MAX_EXP ) * (EXP_TABLE_SIZE / MAX_EXP / 2.0 )).toInt
330332 f = expTable.value(ind)
331333 val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
332- blas.saxpy(vectorSize, g, syn , l2, 1 , neu1e, 0 , 1 )
333- blas.saxpy(vectorSize, g, syn , l1, 1 , syn , l2, 1 )
334- synModify (inner) += 1
334+ blas.saxpy(vectorSize, g, syn1 , l2, 1 , neu1e, 0 , 1 )
335+ blas.saxpy(vectorSize, g, syn0 , l1, 1 , syn1 , l2, 1 )
336+ syn1Modify (inner) += 1
335337 }
336338 d += 1
337339 }
338- blas.saxpy(vectorSize, 1.0f , neu1e, 0 , 1 , syn , l1, 1 )
339- synModify (lastWord) += 1
340+ blas.saxpy(vectorSize, 1.0f , neu1e, 0 , 1 , syn0 , l1, 1 )
341+ syn0Modify (lastWord) += 1
340342 }
341343 }
342344 a += 1
343345 }
344346 pos += 1
345347 }
346- (syn , lwc, wc)
348+ (syn0, syn1 , lwc, wc)
347349 }
348- val synLocal = model._1
350+ val syn0Local = model._1
351+ val syn1Local = model._2
349352 val synOut = new PrimitiveKeyOpenHashMap [Int , Array [Float ]](vocabSize * 2 )
350353 var index = 0
351- while (index < 2 * vocabSize) {
352- if (synModify(index) != 0 ) {
353- synOut.update(index, synLocal.slice(index * vectorSize, (index + 1 ) * vectorSize))
354+ while (index < vocabSize) {
355+ if (syn0Modify(index) != 0 ) {
356+ synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1 ) * vectorSize))
357+ }
358+ if (syn1Modify(index) != 0 ) {
359+ synOut.update(index + vocabSize,
360+ syn1Local.slice(index * vectorSize, (index + 1 ) * vectorSize))
354361 }
355362 index += 1
356363 }
357364 Iterator (synOut)
358365 }
359366 val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
360- blas.saxpy(vectorSize, 1.0f , v2, 1 , v1, 1 )
367+ blas.saxpy(vectorSize, 1.0f , v2, 1 , v1, 1 )
361368 v1
362369 }.collect()
363370 var i = 0
364371 while (i < synAgg.length) {
365- Array .copy(synAgg(i)._2, 0 , synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
372+ val index = synAgg(i)._1
373+ if (index < vocabSize) {
374+ Array .copy(synAgg(i)._2, 0 , syn0Global, index * vectorSize, vectorSize)
375+ } else {
376+ Array .copy(synAgg(i)._2, 0 , syn1Global, (index - vocabSize) * vectorSize, vectorSize)
377+ }
366378 i += 1
367379 }
368380 }
@@ -373,7 +385,7 @@ class Word2Vec extends Serializable with Logging {
373385 while (i < vocabSize) {
374386 val word = bcVocab.value(i).word
375387 val vector = new Array [Float ](vectorSize)
376- Array .copy(synGlobal , i * vectorSize, vector, 0 , vectorSize)
388+ Array .copy(syn0Global , i * vectorSize, vector, 0 , vectorSize)
377389 word2VecMap += word -> vector
378390 i += 1
379391 }
0 commit comments