diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f1d2e3788918..25c2fb4af5c3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -474,7 +474,8 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite", "org.apache.spark.ml.classification.LogisticRegressionSuite", "org.apache.spark.ml.classification.LinearSVCSuite", - "org.apache.spark.sql.SQLQueryTestSuite" + "org.apache.spark.sql.SQLQueryTestSuite", + "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite" ) private val DEFAULT_TEST_GROUP = "default_test_group" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4bdf25051127..5c1ff9cd735e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -107,8 +107,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { import IntegratedUDFTestUtils._ private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" + protected val isTestWithConfigSets: Boolean = true - private val baseResourcePath = { + protected val baseResourcePath = { // If regenerateGoldenFiles is true, we must be running this in SBT and we use hard-coded // relative path. Otherwise, we use classloader's getResource to find the location. if (regenerateGoldenFiles) { @@ -119,13 +120,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } } - private val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath - private val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath + protected val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath + protected val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath - private val validFileExtensions = ".sql" + protected val validFileExtensions = ".sql" + + private val notIncludedMsg = "[not included in comparison]" + private val clsName = this.getClass.getCanonicalName /** List of test cases to ignore, in lower cases. */ - private val blackList = Set( + protected def blackList: Set[String] = Set( "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. ) @@ -133,7 +137,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { listTestCases().foreach(createScalaTestCase) /** A single SQL query's output. */ - private case class QueryOutput(sql: String, schema: String, output: String) { + protected case class QueryOutput(sql: String, schema: String, output: String) { def toString(queryIndex: Int): String = { // We are explicitly not using multi-line string due to stripMargin removing "|" in output. s"-- !query $queryIndex\n" + @@ -146,7 +150,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } /** A test case. */ - private trait TestCase { + protected trait TestCase { val name: String val inputFile: String val resultFile: String @@ -156,35 +160,35 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { * traits that indicate UDF or PgSQL to trigger the code path specific to each. For instance, * PgSQL tests require to register some UDF functions. */ - private trait PgSQLTest + protected trait PgSQLTest - private trait UDFTest { + protected trait UDFTest { val udf: TestUDF } /** A regular test case. 
*/ - private case class RegularTestCase( + protected case class RegularTestCase( name: String, inputFile: String, resultFile: String) extends TestCase /** A PostgreSQL test case. */ - private case class PgSQLTestCase( + protected case class PgSQLTestCase( name: String, inputFile: String, resultFile: String) extends TestCase with PgSQLTest /** A UDF test case. */ - private case class UDFTestCase( + protected case class UDFTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends TestCase with UDFTest /** A UDF PostgreSQL test case. */ - private case class UDFPgSQLTestCase( + protected case class UDFPgSQLTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends TestCase with UDFTest with PgSQLTest - private def createScalaTestCase(testCase: TestCase): Unit = { + protected def createScalaTestCase(testCase: TestCase): Unit = { if (blackList.exists(t => testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. @@ -222,7 +226,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } /** Run a test case. */ - private def runTest(testCase: TestCase): Unit = { + protected def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) val (comments, code) = input.split("\n").partition(_.trim.startsWith("--")) @@ -235,7 +239,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // When we are regenerating the golden files, we don't need to set any config as they // all need to return the same result - if (regenerateGoldenFiles) { + if (regenerateGoldenFiles || !isTestWithConfigSets) { runQueries(queries, testCase, None) } else { val configSets = { @@ -271,7 +275,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } } - private def runQueries( + protected def runQueries( queries: Seq[String], testCase: TestCase, configSet: Option[Seq[(String, String)]]): Unit = { @@ -388,19 +392,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { try { val df = session.sql(sql) val schema = df.schema - val notIncludedMsg = "[not included in comparison]" - val clsName = this.getClass.getCanonicalName // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = hiveResultString(df.queryExecution.executedPlan) - .map(_.replaceAll("#\\d+", "#x") - .replaceAll( - s"Location.*/sql/core/spark-warehouse/$clsName/", - s"Location ${notIncludedMsg}sql/core/spark-warehouse/") - .replaceAll("Created By.*", s"Created By $notIncludedMsg") - .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") - .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") - .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") - .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds + val answer = hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg) // If the output is not pre-sorted, sort it. 
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) @@ -418,7 +411,19 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } } - private def listTestCases(): Seq[TestCase] = { + protected def replaceNotIncludedMsg(line: String): String = { + line.replaceAll("#\\d+", "#x") + .replaceAll( + s"Location.*/sql/core/spark-warehouse/$clsName/", + s"Location ${notIncludedMsg}sql/core/spark-warehouse/") + .replaceAll("Created By.*", s"Created By $notIncludedMsg") + .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds + } + + protected def listTestCases(): Seq[TestCase] = { listFilesRecursively(new File(inputFilePath)).flatMap { file => val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" val absPath = file.getAbsolutePath @@ -444,7 +449,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } /** Returns all the files (not directories) in a directory, recursively. */ - private def listFilesRecursively(path: File): Seq[File] = { + protected def listFilesRecursively(path: File): Seq[File] = { val (dirs, files) = path.listFiles().partition(_.isDirectory) // Filter out test files with invalid extensions such as temp files created // by vi (.swp), Mac (.DS_Store) etc. diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 1abc65ad806b..5b1352adddd8 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -47,6 +47,13 @@ test-jar test + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-hive_${scala.binary.version} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index b4d1d0d58aad..abb53cf3429f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -53,7 +53,7 @@ object HiveThriftServer2 extends Logging { * Starts a new thrift server with the given context. */ @DeveloperApi - def startWithContext(sqlContext: SQLContext): Unit = { + def startWithContext(sqlContext: SQLContext): HiveThriftServer2 = { val server = new HiveThriftServer2(sqlContext) val executionHive = HiveUtils.newClientForExecution( @@ -69,6 +69,7 @@ object HiveThriftServer2 extends Logging { } else { None } + server } def main(args: Array[String]) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala new file mode 100644 index 000000000000..ba3284462b46 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.File +import java.sql.{DriverManager, SQLException, Statement, Timestamp} +import java.util.Locale + +import scala.util.{Random, Try} +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.HiveSQLException + +import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite} +import org.apache.spark.sql.catalyst.util.fileToString +import org.apache.spark.sql.execution.HiveResult +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Re-run all the tests in SQLQueryTestSuite via Thrift Server. + * + * TODO: + * 1. Support UDF testing. + * 2. Support DESC command. + * 3. Support SHOW command. + */ +class ThriftServerQueryTestSuite extends SQLQueryTestSuite { + + private var hiveServer2: HiveThriftServer2 = _ + + override def beforeAll(): Unit = { + super.beforeAll() + // Chooses a random port between 10000 and 19999 + var listeningPort = 10000 + Random.nextInt(10000) + + // Retries up to 3 times with different port numbers if the server fails to start + (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => + started.orElse { + listeningPort += 1 + Try(startThriftServer(listeningPort, attempt)) + } + }.recover { + case cause: Throwable => + throw cause + }.get + logInfo(s"HiveThriftServer2 started successfully") + } + + override def afterAll(): Unit = { + try { + hiveServer2.stop() + } finally { + super.afterAll() + } + } + + override val isTestWithConfigSets = false + + /** List of test cases to ignore, in lower cases. */ + override def blackList: Set[String] = Set( + "blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality. + // Missing UDF + "pgSQL/boolean.sql", + "pgSQL/case.sql", + // SPARK-28624 + "date.sql", + // SPARK-28619 + "pgSQL/aggregates_part1.sql", + "group-by.sql", + // SPARK-28620 + "pgSQL/float4.sql", + // SPARK-28636 + "decimalArithmeticOperations.sql", + "literals.sql", + "subquery/scalar-subquery/scalar-subquery-predicate.sql", + "subquery/in-subquery/in-limit.sql", + "subquery/in-subquery/simple-in.sql", + "subquery/in-subquery/in-order-by.sql", + "subquery/in-subquery/in-set-operations.sql", + // SPARK-28637 + "cast.sql", + "ansi/interval.sql" + ) + + override def runQueries( + queries: Seq[String], + testCase: TestCase, + configSet: Option[Seq[(String, String)]]): Unit = { + // We do not test with configSet. + withJdbcStatement { statement => + + loadTestData(statement) + + testCase match { + case _: PgSQLTest => + // PostgreSQL enabled cartesian product by default. 
+ statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true") + statement.execute(s"SET ${SQLConf.ANSI_SQL_PARSER.key} = true") + statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true") + case _ => + } + + // Run the SQL queries preparing them for comparison. + val outputs: Seq[QueryOutput] = queries.map { sql => + val output = getNormalizedResult(statement, sql) + // We might need to do some query canonicalization in the future. + QueryOutput( + sql = sql, + schema = "", + output = output.mkString("\n").replaceAll("\\s+$", "")) + } + + // Read back the golden file. + val expectedOutputs: Seq[QueryOutput] = { + val goldenOutput = fileToString(new File(testCase.resultFile)) + val segments = goldenOutput.split("-- !query.+\n") + + // each query has 3 segments, plus the header + assert(segments.size == outputs.size * 3 + 1, + s"Expected ${outputs.size * 3 + 1} blocks in result file but got ${segments.size}. " + + s"Try regenerate the result files.") + Seq.tabulate(outputs.size) { i => + val sql = segments(i * 3 + 1).trim + val originalOut = segments(i * 3 + 3) + val output = if (isNeedSort(sql)) { + originalOut.split("\n").sorted.mkString("\n") + } else { + originalOut + } + QueryOutput( + sql = sql, + schema = "", + output = output.replaceAll("\\s+$", "") + ) + } + } + + // Compare results. + assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") { + outputs.size + } + + outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) => + assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") { + output.sql + } + + expected match { + // Skip desc command, see HiveResult.hiveResultString + case d if d.sql.toUpperCase(Locale.ROOT).startsWith("DESC ") + || d.sql.toUpperCase(Locale.ROOT).startsWith("DESC\n") + || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE ") + || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE\n") => + // Skip show command, see HiveResult.hiveResultString + case s if s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW ") + || s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW\n") => + // AnalysisException should exactly match. + // SQLException should not exactly match. We only assert the result contains Exception. + case _ if output.output.startsWith(classOf[SQLException].getName) => + assert(expected.output.contains("Exception"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + // HiveSQLException is usually a feature that our ThriftServer cannot support. + // Please add SQL to blackList. + case _ if output.output.startsWith(classOf[HiveSQLException].getName) => + assert(false, s"${output.output} for query #$i\n${expected.sql}") + case _ => + assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { + output.output + } + } + } + } + } + + override def createScalaTestCase(testCase: TestCase): Unit = { + if (blackList.exists(t => + testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { + // Create a test case to ignore this case. + ignore(testCase.name) { /* Do nothing */ } + } else { + // Create a test case to run this case. 
+ test(testCase.name) { + runTest(testCase) + } + } + } + + override def listTestCases(): Seq[TestCase] = { + listFilesRecursively(new File(inputFilePath)).flatMap { file => + val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" + val absPath = file.getAbsolutePath + val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) + + if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) { + Seq.empty + } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) { + PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil + } else { + RegularTestCase(testCaseName, absPath, resultFile) :: Nil + } + } + } + + test("Check if ThriftServer can work") { + withJdbcStatement { statement => + val rs = statement.executeQuery("select 1L") + rs.next() + assert(rs.getLong(1) === 1L) + } + } + + private def getNormalizedResult(statement: Statement, sql: String): Seq[String] = { + try { + val rs = statement.executeQuery(sql) + val cols = rs.getMetaData.getColumnCount + val buildStr = () => (for (i <- 1 to cols) yield { + getHiveResult(rs.getObject(i)) + }).mkString("\t") + + val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq + .map(replaceNotIncludedMsg) + if (isNeedSort(sql)) { + answer.sorted + } else { + answer + } + } catch { + case a: AnalysisException => + // Do not output the logical plan tree which contains expression IDs. + // Also implement a crude way of masking expression IDs in the error message + // with a generic pattern "###". + val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage + Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted + case NonFatal(e) => + // If there is an exception, put the exception class followed by the message. + Seq(e.getClass.getName, e.getMessage) + } + } + + private def startThriftServer(port: Int, attempt: Int): Unit = { + logInfo(s"Trying to start HiveThriftServer2: port=$port, attempt=$attempt") + val sqlContext = spark.newSession().sqlContext + sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, port.toString) + hiveServer2 = HiveThriftServer2.startWithContext(sqlContext) + } + + private def withJdbcStatement(fs: (Statement => Unit)*) { + val user = System.getProperty("user.name") + + val serverPort = hiveServer2.getHiveConf.get(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname) + val connections = + fs.map { _ => DriverManager.getConnection(s"jdbc:hive2://localhost:$serverPort", user, "") } + val statements = connections.map(_.createStatement()) + + try { + statements.zip(fs).foreach { case (s, f) => f(s) } + } finally { + statements.foreach(_.close()) + connections.foreach(_.close()) + } + } + + /** Load built-in test tables. 
*/ + private def loadTestData(statement: Statement): Unit = { + // Prepare the data + statement.execute( + """ + |CREATE OR REPLACE TEMPORARY VIEW testdata as + |SELECT id AS key, CAST(id AS string) AS value FROM range(1, 101) + """.stripMargin) + statement.execute( + """ + |CREATE OR REPLACE TEMPORARY VIEW arraydata as + |SELECT * FROM VALUES + |(ARRAY(1, 2, 3), ARRAY(ARRAY(1, 2, 3))), + |(ARRAY(2, 3, 4), ARRAY(ARRAY(2, 3, 4))) AS v(arraycol, nestedarraycol) + """.stripMargin) + statement.execute( + """ + |CREATE OR REPLACE TEMPORARY VIEW mapdata as + |SELECT * FROM VALUES + |MAP(1, 'a1', 2, 'b1', 3, 'c1', 4, 'd1', 5, 'e1'), + |MAP(1, 'a2', 2, 'b2', 3, 'c2', 4, 'd2'), + |MAP(1, 'a3', 2, 'b3', 3, 'c3'), + |MAP(1, 'a4', 2, 'b4'), + |MAP(1, 'a5') AS v(mapcol) + """.stripMargin) + statement.execute( + s""" + |CREATE TEMPORARY VIEW aggtest + | (a int, b float) + |USING csv + |OPTIONS (path '${testFile("test-data/postgresql/agg.data")}', + | header 'false', delimiter '\t') + """.stripMargin) + statement.execute( + s""" + |CREATE OR REPLACE TEMPORARY VIEW onek + | (unique1 int, unique2 int, two int, four int, ten int, twenty int, hundred int, + | thousand int, twothousand int, fivethous int, tenthous int, odd int, even int, + | stringu1 string, stringu2 string, string4 string) + |USING csv + |OPTIONS (path '${testFile("test-data/postgresql/onek.data")}', + | header 'false', delimiter '\t') + """.stripMargin) + statement.execute( + s""" + |CREATE OR REPLACE TEMPORARY VIEW tenk1 + | (unique1 int, unique2 int, two int, four int, ten int, twenty int, hundred int, + | thousand int, twothousand int, fivethous int, tenthous int, odd int, even int, + | stringu1 string, stringu2 string, string4 string) + |USING csv + | OPTIONS (path '${testFile("test-data/postgresql/tenk.data")}', + | header 'false', delimiter '\t') + """.stripMargin) + } + + // Returns true if sql is retrieving data. + private def isNeedSort(sql: String): Boolean = { + val upperCase = sql.toUpperCase(Locale.ROOT) + upperCase.startsWith("SELECT ") || upperCase.startsWith("SELECT\n") || + upperCase.startsWith("WITH ") || upperCase.startsWith("WITH\n") || + upperCase.startsWith("VALUES ") || upperCase.startsWith("VALUES\n") || + // pgSQL/union.sql + upperCase.startsWith("(") + } + + private def getHiveResult(obj: Object): String = { + obj match { + case null => + HiveResult.toHiveString((null, StringType)) + case d: java.sql.Date => + HiveResult.toHiveString((d, DateType)) + case t: Timestamp => + HiveResult.toHiveString((t, TimestampType)) + case d: java.math.BigDecimal => + HiveResult.toHiveString((d, DecimalType.fromBigDecimal(d))) + case bin: Array[Byte] => + HiveResult.toHiveString((bin, BinaryType)) + case other => + other.toString + } + } +}
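
Usage sketch: with startWithContext now returning the HiveThriftServer2 instance instead of Unit, a caller can keep a handle to the embedded server and stop it explicitly, which is how ThriftServerQueryTestSuite manages the server in beforeAll()/afterAll(). The snippet below is a minimal illustration of that pattern; the SparkSession construction, the user name, and the fixed port 10000 are assumptions for the example, not part of the patch.

    import java.sql.DriverManager

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

    // Assumes a Hive-enabled SparkSession; the port is the Hive default and only an example.
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    spark.sqlContext.setConf("hive.server2.thrift.port", "10000")

    // The new return value lets the caller keep the server handle for an explicit shutdown.
    val server: HiveThriftServer2 = HiveThriftServer2.startWithContext(spark.sqlContext)
    try {
      // Hypothetical smoke check over JDBC, mirroring the "Check if ThriftServer can work" test.
      val connection = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
      val statement = connection.createStatement()
      val rs = statement.executeQuery("SELECT 1L")
      rs.next()
      assert(rs.getLong(1) == 1L)
      statement.close()
      connection.close()
    } finally {
      // Explicit shutdown, mirroring afterAll() in ThriftServerQueryTestSuite.
      server.stop()
    }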
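
Note on result comparison: the suite tolerates row-order differences between the Thrift server and the golden files. getNormalizedResult sorts the JDBC answer and runQueries sorts the matching golden block whenever isNeedSort(sql) is true (SELECT/WITH/VALUES queries and the parenthesized queries from pgSQL/union.sql). A small hypothetical illustration of that normalization, with made-up rows:

    // Rows returned in different orders by the Thrift server and recorded in the golden file.
    val fromJdbc   = Seq("2\tb", "1\ta")
    val fromGolden = "1\ta\n2\tb"

    // Both sides are sorted line by line before the string comparison, so they agree.
    assert(fromJdbc.sorted.mkString("\n") == fromGolden.split("\n").sorted.mkString("\n"))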