diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 104ad98ca099..55ed617e2904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -381,7 +381,21 @@ case class GetMapValue(child: Expression, key: Expression) override def right: Expression = key /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true + override def nullable: Boolean = if (key.foldable && !key.nullable) { + val keyObj = key.eval() + child match { + case m: CreateMap if m.resolved => + m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find { + case (k, _) if k.eval() == keyObj => true + case _ => false + }.map(_._2.nullable).getOrElse(true) + case _ => + true + } + } else { + true + } + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index d8d65715281d..d65b49f11884 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -110,6 +110,25 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } + test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") { + // String key test + val k1 = Literal("k1") + val v1 = AttributeReference("v1", StringType, nullable = true)() + val k2 = Literal("k2") + val v2 = AttributeReference("v2", StringType, nullable = false)() + val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil) + assert(GetMapValue(map1, Literal("k1")).nullable) + assert(!GetMapValue(map1, Literal("k2")).nullable) + assert(GetMapValue(map1, Literal("non-existent-key")).nullable) + + // Complex type key test + val k3 = Literal.create((1, "a")) + val k4 = Literal.create((2, "b")) + val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil) + assert(GetMapValue(map2, Literal.create((1, "a"))).nullable) + assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable) + } + test("GetStructField") { val typeS = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(create_row(1), typeS)