Skip to content

Commit 4e95e24

Browse files
committed
simplify fromCOO implementation
1 parent 10a63a6 commit 4e95e24

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -362,46 +362,46 @@ object SparseMatrix {
362362
* @return The corresponding `SparseMatrix`
363363
*/
364364
def fromCOO(numRows: Int, numCols: Int, entries: Array[(Int, Int, Double)]): SparseMatrix = {
365+
val numEntries = entries.size
365366
val sortedEntries = entries.sortBy(v => (v._2, v._1))
366-
val colPtrs = new Array[Int](numCols + 1)
367-
var nnz = 0
368-
var lastCol = -1
369-
var lastIndex = -1
370-
sortedEntries.foreach { case (i, j, v) =>
371-
require(i >= 0 && j >= 0, "Negative indices given. Please make sure all indices are " +
372-
s"greater than or equal to zero. i: $i, j: $j, value: $v")
373-
if (v != 0.0) {
374-
while (j != lastCol) {
375-
colPtrs(lastCol + 1) = nnz
376-
lastCol += 1
377-
}
378-
val index = j * numRows + i
379-
if (lastIndex != index) {
380-
nnz += 1
381-
lastIndex = index
382-
}
367+
if (sortedEntries.nonEmpty) {
368+
// Since the entries are sorted by column index, we only need to check the first and the last.
369+
for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) {
370+
require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.")
383371
}
384372
}
385-
while (numCols > lastCol) {
386-
colPtrs(lastCol + 1) = nnz
387-
lastCol += 1
388-
}
389-
val values = new Array[Double](nnz)
390-
val rowIndices = new Array[Int](nnz)
391-
lastIndex = -1
392-
var cnt = -1
393-
sortedEntries.foreach { case (i, j, v) =>
394-
if (v != 0.0) {
395-
val index = j * numRows + i
396-
if (lastIndex != index) {
397-
cnt += 1
398-
lastIndex = index
373+
val colPtrs = new Array[Int](numCols + 1)
374+
val rowIndices = MArrayBuilder.make[Int]
375+
rowIndices.sizeHint(numEntries)
376+
val values = MArrayBuilder.make[Double]
377+
values.sizeHint(numEntries)
378+
var nnz = 0
379+
var prevCol = 0
380+
var prevRow = -1
381+
var prevVal = 0.0
382+
// Append a dummy entry to include the last one at the end of the loop.
383+
(sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) =>
384+
if (v != 0) {
385+
if (i == prevRow && j == prevCol) {
386+
prevVal += v
387+
} else {
388+
if (prevVal != 0) {
389+
require(prevRow >= 0 && prevRow < numRows,
390+
s"Row index out of range [0, $numRows): $prevRow.")
391+
nnz += 1
392+
rowIndices += prevRow
393+
values += prevVal
394+
}
395+
prevRow = i
396+
prevVal = v
397+
while (prevCol < j) {
398+
colPtrs(prevCol + 1) = nnz
399+
prevCol += 1
400+
}
399401
}
400-
values(cnt) += v
401-
rowIndices(cnt) = i
402402
}
403403
}
404-
new SparseMatrix(numRows, numCols, colPtrs.toArray, rowIndices, values)
404+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result())
405405
}
406406

407407
/**

0 commit comments

Comments
 (0)