@@ -2062,18 +2062,23 @@ case class ArrayPosition(left: Expression, right: Expression)
   override def dataType: DataType = LongType

   override def inputTypes: Seq[AbstractDataType] = {
-    val elementType = left.dataType match {
-      case t: ArrayType => t.elementType
-      case _ => AnyDataType
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
     }
-    Seq(ArrayType, elementType)
   }

   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+        s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
+        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
     }
   }

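To see what the new inputTypes/checkInputDataTypes pair buys at the user level, here is a minimal, hedged sketch (not part of this change; the object name, app name, and local master are illustrative assumptions) that exercises array_position through a local SparkSession. The expected positions mirror the SQL test cases added to DataFrameFunctionsSuite below.

import org.apache.spark.sql.SparkSession

object ArrayPositionCoercionSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; any Spark build containing this change works.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("array_position-coercion-sketch")
      .getOrCreate()

    // The int element type and the double search value are widened to their
    // tightest common type (double) before the lookup, so 1 matches 1.0D ...
    spark.sql("SELECT array_position(array(1), 1.0D)").show()   // expected: 1

    // ... while 1.23D is simply not found (position 0 means "absent").
    spark.sql("SELECT array_position(array(1), 1.23D)").show()  // expected: 0

    // A pair with no common type (array<int> vs. string) now fails analysis
    // with the new error message (see the e2 test case below):
    //   spark.sql("SELECT array_position(array(1), '1')")

    spark.stop()
  }
}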
@@ -26,6 +26,7 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -1046,18 +1047,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )

     checkAnswer(
-      df.selectExpr("array_position(array(array(1), null)[0], 1)"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.23D)"),
+      Seq(Row(0L))
     )

     checkAnswer(
-      df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.0D)"),
+      Seq(Row(1L))
     )

-    val e = intercept[AnalysisException] {
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.D), 1)"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.23D), 1)"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"),
+      Seq(Row(1L))
+    )
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"),
+      Seq(Row(1L))
+    )
+
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
     }
-    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
+    val errorMsg1 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [string, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("array_position(array(1), '1')")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [array<int>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
   }

   test("element_at function") {
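The two intercept[AnalysisException] cases above pin down the new error message through selectExpr. As a complementary, hedged sketch (illustrative object and app names; it assumes the array_position column function in org.apache.spark.sql.functions, available from Spark 2.4), the same coercion and the same analysis-time failure can also be observed through the DataFrame API:

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.functions.{array, array_position, lit}

object ArrayPositionColumnApiSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    val df = spark.range(1)  // one-row DataFrame, standing in for OneRowRelation

    // Coercible pair: the int element is widened to double, so 1.0 is found at
    // position 1 and 1.23 is reported as absent (position 0).
    df.select(array_position(array(lit(1)), 1.0)).show()
    df.select(array_position(array(lit(1)), 1.23)).show()

    // Non-coercible pair: array<int> vs. string. The new TypeCheckFailure
    // surfaces as an AnalysisException as soon as the plan is analyzed.
    try {
      df.select(array_position(array(lit(1)), "1")).collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    spark.stop()
  }
}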