@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg
1919
2020import java .util .{Arrays , Random }
2121
22- import scala .collection .mutable .{ArrayBuffer , ArrayBuilder , Map }
22+ import scala .collection .mutable .{ArrayBuffer , ArrayBuilder , Map => MutableMap }
2323
2424import breeze .linalg .{CSCMatrix => BSM , DenseMatrix => BDM , Matrix => BM }
2525
@@ -408,8 +408,15 @@ object SparseMatrix {
408408 require(density >= 0.0 && density <= 1.0 , " density must be a double in the range " +
409409 s " 0.0 <= d <= 1.0. Currently, density: $density" )
410410 val length = math.ceil(numRows * numCols * density).toInt
411- val entries = Map [(Int , Int ), Double ]()
411+ val entries = MutableMap [(Int , Int ), Double ]()
412412 var i = 0
413+ if (density == 0.0 ) {
414+ return new SparseMatrix (numRows, numCols, new Array [Int ](numCols + 1 ),
415+ Array [Int ](), Array [Double ]())
416+ } else if (density == 1.0 ) {
417+ return new SparseMatrix (numRows, numCols, (0 to numRows * numCols by numRows).toArray,
418+ (0 until numRows * numCols).toArray, Array .fill(numRows * numCols)(method(rng)))
419+ }
413420 // Expected number of iterations is less than 1.5 * length
414421 if (density < 0.34 ) {
415422 while (i < length) {
@@ -424,23 +431,18 @@ object SparseMatrix {
424431 }
425432 } else { // selection - rejection method
426433 var j = 0
427- val triesPerCol = math.ceil(length * 1.0 / numCols).toInt
428434 val pool = numRows * numCols
429435 // loop over columns so that the sort in fromCOO requires less sorting
430436 while (i < length && j < numCols) {
431- var k = 0
432- val leftFromPool = (numCols - j) * numRows
433- while (k < triesPerCol) {
434- if (rng.nextDouble() < 1.0 * (length - i) / (pool - leftFromPool)) {
435- var rowIndex = rng.nextInt(numRows)
436- val colIndex = j
437- while (entries.contains((rowIndex, colIndex))) {
438- rowIndex = rng.nextInt(numRows)
439- }
440- entries += (rowIndex, colIndex) -> method(rng)
437+ var passedInPool = j * numRows
438+ var r = 0
439+ while (i < length && r < numRows) {
440+ if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
441+ entries += (r, j) -> method(rng)
441442 i += 1
442443 }
443- k += 1
444+ r += 1
445+ passedInPool += 1
444446 }
445447 j += 1
446448 }
0 commit comments