-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24878][SQL] Fix reverse function for array type of primitive type containing null. #21830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI | |
| } | ||
|
|
||
| private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { | ||
| val length = ctx.freshName("length") | ||
| val javaElementType = CodeGenerator.javaType(elementType) | ||
|
|
||
| val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) | ||
|
|
||
| val numElements = ctx.freshName("numElements") | ||
| val arrayData = ctx.freshName("arrayData") | ||
|
|
||
| val initialization = if (isPrimitiveType) { | ||
| s"$childName.copy()" | ||
| ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") | ||
| } else { | ||
| s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" | ||
| } | ||
|
|
||
| val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length | ||
|
|
||
| val swapAssigments = if (isPrimitiveType) { | ||
| val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) | ||
| val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) | ||
| s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); | ||
| |boolean isNullAtL = ${ev.value}.isNullAt(l); | ||
| |if(!isNullAtK) { | ||
| | $javaElementType el = ${getCall("k")}; | ||
| | if(!isNullAtL) { | ||
| | ${ev.value}.$setFunc(k, ${getCall("l")}); | ||
| | } else { | ||
| | ${ev.value}.setNullAt(k); | ||
| | } | ||
| | ${ev.value}.$setFunc(l, el); | ||
| |} else if (!isNullAtL) { | ||
| | ${ev.value}.$setFunc(k, ${getCall("l")}); | ||
| | ${ev.value}.setNullAt(l); | ||
| |}""".stripMargin | ||
| val arrayDataClass = classOf[GenericArrayData].getName | ||
| s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" | ||
| } | ||
|
|
||
| val i = ctx.freshName("i") | ||
| val j = ctx.freshName("j") | ||
|
|
||
| val getValue = CodeGenerator.getValue(childName, elementType, i) | ||
|
|
||
| val setFunc = if (isPrimitiveType) { | ||
| s"set${CodeGenerator.primitiveTypeName(elementType)}" | ||
| } else { | ||
| "update" | ||
| } | ||
|
|
||
| val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: we can simplify the code if we do
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't override |
||
| s""" | ||
| |if ($childName.isNullAt($i)) { | ||
| | $arrayData.setNullAt($j); | ||
| |} else { | ||
| | $arrayData.$setFunc($j, $getValue); | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" | ||
| s"$arrayData.$setFunc($j, $getValue);" | ||
| } | ||
|
|
||
| s""" | ||
| |final int $length = $childName.numElements(); | ||
| |${ev.value} = $initialization; | ||
| |for(int k = 0; k < $numberOfIterations; k++) { | ||
| | int l = $length - k - 1; | ||
| | $swapAssigments | ||
| |final int $numElements = $childName.numElements(); | ||
| |$initialization | ||
| |for (int $i = 0; $i < $numElements; $i++) { | ||
| | int $j = $numElements - $i - 1; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. we don't need
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We still need to calculate the index of the opposite side?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ah i see |
||
| | $assignment | ||
| |} | ||
| |${ev.value} = $arrayData; | ||
| """.stripMargin | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { | |
| } | ||
| } | ||
|
|
||
| test("reverse function") { | ||
| // String test cases | ||
| test("reverse function - string") { | ||
| val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") | ||
| def testString(): Unit = { | ||
| checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS"))) | ||
|
|
@@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { | |
| // Test with cached relation, the Project will be evaluated with codegen | ||
| oneRowDF.cache() | ||
| testString() | ||
| } | ||
|
|
||
| // Array test cases (primitive-type elements) | ||
| val idf = Seq( | ||
| test("reverse function - array for primitive type not containing null") { | ||
| val idfNotContainsNull = Seq( | ||
| Seq(1, 9, 8, 7), | ||
| Seq(5, 8, 9, 7, 2), | ||
| Seq.empty, | ||
| null | ||
| ).toDF("i") | ||
|
|
||
| def testArray(): Unit = { | ||
| def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { | ||
| checkAnswer( | ||
| idf.select(reverse('i)), | ||
| idfNotContainsNull.select(reverse('i)), | ||
| Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) | ||
| ) | ||
| checkAnswer( | ||
| idf.selectExpr("reverse(i)"), | ||
| idfNotContainsNull.selectExpr("reverse(i)"), | ||
| Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) | ||
| ) | ||
| } | ||
|
|
||
| // Test with local relation, the Project will be evaluated without codegen | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are many tests here, so I think it'd be nice to split the tests into two or three parts. |
||
| testArrayOfPrimitiveTypeNotContainsNull() | ||
| // Test with cached relation, the Project will be evaluated with codegen | ||
| idfNotContainsNull.cache() | ||
| testArrayOfPrimitiveTypeNotContainsNull() | ||
| } | ||
|
|
||
| test("reverse function - array for primitive type containing null") { | ||
| val idfContainsNull = Seq[Seq[Integer]]( | ||
| Seq(1, 9, 8, null, 7), | ||
| Seq(null, 5, 8, 9, 7, 2), | ||
| Seq.empty, | ||
| null | ||
| ).toDF("i") | ||
|
|
||
| def testArrayOfPrimitiveTypeContainsNull(): Unit = { | ||
| checkAnswer( | ||
| idfContainsNull.select(reverse('i)), | ||
| Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) | ||
| ) | ||
| checkAnswer( | ||
| idf.selectExpr("reverse(array(1, null, 2, null))"), | ||
| Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1))) | ||
| idfContainsNull.selectExpr("reverse(i)"), | ||
| Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) | ||
| ) | ||
| } | ||
|
|
||
| // Test with local relation, the Project will be evaluated without codegen | ||
| testArray() | ||
| testArrayOfPrimitiveTypeContainsNull() | ||
| // Test with cached relation, the Project will be evaluated with codegen | ||
| idf.cache() | ||
| testArray() | ||
| idfContainsNull.cache() | ||
| testArrayOfPrimitiveTypeContainsNull() | ||
| } | ||
|
|
||
| // Array test cases (non-primitive-type elements) | ||
| test("reverse function - array for non-primitive type") { | ||
| val sdf = Seq( | ||
| Seq("c", "a", "b"), | ||
| Seq("b", null, "c", null), | ||
|
|
@@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { | |
| // Test with cached relation, the Project will be evaluated with codegen | ||
| sdf.cache() | ||
| testArrayOfNonPrimitiveType() | ||
| } | ||
|
|
||
| // Error test cases | ||
| intercept[AnalysisException] { | ||
| oneRowDF.selectExpr("reverse(struct(1, 'a'))") | ||
| test("reverse function - data type mismatch") { | ||
| val ex1 = intercept[AnalysisException] { | ||
| sql("select reverse(struct(1, 'a'))") | ||
| } | ||
| intercept[AnalysisException] { | ||
| oneRowDF.selectExpr("reverse(map(1, 'a'))") | ||
| assert(ex1.getMessage.contains("data type mismatch")) | ||
|
|
||
| val ex2 = intercept[AnalysisException] { | ||
| sql("select reverse(map(1, 'a'))") | ||
| } | ||
| assert(ex2.getMessage.contains("data type mismatch")) | ||
| } | ||
|
|
||
| test("array position function") { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code in master doesn't create an `UnsafeArrayData` per input row, but it seems this change does. Can we avoid it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, the code in master also creates an `UnsafeArrayData` per input row, in `UnsafeArrayData.copy()`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check this? https://gist.github.com/maropu/e9e8afd64ce30cdf824bb4e18d0c9b4f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I talked with @ueshin, and there is no problem with this.