diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index dbe43709d1d3..cdb83d3580f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -143,21 +143,26 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayStructFields") { - val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: Nil)) - val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) - val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) - val nullArrayStruct = Literal.create(null, typeNullAS) - - def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { - expr.dataType match { - case ArrayType(StructType(fields), containsNull) => - val field = fields.find(_.name == fieldName).get - GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) - } + // test 4 types: struct field nullability X array element nullability + val type1 = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val type2 = ArrayType(StructType(StructField("a", IntegerType, nullable = false) :: Nil)) + val type3 = ArrayType(StructType(StructField("a", IntegerType) :: Nil), containsNull = false) + val type4 = ArrayType( + StructType(StructField("a", IntegerType, nullable = false) :: Nil), containsNull = false) + + val input1 = Literal.create(Seq(create_row(1)), type4) + val input2 = Literal.create(Seq(create_row(null)), type3) + val input3 = Literal.create(Seq(null), type2) + val input4 = Literal.create(null, type1) + + def getArrayStructFields(expr: Expression, fieldName: String): Expression = { + ExtractValue.apply(expr, Literal.create(fieldName, StringType), _ == _) } - 
checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) - checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + checkEvaluation(getArrayStructFields(input1, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(input2, "a"), Seq(null)) + checkEvaluation(getArrayStructFields(input3, "a"), Seq(null)) + checkEvaluation(getArrayStructFields(input4, "a"), null) } test("SPARK-32167: nullability of GetArrayStructFields") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6f73c1b0c04f..341b26ddf657 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -40,6 +40,11 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. + * + * Note: when you write a unit test for an expression and call `checkEvaluation` to check the result, + * please make sure that you explore all the cases that can lead to a null result (including + * null in struct fields, array elements and map values). The framework will test the + * nullability flag of the expression automatically. */ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite =>