Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def right: Expression = ordinal

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}

override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,39 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}

test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!GetArrayItem(array, Literal(0)).nullable)
assert(GetArrayItem(array, Literal(1)).nullable)
assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable)
assert(GetArrayItem(array, AttributeReference("ordinal", IntegerType)()).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!GetArrayItem(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(GetArrayItem(stArray2, Literal(0)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!GetArrayItem(stArray3, Literal(0)).nullable)
assert(GetArrayItem(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(GetArrayItem(stArray4, Literal(0)).nullable)
assert(GetArrayItem(stArray4, Literal(1)).nullable)
}

test("GetMapValue") {
val typeM = MapType(StringType, StringType)
val map = Literal.create(Map("a" -> "b"), typeM)
Expand Down