ConfigEntry.scala (core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala)

@@ -275,7 +275,8 @@ private[spark] object ConfigEntry {
 
   val UNDEFINED = "<undefined>"
 
-  private val knownConfigs = new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
+  private[spark] val knownConfigs =
+    new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
 
   def registerEntry(entry: ConfigEntry[_]): Unit = {
     val existing = knownConfigs.putIfAbsent(entry.key, entry)
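Widening knownConfigs from private to private[spark] lets test code elsewhere in the spark package enumerate every registered configuration entry. A minimal sketch of that usage, mirroring the parser suite added below (the println body is illustrative, not part of the PR):

```scala
import scala.collection.JavaConverters._

import org.apache.spark.internal.config.ConfigEntry

// Walk every ConfigEntry registered so far. Each entry exposes its key and,
// when a default exists, a string rendering of that default.
ConfigEntry.knownConfigs.values.asScala.foreach { config =>
  val default = Option(config.defaultValueString).getOrElse(ConfigEntry.UNDEFINED)
  println(s"${config.key} -> $default")
}
```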
SqlBase.g4 (sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4)

@@ -246,11 +246,17 @@ statement
     | SET TIME ZONE interval #setTimeZone
     | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone
     | SET TIME ZONE .*? #setTimeZone
+    | SET configKey (EQ .*?)? #setQuotedConfiguration
     | SET .*? #setConfiguration
+    | RESET configKey #resetQuotedConfiguration
     | RESET .*? #resetConfiguration
     | unsupportedHiveNativeCommands .*? #failNativeCommand
     ;
 
+configKey
+    : quotedIdentifier
+    ;
+
 unsupportedHiveNativeCommands
     : kw1=CREATE kw2=ROLE
     | kw1=DROP kw2=ROLE
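Because configKey is a quotedIdentifier (a backquote-delimited name), the two new alternatives only fire for backquoted keys, and they sit above the catch-all SET .*? and RESET .*? alternatives, so bare keys keep their old path. A sketch of what this enables, assuming a running SparkSession named spark:

```scala
// Backquotes allow characters (here, a space) that a bare key cannot contain.
spark.sql("SET `spark.sql. key`=value")   // handled by #setQuotedConfiguration
spark.sql("RESET `spark.sql. key`")       // handled by #resetQuotedConfiguration
```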
SparkSqlParser.scala (sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala)

@@ -58,6 +58,9 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
 class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._
 
+  private val configKeyValueDef = """([a-zA-Z_\d\\.:]+)\s*=(.*)""".r
+  private val configKeyDef = """([a-zA-Z_\d\\.:]+)$""".r
+
   /**
    * Create a [[SetCommand]] logical plan.
    *
@@ -66,17 +69,28 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    * character in the raw string.
    */
   override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
-    // Construct the command.
-    val raw = remainder(ctx.SET.getSymbol)
-    val keyValueSeparatorIndex = raw.indexOf('=')
-    if (keyValueSeparatorIndex >= 0) {
-      val key = raw.substring(0, keyValueSeparatorIndex).trim
-      val value = raw.substring(keyValueSeparatorIndex + 1).trim
-      SetCommand(Some(key -> Option(value)))
-    } else if (raw.nonEmpty) {
-      SetCommand(Some(raw.trim -> None))
+    remainder(ctx.SET.getSymbol).trim match {
+      case configKeyValueDef(key, value) =>
+        SetCommand(Some(key -> Option(value.trim)))
+      case configKeyDef(key) =>
+        SetCommand(Some(key -> None))
+      case s if s == "-v" =>
+        SetCommand(Some("-v" -> None))
+      case s if s.isEmpty =>
+        SetCommand(None)
+      case _ => throw new ParseException("Expected format is 'SET', 'SET key', or " +
+        "'SET key=value'. If you want to include special characters in key, " +
+        "please use quotes, e.g., SET `ke y`=value.", ctx)
+    }
+  }
+
+  override def visitSetQuotedConfiguration(ctx: SetQuotedConfigurationContext)
+    : LogicalPlan = withOrigin(ctx) {
+    val keyStr = ctx.configKey().getText
+    if (ctx.EQ() != null) {
+      SetCommand(Some(keyStr -> Option(remainder(ctx.EQ().getSymbol).trim)))
     } else {
-      SetCommand(None)
+      SetCommand(Some(keyStr -> None))
     }
   }

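The rewrite replaces the old indexOf('=') split with two anchored regexes: configKeyValueDef for key=value pairs and configKeyDef for a bare key, where a key may contain letters, digits, and the characters _ \ . : only. A sketch of how those patterns behave on their own (the results follow directly from the regexes above):

```scala
val configKeyValueDef = """([a-zA-Z_\d\\.:]+)\s*=(.*)""".r
val configKeyDef = """([a-zA-Z_\d\\.:]+)$""".r

// Scala regex patterns must match the whole string, so:
"spark.sql.key = -1" match {
  case configKeyValueDef(key, value) => (key, value.trim)  // ("spark.sql.key", "-1")
}
"spark.sql.key" match {
  case configKeyDef(key) => key                            // "spark.sql.key"
}
// "spark. sql.key=value" matches neither pattern (a space cannot be part of
// the key), so visitSetConfiguration falls through to the ParseException.
```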
@@ -90,7 +104,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    */
   override def visitResetConfiguration(
       ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) {
-    ResetCommand(Option(remainder(ctx.RESET().getSymbol).trim).filter(_.nonEmpty))
+    remainder(ctx.RESET.getSymbol).trim match {
+      case configKeyDef(key) =>
+        ResetCommand(Some(key))
+      case s if s.trim.isEmpty =>
+        ResetCommand(None)
+      case _ => throw new ParseException("Expected format is 'RESET' or 'RESET key'. " +
+        "If you want to include special characters in key, " +
+        "please use quotes, e.g., RESET `ke y`.", ctx)
+    }
   }
 
+  override def visitResetQuotedConfiguration(
+      ctx: ResetQuotedConfigurationContext): LogicalPlan = withOrigin(ctx) {
+    ResetCommand(Some(ctx.configKey().getText))
+  }
+
   /**
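With the grammar change, RESET now accepts exactly three shapes; anything else is rejected at parse time rather than being silently trimmed into a key. A sketch of the resulting plans, instantiating the parser directly the way the test suite below does:

```scala
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf

val parser = new SparkSqlParser(new SQLConf)
parser.parsePlan("RESET")                    // ResetCommand(None): clear everything
parser.parsePlan("RESET spark.sql.key")      // ResetCommand(Some("spark.sql.key"))
parser.parsePlan("RESET `spark.sql. key`")   // ResetCommand(Some("spark.sql. key")), quoted path
parser.parsePlan("RESET spark.sql.k1 k2")    // throws ParseException: "Expected format is 'RESET' or 'RESET key'. ..."
```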
SparkSqlParserSuite.scala (sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala)

@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.config.ConfigEntry
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
@@ -25,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTable, RefreshResource}
-import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, StaticSQLConf}
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -61,6 +64,82 @@ class SparkSqlParserSuite extends AnalysisTest {
   private def intercept(sqlCommand: String, messages: String*): Unit =
     interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)
 
+  test("Checks if SET/RESET can parse all the configurations") {
+    // Force to build static SQL configurations
+    StaticSQLConf
+    ConfigEntry.knownConfigs.values.asScala.foreach { config =>
+      assertEqual(s"SET ${config.key}", SetCommand(Some(config.key -> None)))
+      if (config.defaultValue.isDefined && config.defaultValueString != null) {
+        assertEqual(s"SET ${config.key}=${config.defaultValueString}",
+          SetCommand(Some(config.key -> Some(config.defaultValueString))))
+      }
+      assertEqual(s"RESET ${config.key}", ResetCommand(Some(config.key)))
+    }
+  }

[Review comment, Member Author] Added tests for the Spark confs, too.

test("Report Error for invalid usage of SET command") {
assertEqual("SET", SetCommand(None))
assertEqual("SET -v", SetCommand(Some("-v", None)))
assertEqual("SET spark.sql.key", SetCommand(Some("spark.sql.key" -> None)))
assertEqual("SET spark.sql.key ", SetCommand(Some("spark.sql.key" -> None)))
assertEqual("SET spark:sql:key=false", SetCommand(Some("spark:sql:key" -> Some("false"))))
assertEqual("SET spark:sql:key=", SetCommand(Some("spark:sql:key" -> Some(""))))
assertEqual("SET spark:sql:key= ", SetCommand(Some("spark:sql:key" -> Some(""))))
assertEqual("SET spark:sql:key=-1 ", SetCommand(Some("spark:sql:key" -> Some("-1"))))
assertEqual("SET spark:sql:key = -1", SetCommand(Some("spark:sql:key" -> Some("-1"))))
assertEqual("SET 1.2.key=value", SetCommand(Some("1.2.key" -> Some("value"))))
assertEqual("SET spark.sql.3=4", SetCommand(Some("spark.sql.3" -> Some("4"))))
assertEqual("SET 1:2:key=value", SetCommand(Some("1:2:key" -> Some("value"))))
assertEqual("SET spark:sql:3=4", SetCommand(Some("spark:sql:3" -> Some("4"))))
assertEqual("SET 5=6", SetCommand(Some("5" -> Some("6"))))
assertEqual("SET spark:sql:key = va l u e ",
SetCommand(Some("spark:sql:key" -> Some("va l u e"))))
assertEqual("SET `spark.sql. key`=value",
SetCommand(Some("spark.sql. key" -> Some("value"))))
assertEqual("SET `spark.sql. key`= v a lu e ",
SetCommand(Some("spark.sql. key" -> Some("v a lu e"))))
assertEqual("SET `spark.sql. key`= -1",
SetCommand(Some("spark.sql. key" -> Some("-1"))))

val expectedErrMsg = "Expected format is 'SET', 'SET key', or " +
"'SET key=value'. If you want to include special characters in key, " +
"please use quotes, e.g., SET `ke y`=value."
intercept("SET spark.sql.key value", expectedErrMsg)
intercept("SET spark.sql.key 'value'", expectedErrMsg)
intercept("SET spark.sql.key \"value\" ", expectedErrMsg)
intercept("SET spark.sql.key value1 value2", expectedErrMsg)
intercept("SET spark. sql.key=value", expectedErrMsg)
intercept("SET spark :sql:key=value", expectedErrMsg)
intercept("SET spark . sql.key=value", expectedErrMsg)
intercept("SET spark.sql. key=value", expectedErrMsg)
intercept("SET spark.sql :key=value", expectedErrMsg)
intercept("SET spark.sql . key=value", expectedErrMsg)
}

test("Report Error for invalid usage of RESET command") {
assertEqual("RESET", ResetCommand(None))
assertEqual("RESET spark.sql.key", ResetCommand(Some("spark.sql.key")))
assertEqual("RESET spark.sql.key ", ResetCommand(Some("spark.sql.key")))
assertEqual("RESET 1.2.key ", ResetCommand(Some("1.2.key")))
assertEqual("RESET spark.sql.3", ResetCommand(Some("spark.sql.3")))
assertEqual("RESET 1:2:key ", ResetCommand(Some("1:2:key")))
assertEqual("RESET spark:sql:3", ResetCommand(Some("spark:sql:3")))
assertEqual("RESET `spark.sql. key`", ResetCommand(Some("spark.sql. key")))

val expectedErrMsg = "Expected format is 'RESET' or 'RESET key'. " +
"If you want to include special characters in key, " +
"please use quotes, e.g., RESET `ke y`."
intercept("RESET spark.sql.key1 key2", expectedErrMsg)
intercept("RESET spark. sql.key1 key2", expectedErrMsg)
intercept("RESET spark.sql.key1 key2 key3", expectedErrMsg)
intercept("RESET spark: sql:key", expectedErrMsg)
intercept("RESET spark .sql.key", expectedErrMsg)
intercept("RESET spark : sql:key", expectedErrMsg)
intercept("RESET spark.sql: key", expectedErrMsg)
intercept("RESET spark.sql .key", expectedErrMsg)
intercept("RESET spark.sql : key", expectedErrMsg)
}

test("refresh resource") {
assertEqual("REFRESH prefix_path", RefreshResource("prefix_path"))
assertEqual("REFRESH /", RefreshResource("/"))
Expand Down
SQLConfSuite.scala (sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala)

@@ -115,7 +115,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
       sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
       assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10)
     } finally {
-      sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original")
+      sql(s"set ${SQLConf.SHUFFLE_PARTITIONS.key}=$original")
     }
   }

[Review comment, Member Author] The existing code looks incorrect.
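The .key fix matters because interpolating the ConfigEntry object invokes its toString, which renders more than the bare key, so the finally block never restored the setting it meant to; under the stricter SET parsing introduced in this PR, such a string would now be rejected outright instead of silently setting a garbage key. A sketch of the difference (the exact toString rendering is illustrative, not quoted from the source):

```scala
// Something like "ConfigEntry(key=spark.sql.shuffle.partitions, defaultValue=...)"
sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original")      // old: wrong (now rejected) key
sql(s"set ${SQLConf.SHUFFLE_PARTITIONS.key}=$original")  // fixed: the intended key
```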
@@ -146,7 +146,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
       assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL))
       assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0)
     } finally {
-      sql(s"set ${SQLConf.GROUP_BY_ORDINAL}=$original")
+      sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=$original")
     }
   }
 
@@ -162,7 +162,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession {
       assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100)
       assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 0)
     } finally {
-      sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS}=$original")
+      sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=$original")
     }
   }