Commit a7a03dc

Addressed srowen comments
1 parent 994b457 commit a7a03dc

2 files changed, +27 -26 lines changed


mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 12 additions & 7 deletions
@@ -113,12 +113,17 @@ class IndexedRowMatrix @Since("1.0.0") (
    require(colsPerBlock > 0,
      s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock")

+   // Since block matrices require an integer row index
+   require(numRows() / rowsPerBlock < Int.MaxValue,
+     "Number of rows divided by rowsPerBlock cannot exceed maximum integer.")
+
    val m = numRows()
    val n = numCols()
-   val lastRowBlockIndex = m / rowsPerBlock
-   val lastColBlockIndex = n / colsPerBlock
-   val lastRowBlockSize = (m % rowsPerBlock).toInt
-   val lastColBlockSize = (n % colsPerBlock).toInt
+   // The remainder calculations only matter when m % rowsPerBlock != 0 or n % colsPerBlock != 0
+   val remainderRowBlockIndex = m / rowsPerBlock
+   val remainderColBlockIndex = n / colsPerBlock
+   val remainderRowBlockSize = (m % rowsPerBlock).toInt
+   val remainderColBlockSize = (n % colsPerBlock).toInt
    val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt
    val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt

@@ -143,9 +148,9 @@ class IndexedRowMatrix @Since("1.0.0") (
    }.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map {
      case ((blockRow, blockColumn), itr) =>
        val actualNumRows =
-         if (blockRow == lastRowBlockIndex) lastRowBlockSize else rowsPerBlock
+         if (blockRow == remainderRowBlockIndex) remainderRowBlockSize else rowsPerBlock
        val actualNumColumns =
-         if (blockColumn == lastColBlockIndex) lastColBlockSize else colsPerBlock
+         if (blockColumn == remainderColBlockIndex) remainderColBlockSize else colsPerBlock

        val arraySize = actualNumRows * actualNumColumns
        val matrixAsArray = new Array[Double](arraySize)
@@ -157,7 +162,7 @@ class IndexedRowMatrix @Since("1.0.0") (
          }
        }
        val denseMatrix = new DenseMatrix(actualNumRows, actualNumColumns, matrixAsArray)
-       val finalMatrix = if (countForValues / arraySize.toDouble >= 0.5) {
+       val finalMatrix = if (countForValues / arraySize.toDouble >= 0.1) {
          denseMatrix
        } else {
          denseMatrix.toSparse
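For context, the renamed remainder* values only differ from the regular block sizes when the block dimensions do not divide the matrix dimensions evenly. The following is a minimal usage sketch, not part of this commit: it assumes an existing SparkContext named sc, and the 7 x 5 matrix and the block sizes are made up for illustration.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

// 7 x 5 matrix: the largest row index is 6, so numRows() == 7; the vectors have size 5.
val rows = sc.parallelize(Seq(
  IndexedRow(0L, Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0)),
  IndexedRow(6L, Vectors.sparse(5, Seq((2, 7.0))))))
val indexedMat = new IndexedRowMatrix(rows)

// rowsPerBlock = 3 and colsPerBlock = 2: neither divides 7 or 5 evenly, so the last
// row/column blocks shrink to remainderRowBlockSize = 7 % 3 = 1 and
// remainderColBlockSize = 5 % 2 = 1, which is the case the renamed variables handle.
val blockMat = indexedMat.toBlockMatrix(3, 2)
assert(blockMat.numRows() == 7 && blockMat.numCols() == 5)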

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 15 additions & 19 deletions
@@ -119,23 +119,23 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("toBlockMatrix sparse backing") {
    val sparseData = Seq(
-     (7L, Vectors.sparse(6, Seq((0, 4.0))))
+     (15L, Vectors.sparse(12, Seq((0, 4.0))))
    ).map(x => IndexedRow(x._1, x._2))

    // Gonna make m and n larger here so the matrices can easily be completely sparse:
-   val m = 8
-   val n = 6
+   val m = 16
+   val n = 12

    val idxRowMatSparse = new IndexedRowMatrix(sc.parallelize(sparseData))

    // Tests when n % colsPerBlock != 0
-   val blockMat = idxRowMatSparse.toBlockMatrix(4, 4)
+   val blockMat = idxRowMatSparse.toBlockMatrix(8, 8)
    assert(blockMat.numRows() === m)
    assert(blockMat.numCols() === n)
    assert(blockMat.toBreeze() === idxRowMatSparse.toBreeze())

    // Tests when m % rowsPerBlock != 0
-   val blockMat2 = idxRowMatSparse.toBlockMatrix(3, 3)
+   val blockMat2 = idxRowMatSparse.toBlockMatrix(6, 6)
    assert(blockMat2.numRows() === m)
    assert(blockMat2.numCols() === n)
    assert(blockMat2.toBreeze() === idxRowMatSparse.toBreeze())
@@ -149,38 +149,34 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
  }

  test("toBlockMatrix mixed backing") {
+   val m = 24
+   val n = 18
+
    val mixedData = Seq(
-     (0L, Vectors.dense(1, 2, 3)),
-     (3L, Vectors.sparse(3, Seq((0, 4.0)))))
+     (0L, Vectors.dense((0 to 17).map(_.toDouble).toArray)),
+     (1L, Vectors.dense((0 to 17).map(_.toDouble).toArray)),
+     (23L, Vectors.sparse(18, Seq((0, 4.0)))))
      .map(x => IndexedRow(x._1, x._2))

    val idxRowMatMixed = new IndexedRowMatrix(
      sc.parallelize(mixedData))

    // Tests when n % colsPerBlock != 0
-   val blockMat = idxRowMatMixed.toBlockMatrix(2, 2)
+   val blockMat = idxRowMatMixed.toBlockMatrix(12, 12)
    assert(blockMat.numRows() === m)
    assert(blockMat.numCols() === n)
    assert(blockMat.toBreeze() === idxRowMatMixed.toBreeze())

    // Tests when m % rowsPerBlock != 0
-   val blockMat2 = idxRowMatMixed.toBlockMatrix(3, 1)
+   val blockMat2 = idxRowMatMixed.toBlockMatrix(18, 6)
    assert(blockMat2.numRows() === m)
    assert(blockMat2.numCols() === n)
    assert(blockMat2.toBreeze() === idxRowMatMixed.toBreeze())

    val blocks = blockMat.blocks.collect()

-   /* Diagram of mixed data blockmat. Lines indicate blocking.
-    1 2 | 3
-    0 0 | 0
-    -------
-    0 0 | 0
-    4 0 | 0
-    */
-
-   blocks.forall { case((row, col), matrix) =>
-     if (row == 0) matrix.isInstanceOf[DenseMatrix] else matrix.isInstanceOf[SparseMatrix]}
+   assert(blocks.forall { case((row, col), matrix) =>
+     if (row == 0) matrix.isInstanceOf[DenseMatrix] else matrix.isInstanceOf[SparseMatrix]})
  }

  test("multiply a local matrix") {
