diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index 643b83b1741ae..f4faa8c7d3612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import java.util.regex.{Pattern, PatternSyntaxException}
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import org.apache.commons.lang.StringUtils.isNotBlank
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.unsafe.types.UTF8String
 
+
 object StringUtils {
 
   /**
@@ -90,12 +93,145 @@ object StringUtils {
     funcNames.toSeq
   }
 
+  /**
+   * Splits the text into one or more SQL statements, with bracketed comments preserved.
+   *
+   * Highlighted corner cases: semicolons inside double quotes, single quotes
+   * or inline comments.
+   * Expected behavior: each statement is trimmed, and blank statements are omitted.
+   *
+   * @param text one or more SQL statements separated by semicolons
+   * @return the trimmed SQL statements (Array is for Java interop)
+   */
+  def split(text: String): Array[String] = {
+    val D_QUOTE: Char = '"'
+    val S_QUOTE: Char = '\''
+    val Q_QUOTE: Char = '`'
+    val SEMICOLON: Char = ';'
+    val ESCAPE: Char = '\\'
+    val STAR = '*'
+    val DASH = '-'
+    val SINGLE_COMMENT = "--"
+    val BRACKETED_COMMENT_START = "/*"
+    val BRACKETED_COMMENT_END = "*/"
+    val FORWARD_SLASH = '/'
+
+    sealed trait Flag
+    case object Common extends Flag
+    trait Quote extends Flag {
+      def toChar: Char
+      def sameAs(quoteChar: Char): Boolean = toChar == quoteChar
+    }
+    object Quote {
+      def validate(quoteChar: Char): Boolean = {
+        List(D_QUOTE, S_QUOTE, Q_QUOTE).contains(quoteChar)
+      }
+
+      def apply(quoteChar: Char): Quote = quoteChar match {
+        case D_QUOTE => DoubleQuote
+        case S_QUOTE => SingleQuote
+        case Q_QUOTE => QuasiQuote
+        case _ =>
+          throw new IllegalArgumentException(s"$quoteChar is not a character for quoting")
+      }
+    }
+    trait Comment extends Flag
+    trait OpenAndClose {
+      def openBy(text: String): Boolean
+      def closeBy(text: String): Boolean
+    }
+    // the cursor stands on a double-quoted string
+    case object DoubleQuote extends Quote with OpenAndClose {
+      override def openBy(text: String): Boolean = text.startsWith(D_QUOTE.toString)
+      override def closeBy(text: String): Boolean = text.startsWith(D_QUOTE.toString)
+      override def toChar: Char = D_QUOTE
+    }
+    // the cursor stands on a quasi-quoted (backquoted) string
+    case object QuasiQuote extends Quote with OpenAndClose {
+      override def openBy(text: String): Boolean = text.startsWith(Q_QUOTE.toString)
+      override def closeBy(text: String): Boolean = text.startsWith(Q_QUOTE.toString)
+      override def toChar: Char = Q_QUOTE
+    }
+    // the cursor stands on a single-quoted string
+    case object SingleQuote extends Quote with OpenAndClose {
+      override def openBy(text: String): Boolean = text.startsWith(S_QUOTE.toString)
+      override def closeBy(text: String): Boolean = text.startsWith(S_QUOTE.toString)
+      override def toChar: Char = S_QUOTE
+    }
+    // the cursor stands inside a single-line comment
+    case object SingleComment extends Comment {
+      def openBy(text: String): Boolean = text.startsWith(SINGLE_COMMENT)
+    }
+    // the cursor stands inside a bracketed comment
+    case object BracketedComment extends Comment with OpenAndClose {
+      override def openBy(text: String): Boolean = text.startsWith(BRACKETED_COMMENT_START)
+      override def closeBy(text: String): Boolean = text.startsWith(BRACKETED_COMMENT_END)
+    }
+
+    var flag: Flag = Common
+    var cursor: Int = 0
+    val ret: mutable.ArrayBuffer[String] = mutable.ArrayBuffer()
+    val currentSQL: mutable.StringBuilder = mutable.StringBuilder.newBuilder
+
+    while (cursor < text.length) {
+      val current: Char = text(cursor)
+      val remaining = text.substring(cursor)
+
+      (flag, current) match {
+        // if it stands on the opening of a bracketed comment, consume 2 characters
+        case (Common, FORWARD_SLASH) if BracketedComment.openBy(remaining) =>
+          flag = BracketedComment
+          currentSQL.append(BRACKETED_COMMENT_START)
+          cursor += 2
+        // if it stands on the ending of a bracketed comment, consume 2 characters
+        case (BracketedComment, STAR) if BracketedComment.closeBy(remaining) =>
+          flag = Common
+          currentSQL.append(BRACKETED_COMMENT_END)
+          cursor += 2
+
+        // if it stands on the opening of an inline comment, move the cursor to the
+        // end of this line
+        case (Common, DASH) if SingleComment.openBy(remaining) =>
+          cursor += remaining.takeWhile(_ != '\n').length
+
+        // if it stands on a normal semicolon, stage the current statement and move on
+        case (Common, SEMICOLON) =>
+          ret += currentSQL.toString.trim
+          currentSQL.clear()
+          cursor += 1
+
+        // if it stands on the opening of quotes, mark the flag and move on
+        case (Common, quoteChar) if Quote.validate(quoteChar) =>
+          flag = Quote(quoteChar)
+          currentSQL += current
+          cursor += 1
+        // if it stands on '\' inside quotes, consume 2 characters so that an
+        // escaped " or ' cannot close the string
+        case (_: Quote, ESCAPE) if remaining.length >= 2 =>
+          currentSQL.append(remaining.take(2))
+          cursor += 2
+        // if it stands on the ending of quotes, clear the flag and move on
+        case (quote: Quote, quoteChar) if quote.sameAs(quoteChar) =>
+          flag = Common
+          currentSQL += current
+          cursor += 1
+
+        // otherwise, push the char to the current statement and move on
+        case _ =>
+          currentSQL += current
+          cursor += 1
+      }
+    }
+
+    ret += currentSQL.toString.trim
+    ret.filter(isNotBlank).toArray
+  }
+
+
   /**
    * Concatenation of sequence of strings to final string with cheap append method
    * and one memory allocation for the final string.
    */
   class StringConcat {
-    private val strings = new ArrayBuffer[String]
+    private val strings = new mutable.ArrayBuffer[String]
     private var length: Int = 0
 
     /**
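A minimal usage sketch of the new splitter (illustrative only, not part of the patch; the expected results mirror the StringUtilsSuite cases below):

    import org.apache.spark.sql.catalyst.util.StringUtils

    object SplitDemo extends App {
      // Semicolons split statements; results are trimmed and blanks are dropped.
      assert(StringUtils.split(" select 1;;select 2; ").sameElements(Array("select 1", "select 2")))
      // Semicolons inside quotes do not split.
      assert(StringUtils.split("""select ";" as c""").sameElements(Array("""select ";" as c""")))
      // Single-line comments are dropped up to the end of the line.
      assert(StringUtils.split("select 1; -- a comment").sameElements(Array("select 1")))
      // Bracketed comments (and therefore hints) are preserved in place.
      val hinted = "select /* bla bla */ 1"
      assert(StringUtils.split(hinted).sameElements(Array(hinted)))
    }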
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index 616ec12032dbd..14d1e79e07906 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -44,6 +44,74 @@ class StringUtilsSuite extends SparkFunSuite {
     assert(filterPattern(names, " d* ") === Nil)
   }
 
+  test("split the SQL text") {
+    val statement = "select * from tmp.dada;"
+    assert(StringUtils.split(statement) === Array("select * from tmp.dada"))
+
+    // blank statements are omitted
+    val statements = " select * from tmp.data;;select * from tmp.ata;"
+    assert(StringUtils.split(statements) ===
+      Array("select * from tmp.data", "select * from tmp.ata"))
+
+    val escapedSingleQuote =
+      raw"""
+        |select "\';"
+      """.stripMargin.trim
+    val escapedDoubleQuote =
+      raw"""
+        |select "\";"
+      """.stripMargin.trim
+    assert(StringUtils.split(escapedSingleQuote) === Array(escapedSingleQuote))
+    assert(StringUtils.split(escapedDoubleQuote) === Array(escapedDoubleQuote))
+
+    val semicolonInDoubleQuotes =
+      """
+        |select "^;^"
+      """.stripMargin.trim
+    val semicolonInSingleQuotes =
+      """
+        |select '^;^'
+      """.stripMargin.trim
+    assert(StringUtils.split(semicolonInDoubleQuotes) === Array(semicolonInDoubleQuotes))
+    assert(StringUtils.split(semicolonInSingleQuotes) === Array(semicolonInSingleQuotes))
+
+    val inlineComments =
+      """
+        |select 1; --;;;;;;;;
+        |select "---";
+      """.stripMargin
+    val select1 = "select 1"
+    val selectComments =
+      """
+        |select "---"
+      """.stripMargin.trim
+    assert(StringUtils.split(inlineComments) === Array(select1, selectComments))
+
+    val bracketedComment1 = "select 1 /*;*/" // Good
+    assert(StringUtils.split(bracketedComment1) === Array(bracketedComment1))
+    val bracketedComment2 = "select 1 /* /* ; */" // Good
+    assert(StringUtils.split(bracketedComment2) === Array(bracketedComment2))
+    val bracketedComment3 = "select 1 /* */ ; */" // Bad
+    assert(StringUtils.split(bracketedComment3).head === "select 1 /* */")
+    val bracketedComment4 = "select 1 /**/ ; /* */" // Good
+    assert(StringUtils.split(bracketedComment4).head === "select 1 /**/")
+    val bracketedComment5 = "select 1 /**/ ; /**/" // Good
+    assert(StringUtils.split(bracketedComment5).head === "select 1 /**/")
+    val bracketedComment6 = "select /* bla bla */ 1" // Hints are preserved
+    assert(StringUtils.split(bracketedComment6).head === bracketedComment6)
+
+    val qQuote1 = "select 1 as `;`" // Good
+    assert(StringUtils.split(qQuote1) === Array(qQuote1))
+    val qQuote2 = "select 1 as ``;`" // Bad
+    assert(StringUtils.split(qQuote2) === Array("select 1 as ``", "`"))
+
+    // The splitting rules for the following cases do not match the actual ANTLR
+    // grammar; we should not make the splitter too complicated
+    // val bracketedComment7 = "select 1 /**/ ; */" // Good
+    // val bracketedComment8 = "select 1 /* */ ; /* */" // Bad
+    // val qQuote3 = "select 1 as ```;`" // Good
+  }
+
   test("string concatenation") {
     def concat(seq: String*): String = {
       seq.foldLeft(new StringConcat())((acc, s) => {acc.append(s); acc}).toString
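As a worked example of the escape handling above (under the splitter's rules, not the ANTLR grammar), here is why the escaped-double-quote case stays a single statement:

    // Input: select "\";"
    // cursor at '"'  -> (Common, D_QUOTE)      : flag = DoubleQuote
    // cursor at '\'  -> (DoubleQuote, ESCAPE)  : consumes two chars ('\' and '"'),
    //                                            so the escaped quote cannot close the string
    // cursor at ';'  -> (DoubleQuote, ';')     : no Common-state case applies, plain char
    // cursor at '"'  -> (DoubleQuote, D_QUOTE) : quote.sameAs('"') -> flag = Common
    // No semicolon is ever seen in the Common state, so nothing splits:
    val escaped = raw"""select "\";" """.trim
    assert(StringUtils.split(escaped).sameElements(Array(escaped)))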
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 4a4629fae2706..1e519ae6aacd7 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -84,6 +84,11 @@
       <artifactId>selenium-htmlunit-driver</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 463cf1680efde..9cfdc4e3258c6 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 
 import jline.console.ConsoleReader
 import jline.console.history.FileHistory
-import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang.{StringUtils => ApacheStringUtils}
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
@@ -37,14 +37,16 @@ import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
 import org.apache.log4j.Level
 import org.apache.thrift.transport.TSocket
+import sun.misc.{Signal, SignalHandler}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.security.HiveDelegationTokenProvider
-import org.apache.spark.util.ShutdownHookManager
+import org.apache.spark.util.{ShutdownHookManager, Utils}
 
 /**
  * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver
@@ -143,8 +145,8 @@ private[hive] object SparkSQLCLIDriver extends Logging {
       // See also: code in ExecDriver.java
       var loader = conf.getClassLoader
       val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
-      if (StringUtils.isNotBlank(auxJars)) {
-        loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
+      if (ApacheStringUtils.isNotBlank(auxJars)) {
+        loader = Utilities.addToClassPath(loader, ApacheStringUtils.split(auxJars, ","))
       }
       conf.setClassLoader(loader)
       Thread.currentThread().setContextClassLoader(loader)
@@ -309,12 +311,16 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
   private val conf: Configuration =
     if (sessionState != null) sessionState.getConf else new Configuration()
 
-  // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver
-  // because the Hive unit tests do not go through the main() code path.
   if (!isRemoteMode) {
-    SparkSQLEnv.init()
-    if (sessionState.getIsSilent) {
-      SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString)
+    // Utils.isTesting checks env[SPARK_TESTING] or props[spark.testing]:
+    // the env var is visible across child processes, while the system property
+    // is per-JVM. CliSuite runs with env[SPARK_TESTING] set and still needs
+    // SparkSQLEnv, so props[spark.testing] acts as the switch for
+    // SparkSQLCLIDriverSuite.
+    if (!sys.props.contains("spark.testing")) {
+      SparkSQLEnv.init()
+      if (sessionState.getIsSilent) {
+        SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString)
+      }
     }
   } else {
     // Hive 1.2 + not supported in CLI
@@ -331,6 +337,65 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
     console.printInfo(s"Spark master: $master, Application Id: $appId")
   }
 
+  // Adapted from Hive's CliDriver.processLine, translated from Java to Scala
+  override def processLine(line: String, allowInterrupting: Boolean): Int = {
+    var oldSignal: SignalHandler = null
+    var interruptSignal: Signal = null
+
+    if (allowInterrupting) {
+      // Hook up the custom Ctrl+C handler while processing this line
+      interruptSignal = new Signal("INT")
+      oldSignal = Signal.handle(interruptSignal, new SignalHandler() {
+        private val cliThread = Thread.currentThread()
+        private var interruptRequested: Boolean = false
+
+        override def handle(signal: Signal): Unit = {
+          val initialRequest = !interruptRequested
+          interruptRequested = true
+
+          // Kill the VM on the second Ctrl+C
+          if (!initialRequest) {
+            console.printInfo("Exiting the JVM")
+            System.exit(127)
+          }
+
+          // Interrupt the CLI thread to stop the current statement and return
+          // to the prompt
+          console.printInfo("Interrupting... Be patient, this might take some time.")
+          console.printInfo("Press Ctrl+C again to kill JVM")
+
+          // First, kill any running Spark jobs
+          // TODO
+        }
+      })
+    }
+
+    try {
+      var lastRet: Int = 0
+      var ret: Int = 0
+
+      for (command <- StringUtils.split(line)) {
+        ret = processCmd(command)
+        // wipe CLI query state
+        sessionState.setCommandType(null)
+        lastRet = ret
+        val ignoreErrors = HiveConf.getBoolVar(conf, HiveConf.ConfVars.CLIIGNOREERRORS)
+        if (ret != 0 && !ignoreErrors) {
+          CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+          return ret
+        }
+      }
+      CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+      lastRet
+    } finally {
+      // Once we are done processing the line, restore the old handler
+      if (oldSignal != null && interruptSignal != null) {
+        Signal.handle(interruptSignal, oldSignal)
+      }
+    }
+  }
+
   override def processCmd(cmd: String): Int = {
     val cmd_trimmed: String = cmd.trim()
     val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT)
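For reference, a self-contained sketch of the double-Ctrl+C pattern that processLine installs (illustrative only; sun.misc.Signal is the JDK-internal API already used above):

    import sun.misc.{Signal, SignalHandler}

    object InterruptDemo extends App {
      val interruptSignal = new Signal("INT")
      // Warn on the first Ctrl+C, kill the JVM on the second, and keep a
      // reference to the previous handler so it can be restored afterwards.
      val oldHandler: SignalHandler = Signal.handle(interruptSignal, new SignalHandler {
        private var interruptRequested = false
        override def handle(signal: Signal): Unit = {
          if (interruptRequested) {
            println("Exiting the JVM")
            System.exit(127)
          }
          interruptRequested = true
          println("Press Ctrl+C again to kill JVM")
        }
      })
      try {
        Thread.sleep(30000) // stands in for processing one line of SQL
      } finally {
        Signal.handle(interruptSignal, oldHandler) // restore the previous handler
      }
    }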
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala
new file mode 100644
index 0000000000000..d1151cc9851dc
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriverSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.cli.CliSessionState
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.mockito.ArgumentMatcher
+import org.mockito.ArgumentMatchers.argThat
+import org.mockito.Mockito._
+
+import org.apache.spark.SparkFunSuite
+
+class SparkSQLCLIDriverSuite extends SparkFunSuite {
+
+  def matchSQL(sqlText: String, candidates: String*): Unit = {
+    class SQLMatcher extends ArgumentMatcher[String] {
+      override def matches(command: String): Boolean = candidates.contains(command)
+    }
+
+    val conf = new HiveConf(classOf[SessionState])
+    val sessionState = new CliSessionState(conf)
+    SessionState.start(sessionState)
+    // Spy on a real driver so that the real processLine runs; the spark.testing
+    // guard in SparkSQLCLIDriver keeps SparkSQLEnv from being initialized here.
+    // A plain mock would stub processLine itself and make the test vacuous.
+    val cli = spy(new SparkSQLCLIDriver)
+
+    // doReturn avoids invoking the real processCmd while stubbing the spy
+    doReturn(0, Nil: _*).when(cli).processCmd(argThat(new SQLMatcher))
+    assert(cli.processLine(sqlText) == 0)
+  }
+
+  test("SPARK-26312: sql text splitting for the processCmd method") {
+    // semicolon in a string
+    val sql =
+      """
+        |select "^;^"
+      """.stripMargin.trim
+    matchSQL(sql, sql)
+
+    // normal statements
+    val statements =
+      """
+        |select d from data;
+        |select a from data
+      """.stripMargin
+    val dStatement = "select d from data"
+    val aStatement = "select a from data"
+    matchSQL(statements, dStatement, aStatement)
+  }
+}