From d2bf2be8c244910dbc2fc066f8c73212208ffe01 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 11 Aug 2018 13:52:59 +0900 Subject: [PATCH 1/7] Rename inputs to arguments. --- .../analysis/higherOrderFunctions.scala | 2 +- .../expressions/higherOrderFunctions.scala | 82 +++++++++---------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 5e2029c251ee4..7561ed9f555ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { case _ if e.resolved => e - case h: HigherOrderFunction if h.inputResolved => + case h: HigherOrderFunction if h.argumentsResolved => h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) case l: LambdaFunction if !l.bound => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7f8203ab92213..7914e01aaa2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -90,18 +90,18 @@ object LambdaFunction { */ trait HigherOrderFunction extends Expression { - override def children: Seq[Expression] = inputs ++ functions + override def children: Seq[Expression] = arguments ++ functions /** - * Inputs to the higher ordered function. + * Arguments of the higher ordered function. 
*/ - def inputs: Seq[Expression] + def arguments: Seq[Expression] /** - * All inputs have been resolved. This means that the types and nullabilty of (most of) the + * All arguments have been resolved. This means that the types and nullabilty of (most of) the * lambda function arguments is known, and that we can start binding the lambda functions. */ - lazy val inputResolved: Boolean = inputs.forall(_.resolved) + lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) /** * Functions applied by the higher order function. @@ -111,7 +111,7 @@ trait HigherOrderFunction extends Expression { /** * All inputs must be resolved and all functions must be resolved lambda functions. */ - override lazy val resolved: Boolean = inputResolved && functions.forall { + override lazy val resolved: Boolean = argumentsResolved && functions.forall { case l: LambdaFunction => l.resolved case _ => false } @@ -157,9 +157,9 @@ object HigherOrderFunction { */ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { - def input: Expression + def argument: Expression - override def inputs: Seq[Expression] = input :: Nil + override def arguments: Seq[Expression] = argument :: Nil def function: Expression @@ -173,11 +173,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method * in order to save null-check code. 
*/ - protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = + protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") override def eval(inputRow: InternalRow): Any = { - val value = input.eval(inputRow) + val value = argument.eval(inputRow) if (value == null) { null } else { @@ -209,16 +209,16 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { """, since = "2.4.0") case class ArrayTransform( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => copy(function = f(function, elem :: (IntegerType, false) :: Nil)) @@ -237,8 +237,8 @@ case class ArrayTransform( (elementVar, indexVar) } - override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { - val arr = inputValue.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val result = new GenericArrayData(new Array[Any](arr.numElements)) var i = 0 @@ -268,7 +268,7 @@ examples = """ """, since = "2.4.0") case class MapFilter( - input: Expression, + argument: Expression, function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { @@ -278,16 +278,16 @@ case class MapFilter( } @transient val (keyType, valueType, valueContainsNull) = - 
HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + HigherOrderFunction.mapKeyValueArgumentType(argument.dataType) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val m = value.asInstanceOf[MapData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val m = argumentValue.asInstanceOf[MapData] val f = functionForEval val retKeys = new mutable.ListBuffer[Any] val retValues = new mutable.ListBuffer[Any] @@ -302,7 +302,7 @@ case class MapFilter( ArrayBasedMapData(retKeys.toArray, retValues.toArray) } - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType override def expectingFunctionType: AbstractDataType = BooleanType @@ -321,25 +321,25 @@ case class MapFilter( """, since = "2.4.0") case class ArrayFilter( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) copy(function = f(function, elem :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = 
value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val buffer = new mutable.ArrayBuffer[Any](arr.numElements) var i = 0 @@ -368,25 +368,25 @@ case class ArrayFilter( """, since = "2.4.0") case class ArrayExists( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: DataType = BooleanType override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) copy(function = f(function, elem :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval var exists = false var i = 0 @@ -422,29 +422,29 @@ case class ArrayExists( """, since = "2.4.0") case class ArrayAggregate( - input: Expression, + argument: Expression, zero: Expression, merge: Expression, finish: Expression) extends HigherOrderFunction with CodegenFallback { - def this(input: Expression, zero: Expression, merge: Expression) = { - this(input, zero, merge, LambdaFunction.identity) + def this(argument: Expression, zero: Expression, merge: Expression) = { + this(argument, zero, merge, LambdaFunction.identity) } - override def inputs: Seq[Expression] = input :: zero :: Nil + override def arguments: Seq[Expression] = argument :: zero :: Nil override 
def functions: Seq[Expression] = merge :: finish :: Nil - override def nullable: Boolean = input.nullable || finish.nullable + override def nullable: Boolean = argument.nullable || finish.nullable override def dataType: DataType = finish.dataType override def checkInputDataTypes(): TypeCheckResult = { - if (!ArrayType.acceptsType(input.dataType)) { + if (!ArrayType.acceptsType(argument.dataType)) { TypeCheckResult.TypeCheckFailure( s"argument 1 requires ${ArrayType.simpleString} type, " + - s"however, '${input.sql}' is of ${input.dataType.catalogString} type.") + s"however, '${argument.sql}' is of ${argument.dataType.catalogString} type.") } else if (!DataType.equalsStructurally( zero.dataType, merge.dataType, ignoreNullability = true)) { TypeCheckResult.TypeCheckFailure( @@ -458,7 +458,7 @@ case class ArrayAggregate( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) val acc = zero.dataType -> true val newMerge = f(merge, acc :: elem :: Nil) val newFinish = f(finish, acc :: Nil) @@ -470,7 +470,7 @@ case class ArrayAggregate( @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] + val arr = argument.eval(input).asInstanceOf[ArrayData] if (arr == null) { null } else { From 1ece98baeb5415b975995e1a1626dd90101dd64a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 11 Aug 2018 13:54:18 +0900 Subject: [PATCH 2/7] Rename partialArguments to argInfo. 
--- .../sql/catalyst/analysis/higherOrderFunctions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 7561ed9f555ed..8c14ad63fb500 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -95,15 +95,15 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { */ private def createLambda( e: Expression, - partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match { + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { case f: LambdaFunction if f.bound => f case LambdaFunction(function, names, _) => - if (names.size != partialArguments.size) { + if (names.size != argInfo.size) { e.failAnalysis( s"The number of lambda function arguments '${names.size}' does not " + "match the number of arguments expected by the higher order function " + - s"'${partialArguments.size}'.") + s"'${argInfo.size}'.") } if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { @@ -111,7 +111,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { "Lambda function arguments should not have names that are semantically the same.") } - val arguments = partialArguments.zip(names).map { + val arguments = argInfo.zip(names).map { case ((dataType, nullable), ne) => NamedLambdaVariable(ne.name, dataType, nullable) } @@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { // create a lambda function with default parameters because this is expected by the higher // order function. Note that we hide the lambda variables produced by this function in order // to prevent accidental naming collisions. 
- val arguments = partialArguments.zipWithIndex.map { + val arguments = argInfo.zipWithIndex.map { case ((dataType, nullable), i) => NamedLambdaVariable(s"col$i", dataType, nullable) } From fb23aba7e3e5bb4eef2f3927212190b6507cbd65 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 11 Aug 2018 15:20:55 +0900 Subject: [PATCH 3/7] Add argument data type check. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 13 +++ .../analysis/higherOrderFunctions.scala | 2 +- .../expressions/ExpectsInputTypes.scala | 16 +++- .../expressions/higherOrderFunctions.scala | 85 +++++++++---------- .../spark/sql/DataFrameFunctionsSuite.scala | 25 ++++++ 5 files changed, 93 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4addc83add3e0..b3e04413c57d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -90,6 +90,19 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => + // Check argument data types of higher-order functions downwards first because function + // arguments of the higher-order functions might be unresolved due to the unresolved + // argument data types, otherwise always claims the function arguments are unresolved. 
+ operator transformExpressionsDown { + case hof: HigherOrderFunction + if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => + hof.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + hof.failAnalysis( + s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message") + } + } + operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.qualifiedName).mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 8c14ad63fb500..dd08190e1e8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { case _ if e.resolved => e - case h: HigherOrderFunction if h.argumentsResolved => + case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess => h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) case l: LambdaFunction if !l.bound => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index d8f046c0028a9..981ce0b6a29fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - 
val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) + } +} + +object ExpectsInputTypes { + + def checkInputDataTypes( + inputs: Seq[Expression], + inputTypes: Seq[AbstractDataType]): TypeCheckResult = { + val mismatches = inputs.zip(inputTypes).zipWithIndex.collect { + case ((input, expected), idx) if !expected.acceptsType(input.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.catalogString} type." + s"however, '${input.sql}' is of ${input.dataType.catalogString} type." } if (mismatches.isEmpty) { @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression { } } - /** * A mixin for the analyzer to perform implicit type casting using * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7914e01aaa2e8..8bec2171f2b88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -103,6 +103,13 @@ trait HigherOrderFunction extends Expression { */ lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) + /** + * Checks the argument data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `argumentsResolved == true`. + */ + def checkArgumentDataTypes(): TypeCheckResult + /** * Functions applied by the higher order function. 
*/ @@ -133,25 +140,6 @@ trait HigherOrderFunction extends Expression { } } -object HigherOrderFunction { - - def arrayArgumentType(dt: DataType): (DataType, Boolean) = { - dt match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) - } - } - - def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { - case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) - case _ => - val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType - (kType, vType, vContainsNull) - } -} - /** * Trait for functions having as input one argument and one function. */ @@ -161,12 +149,20 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp override def arguments: Seq[Expression] = argument :: Nil + def argumentType: AbstractDataType + + override def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentType :: Nil) + } + def function: Expression override def functions: Seq[Expression] = function :: Nil def expectingFunctionType: AbstractDataType = AnyDataType + override def inputTypes: Seq[AbstractDataType] = Seq(argumentType, expectingFunctionType) + @transient lazy val functionForEval: Expression = functionsForEval.head /** @@ -187,11 +183,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) + override def argumentType: AbstractDataType = ArrayType } trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + override def argumentType: AbstractDataType = MapType } /** @@ -218,12 +214,12 @@ case class 
ArrayTransform( override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) + val ArrayType(elementType, containsNull) = argument.dataType function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, elem :: (IntegerType, false) :: Nil)) + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) case _ => - copy(function = f(function, elem :: Nil)) + copy(function = f(function, (elementType, containsNull) :: Nil)) } } @@ -277,8 +273,7 @@ case class MapFilter( (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) } - @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(argument.dataType) + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) @@ -332,8 +327,8 @@ case class ArrayFilter( override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function @@ -379,8 +374,8 @@ case class ArrayExists( override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { - val elem = 
HigherOrderFunction.arrayArgumentType(argument.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function @@ -440,27 +435,31 @@ case class ArrayAggregate( override def dataType: DataType = finish.dataType + override def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, Seq(ArrayType, AnyDataType)) + } + override def checkInputDataTypes(): TypeCheckResult = { - if (!ArrayType.acceptsType(argument.dataType)) { - TypeCheckResult.TypeCheckFailure( - s"argument 1 requires ${ArrayType.simpleString} type, " + - s"however, '${argument.sql}' is of ${argument.dataType.catalogString} type.") - } else if (!DataType.equalsStructurally( - zero.dataType, merge.dataType, ignoreNullability = true)) { - TypeCheckResult.TypeCheckFailure( - s"argument 3 requires ${zero.dataType.simpleString} type, " + - s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") - } else { - TypeCheckResult.TypeCheckSuccess + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure } } override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. 
- val elem = HigherOrderFunction.arrayArgumentType(argument.dataType) + val ArrayType(elementType, containsNull) = argument.dataType val acc = zero.dataType -> true - val newMerge = f(merge, acc :: elem :: Nil) + val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil) val newFinish = f(finish, acc :: Nil) copy(merge = newMerge, finish = newFinish) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2c4238e69ad7c..6401e3fc99783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1852,6 +1852,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("transform(i, x -> x)") } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("transform(a, x -> x)") + } + assert(ex3.getMessage.contains("cannot resolve '`a`'")) } test("map_filter") { @@ -1898,6 +1903,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("map_filter(i, (k, v) -> k > v)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_filter(a, (k, v) -> k > v)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("filter function - array for primitive type not containing null") { @@ -1994,6 +2004,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("filter(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("filter(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("exists function - array for primitive type 
not containing null") { @@ -2090,6 +2105,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("exists(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { @@ -2211,6 +2231,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("aggregate(a, 0, (acc, x) -> x)") + } + assert(ex5.getMessage.contains("cannot resolve '`a`'")) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { From 3ccd995e8a1260205fea5db63c40c644499a5e01 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 11 Aug 2018 13:34:13 +0900 Subject: [PATCH 4/7] Address other comments. 
--- .../catalyst/expressions/higherOrderFunctions.scala | 10 ++++++---- .../org/apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 8bec2171f2b88..6a3323849f66f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -35,8 +35,8 @@ case class NamedLambdaVariable( name: String, dataType: DataType, nullable: Boolean, - value: AtomicReference[Any] = new AtomicReference(), - exprId: ExprId = NamedExpression.newExprId) + exprId: ExprId = NamedExpression.newExprId, + value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression with NamedExpression with CodegenFallback { @@ -44,7 +44,7 @@ case class NamedLambdaVariable( override def qualifier: Seq[String] = Seq.empty override def newInstance(): NamedExpression = - copy(value = new AtomicReference(), exprId = NamedExpression.newExprId) + copy(exprId = NamedExpression.newExprId, value = new AtomicReference()) override def toAttribute: Attribute = { AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) @@ -130,6 +130,8 @@ trait HigherOrderFunction extends Expression { */ def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + // Make sure the lambda variables refer to the same instances as the arguments, in case the + // variables are instantiated separately during serialization or for some other reason. 
@transient lazy val functionsForEval: Seq[Expression] = functions.map { case LambdaFunction(function, arguments, hidden) => val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap @@ -163,7 +165,7 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp override def inputTypes: Seq[AbstractDataType] = Seq(argumentType, expectingFunctionType) - @transient lazy val functionForEval: Expression = functionsForEval.head + def functionForEval: Expression = functionsForEval.head /** * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 9e95b192968c7..67740c3166471 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -81,7 +81,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) case lv: NamedLambdaVariable => - lv.copy(value = null, exprId = ExprId(0)) + lv.copy(exprId = ExprId(0), value = null) } } From 388c2d3d812bf749ddf9de029432eab729bcc932 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Aug 2018 01:21:04 +0900 Subject: [PATCH 5/7] Address comments. 
--- .../expressions/higherOrderFunctions.scala | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 6a3323849f66f..5d1b8c4da0bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -88,7 +88,7 @@ object LambdaFunction { * A higher order function takes one or more (lambda) functions and applies these to some objects. * The function produces a number of variables which can be consumed by some lambda function. */ -trait HigherOrderFunction extends Expression { +trait HigherOrderFunction extends Expression with ExpectsInputTypes { override def children: Seq[Expression] = arguments ++ functions @@ -97,6 +97,8 @@ trait HigherOrderFunction extends Expression { */ def arguments: Seq[Expression] + def argumentTypes: Seq[AbstractDataType] + /** * All arguments have been resolved. This means that the types and nullabilty of (most of) the * lambda function arguments is known, and that we can start binding the lambda functions. @@ -108,13 +110,19 @@ trait HigherOrderFunction extends Expression { * or returns a `TypeCheckResult` with an error message if invalid. * Note: it's not valid to call this method until `argumentsResolved == true`. */ - def checkArgumentDataTypes(): TypeCheckResult + def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes) + } /** * Functions applied by the higher order function. 
*/ def functions: Seq[Expression] + def functionTypes: Seq[AbstractDataType] + + override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes + /** * All inputs must be resolved and all functions must be resolved lambda functions. */ @@ -145,7 +153,7 @@ trait HigherOrderFunction extends Expression { /** * Trait for functions having as input one argument and one function. */ -trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +trait SimpleHigherOrderFunction extends HigherOrderFunction { def argument: Expression @@ -153,17 +161,15 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp def argumentType: AbstractDataType - override def checkArgumentDataTypes(): TypeCheckResult = { - ExpectsInputTypes.checkInputDataTypes(arguments, argumentType :: Nil) - } + override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil def function: Expression override def functions: Seq[Expression] = function :: Nil - def expectingFunctionType: AbstractDataType = AnyDataType + def functionType: AbstractDataType = AnyDataType - override def inputTypes: Seq[AbstractDataType] = Seq(argumentType, expectingFunctionType) + override def functionTypes: Seq[AbstractDataType] = functionType :: Nil def functionForEval: Expression = functionsForEval.head @@ -301,7 +307,7 @@ case class MapFilter( override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def prettyName: String = "map_filter" } @@ -326,7 +332,7 @@ case class ArrayFilter( override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { val ArrayType(elementType, containsNull) = argument.dataType @@ 
-373,7 +379,7 @@ case class ArrayExists( override def dataType: DataType = BooleanType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { val ArrayType(elementType, containsNull) = argument.dataType @@ -431,16 +437,16 @@ case class ArrayAggregate( override def arguments: Seq[Expression] = argument :: zero :: Nil + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil + override def functions: Seq[Expression] = merge :: finish :: Nil + override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil + override def nullable: Boolean = argument.nullable || finish.nullable override def dataType: DataType = finish.dataType - override def checkArgumentDataTypes(): TypeCheckResult = { - ExpectsInputTypes.checkInputDataTypes(arguments, Seq(ArrayType, AnyDataType)) - } - override def checkInputDataTypes(): TypeCheckResult = { checkArgumentDataTypes() match { case TypeCheckResult.TypeCheckSuccess => From 4b48a39238ab80f2bd1ebb36fd653ecc6495e492 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Aug 2018 11:49:20 +0900 Subject: [PATCH 6/7] Reword a comment. 
--- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b3e04413c57d7..ca7394ee7af9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -90,9 +90,10 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => - // Check argument data types of higher-order functions downwards first because function - // arguments of the higher-order functions might be unresolved due to the unresolved - // argument data types, otherwise always claims the function arguments are unresolved. + // Check argument data types of higher-order functions downwards first. + // If the arguments of the higher-order functions are resolved but the type check fails, + // the argument functions will not get resolved, but we should report the argument type + // check failure instead of claiming the function arguments are unresolved. operator transformExpressionsDown { case hof: HigherOrderFunction if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => From deee1dcef6d7cbde516fa082e4210261ff89b8ff Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Aug 2018 13:09:52 +0900 Subject: [PATCH 7/7] Fix a comment. 
--- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ca7394ee7af9d..6a91d556b2f3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -93,7 +93,7 @@ trait CheckAnalysis extends PredicateHelper { // Check argument data types of higher-order functions downwards first. // If the arguments of the higher-order functions are resolved but the type check fails, // the argument functions will not get resolved, but we should report the argument type - // check failure instead of claiming the function arguments are unresolved. + // check failure instead of claiming the argument functions are unresolved. operator transformExpressionsDown { case hof: HigherOrderFunction if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>