Skip to content

Commit f34da1f

Browse files
committed
improve testcode
1 parent 647dbbe commit f34da1f

File tree

1 file changed

+58
-73
lines changed

1 file changed

+58
-73
lines changed

mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala

Lines changed: 58 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -75,146 +75,131 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
7575

7676
registerTest(s"$name - mean only") {
7777
val (df, c, w) = wrappedInit()
78-
compare(df.select(metrics("mean").summary(c, w), mean(c, w)),
79-
Seq(Row(summarizer.mean), exp.mean))
78+
compareRow(df.select(metrics("mean").summary(c, w), mean(c, w)).first(),
79+
Row(Row(summarizer.mean), exp.mean))
8080
}
8181

8282
registerTest(s"$name - mean only w/o weight") {
8383
val (df, c, _) = wrappedInit()
84-
compare(df.select(metrics("mean").summary(c), mean(c)),
85-
Seq(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean))
84+
compareRow(df.select(metrics("mean").summary(c), mean(c)).first(),
85+
Row(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean))
8686
}
8787

8888
registerTest(s"$name - variance only") {
8989
val (df, c, w) = wrappedInit()
90-
compare(df.select(metrics("variance").summary(c, w), variance(c, w)),
91-
Seq(Row(summarizer.variance), exp.variance))
90+
compareRow(df.select(metrics("variance").summary(c, w), variance(c, w)).first(),
91+
Row(Row(summarizer.variance), exp.variance))
9292
}
9393

9494
registerTest(s"$name - variance only w/o weight") {
9595
val (df, c, _) = wrappedInit()
96-
compare(df.select(metrics("variance").summary(c), variance(c)),
97-
Seq(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance))
96+
compareRow(df.select(metrics("variance").summary(c), variance(c)).first(),
97+
Row(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance))
9898
}
9999

100100
registerTest(s"$name - count only") {
101101
val (df, c, w) = wrappedInit()
102-
compare(df.select(metrics("count").summary(c, w), count(c, w)),
103-
Seq(Row(summarizer.count), exp.count))
102+
compareRow(df.select(metrics("count").summary(c, w), count(c, w)).first(),
103+
Row(Row(summarizer.count), exp.count))
104104
}
105105

106106
registerTest(s"$name - count only w/o weight") {
107107
val (df, c, _) = wrappedInit()
108-
compare(df.select(metrics("count").summary(c), count(c)),
109-
Seq(Row(summarizerWithoutWeight.count), expWithoutWeight.count))
108+
compareRow(df.select(metrics("count").summary(c), count(c)).first(),
109+
Row(Row(summarizerWithoutWeight.count), expWithoutWeight.count))
110110
}
111111

112112
registerTest(s"$name - numNonZeros only") {
113113
val (df, c, w) = wrappedInit()
114-
compare(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)),
115-
Seq(Row(summarizer.numNonzeros), exp.numNonZeros))
114+
compareRow(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)).first(),
115+
Row(Row(summarizer.numNonzeros), exp.numNonZeros))
116116
}
117117

118118
registerTest(s"$name - numNonZeros only w/o weight") {
119119
val (df, c, _) = wrappedInit()
120-
compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)),
121-
Seq(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros))
120+
compareRow(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)).first(),
121+
Row(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros))
122122
}
123123

124124
registerTest(s"$name - min only") {
125125
val (df, c, w) = wrappedInit()
126-
compare(df.select(metrics("min").summary(c, w), min(c, w)),
127-
Seq(Row(summarizer.min), exp.min))
126+
compareRow(df.select(metrics("min").summary(c, w), min(c, w)).first(),
127+
Row(Row(summarizer.min), exp.min))
128128
}
129129

130130
registerTest(s"$name - min only w/o weight") {
131131
val (df, c, _) = wrappedInit()
132-
compare(df.select(metrics("min").summary(c), min(c)),
133-
Seq(Row(summarizerWithoutWeight.min), expWithoutWeight.min))
132+
compareRow(df.select(metrics("min").summary(c), min(c)).first(),
133+
Row(Row(summarizerWithoutWeight.min), expWithoutWeight.min))
134134
}
135135

136136
registerTest(s"$name - max only") {
137137
val (df, c, w) = wrappedInit()
138-
compare(df.select(metrics("max").summary(c, w), max(c, w)),
139-
Seq(Row(summarizer.max), exp.max))
138+
compareRow(df.select(metrics("max").summary(c, w), max(c, w)).first(),
139+
Row(Row(summarizer.max), exp.max))
140140
}
141141

142142
registerTest(s"$name - max only w/o weight") {
143143
val (df, c, _) = wrappedInit()
144-
compare(df.select(metrics("max").summary(c), max(c)),
145-
Seq(Row(summarizerWithoutWeight.max), expWithoutWeight.max))
144+
compareRow(df.select(metrics("max").summary(c), max(c)).first(),
145+
Row(Row(summarizerWithoutWeight.max), expWithoutWeight.max))
146146
}
147147

148148
registerTest(s"$name - normL1 only") {
149149
val (df, c, w) = wrappedInit()
150-
compare(df.select(metrics("normL1").summary(c, w), normL1(c, w)),
151-
Seq(Row(summarizer.normL1), exp.normL1))
150+
compareRow(df.select(metrics("normL1").summary(c, w), normL1(c, w)).first(),
151+
Row(Row(summarizer.normL1), exp.normL1))
152152
}
153153

154154
registerTest(s"$name - normL1 only w/o weight") {
155155
val (df, c, _) = wrappedInit()
156-
compare(df.select(metrics("normL1").summary(c), normL1(c)),
157-
Seq(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1))
156+
compareRow(df.select(metrics("normL1").summary(c), normL1(c)).first(),
157+
Row(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1))
158158
}
159159

160160
registerTest(s"$name - normL2 only") {
161161
val (df, c, w) = wrappedInit()
162-
compare(df.select(metrics("normL2").summary(c, w), normL2(c, w)),
163-
Seq(Row(summarizer.normL2), exp.normL2))
162+
compareRow(df.select(metrics("normL2").summary(c, w), normL2(c, w)).first(),
163+
Row(Row(summarizer.normL2), exp.normL2))
164164
}
165165

166166
registerTest(s"$name - normL2 only w/o weight") {
167167
val (df, c, _) = wrappedInit()
168-
compare(df.select(metrics("normL2").summary(c), normL2(c)),
169-
Seq(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2))
168+
compareRow(df.select(metrics("normL2").summary(c), normL2(c)).first(),
169+
Row(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2))
170170
}
171171

172172
registerTest(s"$name - multiple metrics at once") {
173173
val (df, c, w) = wrappedInit()
174-
compare(df.select(
175-
metrics("mean", "variance", "count", "numNonZeros").summary(c, w)),
176-
Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros))
174+
compareRow(df.select(
175+
metrics("mean", "variance", "count", "numNonZeros").summary(c, w)).first(),
176+
Row(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros))
177177
)
178178
}
179179

180180
registerTest(s"$name - multiple metrics at once w/o weight") {
181181
val (df, c, _) = wrappedInit()
182-
compare(df.select(
183-
metrics("mean", "variance", "count", "numNonZeros").summary(c)),
184-
Seq(Row(expWithoutWeight.mean, expWithoutWeight.variance,
182+
compareRow(df.select(
183+
metrics("mean", "variance", "count", "numNonZeros").summary(c)).first(),
184+
Row(Row(expWithoutWeight.mean, expWithoutWeight.variance,
185185
expWithoutWeight.count, expWithoutWeight.numNonZeros))
186186
)
187187
}
188188
}
189189

190-
private def compare(df: DataFrame, exp: Seq[Any]): Unit = {
191-
val res = df.head().toSeq
192-
val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" }
193-
assert(res.size === exp.size, (res.size, exp.size))
194-
for (((x1, x2), name) <- res.zip(exp).zip(names)) {
195-
compareStructures(x1, x2, name)
196-
}
197-
}
198-
199-
// Compares structured content.
200-
private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match {
201-
case (r1: Row, r2: Row) =>
202-
assert(r1.size === r2.size, (r1, r2))
203-
for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) }
204-
case (v1: Vector, v2: Vector) =>
205-
assertWithHint(v1 ~== v2 absTol 1e-4, name)
206-
case (v1: Vector, v2: OldVector) =>
207-
compareStructures(v1, v2.asML, name)
208-
case (l1: Long, l2: Long) => assert(l1 === l2)
209-
case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2")
210-
}
211-
212-
private def assertWithHint(x: => Boolean, hint: String): Unit = {
213-
try {
214-
assert(x, hint)
215-
} catch {
216-
case tfe: TestFailedException =>
217-
throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1)
190+
/**
 * Recursively checks that two rows have the same arity and field-wise matching content.
 *
 * At the call sites in this suite, `r1` is the first row of the queried DataFrame and
 * `r2` is the hand-built expected row. Fields are paired by position and dispatched on
 * their runtime types; any pairing not listed below is treated as a test error.
 *
 * @param r1 row under test
 * @param r2 row holding the expected values
 */
private def compareRow(r1: Row, r2: Row): Unit = {
191+
// Fail early on differing arity; both rows are attached as the assertion clue.
assert(r1.size === r2.size, (r1, r2))
192+
// Pair fields positionally; zip is safe here because the sizes were just asserted equal.
r1.toSeq.zip(r2.toSeq).foreach {
193+
// Two vectors of the same kind: approximate comparison with tolerance 1e-4
// (`~==` / `absTol` are presumably the project's test-only numeric helpers — confirm).
case (v1: Vector, v2: Vector) =>
194+
assert(v1 ~== v2 absTol 1e-4)
195+
// Expected value supplied as an OldVector: convert it with `asML` before comparing.
case (v1: Vector, v2: OldVector) =>
196+
assert(v1 ~== v2.asML absTol 1e-4)
197+
// Long fields (e.g. the `count` metric used by the callers) must match exactly.
case (l1: Long, l2: Long) =>
198+
assert(l1 === l2)
199+
// Nested rows: recurse. NOTE(review): these pattern variables shadow the method
// parameters `r1`/`r2`; harmless here, but distinct names would read more clearly.
case (r1: Row, r2: Row) =>
200+
compareRow(r1, r2)
201+
// Any other pairing of runtime types is unsupported: abort with both classes and values.
case (x1: Any, x2: Any) =>
202+
throw new Exception(s"type mismatch: ${x1.getClass} ${x2.getClass} $x1 $x2")
218203
}
219204
}
220205

@@ -228,7 +213,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
228213
max = singleElem,
229214
min = singleElem,
230215
normL1 = Vectors.dense(0.0, 2.0, 4.0),
231-
normL2 = Vectors.dense(0.0, 1.4142135623730951, 2.8284271247461903)
216+
normL2 = Vectors.dense(0.0, 1.414213, 2.828427)
232217
),
233218
ExpectedMetrics(
234219
mean = singleElem,
@@ -249,14 +234,14 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
249234
(Vectors.dense(1.0, -3.0, 0.0), 0.0)
250235
),
251236
ExpectedMetrics(
252-
mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092),
237+
mean = Vectors.dense(2.393939, -2.545454, 0.909090),
253238
variance = Vectors.dense(8.0, 4.5, 18.0),
254239
count = 2L,
255240
numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
256241
max = Vectors.dense(3.0, 0.0, 6.0),
257242
min = Vectors.dense(-1.0, -3.0, 0.0),
258243
normL1 = Vectors.dense(8.9, 8.4, 3.0),
259-
normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285)
244+
normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
260245
),
261246
ExpectedMetrics(
262247
mean = Vectors.dense(1.0, -2.0, 2.0),
@@ -266,7 +251,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
266251
max = Vectors.dense(3.0, 0.0, 6.0),
267252
min = Vectors.dense(-1.0, -3.0, 0.0),
268253
normL1 = Vectors.dense(5.0, 6.0, 6.0),
269-
normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0)
254+
normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
270255
)
271256
)
272257

@@ -277,14 +262,14 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
277262
(Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0)
278263
),
279264
ExpectedMetrics(
280-
mean = Vectors.dense(2.393939393939394, -2.545454545454545, 0.9090909090909092),
265+
mean = Vectors.dense(2.393939, -2.545454, 0.909090),
281266
variance = Vectors.dense(8.0, 4.5, 18.0),
282267
count = 2L,
283268
numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
284269
max = Vectors.dense(3.0, 0.0, 6.0),
285270
min = Vectors.dense(-1.0, -3.0, 0.0),
286271
normL1 = Vectors.dense(8.9, 8.4, 3.0),
287-
normL2 = Vectors.dense(5.06951674225463, 5.0199601592044525, 4.242640687119285)
272+
normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
288273
),
289274
ExpectedMetrics(
290275
mean = Vectors.dense(1.0, -2.0, 2.0),
@@ -294,7 +279,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
294279
max = Vectors.dense(3.0, 0.0, 6.0),
295280
min = Vectors.dense(-1.0, -3.0, 0.0),
296281
normL1 = Vectors.dense(5.0, 6.0, 6.0),
297-
normL2 = Vectors.dense(3.3166247903554, 4.242640687119285, 6.0)
282+
normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
298283
)
299284
)
300285

0 commit comments

Comments
 (0)