SPARK-25388: make checkEvaluation validate nullability as well as the value.

checkResult(result, expected, DataType) is replaced by two overloads: one taking
the Expression itself, and one taking (DataType, nullable). The latter asserts
that a null result is only seen where the expression, struct field, array
element or map value is declared nullable. Callers are switched to the new
overloads and a regression test with a deliberately wrong MapType is added.

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index c383eec3d56b..dd82d4b45a4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -113,7 +113,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = UTF8String.fromString("abc")
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -126,7 +126,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -142,7 +142,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -154,7 +154,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
 
-    if (!checkResult(actual, expected, expressions.head.dataType)) {
+    if (!checkResult(actual, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -170,7 +170,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual.length == 1)
     val expected = InternalRow(Seq.fill(length)(true): _*)
 
-    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
+    if (!checkResult(actual.head, expected, expressions.head)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -365,7 +365,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actualOr.length == 1)
     val expectedOr = false
 
-    if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) {
+    if (!checkResult(actualOr.head, expectedOr, exprOr)) {
       fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr")
     }
 
@@ -379,7 +379,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actualAnd.length == 1)
     val expectedAnd = false
 
-    if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) {
+    if (!checkResult(actualAnd.head, expectedAnd, exprAnd)) {
       fail(
         s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b5986aac6555..da18475276a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -69,11 +69,22 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
 
   /**
    * Check the equality between result of expression and expected value, it will handle
-   * Array[Byte], Spread[Double], MapData and Row.
+   * Array[Byte], Spread[Double], MapData and Row. Also check whether nullable in expression is
+   * true if result is null
    */
-  protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
+  protected def checkResult(result: Any, expected: Any, expression: Expression): Boolean = {
+    checkResult(result, expected, expression.dataType, expression.nullable)
+  }
+
+  protected def checkResult(
+      result: Any,
+      expected: Any,
+      exprDataType: DataType,
+      exprNullable: Boolean): Boolean = {
     val dataType = UserDefinedType.sqlType(exprDataType)
 
+    // The result is null for a non-nullable expression
+    assert(result != null || exprNullable, "exprNullable should be true if result is null")
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)
@@ -83,24 +94,24 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
         val st = dataType.asInstanceOf[StructType]
         assert(result.numFields == st.length && expected.numFields == st.length)
         st.zipWithIndex.forall { case (f, i) =>
-          checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType)
+          checkResult(
+            result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType, f.nullable)
         }
       case (result: ArrayData, expected: ArrayData) =>
         result.numElements == expected.numElements && {
-          val et = dataType.asInstanceOf[ArrayType].elementType
+          val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType]
           var isSame = true
           var i = 0
           while (isSame && i < result.numElements) {
-            isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+            isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn)
             i += 1
           }
           isSame
         }
       case (result: MapData, expected: MapData) =>
-        val kt = dataType.asInstanceOf[MapType].keyType
-        val vt = dataType.asInstanceOf[MapType].valueType
-        checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
-          checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
+        val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType]
+        checkResult(result.keyArray, expected.keyArray, ArrayType(kt, false), false) &&
+          checkResult(result.valueArray, expected.valueArray, ArrayType(vt, vcn), false)
       case (result: Double, expected: Double) =>
         if (expected.isNaN) result.isNaN else expected == result
       case (result: Float, expected: Float) =>
@@ -175,7 +186,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (!checkResult(actual, expected, expression.dataType)) {
+    if (!checkResult(actual, expected, expression)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -191,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     for (fallbackMode <- modes) {
       withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
         val actual = evaluateWithMutableProjection(expression, inputRow)
-        if (!checkResult(actual, expected, expression.dataType)) {
+        if (!checkResult(actual, expected, expression)) {
           val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
           fail(s"Incorrect evaluation (fallback mode = $fallbackMode): $expression, " +
             s"actual: $actual, expected: $expected$input")
@@ -221,6 +232,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
         val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
         val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
 
+        val dataType = expression.dataType
+        if (!checkResult(unsafeRow.get(0, dataType), expected, dataType, expression.nullable)) {
+          fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
+            s"$expression, actual: $unsafeRow, expected: $expected, " +
+            s"dataType: $dataType, nullable: ${expression.nullable}$input")
+        }
         if (expected == null) {
           if (!unsafeRow.isNullAt(0)) {
             val expectedRow = InternalRow(expected, expected)
@@ -229,8 +246,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
           }
         } else {
           val lit = InternalRow(expected, expected)
-          val expectedRow =
-            UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
+          val expectedRow = UnsafeProjection.create(Array(dataType, dataType)).apply(lit)
           if (unsafeRow != expectedRow) {
             fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " +
               s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
@@ -280,7 +296,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
       expression)
     plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression.dataType))
+    assert(checkResult(actual, expected, expression))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
@@ -288,7 +304,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression.dataType))
+    assert(checkResult(actual, expected, expression))
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 7c7c4cccee25..54ef9641bee0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{DataType, IntegerType, MapType}
 
 /**
  * A test suite for testing [[ExpressionEvalHelper]].
@@ -35,6 +36,13 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper
     val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) }
     assert(e.getMessage.contains("some_variable"))
   }
+
+  test("SPARK-25388: checkEvaluation should fail if nullable in DataType is incorrect") {
+    val e = intercept[RuntimeException] {
+      checkEvaluation(MapIncorrectDataTypeExpression(), Map(3 -> 7, 6 -> null))
+    }
+    assert(e.getMessage.contains("and exprNullable was"))
+  }
 }
 
 /**
@@ -53,3 +61,18 @@ case class BadCodegenExpression() extends LeafExpression {
   }
   override def dataType: DataType = IntegerType
 }
+
+/**
+ * An expression that returns a MapData with incorrect DataType whose valueContainsNull is false
+ * while its value includes null
+ */
+case class MapIncorrectDataTypeExpression() extends LeafExpression with CodegenFallback {
+  override def nullable: Boolean = false
+  override def eval(input: InternalRow): Any = {
+    val keys = new GenericArrayData(Array(3, 6))
+    val values = new GenericArrayData(Array(7, null))
+    new ArrayBasedMapData(keys, values)
+  }
+  // since values includes null, valueContainsNull must be true
+  override def dataType: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false)
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 0ef630bbd367..c9309197791b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -416,7 +416,7 @@ class ObjectHashAggregateSuite
     actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
       assert(lhs.length == rhs.length)
       lhs.toSeq.zip(rhs.toSeq).foreach {
-        case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType)
+        case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType, false)
         case (a, b) => a == b
       }
     }