Skip to content

Commit 4582a7e

Browse files
committed
Decided on just one toBlockMatrix implementation
1 parent 12e78bf commit 4582a7e

File tree

2 files changed

+31
-61
lines changed

2 files changed

+31
-61
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,15 @@ class IndexedRowMatrix @Since("1.0.0") (
9191
}
9292

9393
/**
94-
* Converts to BlockMatrix. Creates blocks of `SparseMatrix` with size 1024 x 1024.
94+
* Converts to BlockMatrix. Creates blocks with size 1024 x 1024.
9595
*/
9696
@Since("1.3.0")
9797
def toBlockMatrix(): BlockMatrix = {
9898
toBlockMatrix(1024, 1024)
9999
}
100100

101101
/**
102-
* Converts to BlockMatrix. Creates blocks of `SparseMatrix`.
102+
* Converts to BlockMatrix. Blocks may be sparse or dense depending on the sparsity of the rows.
103103
* @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have
104104
* a smaller value. Must be an integer value greater than 0.
105105
* @param colsPerBlock The number of columns of each block. The blocks at the right edge may have
@@ -108,28 +108,6 @@ class IndexedRowMatrix @Since("1.0.0") (
108108
*/
109109
@Since("1.3.0")
110110
def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = {
111-
// TODO: This implementation may be optimized
112-
toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock)
113-
}
114-
115-
/**
116-
* Converts to BlockMatrix. Creates blocks of `DenseMatrix` with size 1024 x 1024.
117-
*/
118-
@Since("2.2.0")
119-
def toBlockMatrixDense(): BlockMatrix = {
120-
toBlockMatrixDense(1024, 1024)
121-
}
122-
123-
/**
124-
* Converts to BlockMatrix. Creates blocks of `SparseMatrix`.
125-
* @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have
126-
* a smaller value. Must be an integer value greater than 0.
127-
* @param colsPerBlock The number of columns of each block. The blocks at the right edge may have
128-
* a smaller value. Must be an integer value greater than 0.
129-
* @return a [[BlockMatrix]]
130-
*/
131-
@Since("2.2.0")
132-
def toBlockMatrixDense(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = {
133111
require(rowsPerBlock > 0,
134112
s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock")
135113
require(colsPerBlock > 0,
@@ -144,33 +122,48 @@ class IndexedRowMatrix @Since("1.0.0") (
144122
val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt
145123
val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt
146124

147-
val blocks: RDD[((Int, Int), Matrix)] = rows.flatMap{ ir =>
125+
val blocks = rows.flatMap { ir: IndexedRow =>
148126
val blockRow = ir.index / rowsPerBlock
149127
val rowInBlock = ir.index % rowsPerBlock
150128

151-
ir.vector.toArray
152-
.grouped(colsPerBlock)
153-
.zipWithIndex
154-
.map{ case (values, blockColumn) =>
155-
((blockRow.toInt, blockColumn), (rowInBlock.toInt, values))
156-
}
129+
ir.vector match {
130+
case SparseVector(size, indices, values) =>
131+
indices.zip(values).map { case (index, value) =>
132+
val blockColumn = index / colsPerBlock
133+
val columnInBlock = index % colsPerBlock
134+
((blockRow.toInt, blockColumn.toInt), (rowInBlock.toInt, Array((value, columnInBlock))))
135+
}
136+
case DenseVector(values) =>
137+
values.grouped(colsPerBlock)
138+
.zipWithIndex
139+
.map { case (values, blockColumn) =>
140+
((blockRow.toInt, blockColumn), (rowInBlock.toInt, values.zipWithIndex))
141+
}
142+
}
157143
}.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map{
158144
case ((blockRow, blockColumn), itr) =>
159145
val actualNumRows =
160146
if (blockRow == lastRowBlockIndex) lastRowBlockSize else rowsPerBlock
161-
val actualNumColumns: Int =
147+
val actualNumColumns =
162148
if (blockColumn == lastColBlockIndex) lastColBlockSize else colsPerBlock
163149

164150
val arraySize = actualNumRows * actualNumColumns
165151
val matrixAsArray = new Array[Double](arraySize)
166-
itr.foreach{ case (rowWithinBlock, values) =>
167-
var i = 0
168-
while (i < values.length) {
169-
matrixAsArray.update(i * actualNumRows + rowWithinBlock, values(i))
170-
i += 1
152+
var countForValues = 0
153+
itr.foreach { case (rowWithinBlock, valuesWithColumns) =>
154+
valuesWithColumns.foreach { case (value, columnWithinBlock) =>
155+
matrixAsArray.update(columnWithinBlock * actualNumRows + rowWithinBlock, value)
156+
countForValues += 1
171157
}
172158
}
173-
((blockRow, blockColumn), new DenseMatrix(actualNumRows, actualNumColumns, matrixAsArray))
159+
val denseMatrix = new DenseMatrix(actualNumRows, actualNumColumns, matrixAsArray)
160+
val finalMatrix = if (countForValues / arraySize.toDouble > 0.5) {
161+
denseMatrix
162+
} else {
163+
denseMatrix.toSparse
164+
}
165+
166+
((blockRow, blockColumn), finalMatrix)
174167
}
175168
new BlockMatrix(blocks, rowsPerBlock, colsPerBlock)
176169
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,29 +110,6 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
110110
}
111111
}
112112

113-
test("toBlockMatrixDense") {
114-
val idxRowMat = new IndexedRowMatrix(indexedRows)
115-
116-
// Tests when n % colsPerBlock != 0
117-
val blockMat = idxRowMat.toBlockMatrixDense(2, 2)
118-
assert(blockMat.numRows() === m)
119-
assert(blockMat.numCols() === n)
120-
assert(blockMat.toBreeze() === idxRowMat.toBreeze())
121-
122-
// Tests when m % rowsPerBlock != 0
123-
val blockMat2 = idxRowMat.toBlockMatrixDense(3, 1)
124-
assert(blockMat2.numRows() === m)
125-
assert(blockMat2.numCols() === n)
126-
assert(blockMat2.toBreeze() === idxRowMat.toBreeze())
127-
128-
intercept[IllegalArgumentException] {
129-
idxRowMat.toBlockMatrixDense(-1, 2)
130-
}
131-
intercept[IllegalArgumentException] {
132-
idxRowMat.toBlockMatrixDense(2, 0)
133-
}
134-
}
135-
136113
test("multiply a local matrix") {
137114
val A = new IndexedRowMatrix(indexedRows)
138115
val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))

0 commit comments

Comments
 (0)