Patch summary (text before the first "diff --git" line is ignored by patch tools).

ArrayPosition (array_position):
  * inputTypes now asks TypeCoercion.findTightestCommonType for a common type of
    the array's element type and the searched value, and requests
    (array of that type, that type) as the expected input types. When no common
    type exists, or the first argument is not an array, it returns Seq.empty.
  * checkInputDataTypes accepts only (array, value) pairs whose element type
    and value type satisfy sameType, then delegates to
    TypeUtils.checkForOrderingExpr; anything else fails with a message that
    prints both argument types via catalogString.

DataFrameFunctionsSuite:
  * Adds int/double and nested-array cases and checks the new error message.
    Because the message is built from catalogString, array(1) is reported with
    its element type, i.e. "array<int>", not bare "array".

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ea6fcccddfd4..ca41314b98dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2062,18 +2062,23 @@ case class ArrayPosition(left: Expression, right: Expression)
   override def dataType: DataType = LongType
 
   override def inputTypes: Seq[AbstractDataType] = {
-    val elementType = left.dataType match {
-      case t: ArrayType => t.elementType
-      case _ => AnyDataType
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
     }
-    Seq(ArrayType, elementType)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+        s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
+        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 121db442c77f..5467f4d718bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -1046,18 +1047,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
 
     checkAnswer(
-      df.selectExpr("array_position(array(array(1), null)[0], 1)"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.23D)"),
+      Seq(Row(0L))
     )
+
     checkAnswer(
-      df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.0D)"),
+      Seq(Row(1L))
     )
 
-    val e = intercept[AnalysisException] {
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.D), 1)"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.23D), 1)"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"),
+      Seq(Row(1L))
+    )
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"),
+      Seq(Row(1L))
+    )
+
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
     }
-    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
+    val errorMsg1 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [string, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("array_position(array(1), '1')")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [array<int>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
   }
 
   test("element_at function") {