diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7da4c3cc6b9f..21308bcec035 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1394,8 +1394,14 @@ case class ArrayPosition(left: Expression, right: Expression)
     TypeUtils.getInterpretedOrdering(right.dataType)
 
   override def dataType: DataType = LongType
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    val elementType = left.dataType match {
+      case t: ArrayType => t.elementType
+      case _ => AnyDataType
+    }
+    Seq(ArrayType, elementType)
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
     super.checkInputDataTypes() match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index df23e07e441a..afa952528631 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -708,6 +708,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
       Seq(Row(1L), Row(1L))
     )
+
+    val e = intercept[AnalysisException] {
+      Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
+    }
+    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
   }
 
   test("element_at function") {
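
Note (reviewer sketch, not part of the patch): the old inputTypes body called
left.dataType.asInstanceOf[ArrayType], which would throw a raw
ClassCastException during analysis whenever the first argument was not an
array; matching on ArrayType and falling back to AnyDataType lets
checkInputDataTypes surface the readable AnalysisException asserted in the new
test. A minimal reproduction of the fixed behavior, assuming a spark-shell
session with spark.implicits._ in scope:

    import org.apache.spark.sql.AnalysisException

    // Both columns (_1, _2) are strings, so argument 1 is not an array.
    val df = Seq(("a string element", "a")).toDF()
    try {
      df.selectExpr("array_position(_1, _2)").collect()
    } catch {
      case e: AnalysisException =>
        // Expected message fragment after this patch:
        // "argument 1 requires array type, however, '`_1`' is of string type"
        println(e.message)
    }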