mllib/src/main/scala/org/apache/spark/ml/param/params.scala (12 changes: 10 additions & 2 deletions)

@@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
  * :: DeveloperApi ::
@@ -88,9 +89,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
     value match {
       case x: String =>
         compact(render(JString(x)))
+      case v: Vector =>
+        v.toJson
       case _ =>
         throw new NotImplementedError(
-          "The default jsonEncode only supports string. " +
+          "The default jsonEncode only supports string and vector. " +
             s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
     }
   }
@@ -100,9 +103,14 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
     parse(json) match {
       case JString(x) =>
         x.asInstanceOf[T]
+      case JObject(v) =>
+        val keys = v.map(_._1)
+        assert(keys.contains("type") && keys.contains("values"),
+          s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.")
+        Vectors.fromJson(json).asInstanceOf[T]
       case _ =>
         throw new NotImplementedError(
-          "The default jsonDecode only supports string. " +
+          "The default jsonDecode only supports string and vector. " +
             s"${this.getClass.getName} must override jsonDecode to support its value type.")
     }
   }
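The new JObject branch keys off the field names that Vectors.toJson produces: a dense vector serializes as {"type":0,"values":[...]} and a sparse one as {"type":1,"size":...,"indices":[...],"values":[...]}, so checking for "type" and "values" covers both. A minimal round-trip sketch of the new default codec (the "dummy" parent uid and "coefficients" name here are illustrative only):

```scala
import org.apache.spark.ml.param.Param
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Illustrative parent uid and param name, not taken from this PR.
val param = new Param[Vector]("dummy", "coefficients", "doc")

// Dense vectors encode as {"type":0,"values":[0.0,2.0]}.
val dense = Vectors.dense(0.0, 2.0)
assert(param.jsonDecode(param.jsonEncode(dense)) == dense)

// Sparse vectors also carry size and indices:
// {"type":1,"size":2,"indices":[1],"values":[2.0]}.
val sparse = Vectors.sparse(2, Array(1), Array(2.0))
assert(param.jsonDecode(param.jsonEncode(sparse)) == sparse)
```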
mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala (22 changes: 18 additions & 4 deletions)

@@ -18,6 +18,7 @@
 package org.apache.spark.ml.param
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 class ParamsSuite extends SparkFunSuite {
 
@@ -80,7 +81,7 @@ class ParamsSuite extends SparkFunSuite {
       }
     }
 
-    { // StringParam
+    { // Param[String]
       val param = new Param[String](dummy, "name", "doc")
       // Currently we do not support null.
       for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
@@ -89,6 +90,19 @@ class ParamsSuite extends SparkFunSuite {
       }
     }
 
+    { // Param[Vector]
+      val param = new Param[Vector](dummy, "name", "doc")
+      val values = Seq(
+        Vectors.dense(Array.empty[Double]),
+        Vectors.dense(0.0, 2.0),
+        Vectors.sparse(0, Array.empty, Array.empty),
+        Vectors.sparse(2, Array(1), Array(2.0)))
+      for (value <- values) {
+        val json = param.jsonEncode(value)
+        assert(param.jsonDecode(json) === value)
+      }
+    }
+
     { // IntArrayParam
       val param = new IntArrayParam(dummy, "name", "doc")
       val values: Seq[Array[Int]] = Seq(
@@ -138,7 +152,7 @@ class ParamsSuite extends SparkFunSuite {
   test("param") {
     val solver = new TestParams()
     val uid = solver.uid
-    import solver.{maxIter, inputCol}
+    import solver.{inputCol, maxIter}
 
     assert(maxIter.name === "maxIter")
     assert(maxIter.doc === "maximum number of iterations (>= 0)")
@@ -181,7 +195,7 @@
 
   test("param map") {
     val solver = new TestParams()
-    import solver.{maxIter, inputCol}
+    import solver.{inputCol, maxIter}
 
     val map0 = ParamMap.empty
 
@@ -220,7 +234,7 @@
 
   test("params") {
     val solver = new TestParams()
-    import solver.{handleInvalid, maxIter, inputCol}
+    import solver.{handleInvalid, inputCol, maxIter}
 
     val params = solver.params
     assert(params.length === 3)
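For value types the default codec still does not cover, the updated NotImplementedError messages direct Param subclasses to override both methods. A hedged sketch of that override pattern, modeled on the existing array params; the BooleanArrayParam class below is hypothetical and not part of this change:

```scala
import org.json4s.{DefaultFormats, JArray, JBool}
import org.json4s.jackson.JsonMethods.{compact, parse, render}

import org.apache.spark.ml.param.Param

// Hypothetical subclass, sketched only to show the override pattern
// the NotImplementedError message asks for.
class BooleanArrayParam(parent: String, name: String, doc: String)
  extends Param[Array[Boolean]](parent, name, doc) {

  override def jsonEncode(value: Array[Boolean]): String = {
    // Render the array as a plain JSON array, e.g. [true,false].
    compact(render(JArray(value.map(JBool(_)).toList)))
  }

  override def jsonDecode(json: String): Array[Boolean] = {
    // Parse back with json4s default formats.
    implicit val formats = DefaultFormats
    parse(json).extract[Seq[Boolean]].toArray
  }
}
```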