Skip to content

Commit 5ab652c

Browse files
MechCodermengxr
authored andcommitted
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Utilities for pickling and unpickling SparseMatrices using SerDe Author: MechCoder <[email protected]> Closes #5775 from MechCoder/spark-7202 and squashes the following commits: 7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
1 parent c6d1efb commit 5ab652c

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
10151015
}
10161016
}
10171017

1018+
// Pickler for SparseMatrix
1019+
private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
1020+
1021+
def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
1022+
val s = obj.asInstanceOf[SparseMatrix]
1023+
val order = ByteOrder.nativeOrder()
1024+
1025+
val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
1026+
val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
1027+
val valuesBytes = new Array[Byte](8 * s.values.length)
1028+
val isTransposed = if (s.isTransposed) 1 else 0
1029+
ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
1030+
ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
1031+
ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
1032+
1033+
out.write(Opcodes.MARK)
1034+
out.write(Opcodes.BININT)
1035+
out.write(PickleUtils.integer_to_bytes(s.numRows))
1036+
out.write(Opcodes.BININT)
1037+
out.write(PickleUtils.integer_to_bytes(s.numCols))
1038+
out.write(Opcodes.BINSTRING)
1039+
out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
1040+
out.write(colPtrsBytes)
1041+
out.write(Opcodes.BINSTRING)
1042+
out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
1043+
out.write(indicesBytes)
1044+
out.write(Opcodes.BINSTRING)
1045+
out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
1046+
out.write(valuesBytes)
1047+
out.write(Opcodes.BININT)
1048+
out.write(PickleUtils.integer_to_bytes(isTransposed))
1049+
out.write(Opcodes.TUPLE)
1050+
}
1051+
1052+
def construct(args: Array[Object]): Object = {
1053+
if (args.length != 6) {
1054+
throw new PickleException("should be 6")
1055+
}
1056+
val order = ByteOrder.nativeOrder()
1057+
val colPtrsBytes = getBytes(args(2))
1058+
val indicesBytes = getBytes(args(3))
1059+
val valuesBytes = getBytes(args(4))
1060+
val colPtrs = new Array[Int](colPtrsBytes.length / 4)
1061+
val rowIndices = new Array[Int](indicesBytes.length / 4)
1062+
val values = new Array[Double](valuesBytes.length / 8)
1063+
ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
1064+
ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
1065+
ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
1066+
val isTransposed = args(5).asInstanceOf[Int] == 1
1067+
new SparseMatrix(
1068+
args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
1069+
isTransposed)
1070+
}
1071+
}
1072+
10181073
// Pickler for SparseVector
10191074
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
10201075

@@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
10991154
if (!initialized) {
11001155
new DenseVectorPickler().register()
11011156
new DenseMatrixPickler().register()
1157+
new SparseMatrixPickler().register()
11021158
new SparseVectorPickler().register()
11031159
new LabeledPointPickler().register()
11041160
new RatingPickler().register()

mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
22+
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
2323
import org.apache.spark.mllib.regression.LabeledPoint
2424
import org.apache.spark.mllib.recommendation.Rating
2525

@@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
7777
val emptyMatrix = Matrices.dense(0, 0, empty)
7878
val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
7979
assert(emptyMatrix == ne)
80+
81+
val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
82+
val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
83+
assert(sm.toArray === nsm.toArray)
84+
85+
val smt = new SparseMatrix(
86+
3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
87+
isTransposed=true)
88+
val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
89+
assert(smt.toArray === nsmt.toArray)
8090
}
8191

8292
test("pickle rating") {

python/pyspark/mllib/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ def __reduce__(self):
755755
return SparseMatrix, (
756756
self.numRows, self.numCols, self.colPtrs.tostring(),
757757
self.rowIndices.tostring(), self.values.tostring(),
758-
self.isTransposed)
758+
int(self.isTransposed))
759759

760760
def __getitem__(self, indices):
761761
i, j = indices
@@ -801,7 +801,7 @@ def toDense(self):
801801

802802
# TODO: More efficient implementation:
803803
def __eq__(self, other):
804-
return np.all(self.toArray == other.toArray)
804+
return np.all(self.toArray() == other.toArray())
805805

806806

807807
class Matrices(object):

python/pyspark/mllib/tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def test_serialize(self):
9292
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
9393
self._test_serialize(SparseVector(3, {}))
9494
self._test_serialize(DenseMatrix(2, 3, range(6)))
95+
sm1 = SparseMatrix(
96+
3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
97+
self._test_serialize(sm1)
9598

9699
def test_dot(self):
97100
sv = SparseVector(4, {1: 1, 3: 2})

0 commit comments

Comments
 (0)