diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
index b98c7436f9906..a295ef06a6376 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
@@ -275,7 +275,8 @@ private[spark] object ConfigEntry {
 
   val UNDEFINED = "<undefined>"
 
-  private val knownConfigs = new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
+  private[spark] val knownConfigs =
+    new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
 
   def registerEntry(entry: ConfigEntry[_]): Unit = {
     val existing = knownConfigs.putIfAbsent(entry.key, entry)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 39f94651a0cb5..6fce7819897a6 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -246,11 +246,17 @@ statement
     | SET TIME ZONE interval                                           #setTimeZone
     | SET TIME ZONE timezone=(STRING | LOCAL)                          #setTimeZone
     | SET TIME ZONE .*?                                                #setTimeZone
+    | SET configKey (EQ .*?)?                                          #setQuotedConfiguration
     | SET .*?                                                          #setConfiguration
+    | RESET configKey                                                  #resetQuotedConfiguration
     | RESET .*?                                                        #resetConfiguration
     | unsupportedHiveNativeCommands .*?                                #failNativeCommand
     ;
 
+configKey
+    : quotedIdentifier
+    ;
+
 unsupportedHiveNativeCommands
     : kw1=CREATE kw2=ROLE
     | kw1=DROP kw2=ROLE
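
A quick illustration (not part of the patch) of what the two new grammar alternatives accept. The sketch below assumes a build containing this change; Spark's parser post-processor strips the backquotes from a `quotedIdentifier`, so the resulting plan carries the bare key:

```scala
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf

// Hypothetical demo object, not part of the patch.
object QuotedConfigKeyDemo {
  def main(args: Array[String]): Unit = {
    val parser = new SparkSqlParser(new SQLConf)
    // Backquoted keys match the new `SET configKey (EQ .*?)?` and
    // `RESET configKey` alternatives, so special characters survive parsing.
    println(parser.parsePlan("SET `spark.sql. key`=value")) // SetCommand with key "spark.sql. key"
    println(parser.parsePlan("RESET `spark.sql. key`"))     // ResetCommand with key "spark.sql. key"
    // Unquoted keys still take the legacy #setConfiguration path.
    println(parser.parsePlan("SET spark.sql.key=value"))
  }
}
```
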
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 012ae0a76043c..129312160b1b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -58,6 +58,9 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
 class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._
 
+  private val configKeyValueDef = """([a-zA-Z_\d\\.:]+)\s*=(.*)""".r
+  private val configKeyDef = """([a-zA-Z_\d\\.:]+)$""".r
+
   /**
    * Create a [[SetCommand]] logical plan.
    *
@@ -66,17 +69,28 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    * character in the raw string.
    */
   override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
-    // Construct the command.
-    val raw = remainder(ctx.SET.getSymbol)
-    val keyValueSeparatorIndex = raw.indexOf('=')
-    if (keyValueSeparatorIndex >= 0) {
-      val key = raw.substring(0, keyValueSeparatorIndex).trim
-      val value = raw.substring(keyValueSeparatorIndex + 1).trim
-      SetCommand(Some(key -> Option(value)))
-    } else if (raw.nonEmpty) {
-      SetCommand(Some(raw.trim -> None))
+    remainder(ctx.SET.getSymbol).trim match {
+      case configKeyValueDef(key, value) =>
+        SetCommand(Some(key -> Option(value.trim)))
+      case configKeyDef(key) =>
+        SetCommand(Some(key -> None))
+      case s if s == "-v" =>
+        SetCommand(Some("-v" -> None))
+      case s if s.isEmpty =>
+        SetCommand(None)
+      case _ => throw new ParseException("Expected format is 'SET', 'SET key', or " +
+        "'SET key=value'. If you want to include special characters in key, " +
+        "please use quotes, e.g., SET `ke y`=value.", ctx)
+    }
+  }
+
+  override def visitSetQuotedConfiguration(ctx: SetQuotedConfigurationContext)
+    : LogicalPlan = withOrigin(ctx) {
+    val keyStr = ctx.configKey().getText
+    if (ctx.EQ() != null) {
+      SetCommand(Some(keyStr -> Option(remainder(ctx.EQ().getSymbol).trim)))
     } else {
-      SetCommand(None)
+      SetCommand(Some(keyStr -> None))
     }
   }
 
@@ -90,7 +104,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    */
   override def visitResetConfiguration(
       ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) {
-    ResetCommand(Option(remainder(ctx.RESET().getSymbol).trim).filter(_.nonEmpty))
+    remainder(ctx.RESET.getSymbol).trim match {
+      case configKeyDef(key) =>
+        ResetCommand(Some(key))
+      case s if s.trim.isEmpty =>
+        ResetCommand(None)
+      case _ => throw new ParseException("Expected format is 'RESET' or 'RESET key'. " +
+        "If you want to include special characters in key, " +
+        "please use quotes, e.g., RESET `ke y`.", ctx)
+    }
+  }
+
+  override def visitResetQuotedConfiguration(
+      ctx: ResetQuotedConfigurationContext): LogicalPlan = withOrigin(ctx) {
+    ResetCommand(Some(ctx.configKey().getText))
   }
 
   /**
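
The two regular expressions above do all the classification of the text following SET/RESET. A self-contained sketch of that matching behavior (object and method names are ours for illustration; only the patterns come from the patch):

```scala
object SetRemainderDemo {
  // Same patterns as SparkSqlAstBuilder: a key may contain letters, digits,
  // underscores, backslashes, dots and colons -- but never whitespace.
  private val configKeyValueDef = """([a-zA-Z_\d\\.:]+)\s*=(.*)""".r
  private val configKeyDef = """([a-zA-Z_\d\\.:]+)$""".r

  def classify(remainder: String): String = remainder.trim match {
    case configKeyValueDef(key, value) => s"set key=[$key] value=[${value.trim}]"
    case configKeyDef(key)             => s"set key=[$key], no value"
    case s if s == "-v"                => "list configs with details (-v)"
    case s if s.isEmpty                => "list all configs"
    case _                             => "ParseException: quote the key with backticks"
  }

  def main(args: Array[String]): Unit = {
    // Scala regex patterns must match the whole string, so whitespace inside an
    // unquoted key (e.g. "spark. sql.key") falls through to the error case.
    Seq("spark:sql:key = -1", "spark.sql.key", "-v", "", "spark. sql.key=value")
      .foreach(s => println(s"[$s] => ${classify(s)}"))
  }
}
```
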
assertEqual("SET spark.sql.key ", SetCommand(Some("spark.sql.key" -> None))) + assertEqual("SET spark:sql:key=false", SetCommand(Some("spark:sql:key" -> Some("false")))) + assertEqual("SET spark:sql:key=", SetCommand(Some("spark:sql:key" -> Some("")))) + assertEqual("SET spark:sql:key= ", SetCommand(Some("spark:sql:key" -> Some("")))) + assertEqual("SET spark:sql:key=-1 ", SetCommand(Some("spark:sql:key" -> Some("-1")))) + assertEqual("SET spark:sql:key = -1", SetCommand(Some("spark:sql:key" -> Some("-1")))) + assertEqual("SET 1.2.key=value", SetCommand(Some("1.2.key" -> Some("value")))) + assertEqual("SET spark.sql.3=4", SetCommand(Some("spark.sql.3" -> Some("4")))) + assertEqual("SET 1:2:key=value", SetCommand(Some("1:2:key" -> Some("value")))) + assertEqual("SET spark:sql:3=4", SetCommand(Some("spark:sql:3" -> Some("4")))) + assertEqual("SET 5=6", SetCommand(Some("5" -> Some("6")))) + assertEqual("SET spark:sql:key = va l u e ", + SetCommand(Some("spark:sql:key" -> Some("va l u e")))) + assertEqual("SET `spark.sql. key`=value", + SetCommand(Some("spark.sql. key" -> Some("value")))) + assertEqual("SET `spark.sql. key`= v a lu e ", + SetCommand(Some("spark.sql. key" -> Some("v a lu e")))) + assertEqual("SET `spark.sql. key`= -1", + SetCommand(Some("spark.sql. key" -> Some("-1")))) + + val expectedErrMsg = "Expected format is 'SET', 'SET key', or " + + "'SET key=value'. If you want to include special characters in key, " + + "please use quotes, e.g., SET `ke y`=value." + intercept("SET spark.sql.key value", expectedErrMsg) + intercept("SET spark.sql.key 'value'", expectedErrMsg) + intercept("SET spark.sql.key \"value\" ", expectedErrMsg) + intercept("SET spark.sql.key value1 value2", expectedErrMsg) + intercept("SET spark. sql.key=value", expectedErrMsg) + intercept("SET spark :sql:key=value", expectedErrMsg) + intercept("SET spark . sql.key=value", expectedErrMsg) + intercept("SET spark.sql. key=value", expectedErrMsg) + intercept("SET spark.sql :key=value", expectedErrMsg) + intercept("SET spark.sql . key=value", expectedErrMsg) + } + + test("Report Error for invalid usage of RESET command") { + assertEqual("RESET", ResetCommand(None)) + assertEqual("RESET spark.sql.key", ResetCommand(Some("spark.sql.key"))) + assertEqual("RESET spark.sql.key ", ResetCommand(Some("spark.sql.key"))) + assertEqual("RESET 1.2.key ", ResetCommand(Some("1.2.key"))) + assertEqual("RESET spark.sql.3", ResetCommand(Some("spark.sql.3"))) + assertEqual("RESET 1:2:key ", ResetCommand(Some("1:2:key"))) + assertEqual("RESET spark:sql:3", ResetCommand(Some("spark:sql:3"))) + assertEqual("RESET `spark.sql. key`", ResetCommand(Some("spark.sql. key"))) + + val expectedErrMsg = "Expected format is 'RESET' or 'RESET key'. " + + "If you want to include special characters in key, " + + "please use quotes, e.g., RESET `ke y`." + intercept("RESET spark.sql.key1 key2", expectedErrMsg) + intercept("RESET spark. 
sql.key1 key2", expectedErrMsg) + intercept("RESET spark.sql.key1 key2 key3", expectedErrMsg) + intercept("RESET spark: sql:key", expectedErrMsg) + intercept("RESET spark .sql.key", expectedErrMsg) + intercept("RESET spark : sql:key", expectedErrMsg) + intercept("RESET spark.sql: key", expectedErrMsg) + intercept("RESET spark.sql .key", expectedErrMsg) + intercept("RESET spark.sql : key", expectedErrMsg) + } + test("refresh resource") { assertEqual("REFRESH prefix_path", RefreshResource("prefix_path")) assertEqual("REFRESH /", RefreshResource("/")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 0ecc5ee04ce16..0da8c6c4ef5fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -115,7 +115,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { - sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS.key}=$original") } } @@ -146,7 +146,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL)) assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0) } finally { - sql(s"set ${SQLConf.GROUP_BY_ORDINAL}=$original") + sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=$original") } } @@ -162,7 +162,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100) assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 0) } finally { - sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS}=$original") + sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=$original") } }