@@ -434,13 +434,11 @@ class Word2VecModel private[mllib] (
   // Maintain an ordered list of words based on the index in the initial model.
   private val wordList: Array[String] = model.keys.toArray
   private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+  private val numDim = model.head._2.size
+  private val numWords = wordIndex.size
 
-  private val (wordVectors: DenseMatrix, wordVecNorms: Array[Double]) = {
-    val numDim = model.head._2.size
-    val numWords = wordIndex.size
-    val flatVec = model.toSeq.flatMap { case (w, v) =>
-      v.map(_.toDouble)}.toArray
-    val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed = true)
+  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
+    val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
     val wordVecNorms = new Array[Double](numWords)
     var i = 0
     while (i < numWords) {
@@ -500,9 +498,13 @@ class Word2VecModel private[mllib] (
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
 
-    val numWords = wordVectors.numRows
-    val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
-    BLAS.gemv(1.0, wordVectors, new DenseVector(vector.toArray), 0.0, cosineVec)
+    val fVector = vector.toArray.map(_.toFloat)
+    val cosineVec = Array.fill[Float](numWords)(0)
+    val alpha: Float = 1
+    val beta: Float = 1
+
+    blas.sgemv(
+      "T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)
 
     // Need not divide with the norm of the given vector since it is constant.
     val updatedCosines = new Array[Double](numWords)
@@ -523,11 +525,9 @@ class Word2VecModel private[mllib] (
    * Returns a map of words to their vector representations.
    */
   def getVectors: Map[String, Array[Float]] = {
-    val numDim = wordVectors.numCols
     wordIndex.map { case (word, ind) =>
-      val startInd = numDim * ind
-      val endInd = startInd + numDim
-      (word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
+      (word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
+    }
   }
 }
 
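Note for readers of this diff: the sketch below is a plain-Scala illustration (no netlib-java dependency; the FlatVectorsSketch object name, the toy model contents, and the query vector are invented for illustration) of what the new flat Array[Float] layout and the blas.sgemv("T", ...) call compute — one dot product of the query against every word vector, followed by per-word norm division. It approximates the idea behind the patched findSynonyms; it is not the actual implementation.

object FlatVectorsSketch {
  def main(args: Array[String]): Unit = {
    // Toy model: three words with 3-dimensional vectors (hypothetical values).
    val model: Map[String, Array[Float]] = Map(
      "cat" -> Array(1.0f, 0.0f, 0.0f),
      "dog" -> Array(0.8f, 0.6f, 0.0f),
      "car" -> Array(0.0f, 0.0f, 1.0f))

    val numDim = model.head._2.length
    val numWords = model.size
    val wordList = model.keys.toArray
    // Word i occupies indices [i * numDim, (i + 1) * numDim): viewed as a
    // column-major numDim x numWords matrix, which is the layout sgemv("T", ...) expects.
    val wordVectors: Array[Float] = model.toSeq.flatMap { case (_, v) => v }.toArray

    val query = Array(1.0f, 0.0f, 0.0f)

    // Hand-rolled equivalent of
    //   blas.sgemv("T", numDim, numWords, 1f, wordVectors, numDim, query, 1, 1f, dots, 1)
    // i.e. dots(w) = dot(word vector w, query) for every word w.
    val dots = new Array[Float](numWords)
    var w = 0
    while (w < numWords) {
      var d = 0
      while (d < numDim) {
        dots(w) += wordVectors(w * numDim + d) * query(d)
        d += 1
      }
      w += 1
    }

    // Divide by each word vector's norm; the query's own norm is a shared
    // constant and does not affect the ranking (as the diff's comment notes).
    val cosines = dots.zipWithIndex.map { case (dot, i) =>
      val v = wordVectors.slice(i * numDim, (i + 1) * numDim)
      dot / math.sqrt(v.map(x => x.toDouble * x).sum)
    }

    wordList.zip(cosines).sortBy(-_._2).foreach { case (word, cos) =>
      println(f"$word%-4s cosine = $cos%.3f")
    }
  }
}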