Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2360,7 +2360,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val inputType = extractInputType(args)
val bound = unbound.bind(inputType)
validateParameterModes(bound)
val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
val rearrangedArgs =
NamedParametersSupport.defaultRearrange(bound, args, SQLConf.get.resolver)
Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -1024,9 +1025,9 @@ object FunctionRegistry {
name: String,
builder: T,
expressions: Seq[Expression]) : Seq[Expression] = {
val rearrangedExpressions = if (!builder.functionSignature.isEmpty) {
val rearrangedExpressions = if (builder.functionSignature.isDefined) {
val functionSignature = builder.functionSignature.get
builder.rearrange(functionSignature, expressions, name)
builder.rearrange(functionSignature, expressions, name, SQLConf.get.resolver)
} else {
expressions
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1936,7 +1936,7 @@ class SessionCatalog(
}

NamedParametersSupport.defaultRearrange(
FunctionSignature(paramNames), expressions, functionName)
FunctionSignature(paramNames), expressions, functionName, SQLConf.get.resolver)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter}
Expand Down Expand Up @@ -67,8 +68,10 @@ trait FunctionBuilderBase[T] {
def rearrange(
expectedSignature: FunctionSignature,
providedArguments: Seq[Expression],
functionName: String) : Seq[Expression] = {
NamedParametersSupport.defaultRearrange(expectedSignature, providedArguments, functionName)
functionName: String,
resolver: Resolver) : Seq[Expression] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: fix the javadoc to add the new args (in this and 'defaultRearrange')

NamedParametersSupport.defaultRearrange(
expectedSignature, providedArguments, functionName, resolver)
}

def build(funcName: String, expressions: Seq[Expression]): T
Expand All @@ -89,15 +92,17 @@ object NamedParametersSupport {
*/
def splitAndCheckNamedArguments(
args: Seq[Expression],
functionName: String): (Seq[Expression], Seq[NamedArgumentExpression]) = {
functionName: String,
resolver: Resolver):
(Seq[Expression], Seq[NamedArgumentExpression]) = {
val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression])

val namedParametersSet = collection.mutable.Set[String]()

(positionalArgs,
namedArgs.zipWithIndex.map {
case (namedArg @ NamedArgumentExpression(parameterName, _), _) =>
if (namedParametersSet.contains(parameterName)) {
if (namedParametersSet.exists(resolver(_, parameterName))) {
throw QueryCompilationErrors.doubleNamedArgumentReference(
functionName, parameterName)
}
Expand All @@ -123,15 +128,20 @@ object NamedParametersSupport {
final def defaultRearrange(
functionSignature: FunctionSignature,
args: Seq[Expression],
functionName: String): Seq[Expression] = {
defaultRearrange(functionName, functionSignature.parameters, args)
functionName: String,
resolver: Resolver): Seq[Expression] = {
defaultRearrange(functionName, functionSignature.parameters, args, resolver)
}

final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = {
final def defaultRearrange(
procedure: BoundProcedure,
args: Seq[Expression],
resolver: Resolver): Seq[Expression] = {
defaultRearrange(
procedure.name,
procedure.parameters.map(toInputParameter).toSeq,
args)
args,
resolver)
}

private def toInputParameter(param: ProcedureParameter): InputParameter = {
Expand All @@ -144,12 +154,13 @@ object NamedParametersSupport {
private def defaultRearrange(
routineName: String,
parameters: Seq[InputParameter],
args: Seq[Expression]): Seq[Expression] = {
args: Seq[Expression],
resolver: Resolver): Seq[Expression] = {
if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) {
throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters)
}

val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName)
val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName, resolver)
val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size)

// The following loop checks for the following:
Expand All @@ -161,11 +172,11 @@ object NamedParametersSupport {

namedArgs.foreach { namedArg =>
val parameterName = namedArg.key
if (!parameterNamesSet.contains(parameterName)) {
if (!parameterNamesSet.exists(resolver(_, parameterName))) {
throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key,
parameterNamesSet.toSeq)
}
if (positionalParametersSet.contains(parameterName)) {
if (positionalParametersSet.exists(resolver(_, parameterName))) {
throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference(
routineName, namedArg.key)
}
Expand All @@ -187,14 +198,13 @@ object NamedParametersSupport {
// We rearrange named arguments to match their positional order.
val rearrangedNamedArgs: Seq[Expression] = namedParameters.zipWithIndex.map {
case (param, index) =>
namedArgMap.getOrElse(
param.name,
namedArgMap.view.filterKeys(resolver(_, param.name)).headOption.map(_._2).getOrElse {
if (param.default.isEmpty) {
throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index)
} else {
param.default.get
}
)
}
}
val rearrangedArgs = positionalArgs ++ rearrangedNamedArgs
assert(rearrangedArgs.size == parameters.size)
Expand Down
Loading