From 050bbb88af4f23eae587f31af2424a11e48ad44b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 8 Mar 2023 15:10:40 +0800
Subject: [PATCH 1/5] support parameterized query in subquery and CTE

---
 .../connect/planner/SparkConnectPlanner.scala |  4 +-
 .../main/resources/error/error-classes.json   |  5 +
 .../sql/catalyst/analysis/Analyzer.scala      |  1 +
 .../sql/catalyst/analysis/parameters.scala    | 96 +++++++++++++++++++
 .../sql/catalyst/expressions/parameters.scala | 64 -------------
 .../sql/catalyst/analysis/AnalysisSuite.scala |  8 +-
 .../sql/catalyst/parser/PlanParserSuite.scala |  2 +-
 .../org/apache/spark/sql/SparkSession.scala   |  6 +-
 .../apache/spark/sql/ParametersSuite.scala    | 68 +++++++++++++
 9 files changed, 180 insertions(+), 74 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5dd0a7ea3097..29d20b650b51 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,7 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -210,7 +210,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     val args = sql.getArgsMap.asScala.toMap
     val parser = session.sessionState.sqlParser
     val parsedArgs = args.mapValues(parser.parseExpression).toMap
-    Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
+    ParameterizedQuery(parser.parsePlan(sql.getQuery), parsedArgs)
   }

   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 061074cd47f8..4d847ce9fe0b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1730,6 +1730,11 @@
         "Pandas user defined aggregate function in the PIVOT clause."
       ]
     },
+    "PARAMETERIZED_COMMAND" : {
+      "message" : [
+        "Query parameters in Command."
+      ]
+    },
     "PIVOT_AFTER_GROUP_BY" : {
       "message" : [
         "PIVOT clause following a GROUP BY clause. Consider pushing the GROUP BY into a subquery."
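For orientation, this is the entry point the patch reworks: the `sql(sqlText, args)` overload takes a map from parameter name to the SQL text of its value, parses each value with the session's expression parser, and now wraps the parsed plan in a `ParameterizedQuery` node instead of binding eagerly. A minimal usage sketch (the table name `t` and the argument values are illustrative, not from the patch):

    // Named parameters use the `:name` marker; each argument value is a string
    // holding a SQL literal, which parser.parseExpression turns into a Literal.
    val df = spark.sql(
      "SELECT * FROM t WHERE price < :maxPrice AND category = :cat",
      args = Map("maxPrice" -> "100", "cat" -> "'toys'"))

Deferring the binding to an analyzer rule is what lets markers inside CTE bodies and subquery expressions be substituted; right after parsing, those positions are not all reachable from the top-level plan.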
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d7cc34d6f15b..e5d78b21f19e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -265,6 +265,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // at the beginning of analysis.
       OptimizeUpdateFields,
       CTESubstitution,
+      BindParameters,
       WindowsSubstitution,
       EliminateUnions,
       SubstituteUnresolvedOrdinals),
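`BindParameters` is registered in the substitution batch immediately after `CTESubstitution`, so binding happens at the very start of analysis but never before CTE definitions have been inlined. A marker that no argument matches stays behind as an unresolved `Parameter`, which `CheckAnalysis` later rejects; a sketch of the expected behavior, mirroring the tests added further down:

    // No argument named "abc" is supplied, so analysis fails with
    // an UNBOUND_SQL_PARAMETER error pointing at the `:abc` fragment.
    spark.sql("SELECT :abc", Map("def" -> "1"))

The new `analysis/parameters.scala` below carries both the `Parameter` expression (moved out of `catalyst.expressions`) and the binding rule itself.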
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
new file mode 100644
index 000000000000..2b2d76aec839
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern, UNRESOLVED_WITH}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types.DataType
+
+/**
+ * The expression represents a named parameter that should be replaced by a literal.
+ *
+ * @param name The identifier of the parameter without the marker.
+ */
+case class Parameter(name: String) extends LeafExpression with Unevaluable {
+  override lazy val resolved: Boolean = false
+
+  private def unboundError(methodName: String): Nothing = {
+    throw SparkException.internalError(
+      s"Cannot call `$methodName()` of the unbound parameter `$name`.")
+  }
+  override def dataType: DataType = unboundError("dataType")
+  override def nullable: Boolean = unboundError("nullable")
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
+}
+
+/**
+ * The logical plan representing a parameterized query. It will be removed during analysis after
+ * the parameters are bound.
+ */
+case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
+  override def output: Seq[Attribute] = Nil
+  override lazy val resolved = false
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
+
+/**
+ * Finds all named parameters in `ParameterizedQuery` and substitutes them by literals from the
+ * user-specified arguments.
+ */
+object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // No arguments, remove `ParameterizedQuery` directly.
+    case ParameterizedQuery(child, args) if args.isEmpty => child
+
+    // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
+    // relations are not children of `UnresolvedWith`.
+    case ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+      // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
+      // We can't store the original SQL text with parameters, as we don't store the arguments and
+      // are not able to resolve it after parsing it back. Since parameterized query is mostly used
+      // to avoid SQL injection for SELECT queries, we simply forbid commands here.
+      if (child.exists(_.isInstanceOf[Command])) {
+        child.failAnalysis(
+          errorClass = "UNSUPPORTED_FEATURE.PARAMETERIZED_COMMAND",
+          messageParameters = Map.empty
+        )
+      }
+
+      args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+        expr.failAnalysis(
+          errorClass = "INVALID_SQL_ARG",
+          messageParameters = Map("name" -> toSQLId(name)))
+      }
+
+      def bind(p: LogicalPlan): LogicalPlan = {
+        p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
+          case Parameter(name) if args.contains(name) => args(name)
+          case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
+        }
+      }
+      bind(child)

+    case _ => plan
+  }
+}
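Two details in the rule above are worth calling out. First, it refuses to fire while the plan still contains `UNRESOLVED_WITH`; the fixed-point rule executor simply retries it on a later iteration, after `CTESubstitution` has inlined the CTE relations, which is what makes a query like this one (quoted from the new tests) bindable:

    WITH w1 AS (SELECT :p1 AS p)
    SELECT p + :p2 FROM w1

Second, `bind` recurses into `SubqueryExpression` plans explicitly, because an expression-level transform does not descend into the nested plan of a subquery on its own; that is the case that makes `SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2` work.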
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
deleted file mode 100644
index fae2b9a1a9f4..000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/parameters.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern}
-import org.apache.spark.sql.errors.QueryErrorsBase
-import org.apache.spark.sql.types.DataType
-
-/**
- * The expression represents a named parameter that should be replaced by a literal.
- *
- * @param name The identifier of the parameter without the marker.
- */
-case class Parameter(name: String) extends LeafExpression with Unevaluable {
-  override lazy val resolved: Boolean = false
-
-  private def unboundError(methodName: String): Nothing = {
-    throw SparkException.internalError(
-      s"Cannot call `$methodName()` of the unbound parameter `$name`.")
-  }
-  override def dataType: DataType = unboundError("dataType")
-  override def nullable: Boolean = unboundError("nullable")
-
-  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETER)
-}
-
-
-/**
- * Finds all named parameters in the given plan and substitutes them by literals of `args` values.
- */
-object Parameter extends QueryErrorsBase {
-  def bind(plan: LogicalPlan, args: Map[String, Expression]): LogicalPlan = {
-    if (!args.isEmpty) {
-      args.filter(!_._2.isInstanceOf[Literal]).headOption.foreach { case (name, expr) =>
-        expr.failAnalysis(
-          errorClass = "INVALID_SQL_ARG",
-          messageParameters = Map("name" -> toSQLId(name)))
-      }
-      plan.transformAllExpressionsWithPruning(_.containsPattern(PARAMETER)) {
-        case Parameter(name) if args.contains(name) => args(name)
-      }
-    } else {
-      plan
-    }
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 0f26d3a2dc94..16555df0be1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1347,14 +1347,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {

   test("SPARK-41271: bind named parameters to literals") {
     comparePlans(
-      Parameter.bind(
-        plan = parsePlan("SELECT * FROM a LIMIT :limitA"),
+      ParameterizedQuery(
+        child = parsePlan("SELECT * FROM a LIMIT :limitA"),
         args = Map("limitA" -> Literal(10))),
       parsePlan("SELECT * FROM a LIMIT 10"))
     // Ignore unused arguments
     comparePlans(
-      Parameter.bind(
-        plan = parsePlan("SELECT c FROM a WHERE c < :param2"),
+      ParameterizedQuery(
+        child = parsePlan("SELECT c FROM a WHERE c < :param2"),
         args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
       parsePlan("SELECT c FROM a WHERE c < 20"))
   }
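The deleted helper above shows why the rework was needed: `Parameter.bind` ran eagerly right after parsing and used `transformAllExpressionsWithPruning`, which neither descends into the plans nested inside `SubqueryExpression`s nor reaches CTE definitions (per the comment in the new rule, they are not children of `UnresolvedWith`), so markers in those positions were left unbound and later rejected by `CheckAnalysis`. The test update is the corresponding mechanical API swap; in sketch form:

    // before: eager, parse-time binding via the deleted helper
    Parameter.bind(parsePlan(sqlText), args)
    // after: wrap the parse tree and let the analyzer's BindParameters rule bind
    ParameterizedQuery(parsePlan(sqlText), args)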
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 6fc83d8c7825..3b5a24013358 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser

 import org.apache.spark.SparkThrowable
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Parameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.plans._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index adbe593ac56f..6e4bf35a9145 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{ParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Parameter}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.ExternalCommandRunner
@@ -624,7 +624,7 @@ class SparkSession private(
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
       val parser = sessionState.sqlParser
       val parsedArgs = args.mapValues(parser.parseExpression).toMap
-      Parameter.bind(parser.parsePlan(sqlText), parsedArgs)
+      ParameterizedQuery(parser.parsePlan(sqlText), parsedArgs)
     }
     Dataset.ofRows(self, plan, tracker)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 668a1e4ad7d9..c56cc892cb18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -38,6 +38,74 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(true))
   }

+  test("parameters in CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT :p1 AS p)
+        |SELECT p + :p2 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(3))
+  }
+
+  test("parameters in nested CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS
+        |  (WITH w2 AS (SELECT :p1 AS p) SELECT p + :p2 AS p2 FROM w2)
+        |SELECT p2 + :p3 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(6))
+  }
+
+  test("parameters in subquery expression") {
+    val sqlText = "SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2"
+    val args = Map("p1" -> "1", "p2" -> "2")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(12))
+  }
+
+  test("parameters in nested subquery expression") {
+    val sqlText = "SELECT (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2) + :p3"
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("parameters in subquery expression inside CTE") {
+    val sqlText =
+      """
+        |WITH w1 AS (SELECT (SELECT max(id) + :p1 FROM range(10)) + :p2 AS p)
+        |SELECT p + :p3 FROM w1
+        |""".stripMargin
+    val args = Map("p1" -> "1", "p2" -> "2", "p3" -> "3")
+    checkAnswer(
+      spark.sql(sqlText, args),
+      Row(15))
+  }
+
+  test("parameters not allowed in commands") {
+    val sqlText = "CREATE VIEW v AS SELECT :p AS p"
+    val args = Map("p" -> "1")
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlText, args)
+      },
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETERIZED_COMMAND",
+      parameters = Map.empty,
+      context = ExpectedContext(
+        fragment = "CREATE VIEW v AS SELECT :p AS p",
+        start = 0,
+        stop = sqlText.length - 1))
+  }
+
   test("non-substituted parameters") {
     checkError(
       exception = intercept[AnalysisException] {

From c0a1fd46ac3e1d8c3761ee6a7146feb3e5e1bc4d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 8 Mar 2023 23:51:03 +0800
Subject: [PATCH 2/5] address comments

---
 .../main/resources/error/error-classes.json   |  4 ++--
 .../sql/catalyst/analysis/parameters.scala    | 20 ++++++++++++-------
 .../apache/spark/sql/ParametersSuite.scala    | 14 ++++++++++---
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 4d847ce9fe0b..34026083bb9a 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1730,9 +1730,9 @@
         "Pandas user defined aggregate function in the PIVOT clause."
       ]
     },
-    "PARAMETERIZED_COMMAND" : {
+    "PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT" : {
       "message" : [
-        "Query parameters in Command."
+        "Parameter markers in unexpected statement: <statement>. Parameter markers must only be used in a query, or DML statement."
       ]
     },
     "PIVOT_AFTER_GROUP_BY" : {
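The reworded error class makes the policy explicit: markers are fine in queries and in DML statements, which are fully resolved and executed immediately, but not in commands that persist the original SQL text. A view created from parameterized text would be unreadable later, since the arguments are not stored alongside it:

    CREATE VIEW v AS SELECT :p AS p   -- rejected: reading v would re-parse
                                      -- ':p' with no arguments left to bind

The rule change below encodes exactly that distinction as an allowlist of the four DML plan nodes.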
-      if (child.exists(_.isInstanceOf[Command])) {
-        child.failAnalysis(
-          errorClass = "UNSUPPORTED_FEATURE.PARAMETERIZED_COMMAND",
-          messageParameters = Map.empty
-        )
-      }
+      // to avoid SQL injection for SELECT queries, we simply forbid non-DML commands here.
+      child match {
+        case _: InsertIntoStatement => // OK
+        case _: UpdateTable => // OK
+        case _: DeleteFromTable => // OK
+        case _: MergeIntoTable => // OK
+        case cmd: Command =>
+          child.failAnalysis(
+            errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+            messageParameters = Map("statement" -> cmd.nodeName)
+          )
+        case _ => // OK
+      }

       args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index c56cc892cb18..512a52058195 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -91,15 +91,23 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(15))
   }

-  test("parameters not allowed in commands") {
+  test("parameters in INSERT") {
+    withTable("t") {
+      sql("CREATE TABLE t (col INT) USING json")
+      spark.sql("INSERT INTO t SELECT :p", Map("p" -> "1"))
+      checkAnswer(spark.table("t"), Row(1))
+    }
+  }
+
+  test("parameters not allowed in DDL commands") {
     val sqlText = "CREATE VIEW v AS SELECT :p AS p"
     val args = Map("p" -> "1")
     checkError(
       exception = intercept[AnalysisException] {
         spark.sql(sqlText, args)
       },
-      errorClass = "UNSUPPORTED_FEATURE.PARAMETERIZED_COMMAND",
-      parameters = Map.empty,
+      errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+      parameters = Map("statement" -> "CreateView"),
       context = ExpectedContext(
         fragment = "CREATE VIEW v AS SELECT :p AS p",
         start = 0,

From f8522e29fd778a55eca92c6aae8fcfec1719b40e Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 9 Mar 2023 09:58:20 +0800
Subject: [PATCH 3/5] fix

---
 .../connect/planner/SparkConnectPlanner.scala |  8 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala |  2 +-
 .../sql/catalyst/analysis/parameters.scala    | 80 +++++++++++--------
 .../sql/catalyst/trees/TreePatterns.scala     |  1 +
 .../org/apache/spark/sql/SparkSession.scala   |  8 +-
 .../apache/spark/sql/ParametersSuite.scala    | 24 +++++-
 6 files changed, 80 insertions(+), 43 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 29d20b650b51..24717e07b002 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -209,8 +209,12 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformSql(sql: proto.SQL): LogicalPlan = {
     val args = sql.getArgsMap.asScala.toMap
     val parser = session.sessionState.sqlParser
-    val parsedArgs = args.mapValues(parser.parseExpression).toMap
-    ParameterizedQuery(parser.parsePlan(sql.getQuery), parsedArgs)
+    val parsedPlan = parser.parsePlan(sql.getQuery)
+    if (args.nonEmpty) {
+      ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+    } else {
+      parsedPlan
+    }
   }

   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fafdd679aa55..321d8dfd1ac0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -336,7 +336,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           case p: Parameter =>
             p.failAnalysis(
               errorClass = "UNBOUND_SQL_PARAMETER",
-              messageParameters = Map("name" -> toSQLId(p.name)))
+              messageParameters = Map("name" -> p.name))

           case _ =>
         })
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index f1b5ff2c02d8..29c36300673b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LeafExpression, Literal, SubqueryExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, InsertIntoStatement, LogicalPlan, MergeIntoTable, UnaryNode, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, TreePattern, UNRESOLVED_WITH}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
 import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types.DataType
@@ -48,8 +48,10 @@ case class Parameter(name: String) extends LeafExpression with Unevaluable {
  * the parameters are bound.
  */
 case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression]) extends UnaryNode {
+  assert(args.nonEmpty)
   override def output: Seq[Attribute] = Nil
   override lazy val resolved = false
+  final override val nodePatterns: Seq[TreePattern] = Seq(PARAMETERIZED_QUERY)
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
     copy(child = newChild)
 }
@@ -59,44 +61,52 @@ case class ParameterizedQuery(child: LogicalPlan, args: Map[String, Expression])
  * user-specified arguments.
  */
 object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
-    // No arguments, remove `ParameterizedQuery` directly.
-    case ParameterizedQuery(child, args) if args.isEmpty => child
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (plan.containsPattern(PARAMETERIZED_QUERY)) {
+      // One unresolved plan can have at most one ParameterizedQuery.
+      val parameterizedQueries = plan.collect { case p: ParameterizedQuery => p }
+      assert(parameterizedQueries.length == 1)
+    }

-    // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
-    // relations are not children of `UnresolvedWith`.
-    case ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
-      // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
-      // We can't store the original SQL text with parameters, as we don't store the arguments and
-      // are not able to resolve it after parsing it back. Since parameterized query is mostly used
-      // to avoid SQL injection for SELECT queries, we simply forbid non-DML commands here.
-      child match {
-        case _: InsertIntoStatement => // OK
-        case _: UpdateTable => // OK
-        case _: DeleteFromTable => // OK
-        case _: MergeIntoTable => // OK
-        case cmd: Command =>
-          child.failAnalysis(
-            errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
-            messageParameters = Map("statement" -> cmd.nodeName)
-          )
-        case _ => // OK
-      }
+    plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
+      // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
+      // relations are not children of `UnresolvedWith`.
+      case p @ ParameterizedQuery(child, args) if !child.containsPattern(UNRESOLVED_WITH) =>
+        // Some commands may store the original SQL text, like CREATE VIEW, GENERATED COLUMN, etc.
+        // We can't store the original SQL text with parameters, as we don't store the arguments and
+        // are not able to resolve it after parsing it back. Since parameterized query is mostly
+        // used to avoid SQL injection for SELECT queries, we simply forbid non-DML commands here.
+        child match {
+          case _: InsertIntoStatement => // OK
+          case _: UpdateTable => // OK
+          case _: DeleteFromTable => // OK
+          case _: MergeIntoTable => // OK
+          case cmd: Command =>
+            child.failAnalysis(
+              errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT",
+              messageParameters = Map("statement" -> cmd.nodeName)
+            )
+          case _ => // OK
+        }

-      args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
-        expr.failAnalysis(
-          errorClass = "INVALID_SQL_ARG",
-          messageParameters = Map("name" -> toSQLId(name)))
-      }
+        args.find(!_._2.isInstanceOf[Literal]).foreach { case (name, expr) =>
+          expr.failAnalysis(
+            errorClass = "INVALID_SQL_ARG",
+            messageParameters = Map("name" -> name))
+        }

-      def bind(p: LogicalPlan): LogicalPlan = {
-        p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
-          case Parameter(name) if args.contains(name) => args(name)
-          case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
+        def bind(p: LogicalPlan): LogicalPlan = {
+          p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) {
+            case Parameter(name) if args.contains(name) =>
+              args(name)
+            case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan))
+          }
         }
-      }
-      bind(child)
+        val res = bind(child)
+        res.copyTagsFrom(p)
+        res

-    case _ => plan
+      case _ => plan
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 11b47b7d5c8a..37d3ada53494 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -73,6 +73,7 @@ object TreePattern extends Enumeration {
   val OR: Value = Value
   val OUTER_REFERENCE: Value = Value
   val PARAMETER: Value = Value
+  val PARAMETERIZED_QUERY: Value = Value
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val PYTHON_UDF: Value = Value
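The new `PARAMETERIZED_QUERY` tree pattern is what lets both the sanity check and the rewrite above prune cheaply: `containsPattern` consults the node's cached pattern bits instead of traversing the tree, so the rule costs almost nothing on the vast majority of plans that carry no parameters. The `copyTagsFrom(p)` call preserves any tree-node tags attached to the `ParameterizedQuery` wrapper so they survive on the bound child that replaces it. A schematic of the pruning idiom, as used throughout the analyzer:

    // Only descend into subtrees whose pattern bits say a
    // ParameterizedQuery node may actually be present.
    plan.resolveOperatorsWithPruning(_.containsPattern(PARAMETERIZED_QUERY)) {
      case p: ParameterizedQuery => /* bind and unwrap */ p.child
    }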
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 6e4bf35a9145..066e609a6d33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -623,8 +623,12 @@ class SparkSession private(
     val tracker = new QueryPlanningTracker
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
       val parser = sessionState.sqlParser
-      val parsedArgs = args.mapValues(parser.parseExpression).toMap
-      ParameterizedQuery(parser.parsePlan(sqlText), parsedArgs)
+      val parsedPlan = parser.parsePlan(sqlText)
+      if (args.nonEmpty) {
+        ParameterizedQuery(parsedPlan, args.mapValues(parser.parseExpression).toMap)
+      } else {
+        parsedPlan
+      }
     }
     Dataset.ofRows(self, plan, tracker)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 512a52058195..e6e5eb9fac4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -38,6 +38,24 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
       Row(true))
   }

+  test("parameter binding is case sensitive") {
+    checkAnswer(
+      spark.sql("SELECT :p, :P", Map("p" -> "1", "P" -> "2")),
+      Row(1, 2)
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql("select :P", Map("p" -> "1"))
+      },
+      errorClass = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "P"),
+      context = ExpectedContext(
+        fragment = ":P",
+        start = 7,
+        stop = 8))
+  }
+
   test("parameters in CTE") {
     val sqlText =
       """
@@ -120,7 +138,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
         spark.sql("select :abc, :def", Map("abc" -> "1"))
       },
       errorClass = "UNBOUND_SQL_PARAMETER",
-      parameters = Map("name" -> "`def`"),
+      parameters = Map("name" -> "def"),
       context = ExpectedContext(
         fragment = ":def",
         start = 13,
@@ -130,7 +148,7 @@
         sql("select :abc").collect()
       },
       errorClass = "UNBOUND_SQL_PARAMETER",
-      parameters = Map("name" -> "`abc`"),
+      parameters = Map("name" -> "abc"),
       context = ExpectedContext(
         fragment = ":abc",
         start = 7,
@@ -144,7 +162,7 @@
         spark.sql("SELECT :param1 FROM VALUES (1) AS t(col1)", Map("param1" -> arg))
       },
       errorClass = "INVALID_SQL_ARG",
-      parameters = Map("name" -> "`param1`"),
+      parameters = Map("name" -> "param1"),
       context = ExpectedContext(
         fragment = arg,
         start = 0,
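The test fix below deals with an artifact of binding now happening inside analysis: once the test queries contain CTEs, `CTESubstitution` materializes `CTERelationDef` nodes whose IDs come from a global counter, so two independently analyzed plans only compare equal if the counter is reset to the same value before each `analyze` call. Condensed from the updated test (the `sqlText`/`boundSqlText` names are placeholders):

    CTERelationDef.curId.set(0)
    val actual = ParameterizedQuery(parsePlan(sqlText), args).analyze
    CTERelationDef.curId.set(0)
    val expected = parsePlan(boundSqlText).analyze
    comparePlans(actual, expected)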
From 5d2f0aed525cf6cd05dd9179ddb9be4a95409533 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 9 Mar 2023 22:26:23 +0800
Subject: [PATCH 4/5] fix tests

---
 .../sql/catalyst/analysis/AnalysisSuite.scala | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 16555df0be1d..54ea4086c9b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1346,17 +1346,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }

   test("SPARK-41271: bind named parameters to literals") {
-    comparePlans(
-      ParameterizedQuery(
-        child = parsePlan("SELECT * FROM a LIMIT :limitA"),
-        args = Map("limitA" -> Literal(10))),
-      parsePlan("SELECT * FROM a LIMIT 10"))
+    CTERelationDef.curId.set(0)
+    val actual1 = ParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT :limitA"),
+      args = Map("limitA" -> Literal(10))).analyze
+    CTERelationDef.curId.set(0)
+    val expected1 = parsePlan("WITH a AS (SELECT 1 c) SELECT * FROM a LIMIT 10").analyze
+    comparePlans(actual1, expected1)
     // Ignore unused arguments
-    comparePlans(
-      ParameterizedQuery(
-        child = parsePlan("SELECT c FROM a WHERE c < :param2"),
-        args = Map("param1" -> Literal(10), "param2" -> Literal(20))),
-      parsePlan("SELECT c FROM a WHERE c < 20"))
+    CTERelationDef.curId.set(0)
+    val actual2 = ParameterizedQuery(
+      child = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < :param2"),
+      args = Map("param1" -> Literal(10), "param2" -> Literal(20))).analyze
+    CTERelationDef.curId.set(0)
+    val expected2 = parsePlan("WITH a AS (SELECT 1 c) SELECT c FROM a WHERE c < 20").analyze
+    comparePlans(actual2, expected2)
   }

   test("SPARK-41489: type of filter expression should be a bool") {

From 2a989069048d169ea36d1a37bbe39b00cd21009a Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 10 Mar 2023 10:59:16 +0800
Subject: [PATCH 5/5] fix compilation

---
 .../apache/spark/sql/errors/QueryExecutionErrorsSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index d6a310df39e8..03a310b9e695 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -29,8 +29,8 @@ import org.mockito.Mockito.{mock, spy, when}
 import org.apache.spark._
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
-import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, Parameter}
+import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator}
+import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean
 import org.apache.spark.sql.catalyst.util.BadRecordException
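Taken together, the five patches leave the feature in this state: named, case-sensitive parameters bind anywhere in a query, including CTE bodies, nested CTEs, and scalar subqueries, as well as in DML, while non-DML commands and non-literal argument values are rejected during analysis. A compact recap of the resulting behavior, with illustrative values drawn from the test suite:

    spark.sql("WITH w AS (SELECT :p AS c) SELECT c FROM w", Map("p" -> "1"))  // ok
    spark.sql("INSERT INTO t SELECT :p", Map("p" -> "1"))                     // ok (DML)
    spark.sql("SELECT :p, :P", Map("p" -> "1", "P" -> "2"))                   // two distinct names
    spark.sql("CREATE VIEW v AS SELECT :p", Map("p" -> "1"))                  // rejected (DDL)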