Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

/**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an empty line after imports

* Entry in vocabulary
Expand Down Expand Up @@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
var syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)

var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
Expand Down Expand Up @@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging {
// Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * vectorSize
val inner = bcVocab.value(word).point(d)
val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
Expand All @@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging {
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
syn0Modify(lastWord) += 1
}
}
a += 1
Expand All @@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging {
}
(syn0, syn1, lwc, wc)
}
Iterator(model)
val syn0Local = model._1
val syn1Local = model._2
val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
var index = 0
while(index < vocabSize) {
if (syn0Modify(index) != 0) {
synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
}
if (syn1Modify(index) != 0) {
synOut.update(index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
}
index += 1
}
Iterator(synOut)
}
val (aggSyn0, aggSyn1, _, _) =
partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
val n = syn0_1.length
val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
blas.sscal(n, weight1, syn0_1, 1)
blas.sscal(n, weight1, syn1_1, 1)
blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
(syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
var i = 0
while (i < synAgg.length) {
val index = synAgg(i)._1
if (index < vocabSize) {
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
} else {
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
syn0Global = aggSyn0
syn1Global = aggSyn1
i += 1
}
}
newSentences.unpersist()

Expand Down