Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
case class ElementAt(left: Expression, right: Expression)
extends GetMapValueUtil with GetArrayItemUtil {

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

Expand Down Expand Up @@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}

override def nullable: Boolean = true
override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: MapType => true
}

override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
Expand All @@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def left: Expression = child
override def right: Expression = ordinal

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}

override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
Expand Down Expand Up @@ -281,10 +266,34 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

/**
* Common base class for [[GetMapValue]] and [[ElementAt]].
* Common trait for [[GetArrayItem]] and [[ElementAt]].
*/
trait GetArrayItemUtil {

/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}
}
}

/**
* Common trait for [[GetMapValue]] and [[ElementAt]].
*/
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
Expand Down Expand Up @@ -381,22 +390,7 @@ case class GetMapValue(child: Expression, key: Expression)
override def right: Expression = key

/** `Null` is returned for invalid ordinals. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to update this comment, to say why the nullability is always true.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, added.

override def nullable: Boolean = if (key.foldable && !key.nullable) {
val keyObj = key.eval()
child match {
case m: CreateMap if m.resolved =>
m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
case (k, _) if k.eval() == keyObj => true
case _ => false
}.map(_._2.nullable).getOrElse(true)
case _ =>
true
}
} else {
true
}


override def nullable: Boolean = true
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

// todo: current search is O(n), improve it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,39 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}

test("correctly handles ElementAt nullability for arrays") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(0)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(1)).nullable)
}

test("Concat") {
// Primitive-type elements
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}

test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") {
// String key test
val k1 = Literal("k1")
val v1 = AttributeReference("v1", StringType, nullable = true)()
val k2 = Literal("k2")
val v2 = AttributeReference("v2", StringType, nullable = false)()
val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
assert(GetMapValue(map1, Literal("k1")).nullable)
assert(!GetMapValue(map1, Literal("k2")).nullable)
assert(GetMapValue(map1, Literal("non-existent-key")).nullable)

// Complex type key test
val k3 = Literal.create((1, "a"))
val k4 = Literal.create((2, "b"))
val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
assert(GetMapValue(map2, Literal.create((1, "a"))).nullable)
assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable)
}

test("GetStructField") {
val typeS = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(create_row(1), typeS)
Expand Down