Skip to content

Commit 0a04721

Browse files
bwahlgreenjkbradley
authored andcommitted
[SPARK-17721][MLLIB][BACKPORT] Fix for multiplying transposed SparseMatrix with SparseVector
Backport PR of changes relevant to mllib only, but otherwise identical to #15296 jkbradley Author: Bjarne Fruergaard <[email protected]> Closes #15311 from bwahlgreen/bugfix-spark-17721-1.6. (cherry picked from commit 376545e) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent 576265f commit 0a04721

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,12 +587,16 @@ private[spark] object BLAS extends Serializable with Logging {
587587
val indEnd = Arows(rowCounter + 1)
588588
var sum = 0.0
589589
var k = 0
590-
while (k < xNnz && i < indEnd) {
590+
while (i < indEnd && k < xNnz) {
591591
if (xIndices(k) == Acols(i)) {
592592
sum += Avals(i) * xValues(k)
593+
k += 1
594+
i += 1
595+
} else if (xIndices(k) < Acols(i)) {
596+
k += 1
597+
} else {
593598
i += 1
594599
}
595-
k += 1
596600
}
597601
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
598602
rowCounter += 1

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,23 @@ class BLASSuite extends SparkFunSuite {
367367
}
368368
}
369369

370+
val y17 = new DenseVector(Array(0.0, 0.0))
371+
val y18 = y17.copy
372+
373+
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
374+
.transpose
375+
val sA4 =
376+
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
377+
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
378+
379+
val expected4 = new DenseVector(Array(5.0, 4.0))
380+
381+
gemv(1.0, sA3, sx3, 0.0, y17)
382+
gemv(1.0, sA4, sx3, 0.0, y18)
383+
384+
assert(y17 ~== expected4 absTol 1e-15)
385+
assert(y18 ~== expected4 absTol 1e-15)
386+
370387
val dAT =
371388
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
372389
val sAT =

0 commit comments

Comments
 (0)