@@ -17,15 +17,13 @@
 
 package org.apache.spark.ml.stat
 
-import org.scalatest.exceptions.TestFailedException
-
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Row
 
 class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -203,6 +201,16 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("no element") {
+    val df = Seq[Tuple1[Vector]]().toDF("features")
+    val c = df.col("features")
+    intercept[SparkException] {
+      df.select(metrics("mean").summary(c), mean(c)).first()
+    }
+    compareRow(df.select(metrics("count").summary(c), count(c)).first(),
+      Row(Row(0L), 0L))
+  }
+
   val singleElem = Vectors.dense(0.0, 1.0, 2.0)
   testExample("single element", Seq((singleElem, 2.0)),
     ExpectedMetrics(
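For reference, a minimal standalone sketch of the behaviour the new "no element" test asserts, using the Summarizer.metrics(...).summary(col) API that this suite exercises. The object name, the local SparkSession setup, and the printed output are illustrative assumptions and not part of the patch: computing "mean" over an empty vector column is expected to fail with a SparkException, while "count" is well defined on empty input and yields 0.

import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.SparkSession

// Hypothetical driver, not part of the patch: mirrors the "no element" test above.
object EmptySummarizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("summarizer-sketch").getOrCreate()
    import spark.implicits._

    // An empty DataFrame with a single vector column, as in the test.
    val df = Seq.empty[Tuple1[Vector]].toDF("features")
    val c = df.col("features")

    // "mean" has no defined value over zero rows; the action is expected to
    // surface a SparkException (this is what intercept[SparkException] checks).
    try {
      df.select(Summarizer.metrics("mean").summary(c)).first()
    } catch {
      case e: SparkException => println(s"mean on empty input failed: ${e.getMessage}")
    }

    // "count" is well defined on empty input: the summary struct contains 0.
    val row = df.select(Summarizer.metrics("count").summary(c)).first()
    println(row) // expected: [[0]]

    spark.stop()
  }
}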