@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
2323
2424import breeze .linalg .{CSCMatrix => BSM , DenseMatrix => BDM , Matrix => BM }
2525
26+ import org .apache .spark .annotation .DeveloperApi
27+ import org .apache .spark .sql .Row
28+ import org .apache .spark .sql .types ._
29+ import org .apache .spark .sql .catalyst .expressions .GenericMutableRow
30+
2631/**
2732 * Trait for a local matrix.
2833 */
34+ @ SQLUserDefinedType (udt = classOf [MatrixUDT ])
2935sealed trait Matrix extends Serializable {
3036
3137 /** Number of rows. */
@@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
102108 private [spark] def foreachActive (f : (Int , Int , Double ) => Unit )
103109}
104110
/**
 * User-defined SQL type for [[Matrix]].
 *
 * Both [[SparseMatrix]] and [[DenseMatrix]] are stored using one shared 7-field struct:
 * `type`, `numRows`, `numCols`, `colPtrs`, `rowIndices`, `values`, `isTransposed`.
 * The `type` tag is 0 for a sparse matrix and 1 for a dense matrix.
 */
@DeveloperApi
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {

  /**
   * Struct schema shared by the sparse and dense representations.
   *
   * @return a [[StructType]] with the seven fields described on this class
   */
  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // The dense matrix is built from numRows, numCols, values and isTransposed, all of which are
    // declared not nullable, except values: support for binary matrices might be added in the
    // future, for which values are not needed.
    // The sparse matrix additionally needs colPtrs and rowIndices; these are written as null
    // when a dense matrix is serialized, so they must be nullable.
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("numRows", IntegerType, nullable = false),
      StructField("numCols", IntegerType, nullable = false),
      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
      StructField("isTransposed", BooleanType, nullable = false)
    ))
  }

  /**
   * Converts a matrix into a 7-field row laid out as in [[sqlType]].
   *
   * @param obj the object to serialize; must be a [[SparseMatrix]] or a [[DenseMatrix]]
   * @return the populated row
   * @throws IllegalArgumentException if `obj` is neither a sparse nor a dense matrix
   */
  override def serialize(obj: Any): Row = {
    val row = new GenericMutableRow(7)
    obj match {
      case sm: SparseMatrix =>
        row.setByte(0, 0)
        row.setInt(1, sm.numRows)
        row.setInt(2, sm.numCols)
        row.update(3, sm.colPtrs.toSeq)
        row.update(4, sm.rowIndices.toSeq)
        row.update(5, sm.values.toSeq)
        row.setBoolean(6, sm.isTransposed)

      case dm: DenseMatrix =>
        row.setByte(0, 1)
        row.setInt(1, dm.numRows)
        row.setInt(2, dm.numCols)
        // A dense matrix has no column pointers or row indices.
        row.setNullAt(3)
        row.setNullAt(4)
        row.update(5, dm.values.toSeq)
        row.setBoolean(6, dm.isTransposed)

      case other =>
        // Fail with a descriptive message instead of an opaque scala.MatchError.
        val className = if (other == null) "null" else other.getClass.getName
        throw new IllegalArgumentException(
          s"MatrixUDT.serialize given unsupported object of class $className; " +
            "expected a SparseMatrix or a DenseMatrix")
    }
    row
  }

  /**
   * Rebuilds a matrix from a row produced by [[serialize]].
   *
   * @param datum either an already-materialized [[Matrix]] (returned as is) or a 7-field row
   * @return the reconstructed matrix
   * @throws IllegalArgumentException if the row does not have exactly 7 fields, if its type tag
   *                                  is neither 0 nor 1, or if `datum` is not a matrix or a row
   */
  override def deserialize(datum: Any): Matrix = {
    datum match {
      // TODO: something wrong with UDT serialization, should never happen.
      case m: Matrix => m
      case row: Row =>
        require(row.length == 7,
          s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
        val tpe = row.getByte(0)
        val numRows = row.getInt(1)
        val numCols = row.getInt(2)
        val values = row.getAs[Iterable[Double]](5).toArray
        val isTransposed = row.getBoolean(6)
        tpe match {
          case 0 =>
            val colPtrs = row.getAs[Iterable[Int]](3).toArray
            val rowIndices = row.getAs[Iterable[Int]](4).toArray
            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
          case 1 =>
            new DenseMatrix(numRows, numCols, values, isTransposed)
          case other =>
            // Unknown tag: report it rather than surfacing a bare scala.MatchError.
            throw new IllegalArgumentException(
              s"MatrixUDT.deserialize given unknown matrix type tag $other; " +
                "expected 0 (sparse) or 1 (dense)")
        }
      case other =>
        val className = if (other == null) "null" else other.getClass.getName
        throw new IllegalArgumentException(
          s"MatrixUDT.deserialize given unsupported datum of class $className; " +
            "expected a Matrix or a Row")
    }
  }

  /** The user-facing class handled by this UDT. */
  override def userClass: Class[Matrix] = classOf[Matrix]

  /** All instances of this UDT are interchangeable, so any two of them compare equal. */
  override def equals(o: Any): Boolean = {
    o match {
      case _: MatrixUDT => true
      case _ => false
    }
  }

  /** Constant hash, consistent with [[equals]] treating every instance as equal. */
  override def hashCode(): Int = 1994

  /** This type has no separate nullable variant; it returns itself. */
  private[spark] override def asNullable: MatrixUDT = this
}
192+
105193/**
106194 * Column-major dense matrix.
107195 * The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
119207 * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
120208 * row major.
121209 */
210+ @ SQLUserDefinedType (udt = classOf [MatrixUDT ])
122211class DenseMatrix (
123212 val numRows : Int ,
124213 val numCols : Int ,
@@ -360,6 +449,7 @@ object DenseMatrix {
360449 * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
361450 * and `rowIndices` behave as colIndices, and `values` are stored in row major.
362451 */
452+ @ SQLUserDefinedType (udt = classOf [MatrixUDT ])
363453class SparseMatrix (
364454 val numRows : Int ,
365455 val numCols : Int ,
0 commit comments