-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-26965][SQL] Makes ElementAt nullability more precise for array cases #23867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression) | |
| b | ||
| """, | ||
| since = "2.4.0") | ||
| case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { | ||
| case class ElementAt(left: Expression, right: Expression) | ||
| extends GetMapValueUtil with GetArrayItemUtil { | ||
|
|
||
| @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType | ||
|
|
||
|
|
@@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti | |
| } | ||
| } | ||
|
|
||
| override def nullable: Boolean = true | ||
| override def nullable: Boolean = left.dataType match { | ||
| case _: ArrayType => computeNullabilityFromArray | ||
| case _: MapType => computeNullabilityFromMap | ||
|
||
| } | ||
|
|
||
| override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -216,24 +216,15 @@ case class GetArrayStructFields( | |||
| } | ||||
|
|
||||
| /** | ||||
| * Returns the field at `ordinal` in the Array `child`. | ||||
| * | ||||
| * We need to do type checking here as `ordinal` expression maybe unresolved. | ||||
| * Common trait for [[GetArrayItem]] and [[ElementAt]]. | ||||
| */ | ||||
| case class GetArrayItem(child: Expression, ordinal: Expression) | ||||
| extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant { | ||||
|
|
||||
| // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. | ||||
| override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) | ||||
|
|
||||
| override def toString: String = s"$child[$ordinal]" | ||||
| override def sql: String = s"${child.sql}[${ordinal.sql}]" | ||||
| trait GetArrayItemUtil extends BinaryExpression { | ||||
|
||||
|
|
||||
| override def left: Expression = child | ||||
| override def right: Expression = ordinal | ||||
| private val child = left | ||||
| private val ordinal = right | ||||
|
|
||||
| /** `Null` is returned for invalid ordinals. */ | ||||
| override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) { | ||||
| protected def computeNullabilityFromArray: Boolean = if (ordinal.foldable && !ordinal.nullable) { | ||||
| val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue() | ||||
| child match { | ||||
| case CreateArray(ar) if intOrdinal < ar.length => | ||||
|
|
@@ -247,7 +238,25 @@ case class GetArrayItem(child: Expression, ordinal: Expression) | |||
| } else { | ||||
| true | ||||
| } | ||||
| } | ||||
|
|
||||
| /** | ||||
| * Returns the field at `ordinal` in the Array `child`. | ||||
| * | ||||
| * We need to do type checking here as `ordinal` expression maybe unresolved. | ||||
| */ | ||||
| case class GetArrayItem(child: Expression, ordinal: Expression) | ||||
| extends GetArrayItemUtil with ExpectsInputTypes with ExtractValue with NullIntolerant { | ||||
|
|
||||
| // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. | ||||
| override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) | ||||
|
|
||||
| override def toString: String = s"$child[$ordinal]" | ||||
| override def sql: String = s"${child.sql}[${ordinal.sql}]" | ||||
|
|
||||
| override def left: Expression = child | ||||
| override def right: Expression = ordinal | ||||
| override def nullable: Boolean = computeNullabilityFromArray | ||||
| override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType | ||||
|
|
||||
| protected override def nullSafeEval(value: Any, ordinal: Any): Any = { | ||||
|
|
@@ -281,10 +290,29 @@ case class GetArrayItem(child: Expression, ordinal: Expression) | |||
| } | ||||
|
|
||||
| /** | ||||
| * Common base class for [[GetMapValue]] and [[ElementAt]]. | ||||
| * Common trait for [[GetMapValue]] and [[ElementAt]]. | ||||
| */ | ||||
| trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { | ||||
|
||||
| nullSafeCodeGen(ctx, ev, (eval1, eval2) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Then, what about GetArrayItemUtil? Actually, the reason why I asked is that this PR makes some non-trivial lineages like the followings.
def BinaryExpression.left->private val GetArrayItemUtil.child->override def GetArrayItem.left.def BinaryExpression.left->private val GetMapValueUtil.child->override def GetMapValue.left
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, ok. I'll simplify it.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for updating, @maropu . For this one, can we simplify like the following? We can remove these alias variables.
- private val child = left
- private val key = right
-
/** `Null` is returned for invalid ordinals. */
- protected def computeNullabilityFromMap: Boolean = if (key.foldable && !key.nullable) {
- val keyObj = key.eval()
- child match {
+ protected def computeNullabilityFromMap: Boolean = if (right.foldable && !right.nullable) {
+ val keyObj = right.eval()
+ left match {There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want to keep the clear meaning, you may declare them inside of this function. But, since we are not using parameters, I think we already have some assumptions that this is inside BinaryExpression. So, for me, I just want to avoid making aliases like my previous comment.
protected def computeNullabilityFromMap: Boolean = {
val child = left
val key = right
....
}
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we rely on ==? What if one expression returns unsafe row and the other returns safe row?
I know this is existing code, but after a hindsight maybe it's not worth to linear scan the map entries and just to get the precise nullability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. I'll update soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dropped the nullablity computation for the map side and added the TODO comment there. How about this?
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1092,6 +1092,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) | ||
| } | ||
|
|
||
| test("SPARK-26965 correctly handles ElementAt nullability for arrays") { | ||
|
||
| // CreateArray case | ||
| val a = AttributeReference("a", IntegerType, nullable = false)() | ||
| val b = AttributeReference("b", IntegerType, nullable = true)() | ||
| val array = CreateArray(a :: b :: Nil) | ||
| assert(!ElementAt(array, Literal(0)).nullable) | ||
| assert(ElementAt(array, Literal(1)).nullable) | ||
| assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable) | ||
| assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) | ||
|
|
||
| // GetArrayStructFields case | ||
| val f1 = StructField("a", IntegerType, nullable = false) | ||
| val f2 = StructField("b", IntegerType, nullable = true) | ||
| val structType = StructType(f1 :: f2 :: Nil) | ||
| val c = AttributeReference("c", structType, nullable = false)() | ||
| val inputArray1 = CreateArray(c :: Nil) | ||
| val inputArray1ContainsNull = c.nullable | ||
| val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) | ||
| assert(!ElementAt(stArray1, Literal(0)).nullable) | ||
| val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) | ||
| assert(ElementAt(stArray2, Literal(0)).nullable) | ||
|
|
||
| val d = AttributeReference("d", structType, nullable = true)() | ||
| val inputArray2 = CreateArray(c :: d :: Nil) | ||
| val inputArray2ContainsNull = c.nullable || d.nullable | ||
| val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) | ||
| assert(!ElementAt(stArray3, Literal(0)).nullable) | ||
| assert(ElementAt(stArray3, Literal(1)).nullable) | ||
| val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) | ||
| assert(ElementAt(stArray4, Literal(0)).nullable) | ||
| assert(ElementAt(stArray4, Literal(1)).nullable) | ||
| } | ||
|
|
||
| test("SPARK-26965 correctly handles ElementAt nullability for maps") { | ||
|
||
| // String key test | ||
| val k1 = Literal("k1") | ||
| val v1 = AttributeReference("v1", StringType, nullable = true)() | ||
| val k2 = Literal("k2") | ||
| val v2 = AttributeReference("v2", StringType, nullable = false)() | ||
| val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil) | ||
| assert(ElementAt(map1, Literal("k1")).nullable) | ||
| assert(!ElementAt(map1, Literal("k2")).nullable) | ||
| assert(ElementAt(map1, Literal("non-existent-key")).nullable) | ||
|
|
||
| // Complex type key test | ||
| val k3 = Literal.create((1, "a")) | ||
| val k4 = Literal.create((2, "b")) | ||
| val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil) | ||
| assert(ElementAt(map2, Literal.create((1, "a"))).nullable) | ||
| assert(!ElementAt(map2, Literal.create((2, "b"))).nullable) | ||
| } | ||
|
|
||
| test("Concat") { | ||
| // Primitive-type elements | ||
| val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indentation?