Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
Expand Down Expand Up @@ -209,8 +209,12 @@ class SparkConnectPlanner(val session: SparkSession) {
/**
 * Transforms a Connect SQL command into a logical plan. If named arguments were supplied,
 * each argument value is parsed into an expression and the plan is wrapped in a
 * `ParameterizedQuery`, deferring actual parameter substitution to the analyzer rule
 * `BindParameters`.
 */
private def transformSql(sql: proto.SQL): LogicalPlan = {
  // Named arguments: parameter name -> SQL expression text.
  val args = sql.getArgsMap.asScala.toMap
  val parser = session.sessionState.sqlParser
  val parsedPlan = parser.parsePlan(sql.getQuery)
  if (args.nonEmpty) {
    ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
  } else {
    // No arguments: avoid the wrapper so non-parameterized queries analyze unchanged.
    parsedPlan
  }
}

private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,11 @@
"Pandas user defined aggregate function in the PIVOT clause."
]
},
"PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT" : {
"message" : [
"Parameter markers in unexpected statement: <statement>. Parameter markers must only be used in a query, or DML statement."
]
},
"PIVOT_AFTER_GROUP_BY" : {
"message" : [
"PIVOT clause following a GROUP BY clause. Consider pushing the GROUP BY into a subquery."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// at the beginning of analysis.
OptimizeUpdateFields,
CTESubstitution,
BindParameters,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case p: Parameter =>
p.failAnalysis(
errorClass = "UNBOUND_SQL_PARAMETER",
messageParameters = Map("name" -> toSQLId(p.name)))
messageParameters = Map("name" -> p.name))

case _ =>
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, InsertIntoStatement, LogicalPlan, MergeIntoTable, UnaryNode, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types.DataType

/**
* The expression represents a named parameter that should be replaced by a literal.
*
* @param name The identifier of the parameter without the marker.
*/
case class Parameter(name: String) extends LeafExpression with Unevaluable {
  // A parameter is only a placeholder; it can never be considered resolved.
  override lazy val resolved: Boolean = false

  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)

  // `dataType` and `nullable` are meaningless until the parameter is bound to a
  // literal; reaching them indicates an analyzer bug, hence an internal error.
  override def dataType: DataType = failUnbound("dataType")
  override def nullable: Boolean = failUnbound("nullable")

  private def failUnbound(method: String): Nothing =
    throw SparkException.internalError(
      s"Cannot call `$method()` of the unbound parameter `$name`.")
}

/**
* The logical plan representing a parameterized query. It will be removed during analysis after
 * the parameters are bound.
*/
case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
  // Callers must not wrap a plan without arguments; see `SparkSession.sql`.
  assert(args.nonEmpty)

  // The wrapper itself produces no output; it only carries the arguments.
  override def output: Seq[Attribute] = Nil

  // Never resolved: the node must be eliminated by `BindParameters` during analysis.
  override lazy val resolved = false

  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)

  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
    copy(child = newChild)
}

/**
 * Finds all named parameters in `ParameterizedQuery` and substitutes them with literals from the
* user-specified arguments.
*/
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (plan.containsPattern(PARAMETERIZED_QUERY)) {
      // One unresolved plan can have at most one ParameterizedQuery.
      val parameterizedQueries = plan.collect { case p: ParameterizedQuery => p }
      assert(parameterizedQueries.length == 1)
    }

    plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
      // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
      // relations are not children of `UnresolvedWith`.
      case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
        // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN,
        // etc. We can't store the original SQL text with parameters, as we don't store the
        // arguments and are not able to resolve it after parsing it back. Since parameterized
        // query is mostly used to avoid SQL injection for SELECT queries, we simply forbid
        // non-DML commands here.
        child match {
          case _: InsertIntoStatement => // OK
          case _: UpdateTable => // OK
          case _: DeleteFromTable => // OK
          case _: MergeIntoTable => // OK
          case cmd: Command =>
            child.failAnalysis(
              errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
              messageParameters = Map("statement" -> cmd.nodeName))
          case _ => // OK
        }

        // Every argument value must already be a literal; anything else is rejected.
        args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
          expr.failAnalysis(
            errorClass = "INVALID_SQL_ARG",
            messageParameters = Map("name" -> name))
        }

        // Substitute each named parameter with its literal, recursing explicitly into
        // subquery plans.
        def bind(p: LogicalPlan): LogicalPlan = {
          p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
            case Parameter(name) if args.contains(name) => args(name)
            case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
          }
        }
        val res = bind(child)
        // Preserve tree node tags from the removed wrapper node.
        res.copyTagsFrom(p)
        res

      // No catch-all case here: the partial function passed to `resolveOperatorsWithPruning`
      // leaves unmatched operators untouched. The previous `case _ => plan` would have
      // replaced an arbitrary non-matching operator with the *entire* plan, corrupting the
      // tree (e.g. when the CTE guard above does not hold yet).
    }
  }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object TreePattern extends Enumeration {
val OR: Value = Value
val OUTER_REFERENCE: Value = Value
val PARAMETER: Value = Value
val PARAMETERIZED_QUERY: Value = Value
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1346,17 +1346,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-41271: bind named parameters to literals") {
  // Reset the global CTE id counter before each parse so that the actual and expected
  // plans allocate identical CTE relation ids and can be compared structurally.
  CTERelationDef.curId.set(0)
  val actual1 = ParameterizedQuery(
    child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
    args = Map("limitA" -> Literal(10))).analyze
  CTERelationDef.curId.set(0)
  val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
  comparePlans(actual1, expected1)
  // Ignore unused arguments: `param1` is never referenced by the query.
  CTERelationDef.curId.set(0)
  val actual2 = ParameterizedQuery(
    child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
    args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
  CTERelationDef.curId.set(0)
  val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
  comparePlans(actual2, expected2)
}

test("SPARK-41489: type of filter expression should be a bool") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.plans._
Expand Down
12 changes: 8 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.ExternalCommandRunner
Expand Down Expand Up @@ -623,8 +623,12 @@ class SparkSession private(
val tracker = new QueryPlanningTracker
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parser = sessionState.sqlParser
val parsedArgs = args.mapValues(parser.parseExpression).toMap
Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
val parsedPlan = parser.parsePlan(sqlText)
if (args.nonEmpty) {
ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
} else {
parsedPlan
}
}
Dataset.ofRows(self, plan, tracker)
}
Expand Down
Loading