Skip to content

Commit 85c42fd

Browse files
mengxrDB Tsai
authored andcommitted
[SPARK-13927][MLLIB] add row/column iterator to local matrices
## What changes were proposed in this pull request? Add row/column iterator to local matrices to simplify tasks like BlockMatrix => RowMatrix conversion. It handles dense and sparse matrices properly. ## How was this patch tested? Unit tests on sparse and dense matrix. cc: dbtsai Author: Xiangrui Meng <[email protected]> Closes #11757 from mengxr/SPARK-13927.
1 parent 6fc2b65 commit 85c42fd

File tree

3 files changed

+80
-1
lines changed

3 files changed

+80
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ import java.util.{Arrays, Random}
2222
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet}
2323

2424
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
25+
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2526

2627
import org.apache.spark.annotation.{DeveloperApi, Since}
27-
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
2828
import org.apache.spark.sql.catalyst.InternalRow
29+
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
2930
import org.apache.spark.sql.catalyst.util.GenericArrayData
3031
import org.apache.spark.sql.types._
3132

@@ -58,6 +59,20 @@ sealed trait Matrix extends Serializable {
5859
newArray
5960
}
6061

62+
/**
63+
* Returns an iterator of column vectors.
64+
* This operation could be expensive, depending on the underlying storage.
65+
*/
66+
@Since("2.0.0")
67+
def colIter: Iterator[Vector]
68+
69+
/**
70+
* Returns an iterator of row vectors.
71+
* This operation could be expensive, depending on the underlying storage.
72+
*/
73+
@Since("2.0.0")
74+
def rowIter: Iterator[Vector] = this.transpose.colIter
75+
6176
/** Converts to a breeze matrix. */
6277
private[mllib] def toBreeze: BM[Double]
6378

@@ -386,6 +401,21 @@ class DenseMatrix @Since("1.3.0") (
386401
}
387402
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result())
388403
}
404+
405+
@Since("2.0.0")
406+
override def colIter: Iterator[Vector] = {
407+
if (isTransposed) {
408+
Iterator.tabulate(numCols) { j =>
409+
val col = new Array[Double](numRows)
410+
blas.dcopy(numRows, values, j, numCols, col, 0, 1)
411+
new DenseVector(col)
412+
}
413+
} else {
414+
Iterator.tabulate(numCols) { j =>
415+
new DenseVector(values.slice(j * numRows, (j + 1) * numRows))
416+
}
417+
}
418+
}
389419
}
390420

391421
/**
@@ -656,6 +686,38 @@ class SparseMatrix @Since("1.3.0") (
656686
@Since("1.5.0")
657687
override def numActives: Int = values.length
658688

689+
@Since("2.0.0")
690+
override def colIter: Iterator[Vector] = {
691+
if (isTransposed) {
692+
val indicesArray = Array.fill(numCols)(MArrayBuilder.make[Int])
693+
val valuesArray = Array.fill(numCols)(MArrayBuilder.make[Double])
694+
var i = 0
695+
while (i < numRows) {
696+
var k = colPtrs(i)
697+
val rowEnd = colPtrs(i + 1)
698+
while (k < rowEnd) {
699+
val j = rowIndices(k)
700+
indicesArray(j) += i
701+
valuesArray(j) += values(k)
702+
k += 1
703+
}
704+
i += 1
705+
}
706+
Iterator.tabulate(numCols) { j =>
707+
val ii = indicesArray(j).result()
708+
val vv = valuesArray(j).result()
709+
new SparseVector(numRows, ii, vv)
710+
}
711+
} else {
712+
Iterator.tabulate(numCols) { j =>
713+
val colStart = colPtrs(j)
714+
val colEnd = colPtrs(j + 1)
715+
val ii = rowIndices.slice(colStart, colEnd)
716+
val vv = values.slice(colStart, colEnd)
717+
new SparseVector(numRows, ii, vv)
718+
}
719+
}
720+
}
659721
}
660722

661723
/**

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,4 +494,17 @@ class MatricesSuite extends SparkFunSuite {
494494
assert(sm1.numNonzeros === 1)
495495
assert(sm1.numActives === 3)
496496
}
497+
498+
test("row/col iterator") {
499+
val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0))
500+
val sm = dm.toSparse
501+
val rows = Seq(Vectors.dense(0, 3), Vectors.dense(1, 4), Vectors.dense(2, 0))
502+
val cols = Seq(Vectors.dense(0, 1, 2), Vectors.dense(3, 4, 0))
503+
for (m <- Seq(dm, sm)) {
504+
assert(m.rowIter.toSeq === rows)
505+
assert(m.colIter.toSeq === cols)
506+
assert(m.transpose.rowIter.toSeq === cols)
507+
assert(m.transpose.colIter.toSeq === rows)
508+
}
509+
}
497510
}

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,10 @@ object MimaExcludes {
531531
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"),
532532
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"),
533533
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert")
534+
) ++ Seq(
535+
// SPARK-13927: add row/column iterator to local matrices
536+
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"),
537+
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter")
534538
)
535539
case v if v.startsWith("1.6") =>
536540
Seq(

0 commit comments

Comments
 (0)