diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index 643b83b1741ae..f4faa8c7d3612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.catalyst.util
import java.util.regex.{Pattern, PatternSyntaxException}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import org.apache.commons.lang.StringUtils.isNotBlank
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.unsafe.types.UTF8String
object StringUtils {
/**
@@ -90,12 +93,145 @@ object StringUtils {
funcNames.toSeq
}
+ /**
+ * Splits the text into one or more SQL statements, preserving bracketed comments.
+ *
+ * Highlighted corner cases: semicolons inside double quotes, single quotes, or
+ * comments are not treated as statement separators.
+ * Expected behavior: each statement is trimmed, and blank statements are omitted.
+ *
+ * @param text one or more SQL statements separated by semicolons
+ * @return the trimmed SQL statements (an Array for Java interop)
+ */
+ def split(text: String): Array[String] = {
+ val D_QUOTE: Char = '"'
+ val S_QUOTE: Char = '\''
+ val Q_QUOTE: Char = '`'
+ val SEMICOLON: Char = ';'
+ val ESCAPE: Char = '\\'
+ val STAR = '*'
+ val DASH = '-'
+ val SINGLE_COMMENT = "--"
+ val BRACKETED_COMMENT_START = "/*"
+ val BRACKETED_COMMENT_END = "*/"
+ val FORWARD_SLASH = '/'
+
+ sealed trait Flag
+ case object Common extends Flag
+ trait Quote extends Flag {
+ def toChar: Char
+ def sameAs(quoteChar: Char): Boolean = { toChar == quoteChar }
+ }
+ object Quote {
+ def validate(quoteChar: Char): Boolean = {
+ List(D_QUOTE, S_QUOTE, Q_QUOTE).contains(quoteChar)
+ }
+
+ def apply(quoteChar: Char): Quote = quoteChar match {
+ case D_QUOTE => DoubleQuote
+ case S_QUOTE => SingleQuote
+ case Q_QUOTE => QuasiQuote
+ case _ =>
+ throw new IllegalArgumentException(s"$quoteChar is not a character for quoting")
+ }
+ }
+ trait Comment extends Flag
+ trait OpenAndClose {
+ def openBy(text: String): Boolean
+ def closeBy(text: String): Boolean
+ }
+ // the cursor stands on a double-quoted string
+ case object DoubleQuote extends Quote with OpenAndClose {
+ override def openBy(text: String): Boolean = text.startsWith(D_QUOTE.toString)
+ override def closeBy(text: String): Boolean = text.startsWith(D_QUOTE.toString)
+ override def toChar: Char = D_QUOTE
+ }
+ // the cursor stands on a quasiquoted string
+ case object QuasiQuote extends Quote with OpenAndClose {
+ override def openBy(text: String): Boolean = text.startsWith(Q_QUOTE.toString)
+ override def closeBy(text: String): Boolean = text.startsWith(Q_QUOTE.toString)
+ override def toChar: Char = Q_QUOTE
+ }
+ // the cursor stands on a single-quoted string
+ case object SingleQuote extends Quote with OpenAndClose {
+ override def openBy(text: String): Boolean = text.startsWith(S_QUOTE.toString)
+ override def closeBy(text: String): Boolean = text.startsWith(S_QUOTE.toString)
+ override def toChar: Char = S_QUOTE
+ }
+ // the cursor stands inside a single-line comment
+ case object SingleComment extends Comment {
+ def openBy(text: String): Boolean = text.startsWith(SINGLE_COMMENT)
+ }
+ // the cursor stands inside a bracketed comment
+ case object BracketedComment extends Comment with OpenAndClose {
+ override def openBy(text: String): Boolean = text.startsWith(BRACKETED_COMMENT_START)
+ override def closeBy(text: String): Boolean = text.startsWith(BRACKETED_COMMENT_END)
+ }
+
+ var flag: Flag = Common
+ var cursor: Int = 0
+ val ret: mutable.ArrayBuffer[String] = mutable.ArrayBuffer()
+ val currentSQL = new mutable.StringBuilder()
+
+ while (cursor < text.length) {
+ val current: Char = text(cursor)
+ val remaining = text.substring(cursor)
+
+ (flag, current) match {
+ // if it stands on the opening of a bracketed comment, consume 2 characters
+ case (Common, FORWARD_SLASH) if BracketedComment.openBy(remaining) =>
+ flag = BracketedComment
+ currentSQL.append(BRACKETED_COMMENT_START)
+ cursor += 2
+ // if it stands on the ending of a bracketed comment, consume 2 characters
+ case (BracketedComment, STAR) if BracketedComment.closeBy(remaining) =>
+ flag = Common
+ currentSQL.append(BRACKETED_COMMENT_END)
+ cursor += 2
+
+ // if it stands on the opening of a single-line comment, skip to the end of the line;
+ // unlike bracketed comments, single-line comments are dropped from the output
+ case (Common, DASH) if SingleComment.openBy(remaining) =>
+ cursor += remaining.takeWhile(x => x != '\n').length
+
+ // if it stands on a normal semicolon, stage the current sql and move the cursor on
+ case (Common, SEMICOLON) =>
+ ret += currentSQL.toString.trim
+ currentSQL.clear()
+ cursor += 1
+
+ // if it stands on the opening of a quote, set the flag and move on
+ case (Common, quoteChar) if Quote.validate(quoteChar) =>
+ flag = Quote(quoteChar)
+ currentSQL += current
+ cursor += 1
+ // if it stands on '\' inside quotes, consume 2 characters so an escaped quote does not close the string
+ case (_: Quote, ESCAPE) if remaining.length >= 2 =>
+ currentSQL.append(remaining.take(2))
+ cursor += 2
+ // if it stands on the matching closing quote, clear the flag and move on
+ case (quote: Quote, quoteChar) if quote.sameAs(quoteChar) =>
+ flag = Common
+ currentSQL += current
+ cursor += 1
+
+ // otherwise, append the character to the current statement and move on
+ case _ =>
+ currentSQL += current
+ cursor += 1
+ }
+ }
+
+ ret += currentSQL.toString.trim
+ ret.filter(isNotBlank).toArray
+ }
+
/**
* Concatenation of sequence of strings to final string with cheap append method
* and one memory allocation for the final string.
*/
class StringConcat {
- private val strings = new ArrayBuffer[String]
+ private val strings = new mutable.ArrayBuffer[String]
private var length: Int = 0
/**
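
A quick illustration of the splitter's contract above, mirroring the new test
cases (a sketch; the results in the comments are what the state machine yields):

    import org.apache.spark.sql.catalyst.util.StringUtils

    // semicolons inside quotes or bracketed comments do not split
    StringUtils.split("select ';' /* ; */ as c; select 1")
    // => Array("select ';' /* ; */ as c", "select 1")

    // blank statements between consecutive semicolons are dropped
    StringUtils.split(" select 1;;select 2; ")
    // => Array("select 1", "select 2")
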
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index 616ec12032dbd..14d1e79e07906 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -44,6 +44,74 @@ class StringUtilsSuite extends SparkFunSuite {
assert(filterPattern(names, " d* ") === Nil)
}
+ test("split the SQL text") {
+ val statement = "select * from tmp.dada;"
+ assert(StringUtils.split(statement) === Array("select * from tmp.dada"))
+
+ // statements are trimmed and blank statements between semicolons are omitted
+ val statements = " select * from tmp.data;;select * from tmp.ata;"
+ assert(StringUtils.split(statements) ===
+ Array("select * from tmp.data", "select * from tmp.ata"))
+
+ val escapedSingleQuote =
+ raw"""
+ |select "\';"
+ """.stripMargin.trim
+ val escapedDoubleQuote =
+ raw"""
+ |select "\";"
+ """.stripMargin.trim
+ assert(StringUtils.split(escapedSingleQuote) === Array(escapedSingleQuote))
+ assert(StringUtils.split(escapedDoubleQuote) === Array(escapedDoubleQuote))
+
+ val semicolonInDoubleQuotes =
+ """
+ |select "^;^"
+ """.stripMargin.trim
+ val semicolonInSingleQuotes =
+ """
+ |select '^;^'
+ """.stripMargin.trim
+ assert(StringUtils.split(semicolonInDoubleQuotes) === Array(semicolonInDoubleQuotes))
+ assert(StringUtils.split(semicolonInSingleQuotes) === Array(semicolonInSingleQuotes))
+
+ val inlineComments =
+ """
+ |select 1; --;;;;;;;;
+ |select "---";
+ """.stripMargin
+ val select1 = "select 1"
+ val selectComments =
+ """
+ |select "---"
+ """.stripMargin.trim
+ assert(StringUtils.split(inlineComments) === Array(select1, selectComments))
+
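+ // "Good" means the split matches SQL semantics; "Bad" marks a known-incorrect split accepted for simplicity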
+ val bracketedComment1 = "select 1 /*;*/" // Good
+ assert(StringUtils.split(bracketedComment1) === Array(bracketedComment1))
+ val bracketedComment2 = "select 1 /* /* ; */" // Good
+ assert(StringUtils.split(bracketedComment2) === Array(bracketedComment2))
+ val bracketedComment3 = "select 1 /* */ ; */" // Bad
+ assert(StringUtils.split(bracketedComment3).head === "select 1 /* */")
+ val bracketedComment4 = "select 1 /**/ ; /* */" // Good
+ assert(StringUtils.split(bracketedComment4).head === "select 1 /**/")
+ val bracketedComment5 = "select 1 /**/ ; /**/" // Good
+ assert(StringUtils.split(bracketedComment5).head === "select 1 /**/")
+ val bracketedComment6 = "select /* bla bla */ 1" // Hints are reserved
+ assert(StringUtils.split(bracketedComment6).head === bracketedComment6)
+
+ val qQuote1 = "select 1 as `;`" // Good
+ assert(StringUtils.split(qQuote1) === Array(qQuote1))
+ val qQuote2 = "select 1 as ``;`" // Bad
+ assert(StringUtils.split(qQuote2) === Array("select 1 as ``", "`"))
+
+ // The splitter's handling of the following three cases does not match the actual
+ // antlr4 rule; we should not make the splitter too complicated
+ // val bracketedComment7 = "select 1 /**/ ; */" // Good
+ // val bracketedComment8 = "select 1 /* */ ; /* */" // Bad
+ // val qQuote3 = "select 1 as ```;`" // Good
+ }
+
test("string concatenation") {
def concat(seq: String*): String = {
seq.foldLeft(new StringConcat())((acc, s) => {acc.append(s); acc}).toString
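
Why bracketedComment3 above is marked Bad: the first */ returns the state machine
to Common, so the later semicolon is a real separator and the trailing */ leaks
into a second statement. A sketch of the scan (illustrative):

    // scan of: select 1 /* */ ; */
    //   "/*" opens a bracketed comment; "*/" closes it (state back to Common)
    //   ";"  is then a real separator: stage "select 1 /* */"
    //   the trailing "*/" is scanned in the Common state and kept as-is
    // => Array("select 1 /* */", "*/")
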
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 4a4629fae2706..1e519ae6aacd7 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -84,6 +84,11 @@
selenium-htmlunit-driver
test
+
+ org.mockito
+ mockito-core
+ test
+
org.apache.spark
spark-sql_${scala.binary.version}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 463cf1680efde..9cfdc4e3258c6 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import jline.console.ConsoleReader
import jline.console.history.FileHistory
-import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang.{StringUtils => ApacheStringUtils}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
@@ -37,14 +37,16 @@ import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.log4j.Level
import org.apache.thrift.transport.TSocket
+import sun.misc.{Signal, SignalHandler}
import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.security.HiveDelegationTokenProvider
import org.apache.spark.util.ShutdownHookManager
/**
* This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver
@@ -143,8 +145,8 @@ private[hive] object SparkSQLCLIDriver extends Logging {
// See also: code in ExecDriver.java
var loader = conf.getClassLoader
val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
- if (StringUtils.isNotBlank(auxJars)) {
- loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
+ if (ApacheStringUtils.isNotBlank(auxJars)) {
+ loader = Utilities.addToClassPath(loader, ApacheStringUtils.split(auxJars, ","))
}
conf.setClassLoader(loader)
Thread.currentThread().setContextClassLoader(loader)
@@ -309,12 +311,16 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
private val conf: Configuration =
if (sessionState != null) sessionState.getConf else new Configuration()
- // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver
- // because the Hive unit tests do not go through the main() code path.
if (!isRemoteMode) {
- SparkSQLEnv.init()
- if (sessionState.getIsSilent) {
- SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString)
+ // Utils.isTesting checks env[SPARK_TESTING] as well as props[spark.testing]: the
+ // environment variable is inherited across processes, while the system property is
+ // scoped to a single JVM. CliSuite sets SPARK_TESTING and still needs SparkSQLEnv,
+ // so only props[spark.testing] is checked here, as a switch for SparkSQLCLIDriverSuite
+ if (!sys.props.contains("spark.testing")) {
+ SparkSQLEnv.init()
+ if (sessionState.getIsSilent) {
+ SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString)
+ }
}
} else {
// Hive 1.2 + not supported in CLI
@@ -331,6 +337,65 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
console.printInfo(s"Spark master: $master, Application Id: $appId")
}
+ // Adapted from Hive's CliDriver.processLine, translated from Java to Scala
+ override def processLine(line: String, allowInterrupting: Boolean): Int = {
+ var oldSignal: SignalHandler = null
+ var interruptSignal: Signal = null
+
+ if (allowInterrupting) {
+ // Remember all threads that were running at the time we started line processing.
+ // Hook up the custom Ctrl+C handler while processing this line
+ interruptSignal = new Signal("INT")
+ oldSignal = Signal.handle(interruptSignal, new SignalHandler() {
+ private val cliThread = Thread.currentThread()
+ private var interruptRequested: Boolean = false
+
+ override def handle(signal: Signal): Unit = {
+ val initialRequest = !interruptRequested
+ interruptRequested = true
+
+ // Kill the VM on second ctrl+c
+ if (!initialRequest) {
+ console.printInfo("Exiting the JVM")
+ System.exit(127)
+ }
+
+ // Interrupt the CLI thread to stop the current statement and return
+ // to prompt
+ console.printInfo("Interrupting... Be patient, this might take some time.")
+ console.printInfo("Press Ctrl+C again to kill JVM")
+
+ // TODO: first, kill any running Spark jobs
+ }
+ })
+ }
+
+ try {
+ var lastRet: Int = 0
+ var ret: Int = 0
+
+ for (command <- StringUtils.split(line)) {
+ ret = processCmd(command)
+ // wipe cli query state
+ sessionState.setCommandType(null)
+ lastRet = ret
+ val ignoreErrors = HiveConf.getBoolVar(conf, HiveConf.ConfVars.CLIIGNOREERRORS)
+ if (ret != 0 && !ignoreErrors) {
+ CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+ return ret
+ }
+ }
+ CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+ lastRet
+ } finally {
+ // Once we are done processing the line, restore the old handler
+ if (oldSignal != null && interruptSignal != null) {
+ Signal.handle(interruptSignal, oldSignal)
+ }
+ }
+ }
+
override def processCmd(cmd: String): Int = {
val cmd_trimmed: String = cmd.trim()
val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT)
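
The Ctrl+C handling in processLine follows the usual install-then-restore pattern
for sun.misc.Signal; a minimal self-contained sketch (the handler body here is
illustrative, not the patch's):

    import sun.misc.{Signal, SignalHandler}

    val sigInt = new Signal("INT")
    // Signal.handle returns the previously installed handler
    val previous: SignalHandler = Signal.handle(sigInt, new SignalHandler {
      override def handle(sig: Signal): Unit = Console.err.println("interrupted")
    })
    try {
      // ... process the current line ...
    } finally {
      Signal.handle(sigInt, previous) // restore the old handler
    }
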
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala
new file mode 100644
index 0000000000000..d1151cc9851dc
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.cli.CliSessionState
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.mockito.ArgumentMatcher
+import org.mockito.ArgumentMatchers.argThat
+import org.mockito.Mockito._
+
+import org.apache.spark.SparkFunSuite
+
+class SparkSQLCLIDriverSuite extends SparkFunSuite {
+
+ def matchSQL(sqlText: String, candidates: String*): Unit = {
+ class SQLMatcher extends ArgumentMatcher[String] {
+ override def matches(command: String): Boolean = candidates.contains(command)
+ }
+
+ val conf = new HiveConf(classOf[SessionState])
+ val sessionState = new CliSessionState(conf)
+ SessionState.start(sessionState)
+ // a bare mock would stub processLine itself and make the assertion pass vacuously;
+ // spy on a real instance so the real processLine exercises the splitter, with
+ // processCmd stubbed to succeed only for the expected statements
+ val cli = spy(new SparkSQLCLIDriver)
+
+ // (Nil: _*) disambiguates Mockito's overloaded doReturn for Scala
+ doReturn(0, Nil: _*).when(cli).processCmd(argThat(new SQLMatcher))
+ assert(cli.processLine(sqlText) == 0)
+ }
+
+ test("SPARK-26312: sql text splitting for the processCmd method") {
+ // semicolon in a string
+ val sql =
+ """
+ |select "^;^"
+ """.stripMargin.trim
+ matchSQL(sql, sql)
+
+ // normal statements
+ val statements =
+ """
+ |select d from data;
+ |select a from data
+ """.stripMargin
+ val dStatement = "select d from data"
+ val aStatement = "select a from data"
+ matchSQL(statements, dStatement, aStatement)
+ }
+}