Skip to content

Commit 083aa66

Browse files
committed
update synGlobal in place and reduce synOut size
1 parent: 9075e1c; commit: 083aa66

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ import org.apache.spark.rdd._
3535
import org.apache.spark.util.Utils
3636
import org.apache.spark.util.random.XORShiftRandom
3737
import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
38+
3839
/**
3940
* Entry in vocabulary
4041
*/
@@ -323,14 +324,14 @@ class Word2Vec extends Serializable with Logging {
323324
val ind = bcVocab.value(word).point(d)
324325
val l2 = ind * vectorSize
325326
// Propagate hidden -> output
326-
synModify(ind) += 1
327327
var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
328328
if (f > -MAX_EXP && f < MAX_EXP) {
329329
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
330330
f = expTable.value(ind)
331331
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
332332
blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
333333
blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
334+
synModify(ind) += 1
334335
}
335336
d += 1
336337
}
@@ -355,11 +356,15 @@ class Word2Vec extends Serializable with Logging {
355356
}
356357
Iterator(synOut)
357358
}
358-
synGlobal = partial.flatMap(x => x).reduceByKey {
359-
case (v1,v2) =>
359+
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
360360
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
361361
v1
362-
}.collect().sortBy(_._1).flatMap(x => x._2)
362+
}.collect()
363+
var i = 0
364+
while (i < synAgg.length) {
365+
Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
366+
i += 1
367+
}
363368
}
364369
newSentences.unpersist()
365370

0 commit comments

Comments
 (0)