diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 381b8f2324ca..fbcf97c2b668 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -18,19 +18,21 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File -import java.sql.{DriverManager, SQLException, Statement, Timestamp} -import java.util.Locale +import java.sql.{DriverManager, Statement, Timestamp} +import java.util.{Locale, MissingFormatArgumentException} import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.HiveSQLException -import org.scalatest.Ignore +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite} +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.util.fileToString import org.apache.spark.sql.execution.HiveResult +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -43,12 +45,12 @@ import org.apache.spark.sql.types._ * 2. Support DESC command. * 3. Support SHOW command. */ -@Ignore class ThriftServerQueryTestSuite extends SQLQueryTestSuite { private var hiveServer2: HiveThriftServer2 = _ - override def beforeEach(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() // Chooses a random port between 10000 and 19999 var listeningPort = 10000 + Random.nextInt(10000) @@ -65,10 +67,19 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { logInfo("HiveThriftServer2 started successfully") } - override def afterEach(): Unit = { - hiveServer2.stop() + override def afterAll(): Unit = { + try { + hiveServer2.stop() + } finally { + super.afterAll() + } } + override def sparkConf: SparkConf = super.sparkConf + // Hive Thrift server should not executes SQL queries in an asynchronous way + // because we may set session configuration. + .set(HiveUtils.HIVE_THRIFT_SERVER_ASYNC, false) + override val isTestWithConfigSets = false /** List of test cases to ignore, in lower cases. */ @@ -79,9 +90,6 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { "pgSQL/case.sql", // SPARK-28624 "date.sql", - // SPARK-28619 - "pgSQL/aggregates_part1.sql", - "group-by.sql", // SPARK-28620 "pgSQL/float4.sql", // SPARK-28636 @@ -89,12 +97,10 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { "literals.sql", "subquery/scalar-subquery/scalar-subquery-predicate.sql", "subquery/in-subquery/in-limit.sql", + "subquery/in-subquery/in-group-by.sql", "subquery/in-subquery/simple-in.sql", "subquery/in-subquery/in-order-by.sql", - "subquery/in-subquery/in-set-operations.sql", - // SPARK-28637 - "cast.sql", - "ansi/interval.sql" + "subquery/in-subquery/in-set-operations.sql" ) override def runQueries( @@ -166,19 +172,42 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { || d.sql.toUpperCase(Locale.ROOT).startsWith("DESC\n") || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE ") || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE\n") => + // Skip show command, see HiveResult.hiveResultString case s if s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW ") || s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW\n") => - // AnalysisException should exactly match. - // SQLException should not exactly match. We only assert the result contains Exception. - case _ if output.output.startsWith(classOf[SQLException].getName) => + + case _ if output.output.startsWith(classOf[NoSuchTableException].getPackage.getName) => + assert(expected.output.startsWith(classOf[NoSuchTableException].getPackage.getName), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[SparkException].getName) && + output.output.contains("overflow") => + assert(expected.output.contains(classOf[ArithmeticException].getName) && + expected.output.contains("overflow"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[RuntimeException].getName) => assert(expected.output.contains("Exception"), s"Exception did not match for query #$i\n${expected.sql}, " + s"expected: ${expected.output}, but got: ${output.output}") - // HiveSQLException is usually a feature that our ThriftServer cannot support. - // Please add SQL to blackList. - case _ if output.output.startsWith(classOf[HiveSQLException].getName) => - assert(false, s"${output.output} for query #$i\n${expected.sql}") + + case _ if output.output.startsWith(classOf[ArithmeticException].getName) && + output.output.contains("causes overflow") => + assert(expected.output.contains(classOf[ArithmeticException].getName) && + expected.output.contains("causes overflow"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[MissingFormatArgumentException].getName) && + output.output.contains("Format specifier") => + assert(expected.output.contains(classOf[MissingFormatArgumentException].getName) && + expected.output.contains("Format specifier"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + case _ => assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { output.output @@ -248,8 +277,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted case NonFatal(e) => + val rootCause = ExceptionUtils.getRootCause(e) // If there is an exception, put the exception class followed by the message. - Seq(e.getClass.getName, e.getMessage) + Seq(rootCause.getClass.getName, rootCause.getMessage) } }