Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ trait CheckAnalysis extends PredicateHelper {
u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")

case operator: LogicalPlan =>
// Check argument data types of higher-order functions downwards first because function
// arguments of the higher-order functions might be unresolved due to the unresolved
// argument data types, otherwise always claims the function arguments are unresolved.
operator transformExpressionsDown {
case hof: HigherOrderFunction
if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>
hof.checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
hof.failAnalysis(
s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
}
}

operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.qualifiedName).mkString(", ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,23 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
*/
private def createLambda(
e: Expression,
partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
case f: LambdaFunction if f.bound => f

case LambdaFunction(function, names, _) =>
if (names.size != partialArguments.size) {
if (names.size != argInfo.size) {
e.failAnalysis(
s"The number of lambda function arguments '${names.size}' does not " +
"match the number of arguments expected by the higher order function " +
s"'${partialArguments.size}'.")
s"'${argInfo.size}'.")
}

if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
e.failAnalysis(
"Lambda function arguments should not have names that are semantically the same.")
}

val arguments = partialArguments.zip(names).map {
val arguments = argInfo.zip(names).map {
case ((dataType, nullable), ne) =>
NamedLambdaVariable(ne.name, dataType, nullable)
}
Expand All @@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
// create a lambda function with default parameters because this is expected by the higher
// order function. Note that we hide the lambda variables produced by this function in order
// to prevent accidental naming collisions.
val arguments = partialArguments.zipWithIndex.map {
val arguments = argInfo.zipWithIndex.map {
case ((dataType, nullable), i) =>
NamedLambdaVariable(s"col$i", dataType, nullable)
}
Expand All @@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match {
case _ if e.resolved => e

case h: HigherOrderFunction if h.inputResolved =>
case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))

case l: LambdaFunction if !l.bound =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression {
def inputTypes: Seq[AbstractDataType]

override def checkInputDataTypes(): TypeCheckResult = {
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
}
}

object ExpectsInputTypes {

def checkInputDataTypes(
inputs: Seq[Expression],
inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
case ((input, expected), idx) if !expected.acceptsType(input.dataType) =>
s"argument ${idx + 1} requires ${expected.simpleString} type, " +
s"however, '${child.sql}' is of ${child.dataType.catalogString} type."
s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
}

if (mismatches.isEmpty) {
Expand All @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression {
}
}


/**
* A mixin for the analyzer to perform implicit type casting using
* [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].
Expand Down
Loading