Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,26 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("GetArrayStructFields") {
val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: Nil))
val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
val nullArrayStruct = Literal.create(null, typeNullAS)

def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = {
expr.dataType match {
case ArrayType(StructType(fields), containsNull) =>
val field = fields.find(_.name == fieldName).get
GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
}
// test 4 types: struct field nullability X array element nullability
val type1 = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val type2 = ArrayType(StructType(StructField("a", IntegerType, nullable = false) :: Nil))
val type3 = ArrayType(StructType(StructField("a", IntegerType) :: Nil), containsNull = false)
val type4 = ArrayType(
StructType(StructField("a", IntegerType, nullable = false) :: Nil), containsNull = false)

val input1 = Literal.create(Seq(create_row(1)), type4)
val input2 = Literal.create(Seq(create_row(null)), type3)
val input3 = Literal.create(Seq(null), type2)
val input4 = Literal.create(null, type1)

def getArrayStructFields(expr: Expression, fieldName: String): Expression = {
ExtractValue.apply(expr, Literal.create(fieldName, StringType), _ == _)
}

checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1))
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
checkEvaluation(getArrayStructFields(input1, "a"), Seq(1))
checkEvaluation(getArrayStructFields(input2, "a"), Seq(null))
checkEvaluation(getArrayStructFields(input3, "a"), Seq(null))
checkEvaluation(getArrayStructFields(input4, "a"), null)
}

test("SPARK-32167: nullability of GetArrayStructFields") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ import org.apache.spark.util.Utils

/**
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
*
* Note: when you write unit test for an expression and call `checkEvaluation` to check the result,
* please make sure that you explore all the cases that can lead to null result (including
* null in struct fields, array elements and map values). The framework will test the
* nullability flag of the expression automatically.
*/
trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestBase {
self: SparkFunSuite =>
Expand Down