Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
}
}

// Pickler for SparseMatrix
/**
 * Pickles a SparseMatrix as the 6-tuple
 * (numRows, numCols, colPtrs bytes, rowIndices bytes, values bytes, isTransposed)
 * and rebuilds a SparseMatrix from arguments laid out the same way.
 *
 * The three arrays are transferred as raw byte strings encoded in the JVM's native
 * byte order (4 bytes per Int, 8 bytes per Double). isTransposed is written as the
 * integer 0 or 1.
 */
private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {

  /**
   * Writes the constructor arguments of a SparseMatrix to `out` as a pickle tuple.
   *
   * @param obj     the object to serialize; cast to SparseMatrix
   * @param out     the stream receiving the pickle opcodes and payload
   * @param pickler the active pickler (not used by this method)
   */
  def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
    val s = obj.asInstanceOf[SparseMatrix]
    val order = ByteOrder.nativeOrder()

    // Pack each primitive array into a byte array using the native byte order.
    val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
    val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
    val valuesBytes = new Array[Byte](8 * s.values.length)
    val isTransposed = if (s.isTransposed) 1 else 0
    ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
    ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
    ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)

    // MARK ... TUPLE brackets the six tuple items; the item order must match construct().
    out.write(Opcodes.MARK)
    out.write(Opcodes.BININT)
    out.write(PickleUtils.integer_to_bytes(s.numRows))
    out.write(Opcodes.BININT)
    out.write(PickleUtils.integer_to_bytes(s.numCols))
    // Each BINSTRING is followed by its length and then the payload bytes.
    out.write(Opcodes.BINSTRING)
    out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
    out.write(colPtrsBytes)
    out.write(Opcodes.BINSTRING)
    out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
    out.write(indicesBytes)
    out.write(Opcodes.BINSTRING)
    out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
    out.write(valuesBytes)
    out.write(Opcodes.BININT)
    out.write(PickleUtils.integer_to_bytes(isTransposed))
    out.write(Opcodes.TUPLE)
  }

  /**
   * Rebuilds a SparseMatrix from unpickled constructor arguments.
   *
   * @param args exactly six items: numRows, numCols, colPtrs bytes, rowIndices bytes,
   *             values bytes, and the isTransposed flag (integer 0/1 or a boolean)
   * @return the reconstructed SparseMatrix
   * @throws PickleException if `args` does not contain exactly six items
   */
  def construct(args: Array[Object]): Object = {
    if (args.length != 6) {
      throw new PickleException(s"length of args should be 6, got ${args.length}")
    }
    val order = ByteOrder.nativeOrder()
    // getBytes is a helper defined outside this class (not visible here); it is
    // presumably inherited from BasePickler and returns the raw byte array.
    val colPtrsBytes = getBytes(args(2))
    val indicesBytes = getBytes(args(3))
    val valuesBytes = getBytes(args(4))
    val colPtrs = new Array[Int](colPtrsBytes.length / 4)
    val rowIndices = new Array[Int](indicesBytes.length / 4)
    val values = new Array[Double](valuesBytes.length / 8)
    ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
    ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
    ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
    // Accept a pickled boolean as well as the 0/1 integer that saveState writes, so a
    // sender that passes the flag as a bool does not fail with a ClassCastException.
    val isTransposed = args(5) match {
      case flag: java.lang.Boolean => flag.booleanValue()
      case other => other.asInstanceOf[Int] == 1
    }
    new SparseMatrix(
      args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
      isTransposed)
  }
}

// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {

Expand Down Expand Up @@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
if (!initialized) {
new DenseVectorPickler().register()
new DenseMatrixPickler().register()
new SparseMatrixPickler().register()
new SparseVectorPickler().register()
new LabeledPointPickler().register()
new RatingPickler().register()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating

Expand Down Expand Up @@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
val emptyMatrix = Matrices.dense(0, 0, empty)
val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
assert(emptyMatrix == ne)

val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
assert(sm.toArray === nsm.toArray)

val smt = new SparseMatrix(
3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
isTransposed=true)
val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
assert(smt.toArray === nsmt.toArray)
}

test("pickle rating") {
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def __reduce__(self):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),
self.rowIndices.tostring(), self.values.tostring(),
self.isTransposed)
int(self.isTransposed))

def __getitem__(self, indices):
i, j = indices
Expand Down Expand Up @@ -801,7 +801,7 @@ def toDense(self):

# TODO: More efficient implementation:
def __eq__(self, other):
return np.all(self.toArray == other.toArray)
return np.all(self.toArray() == other.toArray())


class Matrices(object):
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def test_serialize(self):
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
self._test_serialize(SparseVector(3, {}))
self._test_serialize(DenseMatrix(2, 3, range(6)))
sm1 = SparseMatrix(
3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
self._test_serialize(sm1)

def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})
Expand Down