Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,54 @@

package org.apache.spark.sql.hive.thriftserver

import java.io.File
import java.sql.{DriverManager, Statement}
import java.util.Locale

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.util.Try

import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.apache.hive.service.cli.thrift.ThriftCLIService

import org.apache.spark.SparkConf
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

trait SharedThriftServer extends SharedSparkSession {
Utils.classForName(classOf[HiveDriver].getCanonicalName)

private var hiveServer2: HiveThriftServer2 = _
private var serverPort: Int = 0
protected var hiveServer2: HiveThriftServer2 = _
protected var hostName: String = _
protected var serverPort: Int = 0
protected val hiveConfList = "a=avalue;b=bvalue"
protected val hiveVarList = "c=cvalue;d=dvalue"

private def getTempDir(): File = {
val file = Utils.createTempDir()
file.delete()
file
}

protected val operationLogPath: File = getTempDir()
protected val lScratchDir: File = getTempDir()
protected val metastorePath: File = getTempDir()

protected def extraConf: Map[String, String] = Map.empty
protected def user: String = System.getProperty("user.name")

def mode: ServerMode.Value = ServerMode.binary

override def sparkConf: SparkConf = {
super.sparkConf
.set(StaticSQLConf.CATALOG_IMPLEMENTATION, "hive")
.set("spark.hadoop." + ConfVars.METASTORECONNECTURLKEY.varname,
s"jdbc:derby:;databaseName=$metastorePath;create=true")
.set("spark.ui.enabled", "false")
}

override def beforeAll(): Unit = {
super.beforeAll()
Expand All @@ -46,46 +79,96 @@ trait SharedThriftServer extends SharedSparkSession {
}

override def afterAll(): Unit = {
super.afterAll()
try {
hiveServer2.stop()
if (hiveServer2 != null) {
hiveServer2.stop()
hiveServer2 = null
}
} finally {
super.afterAll()
Utils.deleteRecursively(operationLogPath)
Utils.deleteRecursively(lScratchDir)
Utils.deleteRecursively(metastorePath)
}
}

protected def withJdbcStatement(fs: (Statement => Unit)*): Unit = {
val user = System.getProperty("user.name")
require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2")
val connections =
fs.map { _ => DriverManager.getConnection(s"jdbc:hive2://localhost:$serverPort", user, "") }
protected def jdbcUri: String = if (mode == ServerMode.http) {
s"""jdbc:hive2://localhost:$serverPort/
|default;
|transportMode=http;
|httpPath=cliservice;
|$hiveConfList#$hiveVarList
""".stripMargin.split("\n").mkString.trim
} else {
s"""jdbc:hive2://localhost:$serverPort/?$hiveConfList#$hiveVarList"""
}

def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*): Unit = {
val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
val statements = connections.map(_.createStatement())

try {
statements.zip(fs).foreach { case (s, f) => f(s) }
} finally {
tableNames.foreach { name =>
// TODO: Need a better way to drop the view.
if (name.toUpperCase(Locale.ROOT).startsWith("VIEW")) {
statements.head.execute(s"DROP VIEW IF EXISTS $name")
} else {
statements.head.execute(s"DROP TABLE IF EXISTS $name")
}
}
statements.foreach(_.close())
connections.foreach(_.close())
}
}

def withDatabase(dbNames: String*)(fs: (Statement => Unit)*): Unit = {
val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
val statements = connections.map(_.createStatement())

try {
statements.zip(fs).foreach { case (s, f) => f(s) }
} finally {
dbNames.foreach { name =>
statements.head.execute(s"DROP DATABASE IF EXISTS $name")
}
statements.foreach(_.close())
connections.foreach(_.close())
}
}

def withJdbcStatement(tableNames: String*)(f: Statement => Unit): Unit = {
withMultipleConnectionJdbcStatement(tableNames: _*)(f)
}

private def startThriftServer(attempt: Int): Unit = {
logInfo(s"Trying to start HiveThriftServer2:, attempt=$attempt")
logInfo(s"Trying to start HiveThriftServer2: mode: $mode, attempt=$attempt")
val sqlContext = spark.newSession().sqlContext
// Set the HIVE_SERVER2_THRIFT_PORT to 0, so it could randomly pick any free port to use.
// It's much more robust than set a random port generated by ourselves ahead
sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, "0")
sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT.varname, "0")
sqlContext.setConf(ConfVars.HIVE_SERVER2_TRANSPORT_MODE.varname, mode.toString)
sqlContext.setConf(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION.varname,
operationLogPath.getAbsolutePath)
sqlContext.setConf(ConfVars.LOCALSCRATCHDIR.varname, lScratchDir.getAbsolutePath)
extraConf.foreach { case (k, v) => sqlContext.setConf(k, v) }
hiveServer2 = HiveThriftServer2.startWithContext(sqlContext)
hiveServer2.getServices.asScala.foreach {
case t: ThriftCLIService if t.getPortNumber != 0 =>
case t: ThriftCLIService =>
if (t.getPortNumber == 0) {
Thread.sleep(3000)
}
serverPort = t.getPortNumber
logInfo(s"Started HiveThriftServer2: port=$serverPort, attempt=$attempt")
logInfo(s"Started HiveThriftServer2: mode: $mode port=$serverPort, attempt=$attempt")
case _ =>
}

// Wait for thrift server to be ready to serve the query, via executing simple query
// till the query succeeds. See SPARK-30345 for more details.
eventually(timeout(30.seconds), interval(1.seconds)) {
withJdbcStatement { _.execute("SELECT 1") }
withJdbcStatement() { _.execute("SELECT 1") }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package org.apache.spark.sql.hive.thriftserver

import java.sql.{DatabaseMetaData, ResultSet}

class SparkMetadataOperationSuite extends HiveThriftJdbcTest {

override def mode: ServerMode.Value = ServerMode.binary
class SparkMetadataOperationSuite extends SharedThriftServer {

test("Spark's own GetSchemasOperation(SparkGetSchemasOperation)") {
def checkResult(rs: ResultSet, dbNames: Seq[String]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,13 @@ import org.apache.thrift.transport.TSocket
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.unsafe.types.UTF8String

class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {

override def mode: ServerMode.Value = ServerMode.binary
class SparkThriftServerProtocolVersionsSuite extends SharedThriftServer {

def testExecuteStatementWithProtocolVersion(
version: ThriftserverShimUtils.TProtocolVersion,
sql: String)(f: HiveQueryResultSet => Unit): Unit = {
val rawTransport = new TSocket("localhost", serverPort)
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
transport.open()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
testCase: TestCase,
configSet: Seq[(String, String)]): Unit = {
// We do not test with configSet.
withJdbcStatement { statement =>
withJdbcStatement() { statement =>

configSet.foreach { case (k, v) =>
statement.execute(s"SET $k = $v")
Expand Down Expand Up @@ -236,7 +236,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
}

test("Check if ThriftServer can work") {
withJdbcStatement { statement =>
withJdbcStatement() { statement =>
val rs = statement.executeQuery("select 1L")
rs.next()
assert(rs.getLong(1) === 1L)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ class ThriftServerWithSparkContextSuite extends SharedThriftServer {
test("SPARK-29911: Uncache cached tables when session closed") {
val cacheManager = spark.sharedState.cacheManager
val globalTempDB = spark.sharedState.globalTempViewManager.database
withJdbcStatement { statement =>
withJdbcStatement() { statement =>
statement.execute("CACHE TABLE tempTbl AS SELECT 1")
}
// the cached data of local temporary view should be uncached
assert(cacheManager.isEmpty)
try {
withJdbcStatement { statement =>
withJdbcStatement() { statement =>
statement.execute("CREATE GLOBAL TEMP VIEW globalTempTbl AS SELECT 1, 2")
statement.execute(s"CACHE TABLE $globalTempDB.globalTempTbl")
}
// the cached data of global temporary view shouldn't be uncached
assert(!cacheManager.isEmpty)
} finally {
withJdbcStatement { statement =>
withJdbcStatement() { statement =>
statement.execute(s"UNCACHE TABLE IF EXISTS $globalTempDB.globalTempTbl")
}
assert(cacheManager.isEmpty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,25 @@

package org.apache.spark.sql.hive.thriftserver

import scala.util.Random

import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.scalatestplus.selenium.WebBrowser

import org.apache.spark.SparkConf
import org.apache.spark.ui.SparkUICssErrorHandler

class UISeleniumSuite
extends HiveThriftJdbcTest
class UISeleniumSuite extends SharedThriftServer
with WebBrowser with Matchers with BeforeAndAfterAll {

implicit var webDriver: WebDriver = _
var server: HiveThriftServer2 = _
val uiPort = 20000 + Random.nextInt(10000)
override def mode: ServerMode.Value = ServerMode.binary

override def sparkConf: SparkConf = {
super.sparkConf
.set("spark.ui.enabled", "true")
.set("spark.ui.port", "0")
}

override def beforeAll(): Unit = {
webDriver = new HtmlUnitDriver {
Expand All @@ -55,30 +54,9 @@ class UISeleniumSuite
}
}

override protected def serverStartCommand(port: Int) = {
val portConf = if (mode == ServerMode.binary) {
ConfVars.HIVE_SERVER2_THRIFT_PORT
} else {
ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
}

s"""$startScript
| --master local
| --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
| --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
| --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
| --hiveconf $portConf=$port
| --driver-class-path ${sys.props("java.class.path")}
| --conf spark.ui.enabled=true
| --conf spark.ui.port=$uiPort
""".stripMargin.split("\\s+").toSeq
}

ignore("thrift server ui test") {
test("thrift server ui test") {
withJdbcStatement("test_map") { statement =>
val baseURL = s"http://localhost:$uiPort"
val baseURL = spark.sparkContext.uiWebUrl.get

val queries = Seq(
"CREATE TABLE test_map(key INT, value STRING)",
Expand All @@ -92,13 +70,14 @@ class UISeleniumSuite
}

eventually(timeout(10.seconds), interval(50.milliseconds)) {
go to (baseURL + "/sql")
go to (baseURL + "/sqlserver")
find(id("sessionstat")) should not be None
find(id("sqlstat")) should not be None

// check whether statements exists
queries.foreach { line =>
findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line)
findAll(cssSelector("""table tbody tr td span"""))
.map(_.text).toList should contain (line)
}
}
}
Expand Down