Commit ffc9240

Commit message: Minor
1 parent: 6b74c81

File tree: 1 file changed (+16, -7 lines)

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 16 additions & 7 deletions
```diff
@@ -431,19 +431,28 @@ class Word2Vec extends Serializable with Logging {
 class Word2VecModel private[mllib] (
     model: Map[String, Array[Float]]) extends Serializable with Saveable {
 
-  // Maintain a ordered list of words based on the index in the initial model.
+  // wordList: Ordered list of words obtained from model.
+  // wordIndex: Maps each word to an index, which can retrieve the corresponding
+  // vector from wordVectors (see below)
+  // vectorSize: Dimension of each vector.
+  // numWords: Number of words.
   private val wordList: Array[String] = model.keys.toArray
   private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
-  private val numDim = model.head._2.size
+  private val vectorSize = model.head._2.size
   private val numWords = wordIndex.size
 
+  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
+  // mapped with index i can be retrieved by the slice
+  // (ind * vectorSize, ind * vectorSize + vectorSize)
+  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
+  // of the wordVector.
   private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
-    val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
+    val wordVectors = wordList.flatMap(word => model.get(word).get).toArray
     val wordVecNorms = new Array[Double](numWords)
     var i = 0
     while (i < numWords) {
       val vec = model.get(wordList(i)).get
-      wordVecNorms(i) = blas.snrm2(numDim, vec, 1)
+      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
       i += 1
     }
     (wordVectors, wordVecNorms)
```
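The new comments describe a flattened storage scheme: all word vectors are concatenated into one `Array[Float]`, and a word's vector is recovered by slicing at its index. Below is a minimal, self-contained sketch of that layout; the two-word model and its values are made up for illustration and are not part of the patch.

```scala
// Sketch of the flattened wordVectors layout used by Word2VecModel.
// The model map and its values are illustrative only.
object FlattenedLayoutSketch {
  def main(args: Array[String]): Unit = {
    val model: Map[String, Array[Float]] = Map(
      "spark" -> Array(0.1f, 0.2f, 0.3f),
      "mllib" -> Array(0.4f, 0.5f, 0.6f))

    val wordList: Array[String] = model.keys.toArray
    val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
    val vectorSize = model.head._2.size                       // 3 in this example
    val wordVectors = wordList.flatMap(word => model(word))   // length numWords * vectorSize

    // Retrieve the vector of "mllib" by slicing at its index.
    val ind = wordIndex("mllib")
    val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
    println(vec.mkString(", "))  // prints the three floats stored for "mllib"
  }
}
```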
```diff
@@ -501,10 +510,10 @@ class Word2VecModel private[mllib] (
     val fVector = vector.toArray.map(_.toFloat)
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
-    val beta: Float = 1
+    val beta: Float = 0
 
     blas.sgemv(
-      "T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)
+      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
 
     // Need not divide with the norm of the given vector since it is constant.
     val updatedCosines = new Array[Double](numWords)
```
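The `sgemv` call treats the flattened `wordVectors` as a `vectorSize x numWords` column-major matrix and, with the `"T"` flag, computes `cosineVec := alpha * W^T * fVector + beta * cosineVec`, i.e. the dot product of every word vector with the query vector. Changing `beta` to 0 tells BLAS to ignore whatever is already in `cosineVec` rather than add to it; with `beta = 1` the result was the same only because `cosineVec` had just been zero-filled. A plain-Scala sketch of the equivalent computation, assuming the flattened layout above (this is illustration only, not the code path MLlib uses):

```scala
// Pure-Scala equivalent of the sgemv("T", ...) call:
// cosineVec(j) = <word vector j, fVector>.
def dotWithAllWords(
    wordVectors: Array[Float],
    fVector: Array[Float],
    vectorSize: Int,
    numWords: Int): Array[Float] = {
  val cosineVec = new Array[Float](numWords)
  var j = 0
  while (j < numWords) {
    var sum = 0.0f
    var i = 0
    while (i < vectorSize) {
      sum += wordVectors(j * vectorSize + i) * fVector(i)
      i += 1
    }
    cosineVec(j) = sum  // beta = 0: any previous contents of cosineVec are overwritten
    j += 1
  }
  cosineVec
}
```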
```diff
@@ -526,7 +535,7 @@ class Word2VecModel private[mllib] (
    */
   def getVectors: Map[String, Array[Float]] = {
     wordIndex.map { case (word, ind) =>
-      (word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
+      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
     }
   }
 }
```
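`getVectors` reverses the flattening: for each word it slices out the corresponding `vectorSize`-long segment and rebuilds the word-to-vector map. A hypothetical usage sketch follows; `word2VecModel` stands in for a trained `Word2VecModel` instance and is not defined in the patch.

```scala
// Hypothetical usage: look up the learned vector for a single word.
val vectors: Map[String, Array[Float]] = word2VecModel.getVectors
vectors.get("spark").foreach { v =>
  println(s"dimension = ${v.length}")  // equals vectorSize
}
```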
