Skip to content

Commit f62d6c7

Browse files
committed
[SPARK-4409] Modified genRandMatrix
1 parent 3971c93 commit f62d6c7

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg
1919

2020
import java.util.{Arrays, Random}
2121

22-
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map}
22+
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map => MutableMap}
2323

2424
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
2525

@@ -408,8 +408,15 @@ object SparseMatrix {
408408
require(density >= 0.0 && density <= 1.0, "density must be a double in the range " +
409409
s"0.0 <= d <= 1.0. Currently, density: $density")
410410
val length = math.ceil(numRows * numCols * density).toInt
411-
val entries = Map[(Int, Int), Double]()
411+
val entries = MutableMap[(Int, Int), Double]()
412412
var i = 0
413+
if (density == 0.0) {
414+
return new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1),
415+
Array[Int](), Array[Double]())
416+
} else if (density == 1.0) {
417+
return new SparseMatrix(numRows, numCols, (0 to numRows * numCols by numRows).toArray,
418+
(0 until numRows * numCols).toArray, Array.fill(numRows * numCols)(method(rng)))
419+
}
413420
// Expected number of iterations is less than 1.5 * length
414421
if (density < 0.34) {
415422
while (i < length) {
@@ -424,23 +431,18 @@ object SparseMatrix {
424431
}
425432
} else { // selection - rejection method
426433
var j = 0
427-
val triesPerCol = math.ceil(length * 1.0 / numCols).toInt
428434
val pool = numRows * numCols
429435
// loop over columns so that the sort in fromCOO requires less sorting
430436
while (i < length && j < numCols) {
431-
var k = 0
432-
val leftFromPool = (numCols - j) * numRows
433-
while (k < triesPerCol) {
434-
if (rng.nextDouble() < 1.0 * (length - i) / (pool - leftFromPool)) {
435-
var rowIndex = rng.nextInt(numRows)
436-
val colIndex = j
437-
while (entries.contains((rowIndex, colIndex))) {
438-
rowIndex = rng.nextInt(numRows)
439-
}
440-
entries += (rowIndex, colIndex) -> method(rng)
437+
var passedInPool = j * numRows
438+
var r = 0
439+
while (i < length && r < numRows) {
440+
if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
441+
entries += (r, j) -> method(rng)
441442
i += 1
442443
}
443-
k += 1
444+
r += 1
445+
passedInPool += 1
444446
}
445447
j += 1
446448
}

0 commit comments

Comments
 (0)