Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
}

private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
val length = ctx.freshName("length")
val javaElementType = CodeGenerator.javaType(elementType)

val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)

val numElements = ctx.freshName("numElements")
val arrayData = ctx.freshName("arrayData")

val initialization = if (isPrimitiveType) {
s"$childName.copy()"
ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.")
Copy link
Member

@maropu maropu Jul 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code in the master doesn't create UnsafeArrayData per input row though, it seems this change does so. Can we avoid it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the code in the master also create UnsafeArrayData per input row in UnsafeArrayData.copy().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked with @ueshin and no problem about this.

} else {
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
}

val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length

val swapAssigments = if (isPrimitiveType) {
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|if(!isNullAtK) {
| $javaElementType el = ${getCall("k")};
| if(!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| } else {
| ${ev.value}.setNullAt(k);
| }
| ${ev.value}.$setFunc(l, el);
|} else if (!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| ${ev.value}.setNullAt(l);
|}""".stripMargin
val arrayDataClass = classOf[GenericArrayData].getName
s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);"
}

val i = ctx.freshName("i")
val j = ctx.freshName("j")

val getValue = CodeGenerator.getValue(childName, elementType, i)

val setFunc = if (isPrimitiveType) {
s"set${CodeGenerator.primitiveTypeName(elementType)}"
} else {
"update"
}

val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can simplify the code if we do override def dataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't override dataType only for ArrayType because Reverse is also used for StringType.

s"""
|if ($childName.isNullAt($i)) {
| $arrayData.setNullAt($j);
|} else {
| $arrayData.$setFunc($j, $getValue);
|}
""".stripMargin
} else {
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
s"$arrayData.$setFunc($j, $getValue);"
}

s"""
|final int $length = $childName.numElements();
|${ev.value} = $initialization;
|for(int k = 0; k < $numberOfIterations; k++) {
| int l = $length - k - 1;
| $swapAssigments
|final int $numElements = $childName.numElements();
|$initialization
|for (int $i = 0; $i < $numElements; $i++) {
| int $j = $numElements - $i - 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need j if we do

for (int i = numElements - 1; i >=0; i--)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need to calculate the index of the opposite side?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see

| $assignment
|}
|${ev.value} = $arrayData;
""".stripMargin
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
}

test("reverse function") {
// String test cases
test("reverse function - string") {
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")
def testString(): Unit = {
checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS")))
Expand All @@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
// Test with cached relation, the Project will be evaluated with codegen
oneRowDF.cache()
testString()
}

// Array test cases (primitive-type elements)
val idf = Seq(
test("reverse function - array for primitive type not containing null") {
val idfNotContainsNull = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

def testArray(): Unit = {
def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
checkAnswer(
idf.select(reverse('i)),
idfNotContainsNull.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(i)"),
idfNotContainsNull.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
}

// Test with local relation, the Project will be evaluated without codegen
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are many tests here, so I think it'd be nice to split the tests into two or three parts.

testArrayOfPrimitiveTypeNotContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
idfNotContainsNull.cache()
testArrayOfPrimitiveTypeNotContainsNull()
}

test("reverse function - array for primitive type containing null") {
val idfContainsNull = Seq[Seq[Integer]](
Seq(1, 9, 8, null, 7),
Seq(null, 5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

def testArrayOfPrimitiveTypeContainsNull(): Unit = {
checkAnswer(
idfContainsNull.select(reverse('i)),
Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(array(1, null, 2, null))"),
Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1)))
idfContainsNull.selectExpr("reverse(i)"),
Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null))
)
}

// Test with local relation, the Project will be evaluated without codegen
testArray()
testArrayOfPrimitiveTypeContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
idf.cache()
testArray()
idfContainsNull.cache()
testArrayOfPrimitiveTypeContainsNull()
}

// Array test cases (non-primitive-type elements)
test("reverse function - array for non-primitive type") {
val sdf = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Expand Down Expand Up @@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
// Test with cached relation, the Project will be evaluated with codegen
sdf.cache()
testArrayOfNonPrimitiveType()
}

// Error test cases
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
test("reverse function - data type mismatch") {
val ex1 = intercept[AnalysisException] {
sql("select reverse(struct(1, 'a'))")
}
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(map(1, 'a'))")
assert(ex1.getMessage.contains("data type mismatch"))

val ex2 = intercept[AnalysisException] {
sql("select reverse(map(1, 'a'))")
}
assert(ex2.getMessage.contains("data type mismatch"))
}

test("array position function") {
Expand Down