diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1896a1c7ac27..0da715735426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2360,7 +2360,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val inputType = extractInputType(args) val bound = unbound.bind(inputType) validateParameterModes(bound) - val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args) + val rearrangedArgs = + NamedParametersSupport.defaultRearrange(bound, args, SQLConf.get.resolver) Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 99e0c707d887..afa43e876b26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -1024,9 +1025,9 @@ object FunctionRegistry { name: String, builder: T, expressions: Seq[Expression]) : Seq[Expression] = { - val rearrangedExpressions = if (!builder.functionSignature.isEmpty) { + val rearrangedExpressions = if (builder.functionSignature.isDefined) { val functionSignature = builder.functionSignature.get - builder.rearrange(functionSignature, expressions, name) + builder.rearrange(functionSignature, expressions, name, SQLConf.get.resolver) } else { expressions } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index e0d1cf011e06..5464f5077e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1936,7 +1936,7 @@ class SessionCatalog( } NamedParametersSupport.defaultRearrange( - FunctionSignature(paramNames), expressions, functionName) + FunctionSignature(paramNames), expressions, functionName, SQLConf.get.resolver) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 638d20cff928..26bf05d256b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter} @@ -67,8 +68,10 @@ trait FunctionBuilderBase[T] { def rearrange( expectedSignature: FunctionSignature, providedArguments: Seq[Expression], - functionName: String) : Seq[Expression] = { - NamedParametersSupport.defaultRearrange(expectedSignature, providedArguments, functionName) + functionName: String, + resolver: Resolver) : Seq[Expression] = { + NamedParametersSupport.defaultRearrange( + expectedSignature, providedArguments, functionName, resolver) } def build(funcName: String, expressions: Seq[Expression]): T @@ -89,7 +92,9 @@ object NamedParametersSupport { */ def splitAndCheckNamedArguments( args: Seq[Expression], - functionName: String): (Seq[Expression], Seq[NamedArgumentExpression]) = { + functionName: String, + resolver: Resolver): + (Seq[Expression], Seq[NamedArgumentExpression]) = { val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) val namedParametersSet = collection.mutable.Set[String]() @@ -97,7 +102,7 @@ object NamedParametersSupport { (positionalArgs, namedArgs.zipWithIndex.map { case (namedArg @ NamedArgumentExpression(parameterName, _), _) => - if (namedParametersSet.contains(parameterName)) { + if (namedParametersSet.exists(resolver(_, parameterName))) { throw QueryCompilationErrors.doubleNamedArgumentReference( functionName, parameterName) } @@ -123,15 +128,20 @@ object NamedParametersSupport { final def defaultRearrange( functionSignature: FunctionSignature, args: Seq[Expression], - functionName: String): Seq[Expression] = { - defaultRearrange(functionName, functionSignature.parameters, args) + functionName: String, + resolver: Resolver): Seq[Expression] = { + defaultRearrange(functionName, functionSignature.parameters, args, resolver) } - final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = { + final def defaultRearrange( + procedure: BoundProcedure, + args: Seq[Expression], + resolver: Resolver): Seq[Expression] = { defaultRearrange( procedure.name, procedure.parameters.map(toInputParameter).toSeq, - args) + args, + resolver) } private def toInputParameter(param: ProcedureParameter): InputParameter = { @@ -144,12 +154,13 @@ object NamedParametersSupport { private def defaultRearrange( routineName: String, parameters: Seq[InputParameter], - args: Seq[Expression]): Seq[Expression] = { + args: Seq[Expression], + resolver: Resolver): Seq[Expression] = { if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) { throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters) } - val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName, resolver) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop checks for the following: @@ -161,11 +172,11 @@ object NamedParametersSupport { namedArgs.foreach { namedArg => val parameterName = namedArg.key - if (!parameterNamesSet.contains(parameterName)) { + if (!parameterNamesSet.exists(resolver(_, parameterName))) { throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key, parameterNamesSet.toSeq) } - if (positionalParametersSet.contains(parameterName)) { + if (positionalParametersSet.exists(resolver(_, parameterName))) { throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( routineName, namedArg.key) } @@ -187,14 +198,13 @@ object NamedParametersSupport { // We rearrange named arguments to match their positional order. val rearrangedNamedArgs: Seq[Expression] = namedParameters.zipWithIndex.map { case (param, index) => - namedArgMap.getOrElse( - param.name, + namedArgMap.view.filterKeys(resolver(_, param.name)).headOption.map(_._2).getOrElse { if (param.default.isEmpty) { throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index) } else { param.default.get } - ) + } } val rearrangedArgs = positionalArgs ++ rearrangedNamedArgs assert(rearrangedArgs.size == parameters.size) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala index 0715e27403bc..ca6cee26d4a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala @@ -22,9 +22,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedArgu import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, NamedParametersSupport} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types.DataType - case class DummyExpression( k1: Expression, k2: Expression, @@ -62,6 +63,8 @@ class NamedParameterFunctionSuite extends AnalysisTest { final val k3Arg = NamedArgumentExpression("k3", Literal("v3")) final val k4Arg = NamedArgumentExpression("k4", Literal("v4")) final val namedK1Arg = NamedArgumentExpression("k1", Literal("v1-2")) + final val upperCaseNamedK1Arg = NamedArgumentExpression("K1", Literal("v1")) + final val upperCaseNamedK4Arg = NamedArgumentExpression("K4", Literal("v4")) final val args = Seq(k1Arg, k4Arg, k2Arg, k3Arg) final val expectedSeq = Seq(Literal("v1"), Literal("v2"), Literal("v3"), Literal("v4")) @@ -70,77 +73,203 @@ class NamedParameterFunctionSuite extends AnalysisTest { InputParameter("k1"), InputParameter("k2", Option(Literal("v2"))), InputParameter("k3"))) test("Check rearrangement of expressions") { - val rearrangedArgs = NamedParametersSupport.defaultRearrange( - signature, args, "function") - for ((returnedArg, expectedArg) <- rearrangedArgs.zip(expectedSeq)) { - assert(returnedArg == expectedArg) - } - val rearrangedArgsWithBuilder = - FunctionRegistry.rearrangeExpressions("function", DummyExpressionBuilder, args) - for ((returnedArg, expectedArg) <- rearrangedArgsWithBuilder.zip(expectedSeq)) { - assert(returnedArg == expectedArg) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + val rearrangedArgs = NamedParametersSupport.defaultRearrange( + signature, args, "function", SQLConf.get.resolver) + for ((returnedArg, expectedArg) <- rearrangedArgs.zip(expectedSeq)) { + assert(returnedArg == expectedArg) + } + val rearrangedArgsWithBuilder = + FunctionRegistry.rearrangeExpressions("function", DummyExpressionBuilder, args) + for ((returnedArg, expectedArg) <- rearrangedArgsWithBuilder.zip(expectedSeq)) { + assert(returnedArg == expectedArg) + } + } } } - private def parseRearrangeException(functionSignature: FunctionSignature, - expressions: Seq[Expression], - functionName: String = "function"): SparkThrowable = { + private def parseRearrangeException( + functionSignature: FunctionSignature, + expressions: Seq[Expression], + functionName: String = "function"): SparkThrowable = { intercept[SparkThrowable]( - NamedParametersSupport.defaultRearrange(functionSignature, expressions, functionName)) + NamedParametersSupport.defaultRearrange( + functionSignature, expressions, functionName, SQLConf.get.resolver)) } test("DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT") { - val condition = - "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED" - checkError( - exception = parseRearrangeException( - signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, namedK1Arg), "foo"), - condition = condition, - parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k1")) - ) - checkError( - exception = parseRearrangeException( - signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, k4Arg), "foo"), - condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", - parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4")) - ) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, namedK1Arg), "foo"), + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k1")) + ) + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, k4Arg), "foo"), + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k4")) + ) + } + } + + withSQLConf(CASE_SENSITIVE.key -> "true") { + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, upperCaseNamedK1Arg), "foo"), + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("foo"), + "argumentName" -> toSQLId("K1"), + "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3"))) + ) + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, upperCaseNamedK4Arg), "foo"), + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("foo"), + "argumentName" -> toSQLId("K4"), + "proposal" -> (toSQLId("k4") + " " + toSQLId("k1") + " " + toSQLId("k2"))) + ) + } + + withSQLConf(CASE_SENSITIVE.key -> "false") { + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, upperCaseNamedK1Arg), "foo"), + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("K1")) + ) + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, upperCaseNamedK4Arg), "foo"), + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("K4")) + ) + } } test("REQUIRED_PARAMETER_NOT_FOUND") { - checkError( - exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg), "foo"), - condition = "REQUIRED_PARAMETER_NOT_FOUND", - parameters = Map( - "routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4"), "index" -> "2")) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg), "foo"), + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k4"), + "index" -> "2")) + } + } + + withSQLConf(CASE_SENSITIVE.key -> "true") { + checkError( + exception = parseRearrangeException( + signature, Seq(upperCaseNamedK1Arg, k2Arg, k3Arg), "foo"), + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("foo"), + "argumentName" -> toSQLId("K1"), + "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3")))) + } + + withSQLConf(CASE_SENSITIVE.key -> "false") { + checkError( + exception = parseRearrangeException( + signature, Seq(upperCaseNamedK1Arg, k2Arg, k3Arg), "foo"), + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k4"), + "index" -> "3")) + } } test("UNRECOGNIZED_PARAMETER_NAME") { - checkError( - exception = parseRearrangeException(signature, - Seq(k1Arg, k2Arg, k3Arg, k4Arg, NamedArgumentExpression("k5", Literal("k5"))), "foo"), - condition = "UNRECOGNIZED_PARAMETER_NAME", - parameters = Map("routineName" -> toSQLId("foo"), "argumentName" -> toSQLId("k5"), - "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3"))) - ) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException(signature, + Seq(k1Arg, k2Arg, k3Arg, k4Arg, NamedArgumentExpression("k5", Literal("k5"))), "foo"), + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("foo"), + "argumentName" -> toSQLId("k5"), + "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3"))) + ) + } + } + + withSQLConf(CASE_SENSITIVE.key -> "true") { + checkError( + exception = parseRearrangeException( + signature, Seq(upperCaseNamedK1Arg, k2Arg, k3Arg, k4Arg), "foo"), + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("foo"), + "argumentName" -> toSQLId("K1"), + "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3")))) + } + + withSQLConf(CASE_SENSITIVE.key -> "false") { + val rearrangedArgs = NamedParametersSupport.defaultRearrange( + signature, Seq(upperCaseNamedK1Arg, k2Arg, k3Arg, k4Arg), "foo", SQLConf.get.resolver) + for ((returnedArg, expectedArg) <- rearrangedArgs.zip(expectedSeq)) { + assert(returnedArg == expectedArg) + } + } } test("UNEXPECTED_POSITIONAL_ARGUMENT") { - checkError( - exception = parseRearrangeException(signature, - Seq(k2Arg, k3Arg, k1Arg, k4Arg), "foo"), - condition = "UNEXPECTED_POSITIONAL_ARGUMENT", - parameters = Map("routineName" -> toSQLId("foo"), "parameterName" -> toSQLId("k3")) - ) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException(signature, + Seq(k4Arg, k3Arg, k1Arg, k2Arg), "foo"), + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k3")) + ) + } + } + + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException(signature, + Seq(upperCaseNamedK4Arg, k3Arg, k1Arg, k2Arg), "foo"), + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("foo"), + "parameterName" -> toSQLId("k3")) + ) + } + } } test("INTERNAL_ERROR: Enforce optional arguments after required arguments") { val errorMessage = s"Routine ${toSQLId("foo")} has an unexpected required argument for" + s" the provided routine signature ${illegalSignature.parameters.mkString("[", ", ", "]")}." + s" All required arguments should come before optional arguments." - checkError( - exception = parseRearrangeException(illegalSignature, args, "foo"), - condition = "INTERNAL_ERROR", - parameters = Map("message" -> errorMessage) - ) + Seq("true", "false").foreach { cs => + withSQLConf(CASE_SENSITIVE.key -> cs) { + checkError( + exception = parseRearrangeException(illegalSignature, args, "foo"), + condition = "INTERNAL_ERROR", + parameters = Map("message" -> errorMessage) + ) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 5ab0d259d83c..037b784c9dd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.classic.ExpressionUtils.expression import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.TableValuedFunctionArgument +import org.apache.spark.sql.internal.{SQLConf, TableValuedFunctionArgument} import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ case class UserDefinedPythonFunction( * - don't have duplicated names * - don't contain positional arguments after named arguments */ - NamedParametersSupport.splitAndCheckNamedArguments(e, name) + NamedParametersSupport.splitAndCheckNamedArguments(e, name, SQLConf.get.resolver) } else if (e.exists(_.isInstanceOf[NamedArgumentExpression])) { throw QueryCompilationErrors.namedArgumentsNotSupported(name) } @@ -121,7 +121,7 @@ case class UserDefinedPythonTableFunction( * - don't have duplicated names * - don't contain positional arguments after named arguments */ - NamedParametersSupport.splitAndCheckNamedArguments(exprs, name) + NamedParametersSupport.splitAndCheckNamedArguments(exprs, name, SQLConf.get.resolver) // Check which argument is a table argument here since it will be replaced with // `UnresolvedAttribute` to construct lateral join.