Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual.length == 1)
val expected = UTF8String.fromString("abc")

if (!checkResult(actual.head, expected, expressions.head.dataType)) {
if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
Expand All @@ -126,7 +126,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual.length == 1)
val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))

if (!checkResult(actual.head, expected, expressions.head.dataType)) {
if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
Expand All @@ -142,7 +142,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual.length == 1)
val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))

if (!checkResult(actual.head, expected, expressions.head.dataType)) {
if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
Expand All @@ -154,7 +154,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq(InternalRow(Seq.fill(length)(true): _*))

if (!checkResult(actual, expected, expressions.head.dataType)) {
if (!checkResult(actual, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
Expand All @@ -170,7 +170,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual.length == 1)
val expected = InternalRow(Seq.fill(length)(true): _*)

if (!checkResult(actual.head, expected, expressions.head.dataType)) {
if (!checkResult(actual.head, expected, expressions.head)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
Expand Down Expand Up @@ -365,7 +365,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actualOr.length == 1)
val expectedOr = false

if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) {
if (!checkResult(actualOr.head, expectedOr, exprOr)) {
fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr")
}

Expand All @@ -379,7 +379,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actualAnd.length == 1)
val expectedAnd = false

if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) {
if (!checkResult(actualAnd.head, expectedAnd, exprAnd)) {
fail(
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,22 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa

/**
* Check the equality between result of expression and expected value, it will handle
* Array[Byte], Spread[Double], MapData and Row.
* Array[Byte], Spread[Double], MapData and Row. Also check whether nullable in expression is
* true if result is null
*/
protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
protected def checkResult(result: Any, expected: Any, expression: Expression): Boolean = {
  // Pull the declared type and top-level nullability from the expression and delegate to
  // the overload that walks nested values.
  val declaredType = expression.dataType
  val declaredNullable = expression.nullable
  checkResult(result, expected, declaredType, declaredNullable)
}

protected def checkResult(
result: Any,
expected: Any,
exprDataType: DataType,
exprNullable: Boolean): Boolean = {
val dataType = UserDefinedType.sqlType(exprDataType)

// The result is null for a non-nullable expression
assert(result != null || exprNullable, "exprNullable should be true if result is null")
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
Expand All @@ -83,24 +94,24 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
val st = dataType.asInstanceOf[StructType]
assert(result.numFields == st.length && expected.numFields == st.length)
st.zipWithIndex.forall { case (f, i) =>
checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType)
checkResult(
result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType, f.nullable)
}
case (result: ArrayData, expected: ArrayData) =>
result.numElements == expected.numElements && {
val et = dataType.asInstanceOf[ArrayType].elementType
val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
var isSame = true
var i = 0
while (isSame && i < result.numElements) {
isSame = checkResult(result.get(i, et), expected.get(i, et), et)
isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn)
i += 1
}
isSame
}
case (result: MapData, expected: MapData) =>
val kt = dataType.asInstanceOf[MapType].keyType
val vt = dataType.asInstanceOf[MapType].valueType
checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
checkResult(result.keyArray, expected.keyArray, ArrayType(kt, false), false) &&
checkResult(result.valueArray, expected.valueArray, ArrayType(vt, vcn), false)
case (result: Double, expected: Double) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
Expand Down Expand Up @@ -175,7 +186,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (!checkResult(actual, expected, expression.dataType)) {
if (!checkResult(actual, expected, expression)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
Expand All @@ -191,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
for (fallbackMode <- modes) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
val actual = evaluateWithMutableProjection(expression, inputRow)
if (!checkResult(actual, expected, expression.dataType)) {
if (!checkResult(actual, expected, expression)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (fallback mode = $fallbackMode): $expression, " +
s"actual: $actual, expected: $expected$input")
Expand Down Expand Up @@ -221,6 +232,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

val dataType = expression.dataType
if (!checkResult(unsafeRow.get(0, dataType), expected, dataType, expression.nullable)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you add this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because this statement checks consistency between expression and its nullable, as you proposed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmmh, I am not sure about this. Do we then still need the code below? Seems to me we are checking the same thing twice, please correct me if I am wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check different properties in these two if statements.

  1. Line 231 checks consistency between value and nullable in expected
  2. Line 245 checks bit-wise value between expected and expression

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I just meant that here we are checking the result and we are doing the same after too. Shouldn't we just add an assert for unsafeRow.get(0, dataType) != null || expression.nullable here instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. sees only expected. 2. sees expected and expression. Thus, the two checks do different things.

At 1, as we discussed, we need to check the consistency recursively. IIUC, unsafeRow.get(0, dataType) != null || expression.nullable does not perform checks recursively. Am I misunderstanding something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not perform checks recursively

good point, I was not considering it. Then, do we need the check at https://github.com/apache/spark/pull/22375/files/9ef335d6e43a6ef7d253d0ed3564f95bd0278f71#diff-41747ec3f56901eb7bfb95d2a217e94dL231? Isn't it performed in checkResult?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think checkResult already validates expression according to the nullability of the given expression at 1. Thus, if the expected is not correct, 2. will detect an incorrect point.

// Every fragment carries the `s` interpolator; without it on the first fragment the
// message would contain the literal text "$fallbackMode" instead of the active mode.
fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
  s"$expression, actual: $unsafeRow, expected: $expected, " +
  s"dataType: $dataType, nullable: ${expression.nullable}")
}
if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
Expand All @@ -229,8 +246,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
}
} else {
val lit = InternalRow(expected, expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a more straightforward approach is, validate the expected according to the nullability of the given expression.

val expectedRow =
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
val expectedRow = UnsafeProjection.create(Array(dataType, dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
Expand Down Expand Up @@ -280,15 +296,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
expression)
plan.initialize(0)
var actual = plan(inputRow).get(0, expression.dataType)
assert(checkResult(actual, expected, expression.dataType))
assert(checkResult(actual, expected, expression))

plan = generateProject(
GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
expression)
plan.initialize(0)
actual = FromUnsafeProjection(expression.dataType :: Nil)(
plan(inputRow)).get(0, expression.dataType)
assert(checkResult(actual, expected, expression.dataType))
assert(checkResult(actual, expected, expression))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{DataType, IntegerType, MapType}

/**
* A test suite for testing [[ExpressionEvalHelper]].
Expand All @@ -35,6 +36,13 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper
val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) }
assert(e.getMessage.contains("some_variable"))
}

test("SPARK-25388: checkEvaluation should fail if nullable in DataType is incorrect") {
val e = intercept[RuntimeException] {
checkEvaluation(MapIncorrectDataTypeExpression(), Map(3 -> 7, 6 -> null))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if here you put Map(3 -> 7, 6 -> -1)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an output. This is because the test correctly detects a failure even in codegen-off mode, too.

"Incorrect evaluation (codegen off): mapincorrectdatatypeexpression(), actual: keys: [3,6], values: [7,null], expected: keys: [3,6], values: [7,-1]" did not contain "Incorrect evaluation in unsafe mode"
ScalaTestFailureLocation: org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelperSuite$$anonfun$4 at (ExpressionEvalHelperSuite.scala:44)
org.scalatest.exceptions.TestFailedException: "Incorrect evaluation (codegen off): mapincorrectdatatypeexpression(), actual: keys: [3,6], values: [7,null], expected: keys: [3,6], values: [7,-1]" did not contain "Incorrect evaluation in unsafe mode"
	at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:528)
...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so the example above passes in codegen off and fails with codegen on with this fix, while using Map(3 -> 7, 6 -> -1) passes codegen on and fails codegen off, am I right?

What I am thinking about (but I have not yet found a working implementation) is: since the problem arises when we say we expect null in a non-nullable datatype, can we add such a check? I mean, instead of pretending the expected value to be nullable, can't we add a check in case it is not nullable to be sure that it does not contain null? I think it would be better, because we would be able to distinguish a failure caused by a bad test, i.e. a test written wrongly, from a UT failure caused by a bug in what we are testing. What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a summary.

  • Map(3 -> 7, 6 -> null) passes in codegen off and fails in codegen on with this fix
  • Map(3 -> 7, 6 -> -1) fails in codegen off and fails in codegen on

Would it be possible to share examples of two cases that you think we would be able to distinguish?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thanks for the summary, it states more clearly what I thought.

My point is that this fix works properly only when we test both codegen on and off, but it would fail to detect the error condition it claims to fix if only one of them (for any reason) is tested. So I am wondering if it is possible to perform a check on the expected value, instead of this fix. Something like:

assert(containsNull(expected) && isNullable(expression.dataType))

where containsNull and isNullable have to be defined properly. In this way we should fail properly independently from whether codegen is on or not. And we can also give a more clear hint in the error message about the problem being most likely a bad UT.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we make the check recursive, I think that this case cannot be detected. This is because the mismatch occurs in a different recursive path.

Would it be possible to share a case in other places where we distinguished a wrong output from a badly written UT, as you proposed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I said that the suggestion above is wrong and needs to be rewritten in a recursive way. Sorry for the bad suggestion, I just meant to show my idea. So it should be something like:

assert(!containsNullWhereNotNullable(expected, expression.dataType))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may still not understand your motivation correctly. What is the motivation for introducing this assertion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivations are the 2 mentioned above. Basically, I am proposing the same suggestion @cloud-fan has just commented here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With some hints from @ueshin, this PR implemented the check of null value with nullable bit in checkResult().

}
assert(e.getMessage.contains("and exprNullable was"))
}
}

/**
Expand All @@ -53,3 +61,18 @@ case class BadCodegenExpression() extends LeafExpression {
}
override def dataType: DataType = IntegerType
}

/**
* An expression that returns a MapData with incorrect DataType whose valueContainsNull is false
* while its value includes null
*/
case class MapIncorrectDataTypeExpression() extends LeafExpression with CodegenFallback {
  // The expression itself is declared as never producing a top-level null.
  override def nullable: Boolean = false

  // Deliberately wrong: eval emits a null map value, so valueContainsNull ought to be true.
  override def dataType: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)

  // Always yields the map {3 -> 7, 6 -> null}, regardless of the input row.
  override def eval(input: InternalRow): Any = {
    new ArrayBasedMapData(
      new GenericArrayData(Array(3, 6)),
      new GenericArrayData(Array(7, null)))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ class ObjectHashAggregateSuite
actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
assert(lhs.length == rhs.length)
lhs.toSeq.zip(rhs.toSeq).foreach {
case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType)
case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType, false)
case (a, b) => a == b
}
}
Expand Down