Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
}

/**
 * Encodes a vector as a 4-field row.
 *
 * Field layout, as written by the cases below:
 *  - field 0: type tag byte (0 for a sparse vector, 1 for a dense vector)
 *  - field 1: vector size (sparse only; set to null for dense)
 *  - field 2: indices (sparse only; set to null for dense)
 *  - field 3: values
 *
 * Every case builds and yields its own row, so the `match` expression itself is the
 * method's result. No row is allocated up front: a shared outer `row` would be shadowed
 * by the per-case rows and returned unpopulated.
 *
 * @param obj the value to encode: a sparse vector, a dense vector, or an already
 *            serialized [[Row]]
 * @return the encoded row; a [[Row]] input is returned unchanged
 */
override def serialize(obj: Any): Row = {
  obj match {
    case SparseVector(size, indices, values) =>
      val row = new GenericMutableRow(4)
      row.setByte(0, 0)
      row.setInt(1, size)
      row.update(2, indices.toSeq)
      row.update(3, values.toSeq)
      row
    case DenseVector(values) =>
      val row = new GenericMutableRow(4)
      row.setByte(0, 1)
      // Dense vectors carry no explicit size or indices.
      row.setNullAt(1)
      row.setNullAt(2)
      row.update(3, values.toSeq)
      row
    // TODO: There are bugs in UDT serialization because we don't have a clear separation
    // between internal SQL types and language specific types (including UDT). UDT serialize
    // and deserialize may get called twice. See SPARK-7186.
    case row: Row =>
      // Already serialized by an earlier call; pass it through untouched.
      row
  }
}

override def deserialize(datum: Any): Vector = {
datum match {
// TODO: something wrong with UDT serialization
case v: Vector =>
v
case row: Row =>
require(row.length == 4,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
Expand All @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val values = row.getAs[Iterable[Double]](3).toArray
new DenseVector(values)
}
// TODO: There are bugs in UDT serialization because we don't have a clear separation between
// TODO: internal SQL types and language specific types (including UDT). UDT serialize and
// TODO: deserialize may get called twice. See SPARK-7186.
case v: Vector =>
v
}
}

Expand Down