Skip to content

Commit 28ed664

Browse files
mn-mikkemn-mikke
authored andcommitted
[SPARK-23926][SQL] Adding more tests + fixing a bug in codegen.
1 parent 3a76d87 commit 28ed664

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -225,28 +225,18 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
225225
> SELECT _FUNC_(array(2, 1, 4, 3));
226226
[3, 4, 1, 2]
227227
""",
228-
since = "2.4.0")
228+
since = "1.5.0",
229+
note = "Reverse logic for arrays is available since 2.4.0."
230+
)
229231
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
230232

231233
// Input types are utilized by type coercion in ImplicitTypeCasts.
232-
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
233-
234-
val allowedTypes = Seq(StringType, ArrayType)
234+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
235235

236236
override def dataType: DataType = child.dataType
237237

238238
lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
239239

240-
override def checkInputDataTypes(): TypeCheckResult = {
241-
if (allowedTypes.exists(_.acceptsType(child.dataType))) {
242-
TypeCheckResult.TypeCheckSuccess
243-
} else {
244-
TypeCheckResult.TypeCheckFailure(
245-
s"The argument of function $prettyName should be StringType or ArrayType," +
246-
s" but it's " + child.dataType.simpleString)
247-
}
248-
}
249-
250240
override def nullSafeEval(input: Any): Any = input match {
251241
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
252242
case s: UTF8String => s.reverse()
@@ -266,10 +256,19 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
266256
private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
267257
val length = ctx.freshName("length")
268258
val javaElementType = CodeGenerator.javaType(elementType)
269-
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
259+
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
260+
261+
val initialization = if (isPrimitiveType) {
262+
s"$childName.copy()"
263+
} else {
264+
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
265+
}
266+
267+
val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
270268

271-
val swapAssigments = if (CodeGenerator.isPrimitiveType(elementType)) {
269+
val swapAssigments = if (isPrimitiveType) {
272270
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
271+
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
273272
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
274273
|boolean isNullAtL = ${ev.value}.isNullAt(l);
275274
|if(!isNullAtK) {
@@ -285,19 +284,17 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
285284
| ${ev.value}.setNullAt(l);
286285
|}""".stripMargin
287286
} else {
288-
s"""|Object el = ${getCall("k")};
289-
|${ev.value}.update(k, ${getCall("l")});
290-
|${ev.value}.update(l, el);""".stripMargin
287+
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
291288
}
292289

293290
s"""
294-
|${ev.value} = $childName.copy();
295-
|final int $length = ${ev.value}.numElements();
296-
|for(int k = 0; k < $length / 2; k++) {
297-
| int l = $length - k - 1;
298-
| $swapAssigments
299-
|}
300-
""".stripMargin
291+
|final int $length = $childName.numElements();
292+
|${ev.value} = $initialization;
293+
|for(int k = 0; k < $numberOfIterations; k++) {
294+
| int l = $length - k - 1;
295+
| $swapAssigments
296+
|}
297+
""".stripMargin
301298
}
302299

303300
override def prettyName: String = "reverse"

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
414414
}
415415

416416
test("reverse function") {
417+
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on
418+
417419
// String test cases
418420
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
419421

@@ -438,7 +440,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
438440
Seq(Row(null))
439441
)
440442

441-
442443
// Array test cases (primitive-type elements)
443444
val idf = Seq(
444445
Seq(1, 9, 8, 7),
@@ -451,6 +452,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
451452
idf.select(reverse('i)),
452453
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
453454
)
455+
checkAnswer(
456+
idf.filter(dummyFilter('i)).select(reverse('i)),
457+
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
458+
)
454459
checkAnswer(
455460
idf.selectExpr("reverse(i)"),
456461
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
@@ -459,6 +464,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
459464
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
460465
Seq(Row(Seq(null, 2, null, 1)))
461466
)
467+
checkAnswer(
468+
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
469+
Seq(Row(Seq(null, 2, null, 1)))
470+
)
462471

463472
// Array test cases (complex-type elements)
464473
val sdf = Seq(
@@ -472,6 +481,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
472481
sdf.select(reverse('s)),
473482
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
474483
)
484+
checkAnswer(
485+
sdf.filter(dummyFilter('s)).select(reverse('s)),
486+
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
487+
)
475488
checkAnswer(
476489
sdf.selectExpr("reverse(s)"),
477490
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
@@ -480,6 +493,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
480493
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
481494
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
482495
)
496+
checkAnswer(
497+
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
498+
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
499+
)
483500

484501
// Error test cases
485502
intercept[AnalysisException] {

0 commit comments

Comments
 (0)