From 629ec84389d2decd876be085b8df915f8e6c74cf Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 12 Sep 2018 00:44:23 -0700 Subject: [PATCH 1/3] [SPARK-25415] ArrayPosition function may return incorrect result when right expression is implicitly down casted. --- .../expressions/collectionOperations.scala | 19 +++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 44 ++++++++++++++++++- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ea6fcccddfd4..54e685e955bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2062,18 +2062,23 @@ case class ArrayPosition(left: Expression, right: Expression) override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = { - val elementType = left.dataType match { - case t: ArrayType => t.elementType - case _ => AnyDataType + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty } - Seq(ArrayType, elementType) } override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess => + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + 
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 121db442c77f..2656b252ca36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1045,6 +1045,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(null), Row(null)) ) + checkAnswer( + df.selectExpr("array_position(array(1), 1.23D)"), + Seq(Row(0L), Row(0L)) + ) + + checkAnswer( + df.selectExpr("array_position(array(1), 1.0D)"), + Seq(Row(1L), Row(1L)) + ) + + checkAnswer( + df.selectExpr("array_position(array(1.23D), 1)"), + Seq(Row(0L), Row(0L)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1)), array(1.0D))"), + Seq(Row(1L), Row(1L)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1)), array(1.23D))"), + Seq(Row(0L), Row(0L)) + ) + checkAnswer( df.selectExpr("array_position(array(array(1), null)[0], 1)"), Seq(Row(1L), Row(1L)) @@ -1054,10 +1079,25 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(1L), Row(1L)) ) - val e = intercept[AnalysisException] { + val e1 = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") } - assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [string, string]. 
+ """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + df.selectExpr("array_position(array(1), '1')") + } + val errorMsg2 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [array<int>, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("element_at function") { From bb181084b8deeee0130bf53fcc1417b10d518eae Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 12 Sep 2018 23:23:30 -0700 Subject: [PATCH 2/3] Code review --- .../sql/catalyst/expressions/collectionOperations.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 54e685e955bb..ca41314b98dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2075,7 +2075,7 @@ case class ArrayPosition(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2656b252ca36..b2d31f0f84b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1055,6 +1055,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(1L), Row(1L)) ) + checkAnswer( + df.selectExpr("array_position(array(1.D), 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( df.selectExpr("array_position(array(1.23D), 1)"), Seq(Row(0L), Row(0L)) From 55d4b950951892f3a239f960feadbe1a25198659 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 22 Sep 2018 12:08:02 -0700 Subject: [PATCH 3/3] Code review --- .../spark/sql/DataFrameFunctionsSuite.scala | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b2d31f0f84b2..5467f4d718bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1046,42 +1047,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - df.selectExpr("array_position(array(1), 1.23D)"), - Seq(Row(0L), Row(0L)) + OneRowRelation().selectExpr("array_position(array(1), 1.23D)"), + Seq(Row(0L)) ) checkAnswer( - df.selectExpr("array_position(array(1), 1.0D)"), 
- Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1), 1.0D)"), + Seq(Row(1L)) ) checkAnswer( - df.selectExpr("array_position(array(1.D), 1)"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1.D), 1)"), + Seq(Row(1L)) ) checkAnswer( - df.selectExpr("array_position(array(1.23D), 1)"), - Seq(Row(0L), Row(0L)) + OneRowRelation().selectExpr("array_position(array(1.23D), 1)"), + Seq(Row(0L)) ) checkAnswer( - df.selectExpr("array_position(array(array(1)), array(1.0D))"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"), + Seq(Row(1L)) ) checkAnswer( - df.selectExpr("array_position(array(array(1)), array(1.23D))"), - Seq(Row(0L), Row(0L)) + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"), + Seq(Row(0L)) ) checkAnswer( - df.selectExpr("array_position(array(array(1), null)[0], 1)"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L)) ) checkAnswer( - df.selectExpr("array_position(array(1, null), array(1, null)[0])"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L)) ) val e1 = intercept[AnalysisException] { @@ -1095,7 +1096,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e1.message.contains(errorMsg1)) val e2 = intercept[AnalysisException] { - df.selectExpr("array_position(array(1), '1')") + OneRowRelation().selectExpr("array_position(array(1), '1')") } val errorMsg2 = s"""