mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -19,12 +19,12 @@ package org.apache.spark.mllib.linalg.distributed

 import scala.collection.mutable.ArrayBuffer

-import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV}

 import org.apache.spark.{Partitioner, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel

@@ -264,13 +264,35 @@ class BlockMatrix @Since("1.3.0") (
     new CoordinateMatrix(entryRDD, numRows(), numCols())
   }

+
   /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
   @Since("1.3.0")
   def toIndexedRowMatrix(): IndexedRowMatrix = {
-    require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " +
-      s"numCols: ${numCols()}")
-    // TODO: This implementation may be optimized
-    toCoordinateMatrix().toIndexedRowMatrix()
+    require(numCols() < Int.MaxValue,
+      s"The number of columns should be less than Int.MaxValue (${numCols()}).")
+    val cols = numCols().toInt
+
+    val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) =>
+      mat.rowIter.zipWithIndex.map {
+        case (vector, rowIdx) =>
+          blockRowIdx * rowsPerBlock + rowIdx -> (blockColIdx, vector.toBreeze)
+      }
+    }.groupByKey().map { case (rowIdx, vectors) =>
+      val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble
+
+      val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz
+        BSV.zeros[Double](cols)
+      } else {
+        BDV.zeros[Double](cols)
+      }
+
+      vectors.foreach { case (blockColIdx: Int, vec: BV[Double]) =>
+        val offset = colsPerBlock * blockColIdx
+        wholeVector(offset until offset + colsPerBlock) := vec
+      }
+      new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector))
+    }
+    new IndexedRowMatrix(rows)
   }

   /** Collect the distributed matrix on the driver as a `DenseMatrix`. */
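
Note: the new toIndexedRowMatrix no longer round-trips through CoordinateMatrix. Each block emits its row slices keyed by the global row index (blockRowIdx * rowsPerBlock + rowIdx); after a groupByKey, the slices of one row are copied into a single Breeze vector at offset colsPerBlock * blockColIdx. Rows whose fraction of non-zeros is at most 0.1 are assembled as sparse vectors, all others as dense. A minimal driver-side sketch of calling the method, assuming an existing SparkContext sc (the block layout and values are illustrative, not taken from this patch):

    import org.apache.spark.mllib.linalg.Matrices
    import org.apache.spark.mllib.linalg.distributed.BlockMatrix

    // A 2x4 matrix stored as two 2x2 blocks: one dense, one mostly empty.
    val blocks = sc.parallelize(Seq(
      ((0, 0), Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))),
      ((0, 1), Matrices.sparse(2, 2, Array(0, 1, 1), Array(0), Array(5.0)))))
    val mat = new BlockMatrix(blocks, 2, 2) // rowsPerBlock = 2, colsPerBlock = 2

    // Each IndexedRow pairs a global row index with the assembled row vector.
    mat.toIndexedRowMatrix().rows.collect().foreach(println)
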
mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.linalg.distributed

 import java.{util => ju}

-import breeze.linalg.{DenseMatrix => BDM}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV}

 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
+import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Matrices, Matrix, SparseMatrix, SparseVector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._

@@ -134,6 +134,33 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(rowMat.numRows() === m)
     assert(rowMat.numCols() === n)
     assert(rowMat.toBreeze() === gridBasedMat.toBreeze())
+
+    val rows = 1
+    val cols = 10
+
+    val matDense = new DenseMatrix(rows, cols,
+      Array(1.0, 1.0, 3.0, 2.0, 5.0, 6.0, 7.0, 1.0, 2.0, 3.0))
+    val matSparse = new SparseMatrix(rows, cols,
+      Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), Array(0), Array(1.0))
+
+    val vectors: Seq[((Int, Int), Matrix)] = Seq(
+      ((0, 0), matDense),
+      ((1, 0), matSparse))
+
+    val rdd = sc.parallelize(vectors)
+    val B = new BlockMatrix(rdd, rows, cols)
+
+    val C = B.toIndexedRowMatrix.rows.collect
+
+    (C(0).vector.toBreeze, C(1).vector.toBreeze) match {
+      case (denseVector: BDV[Double], sparseVector: BSV[Double]) =>
+        assert(denseVector.length === sparseVector.length)
+
+        assert(matDense.toArray === denseVector.toArray)
+        assert(matSparse.toArray === sparseVector.toArray)
+      case _ =>
+        throw new RuntimeException("IndexedRow returns vectors of unexpected type")
+    }
   }

   test("toBreeze and toLocalMatrix") {
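
Note on the new test: the dense row has 10 non-zeros out of 10 columns (fraction 1.0 > 0.1), so it should come back as a Breeze DenseVector, while the sparse row has 1 non-zero out of 10 (fraction 0.1 <= 0.1), so it should come back as a SparseVector; the pattern match checks exactly that boundary of the heuristic. The per-row assembly itself relies on Breeze's ranged slice assignment; below is a self-contained sketch of that mechanism with an illustrative block width of 2 (not code from this patch):

    import breeze.linalg.{DenseVector => BDV, Vector => BV}

    val colsPerBlock = 2
    val cols = 6
    // Slices of one logical row, keyed by block column index;
    // block column 1 is absent, i.e. all zeros.
    val slices: Seq[(Int, BV[Double])] = Seq(
      (0, BDV(1.0, 2.0)),
      (2, BDV(5.0, 6.0)))

    val wholeVector = BDV.zeros[Double](cols)
    slices.foreach { case (blockColIdx, vec) =>
      val offset = colsPerBlock * blockColIdx
      wholeVector(offset until offset + colsPerBlock) := vec
    }
    // wholeVector is now DenseVector(1.0, 2.0, 0.0, 0.0, 5.0, 6.0)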