Patch notes (text before the first "diff --git" header is commentary, not patch content).

This patch was recovered from a copy whose newlines had been collapsed. Line structure is
restored so that every hunk matches its header counts. Leading indentation inside hunk lines
was lost in that copy and is restored by convention; verify it against the target tree, or
apply with whitespace-tolerant options.

collectionOperations.scala, Reverse.arrayCodeGen:
  * Removed code: for primitive element types it emitted "<child>.copy()" and then swapped
    elements k and (length - k - 1) in place over length / 2 iterations, tracking nulls on
    both sides; for other element types it filled a new GenericArrayData with update(...).
  * Added code: declares a fresh result array (via ctx.createUnsafeArray for primitive
    element types, otherwise a GenericArrayData over new Object[numElements]), then loops
    over every index i and writes child[i] to position numElements - i - 1. The isNullAt /
    setNullAt branch is generated only when the element type is primitive and the array
    type's containsNull is true. ev.value is assigned once, after the loop.
  * NOTE(review): createUnsafeArray is defined outside this patch; its exact behaviour
    should be confirmed there.

DataFrameFunctionsSuite.scala:
  * The single "reverse function" test is split into separate tests: string, primitive
    array without nulls, primitive array with nulls (new Seq[Seq[Integer]] data), non-primitive
    array, and data type mismatch.
  * The literal "reverse(array(1, null, 2, null))" check is removed.
  * The error cases now call sql(...) and assert that the AnalysisException message contains
    "data type mismatch".

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f438748d9a4f..b3d04bfa8645 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
   }
 
   private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
-    val length = ctx.freshName("length")
-    val javaElementType = CodeGenerator.javaType(elementType)
+
     val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
 
+    val numElements = ctx.freshName("numElements")
+    val arrayData = ctx.freshName("arrayData")
+
     val initialization = if (isPrimitiveType) {
-      s"$childName.copy()"
+      ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.")
     } else {
-      s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
-    }
-
-    val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
-
-    val swapAssigments = if (isPrimitiveType) {
-      val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
-      val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
-      s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
-          |boolean isNullAtL = ${ev.value}.isNullAt(l);
-          |if(!isNullAtK) {
-          |  $javaElementType el = ${getCall("k")};
-          |  if(!isNullAtL) {
-          |    ${ev.value}.$setFunc(k, ${getCall("l")});
-          |  } else {
-          |    ${ev.value}.setNullAt(k);
-          |  }
-          |  ${ev.value}.$setFunc(l, el);
-          |} else if (!isNullAtL) {
-          |  ${ev.value}.$setFunc(k, ${getCall("l")});
-          |  ${ev.value}.setNullAt(l);
-          |}""".stripMargin
+      val arrayDataClass = classOf[GenericArrayData].getName
+      s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);"
+    }
+
+    val i = ctx.freshName("i")
+    val j = ctx.freshName("j")
+
+    val getValue = CodeGenerator.getValue(childName, elementType, i)
+
+    val setFunc = if (isPrimitiveType) {
+      s"set${CodeGenerator.primitiveTypeName(elementType)}"
+    } else {
+      "update"
+    }
+
+    val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) {
+      s"""
+         |if ($childName.isNullAt($i)) {
+         |  $arrayData.setNullAt($j);
+         |} else {
+         |  $arrayData.$setFunc($j, $getValue);
+         |}
+       """.stripMargin
     } else {
-      s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
+      s"$arrayData.$setFunc($j, $getValue);"
     }
 
     s"""
-       |final int $length = $childName.numElements();
-       |${ev.value} = $initialization;
-       |for(int k = 0; k < $numberOfIterations; k++) {
-       |  int l = $length - k - 1;
-       |  $swapAssigments
+       |final int $numElements = $childName.numElements();
+       |$initialization
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  int $j = $numElements - $i - 1;
+       |  $assignment
        |}
+       |${ev.value} = $arrayData;
      """.stripMargin
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bf04251e655e..5a7bd45a4b5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("reverse function") {
-    // String test cases
+  test("reverse function - string") {
     val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
     def testString(): Unit = {
       checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS")))
@@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     // Test with cached relation, the Project will be evaluated with codegen
     oneRowDF.cache()
     testString()
+  }
 
-    // Array test cases (primitive-type elements)
-    val idf = Seq(
+  test("reverse function - array for primitive type not containing null") {
+    val idfNotContainsNull = Seq(
       Seq(1, 9, 8, 7),
       Seq(5, 8, 9, 7, 2),
       Seq.empty,
       null
     ).toDF("i")
 
-    def testArray(): Unit = {
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
       checkAnswer(
-        idf.select(reverse('i)),
+        idfNotContainsNull.select(reverse('i)),
         Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
       )
       checkAnswer(
-        idf.selectExpr("reverse(i)"),
+        idfNotContainsNull.selectExpr("reverse(i)"),
         Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
       )
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    idfNotContainsNull.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("reverse function - array for primitive type containing null") {
+    val idfContainsNull = Seq[Seq[Integer]](
+      Seq(1, 9, 8, null, 7),
+      Seq(null, 5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+      checkAnswer(
+        idfContainsNull.select(reverse('i)),
+        Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
+      )
       checkAnswer(
-        idf.selectExpr("reverse(array(1, null, 2, null))"),
-        Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1)))
+        idfContainsNull.selectExpr("reverse(i)"),
+        Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
       )
     }
 
     // Test with local relation, the Project will be evaluated without codegen
-    testArray()
+    testArrayOfPrimitiveTypeContainsNull()
     // Test with cached relation, the Project will be evaluated with codegen
-    idf.cache()
-    testArray()
+    idfContainsNull.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
 
-    // Array test cases (non-primitive-type elements)
+  test("reverse function - array for non-primitive type") {
     val sdf = Seq(
       Seq("c", "a", "b"),
       Seq("b", null, "c", null),
@@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     // Test with cached relation, the Project will be evaluated with codegen
     sdf.cache()
     testArrayOfNonPrimitiveType()
+  }
 
-    // Error test cases
-    intercept[AnalysisException] {
-      oneRowDF.selectExpr("reverse(struct(1, 'a'))")
+  test("reverse function - data type mismatch") {
+    val ex1 = intercept[AnalysisException] {
+      sql("select reverse(struct(1, 'a'))")
     }
-    intercept[AnalysisException] {
-      oneRowDF.selectExpr("reverse(map(1, 'a'))")
+    assert(ex1.getMessage.contains("data type mismatch"))
+
+    val ex2 = intercept[AnalysisException] {
+      sql("select reverse(map(1, 'a'))")
     }
+    assert(ex2.getMessage.contains("data type mismatch"))
   }
 
   test("array position function") {