project/SparkBuild.scala (2 additions, 1 deletion)

@@ -474,7 +474,8 @@ object SparkParallelTestGrouping {
     "org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite",
     "org.apache.spark.ml.classification.LogisticRegressionSuite",
     "org.apache.spark.ml.classification.LinearSVCSuite",
-    "org.apache.spark.sql.SQLQueryTestSuite"
+    "org.apache.spark.sql.SQLQueryTestSuite",
+    "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite"
   )

   private val DEFAULT_TEST_GROUP = "default_test_group"
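Adding ThriftServerQueryTestSuite to SparkParallelTestGrouping gives it its own forked JVM, like the other heavyweight suites in the set above. The suite itself is not shown in this excerpt; the visibility changes below (private members becoming protected) are what allow it to subclass SQLQueryTestSuite from the hive-thriftserver module. A minimal sketch of such a subclass, with hypothetical overrides, might look like:

```scala
// Hypothetical sketch only; the real ThriftServerQueryTestSuite added by this
// PR is not part of this excerpt.
class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
  // Queries run through a JDBC connection to the Thrift server, so the
  // per-test config-set machinery is bypassed (see runTest below).
  override protected val isTestWithConfigSets: Boolean = false

  // Skip inputs that cannot run through the Thrift server (illustrative name).
  override protected def blackList: Set[String] = super.blackList + "example.sql"
}
```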
sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

@@ -107,33 +107,40 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   import IntegratedUDFTestUtils._

   private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"

-  private val baseResourcePath = {
-    // If regenerateGoldenFiles is true, we must be running this in SBT and we use hard-coded
-    // relative path. Otherwise, we use classloader's getResource to find the location.
-    if (regenerateGoldenFiles) {
-      java.nio.file.Paths.get("src", "test", "resources", "sql-tests").toFile
-    } else {
-      val res = getClass.getClassLoader.getResource("sql-tests")
-      new File(res.getFile)
-    }
-  }
+  protected val isTestWithConfigSets: Boolean = true
+
+  protected val baseResourcePath = {
+    // We use a path based on Spark home for 2 reasons:
+    //   1. Maven can't get correct resource directory when resources in other jars.
+    //   2. We test subclasses in the hive-thriftserver module.
+    val sparkHome = {
+      assert(sys.props.contains("spark.test.home") ||
+        sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.")
+      sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+    }
+
+    java.nio.file.Paths.get(sparkHome,
+      "sql", "core", "src", "test", "resources", "sql-tests").toFile
+  }

-  private val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath
-  private val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath
+  protected val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath
+  protected val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath

-  private val validFileExtensions = ".sql"
+  protected val validFileExtensions = ".sql"

+  private val notIncludedMsg = "[not included in comparison]"
+  private val clsName = this.getClass.getCanonicalName
+
   /** List of test cases to ignore, in lower cases. */
-  private val blackList = Set(
+  protected def blackList: Set[String] = Set(
     "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality.
   )

   // Create all the test cases.
   listTestCases().foreach(createScalaTestCase)

   /** A single SQL query's output. */
-  private case class QueryOutput(sql: String, schema: String, output: String) {
+  protected case class QueryOutput(sql: String, schema: String, output: String) {
     def toString(queryIndex: Int): String = {
       // We are explicitly not using multi-line string due to stripMargin removing "|" in output.
       s"-- !query $queryIndex\n" +
@@ -146,7 +153,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   }

   /** A test case. */
-  private trait TestCase {
+  protected trait TestCase {
     val name: String
     val inputFile: String
     val resultFile: String
@@ -156,35 +163,35 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
    * traits that indicate UDF or PgSQL to trigger the code path specific to each. For instance,
    * PgSQL tests require to register some UDF functions.
    */
-  private trait PgSQLTest
+  protected trait PgSQLTest

-  private trait UDFTest {
+  protected trait UDFTest {
     val udf: TestUDF
   }

   /** A regular test case. */
-  private case class RegularTestCase(
+  protected case class RegularTestCase(
       name: String, inputFile: String, resultFile: String) extends TestCase

   /** A PostgreSQL test case. */
-  private case class PgSQLTestCase(
+  protected case class PgSQLTestCase(
       name: String, inputFile: String, resultFile: String) extends TestCase with PgSQLTest

   /** A UDF test case. */
-  private case class UDFTestCase(
+  protected case class UDFTestCase(
       name: String,
       inputFile: String,
       resultFile: String,
       udf: TestUDF) extends TestCase with UDFTest

   /** A UDF PostgreSQL test case. */
-  private case class UDFPgSQLTestCase(
+  protected case class UDFPgSQLTestCase(
       name: String,
       inputFile: String,
       resultFile: String,
       udf: TestUDF) extends TestCase with UDFTest with PgSQLTest

-  private def createScalaTestCase(testCase: TestCase): Unit = {
+  protected def createScalaTestCase(testCase: TestCase): Unit = {
     if (blackList.exists(t =>
       testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) {
       // Create a test case to ignore this case.
@@ -222,7 +229,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   }

   /** Run a test case. */
-  private def runTest(testCase: TestCase): Unit = {
+  protected def runTest(testCase: TestCase): Unit = {
     val input = fileToString(new File(testCase.inputFile))

     val (comments, code) = input.split("\n").partition(_.trim.startsWith("--"))
@@ -235,7 +242,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {

     // When we are regenerating the golden files, we don't need to set any config as they
     // all need to return the same result
-    if (regenerateGoldenFiles) {
+    if (regenerateGoldenFiles || !isTestWithConfigSets) {
       runQueries(queries, testCase, None)
     } else {
       val configSets = {
Expand Down Expand Up @@ -271,7 +278,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
}
}

private def runQueries(
protected def runQueries(
queries: Seq[String],
testCase: TestCase,
configSet: Option[Seq[(String, String)]]): Unit = {
@@ -388,19 +395,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
     try {
       val df = session.sql(sql)
       val schema = df.schema
-      val notIncludedMsg = "[not included in comparison]"
-      val clsName = this.getClass.getCanonicalName
       // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
-      val answer = hiveResultString(df.queryExecution.executedPlan)
-        .map(_.replaceAll("#\\d+", "#x")
-        .replaceAll(
-          s"Location.*/sql/core/spark-warehouse/$clsName/",
-          s"Location ${notIncludedMsg}sql/core/spark-warehouse/")
-        .replaceAll("Created By.*", s"Created By $notIncludedMsg")
-        .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
-        .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
-        .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
-        .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds
+      val answer = hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)

       // If the output is not pre-sorted, sort it.
       if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
@@ -418,7 +414,19 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
     }
   }

-  private def listTestCases(): Seq[TestCase] = {
+  protected def replaceNotIncludedMsg(line: String): String = {
+    line.replaceAll("#\\d+", "#x")
+      .replaceAll(
+        s"Location.*/sql/core/spark-warehouse/$clsName/",
+        s"Location ${notIncludedMsg}sql/core/spark-warehouse/")
+      .replaceAll("Created By.*", s"Created By $notIncludedMsg")
+      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
+      .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
+      .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
+      .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
+  }
+
+  protected def listTestCases(): Seq[TestCase] = {
     listFilesRecursively(new File(inputFilePath)).flatMap { file =>
       val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
       val absPath = file.getAbsolutePath
@@ -444,7 +452,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   }

   /** Returns all the files (not directories) in a directory, recursively. */
-  private def listFilesRecursively(path: File): Seq[File] = {
+  protected def listFilesRecursively(path: File): Seq[File] = {
     val (dirs, files) = path.listFiles().partition(_.isDirectory)
     // Filter out test files with invalid extensions such as temp files created
     // by vi (.swp), Mac (.DS_Store) etc.
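For intuition about the extracted helper: replaceNotIncludedMsg rewrites the volatile parts of plan and catalog output (expression ids, timestamps, warehouse paths, codegen stage ids) into stable placeholders before the result is compared against the golden files, which are regenerated by running the suite with SPARK_GENERATE_GOLDEN_FILES=1 (the flag read at the top of the suite). A self-contained approximation of a subset of the masking, for illustration only:

```scala
// Standalone demo of a subset of replaceNotIncludedMsg's rewrites; the real
// method above also masks Location, Created By, Last Access and
// Partition Statistics lines.
object NormalizeDemo {
  private val notIncludedMsg = "[not included in comparison]"

  def replaceNotIncludedMsg(line: String): String = {
    line.replaceAll("#\\d+", "#x") // mask expression ids such as key#1234
      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
      .replaceAll("\\*\\(\\d+\\) ", "*") // drop WholeStageCodegen stage ids
  }

  def main(args: Array[String]): Unit = {
    // prints: *Project [key#x, value#x]
    println(replaceNotIncludedMsg("*(2) Project [key#1234, value#1235]"))
  }
}
```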
sql/hive-thriftserver/pom.xml (7 additions)

@@ -47,6 +47,13 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-hive_${scala.binary.version}</artifactId>
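The new test-jar dependency on spark-catalyst presumably pulls in the catalyst test utilities (such as plan-comparison helpers) that SQLQueryTestSuite and its base classes rely on, now that the hive-thriftserver module subclasses that suite; this matches the "We test subclasses in the hive-thriftserver module" comment above.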
sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala

@@ -53,7 +53,7 @@ object HiveThriftServer2 extends Logging {
    * Starts a new thrift server with the given context.
    */
   @DeveloperApi
-  def startWithContext(sqlContext: SQLContext): Unit = {
+  def startWithContext(sqlContext: SQLContext): HiveThriftServer2 = {
     val server = new HiveThriftServer2(sqlContext)

     val executionHive = HiveUtils.newClientForExecution(
@@ -69,6 +69,7 @@ object HiveThriftServer2 extends Logging {
     } else {
       None
     }
+    server
   }

   def main(args: Array[String]) {
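Returning the HiveThriftServer2 instance instead of Unit lets callers keep a handle on the server they just started, for example so a test suite can shut it down during cleanup. A minimal usage sketch, assuming a SQLContext in scope and that stop() is inherited from Hive's service lifecycle classes:

```scala
// Sketch only: sqlContext is assumed to exist, and stop() is assumed to be
// inherited from HiveServer2's service lifecycle (not part of the diff above).
val server = HiveThriftServer2.startWithContext(sqlContext)
try {
  // ... issue JDBC queries against the embedded Thrift server ...
} finally {
  server.stop()
}
```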