diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e76193fd9422..f94e733ad813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -243,7 +243,7 @@ object TypeCoercion { * string. If the wider decimal type exceeds system limitation, this rule will truncate * the decimal type before return it. */ - private[analysis] def findWiderTypeWithoutStringPromotionForTwo( + private[catalyst] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d5d42510842e..6ed68e47ce7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1081,7 +1081,7 @@ case class ArrayContains(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findTightestCommonType(e1, e2) match { + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a346377cd1bc..584768eff700 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -850,7 +850,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val errorMsg1 = s""" |Input to function array_contains should have been array followed by a - |value with same element type, but it's [array, decimal(29,29)]. + |value with same element type, but it's [array, decimal(38,29)]. """.stripMargin.replace("\n", " ").trim() assert(e1.message.contains(errorMsg1)) @@ -865,6 +865,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(e2.message.contains(errorMsg2)) } + test("SPARK-29600: ArrayContains function may return incorrect result for DecimalType") { + checkAnswer( + sql("select array_contains(array(1.10), 1.1)"), + Seq(Row(true)) + ) + + checkAnswer( + sql("SELECT array_contains(array(1.1), 1.10)"), + Seq(Row(true)) + ) + + checkAnswer( + sql("SELECT array_contains(array(1.11), 1.1)"), + Seq(Row(false)) + ) + } + test("arrays_overlap function") { val df = Seq( (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),