diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 83c76c2d4e2b..75d81aea1ca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -471,7 +471,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType)) + .orElse(findWiderTypeForTwo(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -493,7 +493,10 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + findWiderCommonType(b.map(_.dataType)).flatMap { listDataType => + findCommonTypeForBinaryComparison(listDataType, a.dataType, conf) + .orElse(findWiderTypeForTwo(listDataType, a.dataType)) + } match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5983fe63c79e..faa33607a016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1997,6 +1997,14 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } + test("SPARK-22398: type coercion for IN predicates should be coherent") { + val df = spark.range(2) + val inWithLiterals = df.where("id in ('01')") + val inWithSubquery = df.where("id in (select '01' from (select 1))") + assert(inWithLiterals.count() == inWithSubquery.count(), + "IN behavior is not the same with list of Literals and with a subquery") + } + test("SPARK-22520: support code generation for large CaseWhen") { val N = 30 var expr1 = when($"id" === lit(0), 0)