@@ -431,19 +431,28 @@ class Word2Vec extends Serializable with Logging {
 class Word2VecModel private[mllib] (
     model: Map[String, Array[Float]]) extends Serializable with Saveable {

-  // Maintain a ordered list of words based on the index in the initial model.
+  // wordList: Ordered list of words obtained from model.
+  // wordIndex: Maps each word to an index, which can retrieve the corresponding
+  //            vector from wordVectors (see below)
+  // vectorSize: Dimension of each vector.
+  // numWords: Number of words.
   private val wordList: Array[String] = model.keys.toArray
   private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
-  private val numDim = model.head._2.size
+  private val vectorSize = model.head._2.size
   private val numWords = wordIndex.size

+  // wordVectors: Array of length numWords * vectorSize; the vector of the word mapped
+  //              to index ind can be retrieved by the slice
+  //              (ind * vectorSize, ind * vectorSize + vectorSize).
+  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
+  //               of the corresponding word vector.
   private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
-    val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
+    val wordVectors = wordList.flatMap(word => model.get(word).get).toArray
     val wordVecNorms = new Array[Double](numWords)
     var i = 0
     while (i < numWords) {
       val vec = model.get(wordList(i)).get
-      wordVecNorms(i) = blas.snrm2(numDim, vec, 1)
+      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
       i += 1
     }
     (wordVectors, wordVecNorms)
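
As a standalone illustration of the layout the new comments describe (not part of the patch; the toy model, words, and object name are invented), the sketch below builds the flat wordVectors array from a tiny Map and recovers one word's vector and Euclidean norm, using plain math in place of blas.snrm2:

```scala
// Illustrative sketch only: flat word-vector layout and per-word norm.
object FlatLayoutSketch {
  def main(args: Array[String]): Unit = {
    // Invented two-word model with vectorSize = 3.
    val model: Map[String, Array[Float]] = Map(
      "cat" -> Array(1.0f, 2.0f, 2.0f),
      "dog" -> Array(3.0f, 4.0f, 0.0f))

    val wordList = model.keys.toArray
    val wordIndex = wordList.zip(0 until model.size).toMap
    val vectorSize = model.head._2.length
    val wordVectors = wordList.flatMap(word => model(word))

    // The vector of the word with index ind occupies the half-open slice
    // [ind * vectorSize, ind * vectorSize + vectorSize) of wordVectors.
    val ind = wordIndex("dog")
    val vec = wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)

    // Same quantity blas.snrm2 computes in the patch, written with plain math.
    val norm = math.sqrt(vec.map(x => x.toDouble * x).sum)
    println(s"dog -> ${vec.mkString(",")}, norm = $norm")  // norm = 5.0
  }
}
```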
@@ -501,10 +510,10 @@ class Word2VecModel private[mllib] (
     val fVector = vector.toArray.map(_.toFloat)
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
-    val beta: Float = 1
+    val beta: Float = 0

     blas.sgemv(
-      "T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)
+      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)

     // Need not divide with the norm of the given vector since it is constant.
     val updatedCosines = new Array[Double](numWords)
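
For context on the beta change (an illustration, not part of the patch; the helper name is invented): BLAS sgemv computes y := alpha * op(A) * x + beta * y, so beta = 0 overwrites cosineVec, whereas beta = 1 would add the product to whatever the array already holds. The loop below spells out the same transposed matrix-vector product for the column-major layout used here:

```scala
// Sketch of what blas.sgemv("T", vectorSize, numWords, alpha, wordVectors,
// vectorSize, fVector, 1, beta, cosineVec, 1) computes when beta = 0:
// cosineVec(j) = alpha * dot(vector of word j, fVector).
def transposedGemv(
    wordVectors: Array[Float],
    fVector: Array[Float],
    vectorSize: Int,
    numWords: Int,
    alpha: Float): Array[Float] = {
  val cosineVec = new Array[Float](numWords)
  var j = 0
  while (j < numWords) {
    var dot = 0.0f
    var k = 0
    while (k < vectorSize) {
      // Column j of the column-major matrix holds word j's vector.
      dot += wordVectors(j * vectorSize + k) * fVector(k)
      k += 1
    }
    cosineVec(j) = alpha * dot  // beta = 0: overwrite, do not accumulate
    j += 1
  }
  cosineVec
}
```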
@@ -526,7 +535,7 @@ class Word2VecModel private[mllib] (
    */
   def getVectors: Map[String, Array[Float]] = {
     wordIndex.map { case (word, ind) =>
-      (word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
+      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
     }
   }
 }
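
A hypothetical usage sketch of the surrounding public API (not part of the patch; the app name, corpus path, and parameters are made up), showing the word-to-vector map that getVectors reconstructs by slicing the flat wordVectors array:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

object GetVectorsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("w2v-sketch").setMaster("local[*]"))

    // Hypothetical corpus path; each line is treated as one tokenized sentence.
    val sentences: RDD[Seq[String]] = sc.textFile("corpus.txt").map(_.split(" ").toSeq)

    val model = new Word2Vec().setVectorSize(100).fit(sentences)

    // Each vocabulary word maps to its vectorSize-dimensional embedding.
    val vectors: Map[String, Array[Float]] = model.getVectors
    vectors.headOption.foreach { case (word, vec) =>
      println(s"$word -> ${vec.length} dimensions")  // 100, the configured vectorSize
    }

    sc.stop()
  }
}
```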