@@ -35,6 +35,7 @@ import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
+
 /**
  * Entry in vocabulary
  */
@@ -323,14 +324,14 @@ class Word2Vec extends Serializable with Logging {
               val ind = bcVocab.value(word).point(d)
               val l2 = ind * vectorSize
               // Propagate hidden -> output
-              synModify(ind) += 1
               var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
               if (f > -MAX_EXP && f < MAX_EXP) {
                 val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                 f = expTable.value(ind)
                 val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                 blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
                 blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
+                synModify(ind) += 1
               }
               d += 1
             }
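
The relocated increment means synModify now counts a row only when the MAX_EXP guard passes and the two saxpy calls actually write it, so rows whose update was skipped are no longer reported as modified. Below is a minimal, self-contained sketch of that bookkeeping pattern; the names DirtyRowSketch, update, passesGuard, and modified are illustrative and not part of the patch.

    object DirtyRowSketch {
      def main(args: Array[String]): Unit = {
        val vectorSize = 4
        val numRows = 3
        val syn = new Array[Float](numRows * vectorSize) // flattened row-major weights
        val modified = new Array[Int](numRows)           // per-row dirty counter

        // Bump the counter only on the branch that really writes the row,
        // mirroring synModify(ind) += 1 moving inside the if block.
        def update(row: Int, g: Float, passesGuard: Boolean): Unit = {
          if (passesGuard) {
            val offset = row * vectorSize
            var j = 0
            while (j < vectorSize) { syn(offset + j) += g; j += 1 }
            modified(row) += 1
          }
        }

        update(0, 0.5f, passesGuard = true)
        update(1, 0.5f, passesGuard = false) // guard fails: no write, no count
        println(modified.mkString(", "))     // prints: 1, 0, 0
      }
    }
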
@@ -355,11 +356,15 @@ class Word2Vec extends Serializable with Logging {
         }
         Iterator(synOut)
       }
-      synGlobal = partial.flatMap(x => x).reduceByKey {
-        case (v1,v2) =>
+      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
           blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
           v1
-      }.collect().sortBy(_._1).flatMap(x => x._2)
+      }.collect()
+      var i = 0
+      while (i < synAgg.length) {
+        Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
+        i += 1
+      }
     }
     newSentences.unpersist()

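
The aggregation change replaces the sort-and-flatten rebuild of synGlobal with an in-place scatter: each aggregated (row index, vector) pair is copied to offset index * vectorSize, so key order no longer matters and the per-iteration sortBy on the driver is avoided. A standalone sketch of the scatter step follows; synAgg, synGlobal, and vectorSize echo the patch, while the sample data is invented.

    object ScatterSketch {
      def main(args: Array[String]): Unit = {
        val vectorSize = 2
        val synGlobal = new Array[Float](3 * vectorSize) // 3 rows, flattened

        // Aggregated (rowIndex, rowVector) pairs, as reduceByKey { ... }.collect()
        // would hand back; arrival order is irrelevant to the copy below.
        val synAgg = Array((2, Array(5f, 6f)), (0, Array(1f, 2f)))

        var i = 0
        while (i < synAgg.length) {
          // Scatter each row into its slot instead of sorting and flattening.
          Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
          i += 1
        }
        println(synGlobal.mkString(", ")) // 1.0, 2.0, 0.0, 0.0, 5.0, 6.0
      }
    }

A side effect visible in the diff itself: with the direct copy, a row index that never appears in synAgg simply keeps its previous contents in synGlobal, whereas the old flatMap rebuilt the array from whatever keys happened to be present.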