Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash

import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

organize imports (put this group under breeze)


/**
* Trait for a local matrix.
*/
@SQLUserDefinedType(udt = classOf[MatrixUDT])
sealed trait Matrix extends Serializable {

/** Number of rows. */
Expand Down Expand Up @@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}

@DeveloperApi
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

/**
 * SQL schema used to store a matrix as a 7-field struct.
 *
 * Field order: type tag (0 = sparse, 1 = dense), numRows, numCols, colPtrs, rowIndices,
 * values, isTransposed.
 */
override def sqlType: StructType = {
  // colPtrs and rowIndices only apply to sparse matrices; they are left null when a dense
  // matrix is stored, hence nullable.
  val intArray = ArrayType(IntegerType, containsNull = false)
  // values is declared nullable as well: binary matrices, which would not need values, might
  // be supported in the future.
  val doubleArray = ArrayType(DoubleType, containsNull = false)
  val fields = Seq(
    StructField("type", ByteType, nullable = false),
    StructField("numRows", IntegerType, nullable = false),
    StructField("numCols", IntegerType, nullable = false),
    StructField("colPtrs", intArray, nullable = true),
    StructField("rowIndices", intArray, nullable = true),
    StructField("values", doubleArray, nullable = true),
    StructField("isTransposed", BooleanType, nullable = false)
  )
  StructType(fields)
}

/**
 * Encodes a matrix into a 7-field row following the layout of `sqlType`.
 *
 * A [[SparseMatrix]] is written with type tag 0 and all fields populated. A [[DenseMatrix]]
 * is written with type tag 1 and with colPtrs (field 3) and rowIndices (field 4) set to null.
 * Any other input is not handled by the match.
 *
 * @param obj the matrix to encode
 * @return the populated row
 */
override def serialize(obj: Any): Row = {
  val encoded = new GenericMutableRow(7)
  obj match {
    case sparse: SparseMatrix =>
      // Type tag 0 marks a sparse matrix.
      encoded.setByte(0, 0)
      encoded.setInt(1, sparse.numRows)
      encoded.setInt(2, sparse.numCols)
      encoded.update(3, sparse.colPtrs.toSeq)
      encoded.update(4, sparse.rowIndices.toSeq)
      encoded.update(5, sparse.values.toSeq)
      encoded.setBoolean(6, sparse.isTransposed)

    case dense: DenseMatrix =>
      // Type tag 1 marks a dense matrix; the two index arrays do not apply and stay null.
      encoded.setByte(0, 1)
      encoded.setInt(1, dense.numRows)
      encoded.setInt(2, dense.numCols)
      encoded.setNullAt(3)
      encoded.setNullAt(4)
      encoded.update(5, dense.values.toSeq)
      encoded.setBoolean(6, dense.isTransposed)
  }
  encoded
}

/**
 * Rebuilds a [[Matrix]] from its serialized row form.
 *
 * The row must have exactly 7 fields in the order written by `serialize`: type tag, numRows,
 * numCols, colPtrs, rowIndices, values, isTransposed. Type tag 0 yields a [[SparseMatrix]]
 * and type tag 1 yields a [[DenseMatrix]]; colPtrs and rowIndices are only read in the sparse
 * case.
 *
 * @param datum a row produced by `serialize`, or an already-built [[Matrix]], which is
 *              returned unchanged
 * @return the decoded matrix
 * @throws IllegalArgumentException if the row does not have 7 fields or carries a type tag
 *                                  other than 0 or 1
 */
override def deserialize(datum: Any): Matrix = {
  datum match {
    // TODO: something wrong with UDT serialization, should never happen.
    case m: Matrix => m
    case row: Row =>
      require(row.length == 7,
        s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
      val tpe = row.getByte(0)
      val numRows = row.getInt(1)
      val numCols = row.getInt(2)
      val values = row.getAs[Iterable[Double]](5).toArray
      val isTransposed = row.getBoolean(6)
      tpe match {
        case 0 =>
          val colPtrs = row.getAs[Iterable[Int]](3).toArray
          val rowIndices = row.getAs[Iterable[Int]](4).toArray
          new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
        case 1 =>
          new DenseMatrix(numRows, numCols, values, isTransposed)
        case unknown =>
          // Report a corrupt or unsupported tag explicitly rather than letting the match fail
          // with an opaque scala.MatchError; same exception type as the length check above.
          throw new IllegalArgumentException(
            s"MatrixUDT.deserialize given unknown matrix type $unknown " +
              "but requires 0 (sparse) or 1 (dense)")
      }
  }
}

/** The user-facing class this UDT represents: [[Matrix]]. */
override def userClass: Class[Matrix] = classOf[Matrix]

/**
 * Any two MatrixUDT instances are considered equal; every other object is unequal.
 */
override def equals(o: Any): Boolean = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we implement equals, let's add a hashCode that returns a predefined random integer. See

https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala#L186

(and you can pick a value;)

o match {
case v: MatrixUDT => true
case _ => false
}
}

// Fixed, arbitrarily chosen constant: `equals` treats every MatrixUDT instance as equal, so
// all instances must report the same hash code.
override def hashCode(): Int = 1994

/** Returns this same instance; no separate nullable variant is created. */
private[spark] override def asNullable: MatrixUDT = this
}

/**
* Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
Expand All @@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
* @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
* row major.
*/
@SQLUserDefinedType(udt = classOf[MatrixUDT])
class DenseMatrix(
val numRows: Int,
val numCols: Int,
Expand Down Expand Up @@ -356,6 +445,7 @@ object DenseMatrix {
* Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
* and `rowIndices` behave as colIndices, and `values` are stored in row major.
*/
@SQLUserDefinedType(udt = classOf[MatrixUDT])
class SparseMatrix(
val numRows: Int,
val numCols: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
}

// Round-trips dense and sparse matrices (including empty 0x0 ones) through MatrixUDT and
// checks that the concrete type, the dimensions and the contents all survive.
test("MatrixUDT") {
  val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
  val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
  val dm3 = new DenseMatrix(0, 0, Array())
  val sm1 = dm1.toSparse
  val sm2 = dm2.toSparse
  val sm3 = dm3.toSparse
  val mUDT = new MatrixUDT()
  Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { mat =>
    val restored = mUDT.deserialize(mUDT.serialize(mat))
    // Comparing toArray alone would not notice a sparse matrix coming back as dense (or the
    // reverse), nor swapped dimensions, so check those explicitly.
    assert(restored.getClass === mat.getClass)
    assert(restored.numRows === mat.numRows)
    assert(restored.numCols === mat.numCols)
    assert(mat.toArray === restored.toArray)
  }
}
}