Skip to content

Commit 11e0259

Browse files
MechCodermengxr
authored andcommitted
[SPARK-6309] [SQL] [MLlib] Implement MatrixUDT
Utilities to serialize and deserialize Matrices in MLlib Author: MechCoder <[email protected]> Closes apache#5048 from MechCoder/spark-6309 and squashes the following commits: 05dc6f2 [MechCoder] Hashcode and organize imports 16d5d47 [MechCoder] Test some more 6e67020 [MechCoder] TST: Test using Array conversion instead of equals 7fa7a2c [MechCoder] [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT
1 parent 49a01c7 commit 11e0259

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
2323

2424
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
2525

26+
import org.apache.spark.annotation.DeveloperApi
27+
import org.apache.spark.sql.Row
28+
import org.apache.spark.sql.types._
29+
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
30+
2631
/**
2732
* Trait for a local matrix.
2833
*/
34+
@SQLUserDefinedType(udt = classOf[MatrixUDT])
2935
sealed trait Matrix extends Serializable {
3036

3137
/** Number of rows. */
@@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
102108
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
103109
}
104110

111+
@DeveloperApi
112+
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
113+
114+
override def sqlType: StructType = {
115+
// type: 0 = sparse, 1 = dense
116+
// the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
117+
// set as not nullable, except values since in the future, support for binary matrices might
118+
// be added for which values are not needed.
119+
// the sparse matrix needs colPtrs and rowIndices, which are set as
120+
// null, while building the dense matrix.
121+
StructType(Seq(
122+
StructField("type", ByteType, nullable = false),
123+
StructField("numRows", IntegerType, nullable = false),
124+
StructField("numCols", IntegerType, nullable = false),
125+
StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
126+
StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
127+
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
128+
StructField("isTransposed", BooleanType, nullable = false)
129+
))
130+
}
131+
132+
override def serialize(obj: Any): Row = {
133+
val row = new GenericMutableRow(7)
134+
obj match {
135+
case sm: SparseMatrix =>
136+
row.setByte(0, 0)
137+
row.setInt(1, sm.numRows)
138+
row.setInt(2, sm.numCols)
139+
row.update(3, sm.colPtrs.toSeq)
140+
row.update(4, sm.rowIndices.toSeq)
141+
row.update(5, sm.values.toSeq)
142+
row.setBoolean(6, sm.isTransposed)
143+
144+
case dm: DenseMatrix =>
145+
row.setByte(0, 1)
146+
row.setInt(1, dm.numRows)
147+
row.setInt(2, dm.numCols)
148+
row.setNullAt(3)
149+
row.setNullAt(4)
150+
row.update(5, dm.values.toSeq)
151+
row.setBoolean(6, dm.isTransposed)
152+
}
153+
row
154+
}
155+
156+
override def deserialize(datum: Any): Matrix = {
157+
datum match {
158+
// TODO: something wrong with UDT serialization, should never happen.
159+
case m: Matrix => m
160+
case row: Row =>
161+
require(row.length == 7,
162+
s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
163+
val tpe = row.getByte(0)
164+
val numRows = row.getInt(1)
165+
val numCols = row.getInt(2)
166+
val values = row.getAs[Iterable[Double]](5).toArray
167+
val isTransposed = row.getBoolean(6)
168+
tpe match {
169+
case 0 =>
170+
val colPtrs = row.getAs[Iterable[Int]](3).toArray
171+
val rowIndices = row.getAs[Iterable[Int]](4).toArray
172+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
173+
case 1 =>
174+
new DenseMatrix(numRows, numCols, values, isTransposed)
175+
}
176+
}
177+
}
178+
179+
override def userClass: Class[Matrix] = classOf[Matrix]
180+
181+
override def equals(o: Any): Boolean = {
182+
o match {
183+
case v: MatrixUDT => true
184+
case _ => false
185+
}
186+
}
187+
188+
override def hashCode(): Int = 1994
189+
190+
private[spark] override def asNullable: MatrixUDT = this
191+
}
192+
105193
/**
106194
* Column-major dense matrix.
107195
* The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
119207
* @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
120208
* row major.
121209
*/
210+
@SQLUserDefinedType(udt = classOf[MatrixUDT])
122211
class DenseMatrix(
123212
val numRows: Int,
124213
val numCols: Int,
@@ -360,6 +449,7 @@ object DenseMatrix {
360449
* Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
361450
* and `rowIndices` behave as colIndices, and `values` are stored in row major.
362451
*/
452+
@SQLUserDefinedType(udt = classOf[MatrixUDT])
363453
class SparseMatrix(
364454
val numRows: Int,
365455
val numCols: Int,

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
424424
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
425425
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
426426
}
427+
428+
test("MatrixUDT") {
429+
val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
430+
val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
431+
val dm3 = new DenseMatrix(0, 0, Array())
432+
val sm1 = dm1.toSparse
433+
val sm2 = dm2.toSparse
434+
val sm3 = dm3.toSparse
435+
val mUDT = new MatrixUDT()
436+
Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
437+
mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
438+
}
439+
}
427440
}

0 commit comments

Comments
 (0)