Skip to content

Commit 29396e7

Browse files
bwahlgreenjkbradley
authored andcommitted
[SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector
## What changes were proposed in this pull request? * changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical) * adds a test that was failing before this change, but succeeds with these changes. The problem in the previous implementation was that it only increments `i`, that is enumerating the columns of a row in the SparseMatrix, when the row-index of the vector matches the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i` and the remaining column-indices i+1,...,indEnd-1 are never attempted. The test cases in this PR illustrate this issue. ## How was this patch tested? I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`. ## ___ As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license. Mentioning dbtsai, viirya and brkyvz whom I can see have worked/authored on these parts before. Author: Bjarne Fruergaard <[email protected]> Closes #15296 from bwahlgreen/bugfix-spark-17721.
1 parent 4ecc648 commit 29396e7

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable {
638638
val indEnd = Arows(rowCounter + 1)
639639
var sum = 0.0
640640
var k = 0
641-
while (k < xNnz && i < indEnd) {
641+
while (i < indEnd && k < xNnz) {
642642
if (xIndices(k) == Acols(i)) {
643643
sum += Avals(i) * xValues(k)
644+
k += 1
645+
i += 1
646+
} else if (xIndices(k) < Acols(i)) {
647+
k += 1
648+
} else {
644649
i += 1
645650
}
646-
k += 1
647651
}
648652
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
649653
rowCounter += 1

mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite {
392392
}
393393
}
394394

395+
val y17 = new DenseVector(Array(0.0, 0.0))
396+
val y18 = y17.copy
397+
398+
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
399+
.transpose
400+
val sA4 =
401+
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
402+
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
403+
404+
val expected4 = new DenseVector(Array(5.0, 4.0))
405+
406+
gemv(1.0, sA3, sx3, 0.0, y17)
407+
gemv(1.0, sA4, sx3, 0.0, y18)
408+
409+
assert(y17 ~== expected4 absTol 1e-15)
410+
assert(y18 ~== expected4 absTol 1e-15)
411+
395412
val dAT =
396413
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
397414
val sAT =

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging {
637637
val indEnd = Arows(rowCounter + 1)
638638
var sum = 0.0
639639
var k = 0
640-
while (k < xNnz && i < indEnd) {
640+
while (i < indEnd && k < xNnz) {
641641
if (xIndices(k) == Acols(i)) {
642642
sum += Avals(i) * xValues(k)
643+
k += 1
644+
i += 1
645+
} else if (xIndices(k) < Acols(i)) {
646+
k += 1
647+
} else {
643648
i += 1
644649
}
645-
k += 1
646650
}
647651
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
648652
rowCounter += 1

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite {
392392
}
393393
}
394394

395+
val y17 = new DenseVector(Array(0.0, 0.0))
396+
val y18 = y17.copy
397+
398+
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
399+
.transpose
400+
val sA4 =
401+
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
402+
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
403+
404+
val expected4 = new DenseVector(Array(5.0, 4.0))
405+
406+
gemv(1.0, sA3, sx3, 0.0, y17)
407+
gemv(1.0, sA4, sx3, 0.0, y18)
408+
409+
assert(y17 ~== expected4 absTol 1e-15)
410+
assert(y18 ~== expected4 absTol 1e-15)
411+
395412
val dAT =
396413
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
397414
val sAT =

0 commit comments

Comments
 (0)