Skip to content

Commit 6b74c81

Browse files
committed
Switch back to native blas calls
1 parent da1642d commit 6b74c81

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,11 @@ class Word2VecModel private[mllib] (
434434
// Maintain a ordered list of words based on the index in the initial model.
435435
private val wordList: Array[String] = model.keys.toArray
436436
private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
437+
private val numDim = model.head._2.size
438+
private val numWords = wordIndex.size
437439

438-
private val (wordVectors: DenseMatrix, wordVecNorms: Array[Double]) = {
439-
val numDim = model.head._2.size
440-
val numWords = wordIndex.size
441-
val flatVec = model.toSeq.flatMap { case(w, v) =>
442-
v.map(_.toDouble)}.toArray
443-
val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
440+
private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
441+
val wordVectors = model.toSeq.flatMap { case (w, v) => v }.toArray
444442
val wordVecNorms = new Array[Double](numWords)
445443
var i = 0
446444
while (i < numWords) {
@@ -500,9 +498,13 @@ class Word2VecModel private[mllib] (
500498
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
501499
require(num > 0, "Number of similar words should > 0")
502500

503-
val numWords = wordVectors.numRows
504-
val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
505-
BLAS.gemv(1.0, wordVectors, new DenseVector(vector.toArray), 0.0, cosineVec)
501+
val fVector = vector.toArray.map(_.toFloat)
502+
val cosineVec = Array.fill[Float](numWords)(0)
503+
val alpha: Float = 1
504+
val beta: Float = 1
505+
506+
blas.sgemv(
507+
"T", numDim, numWords, alpha, wordVectors, numDim, fVector, 1, beta, cosineVec, 1)
506508

507509
// Need not divide with the norm of the given vector since it is constant.
508510
val updatedCosines = new Array[Double](numWords)
@@ -523,11 +525,9 @@ class Word2VecModel private[mllib] (
523525
* Returns a map of words to their vector representations.
524526
*/
525527
def getVectors: Map[String, Array[Float]] = {
526-
val numDim = wordVectors.numCols
527528
wordIndex.map { case (word, ind) =>
528-
val startInd = numDim * ind
529-
val endInd = startInd + numDim
530-
(word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
529+
(word, wordVectors.slice(numDim * ind, numDim * ind + numDim))
530+
}
531531
}
532532
}
533533

0 commit comments

Comments
 (0)