Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
case class ElementAt(left: Expression, right: Expression)
extends GetMapValueUtil with GetArrayItemUtil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation?


@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

Expand Down Expand Up @@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}

override def nullable: Boolean = true
override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: MapType => computeNullabilityFromMap
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 4, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Map is out of the scope, shall we update the title and use the following code path instead?

  • Remove extends GetMapValueUtil.
  • Remove computeNullabilityFromMap and use the following.
case _: MapType => true

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

}

override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation?


// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
Expand All @@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def left: Expression = child
override def right: Expression = ordinal

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}

override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
Expand Down Expand Up @@ -281,10 +266,53 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

/**
* Common base class for [[GetMapValue]] and [[ElementAt]].
* Common trait for [[GetArrayItem]] and [[ElementAt]].
*/
trait GetArrayItemUtil {

/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}
}
}

/**
* Common trait for [[GetMapValue]] and [[ElementAt]].
*/
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
Copy link
Member

@dongjoon-hyun dongjoon-hyun Feb 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does GetMapValueUtil need to extend extends BinaryExpression with ImplicitCastInputTypes? If child and key are required, computeNullabilityFromMap seems to able to accept the child and key as parameters. How do you think that way?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems BinaryExpression.nullSafeCodeGen is used in the other place?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Then, what about GetArrayItemUtil? Actually, the reason why I asked is that this PR makes some non-trivial lineages like the followings.

  • def BinaryExpression.left -> private val GetArrayItemUtil.child -> override def GetArrayItem.left.
  • def BinaryExpression.left -> private val GetMapValueUtil.child -> override def GetMapValue.left

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, ok. I'll simplify it.


private val child = left
private val key = right

/** `Null` is returned for invalid ordinals. */
protected def computeNullabilityFromMap: Boolean = if (key.foldable && !key.nullable) {
val keyObj = key.eval()
child match {
Copy link
Member

@dongjoon-hyun dongjoon-hyun Feb 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for updating, @maropu . For this one, can we simplify like the following? We can remove these alias variables.

-  private val child = left
-  private val key = right
-
   /** `Null` is returned for invalid ordinals. */
-  protected def computeNullabilityFromMap: Boolean = if (key.foldable && !key.nullable) {
-    val keyObj = key.eval()
-    child match {
+  protected def computeNullabilityFromMap: Boolean = if (right.foldable && !right.nullable) {
+    val keyObj = right.eval()
+    left match {

Copy link
Member

@dongjoon-hyun dongjoon-hyun Feb 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to keep the clear meaning, you may declare them inside of this function. But, since we are not using parameters, I think we already have some assumptions that this is inside BinaryExpression. So, for me, I just want to avoid making aliases like my previous comment.

protected def computeNullabilityFromMap: Boolean = {
    val child = left
    val key = right
    ....
}

case m: CreateMap if m.resolved =>
m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
case (k, _) if k.eval() == keyObj => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rely on ==? What if one expression returns unsafe row and the other returns safe row?

I know this is existing code, but after a hindsight maybe it's not worth to linear scan the map entries and just to get the precise nullability.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I'll update soon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dropped the nullablity computation for the map side and added the TODO comment there. How about this?

case _ => false
}.map(_._2.nullable).getOrElse(true)
case _ =>
true
}
} else {
true
}

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
Expand Down Expand Up @@ -379,24 +407,7 @@ case class GetMapValue(child: Expression, key: Expression)

override def left: Expression = child
override def right: Expression = key

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = if (key.foldable && !key.nullable) {
val keyObj = key.eval()
child match {
case m: CreateMap if m.resolved =>
m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
case (k, _) if k.eval() == keyObj => true
case _ => false
}.map(_._2.nullable).getOrElse(true)
case _ =>
true
}
} else {
true
}


override def nullable: Boolean = computeNullabilityFromMap
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

// todo: current search is O(n), improve it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}

test("SPARK-26965 correctly handles ElementAt nullability for arrays") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a new feature, it seems that we don't need JIRA ID, SPARK-26965, in the test case name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dropped that in the latest commit though, any rule for that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I've heard the rule from @cloud-fan in my old PRs.

I'm not sure if that is a written rule or not.

Hi, @cloud-fan . Do we have some URLs for the above SPARK JIRA ID usage rule in the test case name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I see. I'll be more careful next time, thanks!

// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(1)).nullable)
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(0)).nullable)

val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(1)).nullable)
}

test("SPARK-26965 correctly handles ElementAt nullability for maps") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ur, ditto. SPARK-26965.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh...

// String key test
val k1 = Literal("k1")
val v1 = AttributeReference("v1", StringType, nullable = true)()
val k2 = Literal("k2")
val v2 = AttributeReference("v2", StringType, nullable = false)()
val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
assert(ElementAt(map1, Literal("k1")).nullable)
assert(!ElementAt(map1, Literal("k2")).nullable)
assert(ElementAt(map1, Literal("non-existent-key")).nullable)

// Complex type key test
val k3 = Literal.create((1, "a"))
val k4 = Literal.create((2, "b"))
val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
assert(ElementAt(map2, Literal.create((1, "a"))).nullable)
assert(!ElementAt(map2, Literal.create((2, "b"))).nullable)
}

test("Concat") {
// Primitive-type elements
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
Expand Down