@@ -243,7 +243,7 @@ object TypeCoercion {
* string. If the wider decimal type exceeds system limitation, this rule will truncate
* the decimal type before return it.
*/
- private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
+ private[catalyst] def findWiderTypeWithoutStringPromotionForTwo(
t1: DataType,
t2: DataType): Option[DataType] = {
findTightestCommonType(t1, t2)
@@ -1081,7 +1081,7 @@ case class ArrayContains(left: Expression, right: Expression)
(left.dataType, right.dataType) match {
case (_, NullType) => Seq.empty
case (ArrayType(e1, hasNull), e2) =>
- TypeCoercion.findTightestCommonType(e1, e2) match {
+ TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match {
case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
case _ => Seq.empty
}
@@ -850,7 +850,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
val errorMsg1 =
s"""
|Input to function array_contains should have been array followed by a
- |value with same element type, but it's [array<int>, decimal(29,29)].
+ |value with same element type, but it's [array<int>, decimal(38,29)].
Contributor

Why does the precision become 38 in this case?

Contributor Author

case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
  val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
    // If we cannot do the implicit cast, just use the original input.
    implicitCast(in, expected).getOrElse(in)
  }
  e.withNewChildren(children)

For the query array_contains(array(1), .01234567890123456790123456780),
e.inputTypes returns Seq(Array(Decimal(38,29)), Decimal(38,29)), and the code above casts .01234567890123456790123456780 to Decimal(38,29).
Previously, when we were using findWiderTypeForTwo, decimal types were not upcast, but findWiderTypeWithoutStringPromotionForTwo successfully upcasts DecimalType.
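
To make the jump from precision 29 to 38 concrete, here is a minimal sketch of the widening arithmetic. It assumes the rule is "keep the larger scale and the larger number of integer digits, then cap the precision at 38", and that an int is treated as decimal(10,0). `Dec`, `DecimalWidening`, and `wider` are hypothetical names used for illustration only; they are not Spark's API.

```scala
// Hypothetical stand-in for a decimal type: total digits and fractional digits.
final case class Dec(precision: Int, scale: Int)

object DecimalWidening {
  // Assumption: the system limit on decimal precision is 38 digits.
  val MaxPrecision = 38
  // Assumption: an int is widened as if it were decimal(10, 0).
  val IntAsDecimal = Dec(10, 0)

  // Keep the larger scale and the larger integer-digit range, then cap at the limit.
  def wider(a: Dec, b: Dec): Dec = {
    val scale = math.max(a.scale, b.scale)
    val range = math.max(a.precision - a.scale, b.precision - b.scale)
    Dec(math.min(range + scale, MaxPrecision), math.min(scale, MaxPrecision))
  }
}

// int with decimal(29,29): scale 29, range 10, so 39 digits are needed and the
// precision is capped: DecimalWidening.wider(DecimalWidening.IntAsDecimal, Dec(29, 29)) == Dec(38, 29)
```

Under these assumptions the uncapped result would need 39 digits, which is why the reported type ends up as decimal(38,29) rather than decimal(29,29).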

Contributor

> Previously, when we were using findWiderTypeForTwo

Before this PR, we were using findTightestCommonType. Why do we add a cast but still fail to resolve ArrayContains?

Contributor Author

Do you mean: why does ArrayContains throw an AnalysisException for the test query above instead of casting the integer to a decimal?

An integer cannot be cast to a decimal with scale > 28.

decimalWith28Zeroes = 1.0000000000000000000000000000
SELECT array_contains(array(1), decimalWith28Zeroes);
Result =>> true
decimalWith29Zeroes = 1.00000000000000000000000000000
SELECT array_contains(array(1), decimalWith29Zeroes);
Result =>> AnalysisException
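
The scale limit of 28 can be checked with a small sketch. It assumes a maximum decimal precision of 38 and that any int needs up to 10 integer digits (Int.MaxValue is 2147483647). `Shape` and `canHoldAnyInt` are hypothetical names for illustration, not part of Spark.

```scala
object IntToDecimalFit {
  // Hypothetical stand-in for a decimal type.
  final case class Shape(precision: Int, scale: Int)

  // Assumption: an int may need up to 10 digits before the decimal point.
  val IntDigits = 10

  // A decimal can hold every int only if it leaves at least 10 integer digits.
  def canHoldAnyInt(target: Shape): Boolean =
    target.precision - target.scale >= IntDigits
}

// decimal(38,28) leaves 10 integer digits: canHoldAnyInt(Shape(38, 28)) is true
// decimal(38,29) leaves only 9:           canHoldAnyInt(Shape(38, 29)) is false
```

With precision capped at 38, a scale of 29 leaves only 9 digits for the integer part, so the cast from int is rejected.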

Contributor

@cloud-fan, Dec 17, 2019

Yeah, I get that we can't do the cast here. My question is: since we can't cast, we should leave the expression untouched. But now we add a cast to one side and leave the expression unresolved. Where do we add that useless cast?

Contributor Author

case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
  val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
    // If we cannot do the implicit cast, just use the original input.
    implicitCast(in, expected).getOrElse(in)
  }
  e.withNewChildren(children)

This code casts the left and right expressions one by one. Here,

  • e.children is Seq(array<int>, decimal(29,29)), and
  • e.inputTypes returns Seq(array<decimal(38,29)>, decimal(38,29))

implicitCast(array<int>, array<decimal(38,29)>) returns None, since int can't be cast to decimal(38,29).
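
The per-child fallback can be illustrated with a toy model. This is only a sketch: the `Ty` hierarchy and the simplified `implicitCast` rules below are hypothetical stand-ins that mirror the shape of the quoted code, not Spark's actual type system or cast rules.

```scala
object PartialCast {
  // Hypothetical toy types standing in for Spark data types.
  sealed trait Ty
  case object IntTy extends Ty
  final case class DecTy(precision: Int, scale: Int) extends Ty
  final case class ArrTy(element: Ty) extends Ty

  // Simplified rule (assumption): a cast is allowed only when no digits can be lost.
  def implicitCast(from: Ty, to: Ty): Option[Ty] = (from, to) match {
    case (a, b) if a == b => Some(b)
    // An int needs up to 10 integer digits.
    case (IntTy, DecTy(p, s)) => if (p - s >= 10) Some(to) else None
    case (DecTy(p1, s1), DecTy(p2, s2)) =>
      if (p2 - s2 >= p1 - s1 && s2 >= s1) Some(to) else None
    case (ArrTy(e1), ArrTy(e2)) => implicitCast(e1, e2).map(ArrTy(_))
    case _ => None
  }

  // Same shape as the quoted code: cast each child, or keep the original input.
  def coerceChildren(children: Seq[Ty], expected: Seq[Ty]): Seq[Ty] =
    children.zip(expected).map { case (in, exp) => implicitCast(in, exp).getOrElse(in) }
}

// coerceChildren(Seq(ArrTy(IntTy), DecTy(29, 29)), Seq(ArrTy(DecTy(38, 29)), DecTy(38, 29)))
// yields Seq(ArrTy(IntTy), DecTy(38, 29)): the left child is kept, the right one is cast.
```

Because each child falls back independently, one side can be cast while the other is left as it was, and the two children still do not agree on a type.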

Contributor Author

The code above therefore creates a new expression in which only the right child is updated.

Contributor

ah thanks for finding this out!

""".stripMargin.replace("\n", " ").trim()
assert(e1.message.contains(errorMsg1))

@@ -865,6 +865,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(e2.message.contains(errorMsg2))
}

test("SPARK-29600: ArrayContains function may return incorrect result for DecimalType") {
checkAnswer(
Member

Since this is a bug, can you split these three tests into a separate test unit and add a test title with the JIRA ID (SPARK-29600)?

Member

Also, can you update the title, too?

Contributor Author

Sure, I'll update it.

sql("select array_contains(array(1.10), 1.1)"),
Seq(Row(true))
)

checkAnswer(
sql("SELECT array_contains(array(1.1), 1.10)"),
Seq(Row(true))
)

checkAnswer(
sql("SELECT array_contains(array(1.11), 1.1)"),
Seq(Row(false))
)
}
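
The three expectations in the SPARK-29600 test amount to comparing decimals after bringing them to a common scale. The sketch below illustrates that idea with java.math.BigDecimal, whose `equals` is scale-sensitive. It does not claim to reproduce how Spark evaluates array_contains; `containsAtCommonScale` is a hypothetical helper.

```scala
import java.math.{BigDecimal => JBigDecimal}

object DecimalContainsSketch {
  // Hypothetical helper: raise every value to the largest scale present, then compare.
  def containsAtCommonScale(elements: Seq[String], value: String): Boolean = {
    val v = new JBigDecimal(value)
    val es = elements.map(s => new JBigDecimal(s))
    val scale = (es.map(_.scale()) :+ v.scale()).max
    // Increasing the scale only appends zeros, so no rounding mode is needed.
    es.exists(e => e.setScale(scale).equals(v.setScale(scale)))
  }
}

// Without a common scale, equals treats 1.10 and 1.1 as different:
//   new JBigDecimal("1.10").equals(new JBigDecimal("1.1")) is false
// With a common scale:
//   containsAtCommonScale(Seq("1.10"), "1.1") is true
//   containsAtCommonScale(Seq("1.11"), "1.1") is false
```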

test("arrays_overlap function") {
val df = Seq(
(Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),