SPARK-11766: add toJson to Vector (closed pull request; changes from all commits)
mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala (45 additions, 0 deletions)
@@ -24,6 +24,9 @@ import scala.annotation.varargs
 import scala.collection.JavaConverters._
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson}
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{AlphaComponent, Since}
@@ -171,6 +174,12 @@ sealed trait Vector extends Serializable {
    */
   @Since("1.5.0")
   def argmax: Int
+
+  /**
+   * Converts the vector to a JSON string.
+   */
+  @Since("1.6.0")
+  def toJson: String
 }
 
 /**
@@ -339,6 +348,27 @@ object Vectors {
     parseNumeric(NumericParser.parse(s))
   }
 
+  /**
+   * Parses the JSON representation of a vector into a [[Vector]].
+   */
+  @Since("1.6.0")
+  def fromJson(json: String): Vector = {
+    implicit val formats = DefaultFormats
+    val jValue = parseJson(json)
+    (jValue \ "type").extract[Int] match {
+      case 0 => // sparse
+        val size = (jValue \ "size").extract[Int]
+        val indices = (jValue \ "indices").extract[Seq[Int]].toArray
+        val values = (jValue \ "values").extract[Seq[Double]].toArray
+        sparse(size, indices, values)
+      case 1 => // dense
+        val values = (jValue \ "values").extract[Seq[Double]].toArray
+        dense(values)
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
+    }
+  }
+
   private[mllib] def parseNumeric(any: Any): Vector = {
     any match {
       case values: Array[Double] =>
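(Reviewer note, not part of the diff.) The integer `type` field is the discriminator `fromJson` switches on: 0 selects the sparse layout, which additionally requires `size` and `indices`; 1 selects the dense layout, which only needs `values`. A minimal sketch of the accepted shapes, using only the API added in this patch (the wrapper object is hypothetical):

import org.apache.spark.mllib.linalg.Vectors

object FromJsonSketch { // hypothetical driver, not in the patch
  def main(args: Array[String]): Unit = {
    // Each literal exercises one branch of the `type` match above.
    val dense = Vectors.fromJson("""{"type":1,"values":[0.0,2.0]}""")
    val sparse = Vectors.fromJson("""{"type":0,"size":4,"indices":[1,3],"values":[2.0,4.0]}""")
    assert(dense == Vectors.dense(0.0, 2.0))
    assert(sparse == Vectors.sparse(4, Array(1, 3), Array(2.0, 4.0)))
    // Any other tag, e.g. {"type":2}, falls through to the IllegalArgumentException branch.
  }
}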
@@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") (
       maxIdx
     }
   }
+
+  @Since("1.6.0")
+  override def toJson: String = {
+    val jValue = ("type" -> 1) ~ ("values" -> values.toSeq)
+    compact(render(jValue))
+  }
 }
 
 @Since("1.3.0")
@@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") (
     }.unzip
     new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
   }
+
+  @Since("1.6.0")
+  override def toJson: String = {
+    val jValue = ("type" -> 0) ~
+      ("size" -> size) ~
+      ("indices" -> indices.toSeq) ~
+      ("values" -> values.toSeq)
+    compact(render(jValue))
+  }
 }
 
 @Since("1.3.0")
mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala (17 additions, 0 deletions)
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg
 import scala.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
+import org.json4s.jackson.JsonMethods.{parse => parseJson}
 
 import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -374,4 +375,20 @@ class VectorsSuite extends SparkFunSuite with Logging {
     assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
     assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
   }
+
+  test("toJson/fromJson") {
+    val sv0 = Vectors.sparse(0, Array.empty, Array.empty)
+    val sv1 = Vectors.sparse(1, Array.empty, Array.empty)
+    val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
+    val dv0 = Vectors.dense(Array.empty[Double])
+    val dv1 = Vectors.dense(1.0)
+    val dv2 = Vectors.dense(0.0, 2.0)
+    for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) {
+      val json = v.toJson
+      parseJson(json) // `json` should be a valid JSON string
+      val u = Vectors.fromJson(json)
+      assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.")
+      assert(u === v, "toJson/fromJson should preserve vector values.")
+    }
+  }
 }
project/MimaExcludes.scala (4 additions, 0 deletions)
@@ -137,6 +137,10 @@ object MimaExcludes {
       ) ++ Seq (
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.status.api.v1.ApplicationInfo.this")
+      ) ++ Seq(
+        // SPARK-11766 add toJson to Vector
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.mllib.linalg.Vector.toJson")
       )
     case v if v.startsWith("1.5") =>
       Seq(
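(Reviewer note, not part of the diff.) The exclusion is needed because MiMa, the binary-compatibility checker run by the build, reports any method added to a previously released type as a MissingMethodProblem. Since `Vector` is a sealed trait, no third-party code can implement it, so the new abstract `toJson` cannot break user subclasses; the filter records that the incompatibility is intentional for this release line.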