Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging {
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
if (alpha == 0.0 && beta == 1.0) {
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
} else if (alpha == 0.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
} else {
A match {
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
Expand Down Expand Up @@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0) {
// Scale matrix first if `beta` is not equal to 1.0
if (beta != 1.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
Expand Down Expand Up @@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging {
s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
require(A.numRows == y.size,
s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
if (alpha == 0.0 && beta == 1.0) {
logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.")
} else if (alpha == 0.0) {
scal(beta, y)
} else {
(A, x) match {
case (smA: SparseMatrix, dvx: DenseVector) =>
Expand Down Expand Up @@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging {
val xValues = x.values
val yValues = y.values

if (alpha == 0.0) {
scal(beta, y)
return
}

if (A.isTransposed) {
var rowCounterForA = 0
while (rowCounterForA < mA) {
Expand Down Expand Up @@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging {
val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices

if (alpha == 0.0) {
scal(beta, y)
return
}

if (A.isTransposed) {
var rowCounter = 0
while (rowCounter < mA) {
Expand All @@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
scal(beta, y)
if (beta != 1.0) scal(beta, y)

var colCounterForA = 0
var k = 0
Expand Down Expand Up @@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging {
rowCounter += 1
}
} else {
scal(beta, y)
if (beta != 1.0) scal(beta, y)
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite {
val C14 = C1.copy
val C15 = C1.copy
val C16 = C1.copy
val C17 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
Expand All @@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite {
assert(C2 ~== expected2 absTol 1e-15)
assert(C3 ~== expected3 absTol 1e-15)
assert(C4 ~== expected3 absTol 1e-15)
gemm(1.0, dA, B, 0.0, C17)
assert(C17 ~== expected absTol 1e-15)
gemm(1.0, sA, B, 0.0, C17)
assert(C17 ~== expected absTol 1e-15)

withClue("columns of A don't match the rows of B") {
intercept[Exception] {
Expand Down