From 60394d9d0d5762ade918ee8b301aa3e760abf3ff Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 7 Sep 2017 18:54:58 +0800 Subject: [PATCH 1/8] init pr --- .../org/apache/spark/ml/stat/Summarizer.scala | 98 ++++---- .../spark/ml/stat/SummarizerSuite.scala | 222 +++++------------- 2 files changed, 112 insertions(+), 208 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index cae41edb7aca..210bf85e48c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -41,15 +41,12 @@ sealed abstract class SummaryBuilder { /** * Returns an aggregate object that contains the summary of the column with the requested metrics. * @param featuresCol a column that contains features Vector object. - * @param weightCol a column that contains weight value. + * @param weightCol (Optional param) a column that contains weight value. Default weight is 1.0. * @return an aggregate column that contains the statistics. The exact content of this * structure is determined during the creation of the builder. */ @Since("2.3.0") - def summary(featuresCol: Column, weightCol: Column): Column - - @Since("2.3.0") - def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0)) + def summary(featuresCol: Column, weightCol: Column = lit(1.0)): Column } /** @@ -60,15 +57,17 @@ sealed abstract class SummaryBuilder { * This class lets users pick the statistics they would like to extract for a given column. Here is * an example in Scala: * {{{ - * val dataframe = ... // Some dataframe containing a feature column - * val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features")) - * val Row(Row(min_, max_)) = allStats.first() + * import org.apache.spark.ml.linalg._ + * val dataframe = ... 
// Some dataframe containing a feature column and a weight column + * val multiStatsDF = dataframe.select( + * Summarizer.metrics("min", "max", "count").summary($"features", $"weight")) + * val Tuple1((minVec, maxVec, count)) = multiStatsDF.as[Tuple1[(Vector, Vector, Long)]].first() * }}} * * If one wants to get a single metric, shortcuts are also available: * {{{ * val meanDF = dataframe.select(Summarizer.mean($"features")) - * val Row(mean_) = meanDF.first() + * val Tuple1(meanVec) = meanDF.as[Tuple1[Vector]].first() * }}} * * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD @@ -109,31 +108,47 @@ object Summarizer extends Logging { } @Since("2.3.0") - def mean(col: Column): Column = getSingleMetric(col, "mean") + def mean(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "mean") + } @Since("2.3.0") - def variance(col: Column): Column = getSingleMetric(col, "variance") + def variance(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "variance") + } @Since("2.3.0") - def count(col: Column): Column = getSingleMetric(col, "count") + def count(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "count") + } @Since("2.3.0") - def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros") + def numNonZeros(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "numNonZeros") + } @Since("2.3.0") - def max(col: Column): Column = getSingleMetric(col, "max") + def max(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "max") + } @Since("2.3.0") - def min(col: Column): Column = getSingleMetric(col, "min") + def min(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "min") + } @Since("2.3.0") - def normL1(col: Column): Column = getSingleMetric(col, "normL1") + def normL1(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "normL1") + } @Since("2.3.0") - def normL2(col: Column): Column = getSingleMetric(col, "normL2") + def normL2(col: Column, weightCol: Column = lit(1.0)): Column = { + getSingleMetric(col, weightCol, "normL2") + } - private def getSingleMetric(col: Column, metric: String): Column = { - val c1 = metrics(metric).summary(col) + private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = { + val c1 = metrics(metric).summary(col, weightCol) c1.getField(metric).as(s"$metric($col)") } } @@ -187,8 +202,7 @@ private[ml] object SummaryBuilderImpl extends Logging { StructType(fields) } - private val arrayDType = ArrayType(DoubleType, containsNull = false) - private val arrayLType = ArrayType(LongType, containsNull = false) + private[this] val vectorUDT = new VectorUDT /** * All the metrics that can be currently computed by Spark for vectors. @@ -197,14 +211,14 @@ private[ml] object SummaryBuilderImpl extends Logging { * metrics that need to be computed internally to get the final result.
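 * For example, requesting the "variance" metric requires the ComputeWeightSum, ComputeMean
 * and ComputeM2n accumulators, as the list below records.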
*/ private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq( - ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)), - ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), + ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)), + ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)), ("count", Count, LongType, Seq()), - ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)), - ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)), - ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)), - ("normL2", NormL2, arrayDType, Seq(ComputeM2)), - ("normL1", NormL1, arrayDType, Seq(ComputeL1)) + ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)), + ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)), + ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)), + ("normL2", NormL2, vectorUDT, Seq(ComputeM2)), + ("normL1", NormL1, vectorUDT, Seq(ComputeL1)) ) /** @@ -527,27 +541,28 @@ private[ml] object SummaryBuilderImpl extends Logging { weightExpr: Expression, mutableAggBufferOffset: Int, inputAggBufferOffset: Int) - extends TypedImperativeAggregate[SummarizerBuffer] { + extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes { - override def eval(state: SummarizerBuffer): InternalRow = { + override def eval(state: SummarizerBuffer): Any = { val metrics = requestedMetrics.map { - case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray) - case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray) + case Mean => vectorUDT.serialize(state.mean) + case Variance => vectorUDT.serialize(state.variance) case Count => state.count - case NumNonZeros => UnsafeArrayData.fromPrimitiveArray( - state.numNonzeros.toArray.map(_.toLong)) - case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray) - case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray) - case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray) - case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray) + case NumNonZeros => vectorUDT.serialize(state.numNonzeros) + case Max => vectorUDT.serialize(state.max) + case Min => vectorUDT.serialize(state.min) + case NormL2 => vectorUDT.serialize(state.normL2) + case NormL1 => vectorUDT.serialize(state.normL1) } InternalRow.apply(metrics: _*) } + override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil + override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { - val features = udt.deserialize(featuresExpr.eval(row)) + val features = vectorUDT.deserialize(featuresExpr.eval(row)) val weight = weightExpr.eval(row).asInstanceOf[Double] state.add(features, weight) state @@ -591,7 +606,4 @@ private[ml] object SummaryBuilderImpl extends Logging { override def prettyName: String = "aggregate_metrics" } - - private[this] val udt = new VectorUDT - } diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index 1ea851ef2d67..2ea0e2f1f363 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.ml.stat import org.scalatest.exceptions.TestFailedException -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkFunSuite import 
org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -34,134 +33,88 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { import Summarizer._ import SummaryBuilderImpl._ - private case class ExpectedMetrics( - mean: Seq[Double], - variance: Seq[Double], - count: Long, - numNonZeros: Seq[Long], - max: Seq[Double], - min: Seq[Double], - normL2: Seq[Double], - normL1: Seq[Double]) - /** - * The input is expected to be either a sparse vector, a dense vector or an array of doubles - * (which will be converted to a dense vector) - * The expected is the list of all the known metrics. + * The input is expected to be either a sparse vector or a dense vector. * - * The tests take an list of input vectors and a list of all the summary values that - * are expected for this input. They currently test against some fixed subset of the - * metrics, but should be made fuzzy in the future. + * The tests take a list of input vectors, and compare results with + * `mllib.stat.MultivariateOnlineSummarizer`. They currently test against some fixed subset + * of the metrics, but should be made fuzzy in the future. */ - private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = { - - def inputVec: Seq[Vector] = input.map { - case x: Array[Double @unchecked] => Vectors.dense(x) - case x: Seq[Double @unchecked] => Vectors.dense(x.toArray) - case x: Vector => x - case x => throw new Exception(x.toString) - } + private def testExample(name: String, inputVec: Seq[(Vector, Double)]): Unit = { val summarizer = { val _summarizer = new MultivariateOnlineSummarizer - inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v))) + inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1), v._2)) _summarizer } // Because the Spark context is reset between tests, we cannot hold a reference onto it.
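 // Instead, each registered test calls wrappedInit() below, so the DataFrame is built
 // against the SparkSession that is live when the test body actually runs.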
def wrappedInit() = { - val df = inputVec.map(Tuple1.apply).toDF("features") - val col = df.col("features") - (df, col) + val df = inputVec.toDF("features", "weight") + val featuresCol = df.col("features") + val weightCol = df.col("weight") + (df, featuresCol, weightCol) } registerTest(s"$name - mean only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean)) - } - - registerTest(s"$name - mean only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(mean(c)), Seq(exp.mean)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("mean").summary(c, weight), mean(c, weight)), + Seq(Row(summarizer.mean), summarizer.mean)) } registerTest(s"$name - variance only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("variance").summary(c), variance(c)), - Seq(Row(exp.variance), summarizer.variance)) - } - - registerTest(s"$name - variance only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(variance(c)), Seq(summarizer.variance)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("variance").summary(c, weight), variance(c, weight)), + Seq(Row(summarizer.variance), summarizer.variance)) } registerTest(s"$name - count only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("count").summary(c), count(c)), - Seq(Row(exp.count), exp.count)) - } - - registerTest(s"$name - count only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(count(c)), - Seq(exp.count)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("count").summary(c, weight), count(c, weight)), + Seq(Row(summarizer.count), summarizer.count)) } registerTest(s"$name - numNonZeros only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), - Seq(Row(exp.numNonZeros), exp.numNonZeros)) - } - - registerTest(s"$name - numNonZeros only (direct)") { - val (df, c) = wrappedInit() - compare(df.select(numNonZeros(c)), - Seq(exp.numNonZeros)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("numNonZeros").summary(c, weight), numNonZeros(c, weight)), + Seq(Row(summarizer.numNonzeros), summarizer.numNonzeros)) } registerTest(s"$name - min only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("min").summary(c), min(c)), - Seq(Row(exp.min), exp.min)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("min").summary(c, weight), min(c, weight)), + Seq(Row(summarizer.min), summarizer.min)) } registerTest(s"$name - max only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("max").summary(c), max(c)), - Seq(Row(exp.max), exp.max)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("max").summary(c, weight), max(c, weight)), + Seq(Row(summarizer.max), summarizer.max)) } registerTest(s"$name - normL1 only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("normL1").summary(c), normL1(c)), - Seq(Row(exp.normL1), exp.normL1)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("normL1").summary(c, weight), normL1(c, weight)), + Seq(Row(summarizer.normL1), summarizer.normL1)) } registerTest(s"$name - normL2 only") { - val (df, c) = wrappedInit() - compare(df.select(metrics("normL2").summary(c), normL2(c)), - Seq(Row(exp.normL2), exp.normL2)) + val (df, c, weight) = wrappedInit() + compare(df.select(metrics("normL2").summary(c, weight), normL2(c, weight)), + Seq(Row(summarizer.normL2), summarizer.normL2)) } - registerTest(s"$name - all metrics at 
once") { - val (df, c) = wrappedInit() + registerTest(s"$name - multiple metrics at once") { + val (df, c, weight) = wrappedInit() compare(df.select( - metrics("mean", "variance", "count", "numNonZeros").summary(c), - mean(c), variance(c), count(c), numNonZeros(c)), - Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros), - exp.mean, exp.variance, exp.count, exp.numNonZeros)) + metrics("mean", "variance", "count", "numNonZeros").summary(c, weight)), + Seq(Row(summarizer.mean, summarizer.variance, summarizer.count, summarizer.numNonzeros)) + ) } } - private def denseData(input: Seq[Seq[Double]]): DataFrame = { - input.map(_.toArray).map(Vectors.dense).map(Tuple1.apply).toDF("features") - } - private def compare(df: DataFrame, exp: Seq[Any]): Unit = { - val coll = df.collect().toSeq - val Seq(row) = coll - val res = row.toSeq + val res = df.head().toSeq val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" } assert(res.size === exp.size, (res.size, exp.size)) for (((x1, x2), name) <- res.zip(exp).zip(names)) { @@ -171,32 +124,18 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { // Compares structured content. private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match { - case (y1: Seq[Double @unchecked], v1: OldVector) => - compareStructures(y1, v1.toArray.toSeq, name) - case (d1: Double, d2: Double) => - assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name) - case (r1: GenericRowWithSchema, r2: Row) => - assert(r1.size === r2.size, (r1, r2)) - for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) { - compareStructures(x1, x2, s"$name.$fname") - } case (r1: Row, r2: Row) => assert(r1.size === r2.size, (r1, r2)) for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) } case (v1: Vector, v2: Vector) => - assert2(v1 ~== v2 absTol 1e-4, name) + assertWithHint(v1 ~== v2 absTol 1e-4, name) + case (v1: Vector, v2: OldVector) => + compareStructures(v1, v2.asML, name) case (l1: Long, l2: Long) => assert(l1 === l2) - case (s1: Seq[_], s2: Seq[_]) => - assert(s1.size === s2.size, s"$name ${(s1, s2)}") - for (((x1, idx), x2) <- s1.zipWithIndex.zip(s2)) { - compareStructures(x1, x2, s"$name.$idx") - } - case (arr1: Array[_], arr2: Array[_]) => - assert(arr1.toSeq === arr2.toSeq) case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2") } - private def assert2(x: => Boolean, hint: String): Unit = { + private def assertWithHint(x: => Boolean, hint: String): Unit = { try { assert(x, hint) } catch { @@ -205,67 +144,20 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("debugging test") { - val df = denseData(Nil) - val c = df.col("features") - val c1 = metrics("mean").summary(c) - val res = df.select(c1) - intercept[SparkException] { - compare(res, Seq.empty) - } - } - - test("basic error handling") { - val df = denseData(Nil) - val c = df.col("features") - val res = df.select(metrics("mean").summary(c), mean(c)) - intercept[SparkException] { - compare(res, Seq.empty) - } - } + testExample("single element", Seq((Vectors.dense(0.0, 1.0, 2.0), 2.0))) - test("no element, working metrics") { - val df = denseData(Nil) - val c = df.col("features") - val res = df.select(metrics("count").summary(c), count(c)) - compare(res, Seq(Row(0L), 0L)) - } + testExample("two elements (dense)", + Seq( + (Vectors.dense(-1.0, 0.0, 6.0), 0.5), + (Vectors.dense(3.0, -3.0, 0.0), 2.8) + ) + ) - val singleElem = Seq(0.0, 1.0, 2.0) - 
testExample("single element", Seq(singleElem), ExpectedMetrics( - mean = singleElem, - variance = Seq(0.0, 0.0, 0.0), - count = 1, - numNonZeros = Seq(0, 1, 1), - max = singleElem, - min = singleElem, - normL1 = singleElem, - normL2 = singleElem - )) - - testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics( - mean = Seq(0.0, 0.0, 0.0), - // TODO: I have a doubt about these values, they are not normalized. - variance = Seq(0.0, 2.0, 8.0), - count = 2, - numNonZeros = Seq(0, 2, 2), - max = Seq(0.0, 1.0, 2.0), - min = Seq(0.0, -1.0, -2.0), - normL1 = Seq(0.0, 2.0, 4.0), - normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0) - )) - - testExample("dense vector input", - Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)), - ExpectedMetrics( - mean = Seq(1.0, -1.5, 3.0), - variance = Seq(8.0, 4.5, 18.0), - count = 2, - numNonZeros = Seq(2, 1, 1), - max = Seq(3.0, 0.0, 6.0), - min = Seq(-1.0, -3, 0.0), - normL1 = Seq(4.0, 3.0, 6.0), - normL2 = Seq(math.sqrt(10), 3, 6.0) + testExample("two elements (sparse)", + Seq( + (Vectors.dense(-1.0, 0.0, 6.0).toSparse, 0.5), + (Vectors.dense(3.0, -3.0, 0.0).toSparse, 2.8), + (Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0) ) ) From 6adcfa75dfd679b395480ca40f46bc9f4d201096 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 8 Sep 2017 16:48:43 +0800 Subject: [PATCH 2/8] update --- .../org/apache/spark/ml/stat/Summarizer.scala | 56 +++++++++++---- .../spark/ml/stat/SummarizerSuite.scala | 70 ++++++++++++++++++- 2 files changed, 109 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 210bf85e48c7..b1e886625f16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -41,12 +41,16 @@ sealed abstract class SummaryBuilder { /** * Returns an aggregate object that contains the summary of the column with the requested metrics. * @param featuresCol a column that contains features Vector object. - * @param weightCol (Optional param) a column that contains weight value. Default weight is 1.0. + * @param weightCol a column that contains weight value. Default weight is 1.0. * @return an aggregate column that contains the statistics. The exact content of this * structure is determined during the creation of the builder. */ @Since("2.3.0") - def summary(featuresCol: Column, weightCol: Column = lit(1.0)): Column + def summary(featuresCol: Column, weightCol: Column): Column + + @Since("2.3.0") + def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0)) + } /** @@ -93,8 +97,7 @@ object Summarizer extends Logging { * - min: the minimum for each coefficient. * - normL2: the Euclidian norm for each coefficient. * - normL1: the L1 norm of each coefficient (sum of the absolute values). - * @param firstMetric the metric being provided - * @param metrics additional metrics that can be provided. + * @param metrics metrics that can be provided. * @return a builder. * @throws IllegalArgumentException if one of the metric names is not understood. * @@ -102,51 +105,76 @@ object Summarizer extends Logging { * interface. 
*/ @Since("2.3.0") - def metrics(firstMetric: String, metrics: String*): SummaryBuilder = { - val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics) + def metrics(metrics: String*): SummaryBuilder = { + require(metrics.size >= 1, "Should include at least one metric") + val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics) new SummaryBuilderImpl(typedMetrics, computeMetrics) } @Since("2.3.0") - def mean(col: Column, weightCol: Column = lit(1.0)): Column = { + def mean(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "mean") } @Since("2.3.0") - def variance(col: Column, weightCol: Column = lit(1.0)): Column = { + def mean(col: Column): Column = mean(col, lit(1.0)) + + @Since("2.3.0") + def variance(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "variance") } @Since("2.3.0") - def count(col: Column, weightCol: Column = lit(1.0)): Column = { + def variance(col: Column): Column = variance(col, lit(1.0)) + + @Since("2.3.0") + def count(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "count") } @Since("2.3.0") - def numNonZeros(col: Column, weightCol: Column = lit(1.0)): Column = { + def count(col: Column): Column = count(col, lit(1.0)) + + @Since("2.3.0") + def numNonZeros(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "numNonZeros") } @Since("2.3.0") - def max(col: Column, weightCol: Column = lit(1.0)): Column = { + def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0)) + + @Since("2.3.0") + def max(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "max") } @Since("2.3.0") - def min(col: Column, weightCol: Column = lit(1.0)): Column = { + def max(col: Column): Column = max(col, lit(1.0)) + + @Since("2.3.0") + def min(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "min") } @Since("2.3.0") - def normL1(col: Column, weightCol: Column = lit(1.0)): Column = { + def min(col: Column): Column = min(col, lit(1.0)) + + @Since("2.3.0") + def normL1(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "normL1") } @Since("2.3.0") - def normL2(col: Column, weightCol: Column = lit(1.0)): Column = { + def normL1(col: Column): Column = normL1(col, lit(1.0)) + + @Since("2.3.0") + def normL2(col: Column, weightCol: Column): Column = { getSingleMetric(col, weightCol, "normL2") } + @Since("2.3.0") + def normL2(col: Column): Column = normL2(col, lit(1.0)) + private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = { val c1 = metrics(metric).summary(col, weightCol) c1.getField(metric).as(s"$metric($col)") diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index 2ea0e2f1f363..e7aae44e01aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -48,6 +48,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { _summarizer } + val summarizerWithoutWeight = { + val _summarizer = new MultivariateOnlineSummarizer + inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1))) + _summarizer + } + // Because the Spark context is reset between tests, we cannot hold a reference onto it. 
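 // `summarizer` above is the weighted reference for summary(col, weightCol), while
 // `summarizerWithoutWeight` replays the same vectors with the implicit default weight of
 // 1.0, matching the summary(col) overload exercised by the "w/o weight" tests below.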
def wrappedInit() = { val df = inputVec.toDF("features", "weight") @@ -62,48 +68,96 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { Seq(Row(summarizer.mean), summarizer.mean)) } + registerTest(s"$name - mean only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("mean").summary(c), mean(c)), + Seq(Row(summarizerWithoutWeight.mean), summarizerWithoutWeight.mean)) + } + registerTest(s"$name - variance only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("variance").summary(c, weight), variance(c, weight)), Seq(Row(summarizer.variance), summarizer.variance)) } + registerTest(s"$name - variance only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("variance").summary(c), variance(c)), + Seq(Row(summarizerWithoutWeight.variance), summarizerWithoutWeight.variance)) + } + registerTest(s"$name - count only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("count").summary(c, weight), count(c, weight)), Seq(Row(summarizer.count), summarizer.count)) } + registerTest(s"$name - count only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("count").summary(c), count(c)), + Seq(Row(summarizerWithoutWeight.count), summarizerWithoutWeight.count)) + } + registerTest(s"$name - numNonZeros only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("numNonZeros").summary(c, weight), numNonZeros(c, weight)), Seq(Row(summarizer.numNonzeros), summarizer.numNonzeros)) } + registerTest(s"$name - numNonZeros only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), + Seq(Row(summarizerWithoutWeight.numNonzeros), summarizerWithoutWeight.numNonzeros)) + } + registerTest(s"$name - min only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("min").summary(c, weight), min(c, weight)), Seq(Row(summarizer.min), summarizer.min)) } + registerTest(s"$name - min only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("min").summary(c), min(c)), + Seq(Row(summarizerWithoutWeight.min), summarizerWithoutWeight.min)) + } + registerTest(s"$name - max only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("max").summary(c, weight), max(c, weight)), Seq(Row(summarizer.max), summarizer.max)) } + registerTest(s"$name - max only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("max").summary(c), max(c)), + Seq(Row(summarizerWithoutWeight.max), summarizerWithoutWeight.max)) + } + registerTest(s"$name - normL1 only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("normL1").summary(c, weight), normL1(c, weight)), Seq(Row(summarizer.normL1), summarizer.normL1)) } + registerTest(s"$name - normL1 only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("normL1").summary(c), normL1(c)), + Seq(Row(summarizerWithoutWeight.normL1), summarizerWithoutWeight.normL1)) + } + registerTest(s"$name - normL2 only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("normL2").summary(c, weight), normL2(c, weight)), Seq(Row(summarizer.normL2), summarizer.normL2)) } + registerTest(s"$name - normL2 only w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select(metrics("normL2").summary(c), normL2(c)), + Seq(Row(summarizerWithoutWeight.normL2), summarizerWithoutWeight.normL2)) + } + registerTest(s"$name - multiple metrics at once") { val (df, c, weight) = wrappedInit() compare(df.select( @@ -111,6 
+165,15 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { Seq(Row(summarizer.mean, summarizer.variance, summarizer.count, summarizer.numNonzeros)) ) } + + registerTest(s"$name - multiple metrics at once w/o weight") { + val (df, c, _) = wrappedInit() + compare(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c)), + Seq(Row(summarizerWithoutWeight.mean, summarizerWithoutWeight.variance, + summarizerWithoutWeight.count, summarizerWithoutWeight.numNonzeros)) + ) + } } private def compare(df: DataFrame, exp: Seq[Any]): Unit = { @@ -146,14 +209,15 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { testExample("single element", Seq((Vectors.dense(0.0, 1.0, 2.0), 2.0))) - testExample("two elements (dense)", + testExample("multiple elements (dense)", Seq( (Vectors.dense(-1.0, 0.0, 6.0), 0.5), - (Vectors.dense(3.0, -3.0, 0.0), 2.8) + (Vectors.dense(3.0, -3.0, 0.0), 2.8), + (Vectors.dense(1.0, -3.0, 0.0), 0.0) ) ) - testExample("two elements (sparse)", + testExample("multiple elements (sparse)", Seq( (Vectors.dense(-1.0, 0.0, 6.0).toSparse, 0.5), (Vectors.dense(3.0, -3.0, 0.0).toSparse, 2.8), From 5b0baf5aba839a2c7fa1496847665bf7d4654607 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 8 Nov 2017 09:32:21 +0800 Subject: [PATCH 3/8] update --- .../src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index b1e886625f16..c702aa30dac5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -62,16 +62,17 @@ sealed abstract class SummaryBuilder { * an example in Scala: * {{{ * import org.apache.spark.ml.linalg._ + * import org.apache.spark.sql.Row * val dataframe = ... 
// Some dataframe containing a feature column and a weight column * val multiStatsDF = dataframe.select( * Summarizer.metrics("min", "max", "count").summary($"features", $"weight")) - * val Tuple1((minVec, maxVec, count)) = multiStatsDF.as[Tuple1[(Vector, Vector, Long)]].first() + * val Row(Row(minVec, maxVec, count)) = multiStatsDF.first() * }}} * * If one wants to get a single metric, shortcuts are also available: * {{{ * val meanDF = dataframe.select(Summarizer.mean($"features")) - * val Tuple1(meanVec) = meanDF.as[Tuple1[Vector]].first() + * val Row(meanVec) = meanDF.first() * }}} * * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD From 2742cd38a83714b8041b83c72c02d179bc21b6c8 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 9 Nov 2017 16:34:36 +0800 Subject: [PATCH 4/8] update --- .../org/apache/spark/ml/stat/Summarizer.scala | 3 +- .../spark/ml/stat/JavaSummarizerSuite.java | 64 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index c702aa30dac5..61374a2d65c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -62,7 +62,7 @@ sealed abstract class SummaryBuilder { * an example in Scala: * {{{ * import org.apache.spark.ml.linalg._ - * import org.apache.spark.sql.Row + * import org.apache.spark.sql.Row * val dataframe = ... // Some dataframe containing a feature column and a weight column * val multiStatsDF = dataframe.select( * Summarizer.metrics("min", "max", "count").summary($"features", $"weight")) @@ -106,6 +106,7 @@ object Summarizer extends Logging { * interface. */ @Since("2.3.0") + @scala.annotation.varargs def metrics(metrics: String*): SummaryBuilder = { require(metrics.size >= 1, "Should include at least one metric") val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics) diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java new file mode 100644 index 000000000000..38ab39aa0f49 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.Dataset; +import static org.apache.spark.sql.functions.col; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; + +public class JavaSummarizerSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = new ArrayList(); + points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0))); + points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0))); + + dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @Test + public void testSummarizer() { + dataset.select(col("features")); + Row result = dataset + .select(Summarizer.metrics("mean", "max", "count").summary(col("features"))) + .first().getStruct(0); + Vector meanVec = result.getAs("mean"); + Vector maxVec = result.getAs("max"); + long count = result.getAs("count"); + + assertEquals(2L, count); + assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0); + assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0); + } +} From 680021899f41246363bee588e25e4724a358de22 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 13 Dec 2017 16:49:47 +0800 Subject: [PATCH 5/8] address comments --- .../spark/ml/stat/SummarizerSuite.scala | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index e7aae44e01aa..b2acea99bc08 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -33,6 +33,16 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { import Summarizer._ import SummaryBuilderImpl._ + private case class ExpectedMetrics( + mean: Vector, + variance: Vector, + count: Long, + numNonZeros: Vector, + max: Vector, + min: Vector, + normL2: Vector, + normL1: Vector) + /** * The input is expected to be either a sparse vector, a dense vector. * @@ -40,7 +50,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { * `mllib.stat.MultivariateOnlineSummarizer`. They currently test against some fixed subset * of the metrics, but should be made fuzzy in the future. 
*/ - private def testExample(name: String, inputVec: Seq[(Vector, Double)]): Unit = { + private def testExample(name: String, inputVec: Seq[(Vector, Double)], + exp: ExpectedMetrics = null): Unit = { val summarizer = { val _summarizer = new MultivariateOnlineSummarizer @@ -66,96 +77,112 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { val (df, c, weight) = wrappedInit() compare(df.select(metrics("mean").summary(c, weight), mean(c, weight)), Seq(Row(summarizer.mean), summarizer.mean)) + println(s"${name} mean ${summarizer.mean}") } registerTest(s"$name - mean only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(summarizerWithoutWeight.mean), summarizerWithoutWeight.mean)) + println(s"${name} mean wo ${summarizerWithoutWeight.mean}") } registerTest(s"$name - variance only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("variance").summary(c, weight), variance(c, weight)), Seq(Row(summarizer.variance), summarizer.variance)) + println(s"${name} var ${summarizer.variance}") } registerTest(s"$name - variance only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("variance").summary(c), variance(c)), Seq(Row(summarizerWithoutWeight.variance), summarizerWithoutWeight.variance)) + println(s"${name} var wo ${summarizerWithoutWeight.variance}") } registerTest(s"$name - count only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("count").summary(c, weight), count(c, weight)), Seq(Row(summarizer.count), summarizer.count)) + println(s"${name} count ${summarizer.count}") } registerTest(s"$name - count only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("count").summary(c), count(c)), Seq(Row(summarizerWithoutWeight.count), summarizerWithoutWeight.count)) + println(s"${name} count wo ${summarizerWithoutWeight.count}") } registerTest(s"$name - numNonZeros only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("numNonZeros").summary(c, weight), numNonZeros(c, weight)), Seq(Row(summarizer.numNonzeros), summarizer.numNonzeros)) + println(s"${name} nnz ${summarizer.numNonzeros}") } registerTest(s"$name - numNonZeros only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), Seq(Row(summarizerWithoutWeight.numNonzeros), summarizerWithoutWeight.numNonzeros)) + println(s"${name} nnz wo ${summarizerWithoutWeight.numNonzeros}") } registerTest(s"$name - min only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("min").summary(c, weight), min(c, weight)), Seq(Row(summarizer.min), summarizer.min)) + println(s"${name} min ${summarizer.min}") } registerTest(s"$name - min only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("min").summary(c), min(c)), Seq(Row(summarizerWithoutWeight.min), summarizerWithoutWeight.min)) + println(s"${name} min wo ${summarizerWithoutWeight.min}") } registerTest(s"$name - max only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("max").summary(c, weight), max(c, weight)), Seq(Row(summarizer.max), summarizer.max)) + println(s"${name} max ${summarizer.max}") } registerTest(s"$name - max only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("max").summary(c), max(c)), Seq(Row(summarizerWithoutWeight.max), summarizerWithoutWeight.max)) + println(s"${name} max wo ${summarizerWithoutWeight.max}") } registerTest(s"$name - normL1 only") { val (df, c, weight) = wrappedInit() 
compare(df.select(metrics("normL1").summary(c, weight), normL1(c, weight)), Seq(Row(summarizer.normL1), summarizer.normL1)) + println(s"${name} l1 ${summarizer.normL1}") } registerTest(s"$name - normL1 only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("normL1").summary(c), normL1(c)), Seq(Row(summarizerWithoutWeight.normL1), summarizerWithoutWeight.normL1)) + println(s"${name} l1 wo ${summarizerWithoutWeight.normL1}") } registerTest(s"$name - normL2 only") { val (df, c, weight) = wrappedInit() compare(df.select(metrics("normL2").summary(c, weight), normL2(c, weight)), Seq(Row(summarizer.normL2), summarizer.normL2)) + println(s"${name} l2 ${summarizer.normL2}") } registerTest(s"$name - normL2 only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("normL2").summary(c), normL2(c)), Seq(Row(summarizerWithoutWeight.normL2), summarizerWithoutWeight.normL2)) + println(s"${name} l2 wo ${summarizerWithoutWeight.normL2}") } registerTest(s"$name - multiple metrics at once") { From 647dbbea9f300b7edb40806f7cd687edf4ec7cb2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 13 Dec 2017 18:39:26 +0800 Subject: [PATCH 6/8] address comments --- .../org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../spark/ml/stat/SummarizerSuite.scala | 156 ++++++++++++------ 2 files changed, 102 insertions(+), 56 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 61374a2d65c8..9bed74a9f2c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -232,7 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging { StructType(fields) } - private[this] val vectorUDT = new VectorUDT + private val vectorUDT = new VectorUDT /** * All the metrics that can be currently computed by Spark for vectors. diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index b2acea99bc08..56dd6b54dca8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -51,7 +51,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { * of the metrics, but should be made fuzzy in the future. 
*/ private def testExample(name: String, inputVec: Seq[(Vector, Double)], - exp: ExpectedMetrics = null): Unit = { + exp: ExpectedMetrics, expWithoutWeight: ExpectedMetrics): Unit = { val summarizer = { val _summarizer = new MultivariateOnlineSummarizer @@ -74,122 +74,106 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { } registerTest(s"$name - mean only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("mean").summary(c, weight), mean(c, weight)), - Seq(Row(summarizer.mean), summarizer.mean)) - println(s"${name} mean ${summarizer.mean}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("mean").summary(c, w), mean(c, w)), + Seq(Row(summarizer.mean), exp.mean)) } registerTest(s"$name - mean only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("mean").summary(c), mean(c)), - Seq(Row(summarizerWithoutWeight.mean), summarizerWithoutWeight.mean)) - println(s"${name} mean wo ${summarizerWithoutWeight.mean}") + Seq(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean)) } registerTest(s"$name - variance only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("variance").summary(c, weight), variance(c, weight)), - Seq(Row(summarizer.variance), summarizer.variance)) - println(s"${name} var ${summarizer.variance}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("variance").summary(c, w), variance(c, w)), + Seq(Row(summarizer.variance), exp.variance)) } registerTest(s"$name - variance only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("variance").summary(c), variance(c)), - Seq(Row(summarizerWithoutWeight.variance), summarizerWithoutWeight.variance)) - println(s"${name} var wo ${summarizerWithoutWeight.variance}") + Seq(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance)) } registerTest(s"$name - count only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("count").summary(c, weight), count(c, weight)), - Seq(Row(summarizer.count), summarizer.count)) - println(s"${name} count ${summarizer.count}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("count").summary(c, w), count(c, w)), + Seq(Row(summarizer.count), exp.count)) } registerTest(s"$name - count only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("count").summary(c), count(c)), - Seq(Row(summarizerWithoutWeight.count), summarizerWithoutWeight.count)) - println(s"${name} count wo ${summarizerWithoutWeight.count}") + Seq(Row(summarizerWithoutWeight.count), expWithoutWeight.count)) } registerTest(s"$name - numNonZeros only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("numNonZeros").summary(c, weight), numNonZeros(c, weight)), - Seq(Row(summarizer.numNonzeros), summarizer.numNonzeros)) - println(s"${name} nnz ${summarizer.numNonzeros}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)), + Seq(Row(summarizer.numNonzeros), exp.numNonZeros)) } registerTest(s"$name - numNonZeros only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), - Seq(Row(summarizerWithoutWeight.numNonzeros), summarizerWithoutWeight.numNonzeros)) - println(s"${name} nnz wo ${summarizerWithoutWeight.numNonzeros}") + Seq(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros)) } registerTest(s"$name - min only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("min").summary(c, weight), 
min(c, weight)), - Seq(Row(summarizer.min), summarizer.min)) - println(s"${name} min ${summarizer.min}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("min").summary(c, w), min(c, w)), + Seq(Row(summarizer.min), exp.min)) } registerTest(s"$name - min only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("min").summary(c), min(c)), - Seq(Row(summarizerWithoutWeight.min), summarizerWithoutWeight.min)) - println(s"${name} min wo ${summarizerWithoutWeight.min}") + Seq(Row(summarizerWithoutWeight.min), expWithoutWeight.min)) } registerTest(s"$name - max only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("max").summary(c, weight), max(c, weight)), - Seq(Row(summarizer.max), summarizer.max)) - println(s"${name} max ${summarizer.max}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("max").summary(c, w), max(c, w)), + Seq(Row(summarizer.max), exp.max)) } registerTest(s"$name - max only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("max").summary(c), max(c)), - Seq(Row(summarizerWithoutWeight.max), summarizerWithoutWeight.max)) - println(s"${name} max wo ${summarizerWithoutWeight.max}") + Seq(Row(summarizerWithoutWeight.max), expWithoutWeight.max)) } registerTest(s"$name - normL1 only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("normL1").summary(c, weight), normL1(c, weight)), - Seq(Row(summarizer.normL1), summarizer.normL1)) - println(s"${name} l1 ${summarizer.normL1}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("normL1").summary(c, w), normL1(c, w)), + Seq(Row(summarizer.normL1), exp.normL1)) } registerTest(s"$name - normL1 only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("normL1").summary(c), normL1(c)), - Seq(Row(summarizerWithoutWeight.normL1), summarizerWithoutWeight.normL1)) - println(s"${name} l1 wo ${summarizerWithoutWeight.normL1}") + Seq(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1)) } registerTest(s"$name - normL2 only") { - val (df, c, weight) = wrappedInit() - compare(df.select(metrics("normL2").summary(c, weight), normL2(c, weight)), - Seq(Row(summarizer.normL2), summarizer.normL2)) - println(s"${name} l2 ${summarizer.normL2}") + val (df, c, w) = wrappedInit() + compare(df.select(metrics("normL2").summary(c, w), normL2(c, w)), + Seq(Row(summarizer.normL2), exp.normL2)) } registerTest(s"$name - normL2 only w/o weight") { val (df, c, _) = wrappedInit() compare(df.select(metrics("normL2").summary(c), normL2(c)), - Seq(Row(summarizerWithoutWeight.normL2), summarizerWithoutWeight.normL2)) - println(s"${name} l2 wo ${summarizerWithoutWeight.normL2}") + Seq(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2)) } registerTest(s"$name - multiple metrics at once") { - val (df, c, weight) = wrappedInit() + val (df, c, w) = wrappedInit() compare(df.select( - metrics("mean", "variance", "count", "numNonZeros").summary(c, weight)), - Seq(Row(summarizer.mean, summarizer.variance, summarizer.count, summarizer.numNonzeros)) + metrics("mean", "variance", "count", "numNonZeros").summary(c, w)), + Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros)) ) } @@ -197,8 +181,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { val (df, c, _) = wrappedInit() compare(df.select( metrics("mean", "variance", "count", "numNonZeros").summary(c)), - Seq(Row(summarizerWithoutWeight.mean, summarizerWithoutWeight.variance, - summarizerWithoutWeight.count, summarizerWithoutWeight.numNonzeros)) + 
Seq(Row(expWithoutWeight.mean, expWithoutWeight.variance, + expWithoutWeight.count, expWithoutWeight.numNonZeros)) ) } } @@ -234,13 +218,55 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { } } - testExample("single element", Seq((Vectors.dense(0.0, 1.0, 2.0), 2.0))) + val singleElem = Vectors.dense(0.0, 1.0, 2.0) + testExample("single element", Seq((singleElem, 2.0)), + ExpectedMetrics( + mean = singleElem, + variance = Vectors.dense(0.0, 0.0, 0.0), + count = 1L, + numNonZeros = Vectors.dense(0.0, 1.0, 1.0), + max = singleElem, + min = singleElem, + normL1 = Vectors.dense(0.0, 2.0, 4.0), + normL2 = Vectors.dense(0.0, 1.4142135623730951, 2.8284271247461903) + ), + ExpectedMetrics( + mean = singleElem, + variance = Vectors.dense(0.0, 0.0, 0.0), + count = 1L, + numNonZeros = Vectors.dense(0.0, 1.0, 1.0), + max = singleElem, + min = singleElem, + normL1 = singleElem, + normL2 = singleElem + ) + ) testExample("multiple elements (dense)", Seq( (Vectors.dense(-1.0, 0.0, 6.0), 0.5), (Vectors.dense(3.0, -3.0, 0.0), 2.8), (Vectors.dense(1.0, -3.0, 0.0), 0.0) + ), + ExpectedMetrics( + mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092), + variance = Vectors.dense(8.0, 4.5, 18.0), + count = 2L, + numNonZeros = Vectors.dense(2.0, 1.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(8.9, 8.4, 3.0), + normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285) + ), + ExpectedMetrics( + mean = Vectors.dense(1.0, -2.0, 2.0), + variance = Vectors.dense(4.0, 3.0, 12.0), + count = 3L, + numNonZeros = Vectors.dense(3.0, 2.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(5.0, 6.0, 6.0), + normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0) ) ) @@ -249,6 +275,26 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { (Vectors.dense(-1.0, 0.0, 6.0).toSparse, 0.5), (Vectors.dense(3.0, -3.0, 0.0).toSparse, 2.8), (Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0) + ), + ExpectedMetrics( + mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092), + variance = Vectors.dense(8.0, 4.5, 18.0), + count = 2L, + numNonZeros = Vectors.dense(2.0, 1.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(8.9, 8.4, 3.0), + normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285) + ), + ExpectedMetrics( + mean = Vectors.dense(1.0, -2.0, 2.0), + variance = Vectors.dense(4.0, 3.0, 12.0), + count = 3L, + numNonZeros = Vectors.dense(3.0, 2.0, 1.0), + max = Vectors.dense(3.0, 0.0, 6.0), + min = Vectors.dense(-1.0, -3.0, 0.0), + normL1 = Vectors.dense(5.0, 6.0, 6.0), + normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0) ) ) From f34da1fef6548b7780751c84f53b18befcf48798 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 15 Dec 2017 17:52:51 +0800 Subject: [PATCH 7/8] improve testcode --- .../spark/ml/stat/SummarizerSuite.scala | 131 ++++++++---------- 1 file changed, 58 insertions(+), 73 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index 56dd6b54dca8..9a12d41c9a70 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -75,146 +75,131 @@ class SummarizerSuite extends SparkFunSuite 
with MLlibTestSparkContext { registerTest(s"$name - mean only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("mean").summary(c, w), mean(c, w)), - Seq(Row(summarizer.mean), exp.mean)) + compareRow(df.select(metrics("mean").summary(c, w), mean(c, w)).first(), + Row(Row(summarizer.mean), exp.mean)) } registerTest(s"$name - mean only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("mean").summary(c), mean(c)), - Seq(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean)) + compareRow(df.select(metrics("mean").summary(c), mean(c)).first(), + Row(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean)) } registerTest(s"$name - variance only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("variance").summary(c, w), variance(c, w)), - Seq(Row(summarizer.variance), exp.variance)) + compareRow(df.select(metrics("variance").summary(c, w), variance(c, w)).first(), + Row(Row(summarizer.variance), exp.variance)) } registerTest(s"$name - variance only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("variance").summary(c), variance(c)), - Seq(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance)) + compareRow(df.select(metrics("variance").summary(c), variance(c)).first(), + Row(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance)) } registerTest(s"$name - count only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("count").summary(c, w), count(c, w)), - Seq(Row(summarizer.count), exp.count)) + compareRow(df.select(metrics("count").summary(c, w), count(c, w)).first(), + Row(Row(summarizer.count), exp.count)) } registerTest(s"$name - count only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("count").summary(c), count(c)), - Seq(Row(summarizerWithoutWeight.count), expWithoutWeight.count)) + compareRow(df.select(metrics("count").summary(c), count(c)).first(), + Row(Row(summarizerWithoutWeight.count), expWithoutWeight.count)) } registerTest(s"$name - numNonZeros only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)), - Seq(Row(summarizer.numNonzeros), exp.numNonZeros)) + compareRow(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)).first(), + Row(Row(summarizer.numNonzeros), exp.numNonZeros)) } registerTest(s"$name - numNonZeros only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), - Seq(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros)) + compareRow(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)).first(), + Row(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros)) } registerTest(s"$name - min only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("min").summary(c, w), min(c, w)), - Seq(Row(summarizer.min), exp.min)) + compareRow(df.select(metrics("min").summary(c, w), min(c, w)).first(), + Row(Row(summarizer.min), exp.min)) } registerTest(s"$name - min only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("min").summary(c), min(c)), - Seq(Row(summarizerWithoutWeight.min), expWithoutWeight.min)) + compareRow(df.select(metrics("min").summary(c), min(c)).first(), + Row(Row(summarizerWithoutWeight.min), expWithoutWeight.min)) } registerTest(s"$name - max only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("max").summary(c, w), max(c, w)), - Seq(Row(summarizer.max), exp.max)) + 
compareRow(df.select(metrics("max").summary(c, w), max(c, w)).first(), + Row(Row(summarizer.max), exp.max)) } registerTest(s"$name - max only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("max").summary(c), max(c)), - Seq(Row(summarizerWithoutWeight.max), expWithoutWeight.max)) + compareRow(df.select(metrics("max").summary(c), max(c)).first(), + Row(Row(summarizerWithoutWeight.max), expWithoutWeight.max)) } registerTest(s"$name - normL1 only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("normL1").summary(c, w), normL1(c, w)), - Seq(Row(summarizer.normL1), exp.normL1)) + compareRow(df.select(metrics("normL1").summary(c, w), normL1(c, w)).first(), + Row(Row(summarizer.normL1), exp.normL1)) } registerTest(s"$name - normL1 only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("normL1").summary(c), normL1(c)), - Seq(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1)) + compareRow(df.select(metrics("normL1").summary(c), normL1(c)).first(), + Row(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1)) } registerTest(s"$name - normL2 only") { val (df, c, w) = wrappedInit() - compare(df.select(metrics("normL2").summary(c, w), normL2(c, w)), - Seq(Row(summarizer.normL2), exp.normL2)) + compareRow(df.select(metrics("normL2").summary(c, w), normL2(c, w)).first(), + Row(Row(summarizer.normL2), exp.normL2)) } registerTest(s"$name - normL2 only w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select(metrics("normL2").summary(c), normL2(c)), - Seq(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2)) + compareRow(df.select(metrics("normL2").summary(c), normL2(c)).first(), + Row(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2)) } registerTest(s"$name - multiple metrics at once") { val (df, c, w) = wrappedInit() - compare(df.select( - metrics("mean", "variance", "count", "numNonZeros").summary(c, w)), - Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros)) + compareRow(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c, w)).first(), + Row(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros)) ) } registerTest(s"$name - multiple metrics at once w/o weight") { val (df, c, _) = wrappedInit() - compare(df.select( - metrics("mean", "variance", "count", "numNonZeros").summary(c)), - Seq(Row(expWithoutWeight.mean, expWithoutWeight.variance, + compareRow(df.select( + metrics("mean", "variance", "count", "numNonZeros").summary(c)).first(), + Row(Row(expWithoutWeight.mean, expWithoutWeight.variance, expWithoutWeight.count, expWithoutWeight.numNonZeros)) ) } } - private def compare(df: DataFrame, exp: Seq[Any]): Unit = { - val res = df.head().toSeq - val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" } - assert(res.size === exp.size, (res.size, exp.size)) - for (((x1, x2), name) <- res.zip(exp).zip(names)) { - compareStructures(x1, x2, name) - } - } - - // Compares structured content. 
+  private def compareRow(r1: Row, r2: Row): Unit = {
+    assert(r1.size === r2.size, (r1, r2))
+    r1.toSeq.zip(r2.toSeq).foreach {
+      case (v1: Vector, v2: Vector) =>
+        assert(v1 ~== v2 absTol 1e-4)
+      case (v1: Vector, v2: OldVector) =>
+        assert(v1 ~== v2.asML absTol 1e-4)
+      case (l1: Long, l2: Long) =>
+        assert(l1 === l2)
+      case (r1: Row, r2: Row) =>
+        compareRow(r1, r2)
+      case (x1: Any, x2: Any) =>
+        throw new Exception(s"type mismatch: ${x1.getClass} ${x2.getClass} $x1 $x2")
     }
   }
 
@@ -228,7 +213,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       max = singleElem,
       min = singleElem,
       normL1 = Vectors.dense(0.0, 2.0, 4.0),
-      normL2 = Vectors.dense(0.0, 1.4142135623730951, 2.8284271247461903)
+      normL2 = Vectors.dense(0.0, 1.414213, 2.828427)
     ),
     ExpectedMetrics(
       mean = singleElem,
@@ -249,14 +234,14 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Vectors.dense(1.0, -3.0, 0.0), 0.0)
     ),
     ExpectedMetrics(
-      mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092),
+      mean = Vectors.dense(2.393939, -2.545454, 0.909090),
       variance = Vectors.dense(8.0, 4.5, 18.0),
       count = 2L,
       numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
       max = Vectors.dense(3.0, 0.0, 6.0),
       min = Vectors.dense(-1.0, -3.0, 0.0),
       normL1 = Vectors.dense(8.9, 8.4, 3.0),
-      normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285)
+      normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
     ),
     ExpectedMetrics(
       mean = Vectors.dense(1.0, -2.0, 2.0),
@@ -266,7 +251,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       max = Vectors.dense(3.0, 0.0, 6.0),
       min = Vectors.dense(-1.0, -3.0, 0.0),
       normL1 = Vectors.dense(5.0, 6.0, 6.0),
-      normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0)
+      normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
     )
   )
@@ -277,14 +262,14 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0)
     ),
     ExpectedMetrics(
-      mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092),
+      mean = Vectors.dense(2.393939, -2.545454, 0.909090),
       variance = Vectors.dense(8.0, 4.5, 18.0),
       count = 2L,
       numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
       max = Vectors.dense(3.0, 0.0, 6.0),
       min = Vectors.dense(-1.0, -3.0, 0.0),
       normL1 = Vectors.dense(8.9, 8.4, 3.0),
-      normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285)
+      normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
     ),
     ExpectedMetrics(
       mean = Vectors.dense(1.0, -2.0, 2.0),
@@ -294,7 +279,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
       max = Vectors.dense(3.0, 0.0, 6.0),
       min = Vectors.dense(-1.0, -3.0, 0.0),
       normL1 = Vectors.dense(5.0, 6.0, 6.0),
-      normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0)
+      normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
     )
   )
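The user-facing behavior that the suite above pins down can be exercised roughly as follows. This is a minimal sketch rather than part of either patch: it assumes a SparkSession named `spark` is in scope, and the two-row dataset is invented for illustration.

  import org.apache.spark.ml.linalg.{Vector, Vectors}
  import org.apache.spark.ml.stat.Summarizer
  import org.apache.spark.sql.Row

  val df = spark.createDataFrame(Seq(
    (Vectors.dense(1.0, 2.0), 2.0),
    (Vectors.dense(3.0, 4.0), 1.0)
  )).toDF("features", "weight")

  // All requested metrics come back in a single struct column, so first()
  // yields a Row nesting a Row: the same shape the tests hand to compareRow.
  val Row(Row(weightedMean: Vector, n: Long)) = df
    .select(Summarizer.metrics("mean", "count").summary(df("features"), df("weight")))
    .first()

  // Single-metric shortcut; with no weight column the default weight is 1.0.
  val Row(unweightedMean: Vector) = df.select(Summarizer.mean(df("features"))).first()

With these rows, `unweightedMean` is [2.0, 3.0], while `weightedMean` is pulled toward the double-weighted first row, giving [5/3, 8/3], and `n` is 2.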

From 24697f303292d375f81a39f15fb70cac5aa013b3 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 21 Dec 2017 10:27:01 +0800
Subject: [PATCH 8/8] add no element test

---
 .../apache/spark/ml/stat/SummarizerSuite.scala   | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index 9a12d41c9a70..5e4f40298969 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -17,15 +17,13 @@
 package org.apache.spark.ml.stat
 
-import org.scalatest.exceptions.TestFailedException
-
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Row
 
 class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -203,6 +201,16 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("no element") {
+    val df = Seq[Tuple1[Vector]]().toDF("features")
+    val c = df.col("features")
+    intercept[SparkException] {
+      df.select(metrics("mean").summary(c), mean(c)).first()
+    }
+    compareRow(df.select(metrics("count").summary(c), count(c)).first(),
+      Row(Row(0L), 0L))
+  }
+
   val singleElem = Vectors.dense(0.0, 1.0, 2.0)
   testExample("single element", Seq((singleElem, 2.0)), ExpectedMetrics(