Skip to content

Commit c35fdcb

Browse files
hhbyyh and mengxr
authored and committed
[SPARK-10491] [MLLIB] move RowMatrix.dspr to BLAS
jira: https://issues.apache.org/jira/browse/SPARK-10491 We implemented dspr with sparse vector support in `RowMatrix`. This method is also used in WeightedLeastSquares and other places. It would be useful to move it to `linalg.BLAS`. Let me know if a new unit test is needed. Author: Yuhao Yang <[email protected]> Closes #8663 from hhbyyh/movedspr.
1 parent 09b7e7c commit c35fdcb

File tree

4 files changed

+72
-41
lines changed

4 files changed

+72
-41
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares(
8888
if (fitIntercept) {
8989
// shift centers
9090
// A^T A - aBar aBar^T
91-
RowMatrix.dspr(-1.0, aBar, aaValues)
91+
BLAS.spr(-1.0, aBar, aaValues)
9292
// A^T b - bBar aBar
9393
BLAS.axpy(-bBar, aBar, abBar)
9494
}
@@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares {
203203
bbSum += w * b * b
204204
BLAS.axpy(w, a, aSum)
205205
BLAS.axpy(w * b, a, abSum)
206-
RowMatrix.dspr(w, a, aaSum.values)
206+
BLAS.spr(w, a, aaSum)
207207
this
208208
}
209209

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging {
236236
_nativeBLAS
237237
}
238238

239+
/**
240+
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
241+
*
242+
* @param U the upper triangular part of the matrix packed in a [[DenseVector]] (column major)
243+
*/
244+
def spr(alpha: Double, v: Vector, U: DenseVector): Unit = {
245+
spr(alpha, v, U.values)
246+
}
247+
248+
/**
249+
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
250+
*
251+
* @param U the upper triangular part of the matrix packed in an array (column major)
252+
*/
253+
def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
254+
val n = v.size
255+
v match {
256+
case DenseVector(values) =>
257+
NativeBLAS.dspr("U", n, alpha, values, 1, U)
258+
case SparseVector(size, indices, values) =>
259+
val nnz = indices.length
260+
var colStartIdx = 0
261+
var prevCol = 0
262+
var col = 0
263+
var j = 0
264+
var i = 0
265+
var av = 0.0
266+
while (j < nnz) {
267+
col = indices(j)
268+
// Skip empty columns.
269+
colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
270+
col = indices(j)
271+
av = alpha * values(j)
272+
i = 0
273+
while (i <= j) {
274+
U(colStartIdx + indices(i)) += av * values(i)
275+
i += 1
276+
}
277+
j += 1
278+
prevCol = col
279+
}
280+
}
281+
}
282+
239283
/**
240284
* A := alpha * x * x^T^ + A
241285
* @param alpha a real scalar that will be multiplied to x * x^T^.

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer
2424
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
2525
svd => brzSvd, MatrixSingularException, inv}
2626
import breeze.numerics.{sqrt => brzSqrt}
27-
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2827

2928
import org.apache.spark.Logging
3029
import org.apache.spark.SparkContext._
@@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") (
123122
// Compute the upper triangular part of the gram matrix.
124123
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
125124
seqOp = (U, v) => {
126-
RowMatrix.dspr(1.0, v, U.data)
125+
BLAS.spr(1.0, v, U.data)
127126
U
128127
}, combOp = (U1, U2) => U1 += U2)
129128

@@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") (
673672
@Experimental
674673
object RowMatrix {
675674

676-
/**
677-
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
678-
*
679-
* @param U the upper triangular part of the matrix packed in an array (column major)
680-
*/
681-
// TODO: SPARK-10491 - move this method to linalg.BLAS
682-
private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
683-
// TODO: Find a better home (breeze?) for this method.
684-
val n = v.size
685-
v match {
686-
case DenseVector(values) =>
687-
blas.dspr("U", n, alpha, values, 1, U)
688-
case SparseVector(size, indices, values) =>
689-
val nnz = indices.length
690-
var colStartIdx = 0
691-
var prevCol = 0
692-
var col = 0
693-
var j = 0
694-
var i = 0
695-
var av = 0.0
696-
while (j < nnz) {
697-
col = indices(j)
698-
// Skip empty columns.
699-
colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
700-
col = indices(j)
701-
av = alpha * values(j)
702-
i = 0
703-
while (i <= j) {
704-
U(colStartIdx + indices(i)) += av * values(i)
705-
i += 1
706-
}
707-
j += 1
708-
prevCol = col
709-
}
710-
}
711-
}
712-
713675
/**
714676
* Fills a full square matrix from its upper triangular part.
715677
*/

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite {
126126
}
127127
}
128128

129+
test("spr") {
130+
// test dense vector
131+
val alpha = 0.1
132+
val x = new DenseVector(Array(1.0, 2, 2.1, 4))
133+
val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
134+
val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6))
135+
136+
spr(alpha, x, U)
137+
assert(U ~== expected absTol 1e-9)
138+
139+
val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5))
140+
withClue("Size of vector must match the rank of matrix") {
141+
intercept[Exception] {
142+
spr(alpha, x, matrix33)
143+
}
144+
}
145+
146+
// test sparse vector
147+
val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2))
148+
val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
149+
spr(0.1, sv, U2)
150+
val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4))
151+
assert(U2 ~== expectedSparse absTol 1e-15)
152+
}
153+
129154
test("syr") {
130155
val dA = new DenseMatrix(4, 4,
131156
Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))

0 commit comments

Comments
 (0)