diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 67f6739b1e18f..018b6b9c9c375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
        b
   """,
   since = "2.4.0")
-case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+case class ElementAt(left: Expression, right: Expression)
+  extends GetMapValueUtil with GetArrayItemUtil {
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
@@ -1974,7 +1975,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     }
   }
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = left.dataType match {
+    // Array ordinals here are 1-based (negative ones count from the end), unlike GetArrayItem.
+    case _: ArrayType => computeNullabilityFromArray(left, right, oneBasedOrdinal = true)
+    case _: MapType => true
+  }
 
   override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 55ed617e2904b..e9d60ed3a481f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -221,7 +221,8 @@ case class GetArrayStructFields(
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
 case class GetArrayItem(child: Expression, ordinal: Expression)
-  extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
+  extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
+  with NullIntolerant {
 
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 
   override def left: Expression = child
   override def right: Expression = ordinal
-
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
-    val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
-    child match {
-      case CreateArray(ar) if intOrdinal < ar.length =>
-        ar(intOrdinal).nullable
-      case GetArrayStructFields(CreateArray(elements), field, _, _, _)
-        if intOrdinal < elements.length =>
-        elements(intOrdinal).nullable || field.nullable
-      case _ =>
-        true
-    }
-  } else {
-    true
-  }
-
+  override def nullable: Boolean = computeNullabilityFromArray(left, right)
   override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
 
   protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -281,10 +266,59 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 }
 
 /**
- * Common base class for [[GetMapValue]] and [[ElementAt]].
+ * Common trait for [[GetArrayItem]] and [[ElementAt]].
  */
+trait GetArrayItemUtil {
+
+  /**
+   * Returns whether extracting the element at `ordinal` from the array `child` may yield null.
+   *
+   * The answer is only narrowed when `ordinal` is a non-null constant and `child` is a
+   * [[CreateArray]], optionally wrapped in [[GetArrayStructFields]]. Every other input is
+   * reported as nullable, because `Null` is returned for invalid ordinals.
+   *
+   * @param oneBasedOrdinal when true, `ordinal` is interpreted as [[ElementAt]] does: positive
+   *                        values are 1-based and negative values count back from the last
+   *                        element. When false (the default), `ordinal` is a 0-based index.
+   */
+  protected def computeNullabilityFromArray(
+      child: Expression,
+      ordinal: Expression,
+      oneBasedOrdinal: Boolean = false): Boolean = {
+    if (ordinal.foldable && !ordinal.nullable) {
+      val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
+      // Finds the element expression selected by the constant ordinal. `lift` returns None for
+      // every out-of-range position, negative ones included, so the lookup itself cannot throw.
+      def selected(elements: Seq[Expression]): Option[Expression] = {
+        val position = if (!oneBasedOrdinal) {
+          intOrdinal
+        } else if (intOrdinal > 0) {
+          intOrdinal - 1
+        } else {
+          // A one-based ordinal of 0 maps to `elements.length`, which is out of range.
+          elements.length + intOrdinal
+        }
+        elements.lift(position)
+      }
+      child match {
+        case CreateArray(ar) =>
+          selected(ar).forall(_.nullable)
+        case GetArrayStructFields(CreateArray(elements), field, _, _, _) =>
+          selected(elements).forall(_.nullable || field.nullable)
+        case _ =>
+          true
+      }
+    } else {
+      true
+    }
+  }
+}
+
+/**
+ * Common trait for [[GetMapValue]] and [[ElementAt]].
+ */
+trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
 
-abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
   def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
     val map = value.asInstanceOf[MapData]
@@ -380,23 +414,14 @@ case class GetMapValue(child: Expression, key: Expression)
   override def left: Expression = child
   override def right: Expression = key
 
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = if (key.foldable && !key.nullable) {
-    val keyObj = key.eval()
-    child match {
-      case m: CreateMap if m.resolved =>
-        m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
-          case (k, _) if k.eval() == keyObj => true
-          case _ => false
-        }.map(_._2.nullable).getOrElse(true)
-      case _ =>
-        true
-    }
-  } else {
-    true
-  }
-
-
+  /**
+   * `Null` is returned for invalid ordinals.
+   *
+   * TODO: We could make nullability more precise in foldable cases (e.g., literal input).
+   * But, since the key search is O(n), it takes much time to compute nullability.
+   * If we find efficient key searches, revisit this.
+   */
+  override def nullable: Boolean = true
   override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
 
   // todo: current search is O(n), improve it.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index bed8547dbc83d..2ddad744cbab0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1092,6 +1092,42 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
   }
 
+  test("correctly handles ElementAt nullability for arrays") {
+    // CreateArray case (ElementAt ordinals are 1-based; negative ones count from the end)
+    val a = AttributeReference("a", IntegerType, nullable = false)()
+    val b = AttributeReference("b", IntegerType, nullable = true)()
+    val array = CreateArray(a :: b :: Nil)
+    assert(!ElementAt(array, Literal(1)).nullable)
+    assert(ElementAt(array, Literal(2)).nullable)
+    assert(!ElementAt(array, Literal(-2)).nullable)
+    assert(ElementAt(array, Literal(-1)).nullable)
+    assert(ElementAt(array, Literal(3)).nullable)
+    assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable)
+    assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)
+
+    // GetArrayStructFields case
+    val f1 = StructField("a", IntegerType, nullable = false)
+    val f2 = StructField("b", IntegerType, nullable = true)
+    val structType = StructType(f1 :: f2 :: Nil)
+    val c = AttributeReference("c", structType, nullable = false)()
+    val inputArray1 = CreateArray(c :: Nil)
+    val inputArray1ContainsNull = c.nullable
+    val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
+    assert(!ElementAt(stArray1, Literal(1)).nullable)
+    val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
+    assert(ElementAt(stArray2, Literal(1)).nullable)
+
+    val d = AttributeReference("d", structType, nullable = true)()
+    val inputArray2 = CreateArray(c :: d :: Nil)
+    val inputArray2ContainsNull = c.nullable || d.nullable
+    val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
+    assert(!ElementAt(stArray3, Literal(1)).nullable)
+    assert(ElementAt(stArray3, Literal(2)).nullable)
+    val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
+    assert(ElementAt(stArray4, Literal(1)).nullable)
+    assert(ElementAt(stArray4, Literal(2)).nullable)
+  }
+
   test("Concat") {
     // Primitive-type elements
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index d65b49f11884d..d8d65715281d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -110,25 +110,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
   }
 
-  test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") {
-    // String key test
-    val k1 = Literal("k1")
-    val v1 = AttributeReference("v1", StringType, nullable = true)()
-    val k2 = Literal("k2")
-    val v2 = AttributeReference("v2", StringType, nullable = false)()
-    val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
-    assert(GetMapValue(map1, Literal("k1")).nullable)
-    assert(!GetMapValue(map1, Literal("k2")).nullable)
-    assert(GetMapValue(map1, Literal("non-existent-key")).nullable)
-
-    // Complex type key test
-    val k3 = Literal.create((1, "a"))
-    val k4 = Literal.create((2, "b"))
-    val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
-    assert(GetMapValue(map2, Literal.create((1, "a"))).nullable)
-    assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable)
-  }
-
   test("GetStructField") {
     val typeS = StructType(StructField("a", IntegerType) :: Nil)
     val struct = Literal.create(create_row(1), typeS)