dev/run-tests.py (9 changes: 1 addition & 8 deletions)
@@ -112,7 +112,6 @@ def determine_modules_to_test(changed_modules):
['graphx', 'examples']
>>> x = [x.name for x in determine_modules_to_test([modules.sql])]
>>> x # doctest: +NORMALIZE_WHITESPACE
... # doctest: +SKIP
['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver',
'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml']
"""
@@ -123,15 +122,9 @@ def determine_modules_to_test(changed_modules):
# If we need to run all of the tests, then we should short-circuit and return 'root'
if modules.root in modules_to_test:
return [modules.root]
changed_modules = toposort_flatten(
return toposort_flatten(
{m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True)

# TODO: Skip hive-thriftserver module for hadoop-3.2. remove this once hadoop-3.2 support it
if modules.hadoop_version == "hadoop3.2":
changed_modules = [m for m in changed_modules if m.name != "hive-thriftserver"]

return changed_modules


def determine_tags_to_exclude(changed_modules):
tags = []
dev/sparktestsupport/modules.py (11 changes: 0 additions & 11 deletions)
@@ -15,11 +15,9 @@
# limitations under the License.
#

from __future__ import print_function
from functools import total_ordering
import itertools
import re
import os

all_modules = []

@@ -558,15 +556,6 @@ def __hash__(self):
]
)

# TODO: Skip hive-thriftserver module for hadoop-3.2. remove this once hadoop-3.2 support it
if os.environ.get("AMPLAB_JENKINS"):
hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop2.7")
else:
hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.7")
if hadoop_version == "hadoop3.2":
print("[info] Skip unsupported module:", "hive-thriftserver")
all_modules = [m for m in all_modules if m.name != "hive-thriftserver"]

# The root module is a dummy module which is used to run all of the tests.
# No other modules should directly depend on this module.
root = Module(
@@ -113,7 +113,8 @@ private[hive] class SparkExecuteStatementOperation(
validateDefaultFetchOrientation(order)
assertState(OperationState.FINISHED)
setHasResultSet(true)
val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion)
val resultRowSet: RowSet =
ThriftserverShimUtils.resultRowSet(getResultSetSchema, getProtocolVersion)

// Reset iter to header when fetching start from first row
if (order.equals(FetchOrientation.FETCH_FIRST)) {
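
The shim indirection above exists because Hive relocated its thrift-generated classes between releases (org.apache.hive.service.cli.thrift in Hive 1.2 vs. org.apache.hive.service.rpc.thrift in Hive 2.x), so each build profile can ship its own ThriftserverShimUtils; only the v1.2 file appears at the end of this diff. A hypothetical sketch of the v2.3 counterpart of resultRowSet, assuming Hive 2.3's RowSetFactory.create takes an extra flag for BLOB-based row sets:

// Hypothetical v2.3-profile shim method; imports would mirror the v1.2 file below.
private[thriftserver] def resultRowSet(
    getResultSetSchema: TableSchema,
    getProtocolVersion: TProtocolVersion): RowSet = {
  RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
}
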
@@ -30,6 +30,7 @@ import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.hive.HiveUtils

/**
* Spark's own GetTablesOperation
@@ -83,7 +84,12 @@ private[hive] class SparkGetTablesOperation(
catalogTable.identifier.table,
tableType,
catalogTable.comment.getOrElse(""))
rowSet.addRow(rowData)
// Since HIVE-7575 (Hive 2.0.0), Hive adds 5 additional columns to the ResultSet of GetTables.
if (HiveUtils.isHive23) {
rowSet.addRow(rowData ++ Array(null, null, null, null, null))
} else {
rowSet.addRow(rowData)
}
}
}
}
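
For context on the branch above: HIVE-7575 widened GetTables results to the ten-column layout of JDBC's DatabaseMetaData.getTables. Spark populates the first five (catalog, schema, table name, table type, remarks); the trailing five (TYPE_CAT, TYPE_SCHEM, TYPE_NAME, SELF_REFERENCING_COL_NAME, REF_GENERATION) are not tracked by Spark's catalog, hence the nulls. A small illustrative helper (the name is ours, not part of this PR):

// Pad a five-column Spark row out to the ten-column JDBC getTables layout.
def toJdbcGetTablesRow(
    catalog: String,
    schema: String,
    table: String,
    tableType: String,
    remarks: String): Array[AnyRef] =
  Array[AnyRef](catalog, schema, table, tableType, remarks) ++
    Array.fill[AnyRef](5)(null)
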
@@ -26,7 +26,6 @@ import scala.collection.JavaConverters._
import jline.console.ConsoleReader
import jline.console.history.FileHistory
import org.apache.commons.lang3.StringUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
import org.apache.hadoop.hive.common.HiveInterruptUtils
@@ -297,9 +296,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
private val sessionState = SessionState.get().asInstanceOf[CliSessionState]

private val LOG = LogFactory.getLog(classOf[SparkSQLCLIDriver])

private val console = new SessionState.LogHelper(LOG)
private val console = ThriftserverShimUtils.getConsole

private val isRemoteMode = {
SparkSQLCLIDriver.isRemoteMode(sessionState)
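
getConsole moves into the shim because SessionState.LogHelper wraps a commons-logging Log in Hive 1.2 but an slf4j Logger in Hive 2.3 (the same split drives the Logger/Log branch in ReflectedCompositeService below). A hypothetical sketch of the v2.3-profile implementation:

import org.apache.hadoop.hive.ql.session.SessionState
import org.slf4j.LoggerFactory

// Assumed v2.3 shim: LogHelper there is constructed from an slf4j Logger.
private[thriftserver] def getConsole: SessionState.LogHelper = {
  val LOG = LoggerFactory.getLogger(classOf[SparkSQLCLIDriver])
  new SessionState.LogHelper(LOG)
}
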
@@ -33,8 +33,10 @@ import org.apache.hive.service.Service.STATE
import org.apache.hive.service.auth.HiveAuthFactory
import org.apache.hive.service.cli._
import org.apache.hive.service.server.HiveServer2
import org.slf4j.Logger

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._

private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLContext)
@@ -112,6 +114,10 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
setAncestorField(this, 3, "hiveConf", hiveConf)
invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
if (HiveUtils.isHive23) {
getAncestorField[Logger](this, 3, "LOG").info(s"Service: $getName is inited.")
} else {
getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
}
}
}
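
For readers following the reflection above: getAncestorField comes from ReflectionUtils (imported at the top of this file but not shown in the diff) and reads a private field declared some number of superclasses above the receiver. A minimal sketch of the assumed semantics:

import java.lang.reflect.Field

// Walk `levels` superclasses up from obj's class, then read the named
// field reflectively, regardless of its declared visibility.
def getAncestorField[T](obj: AnyRef, levels: Int, name: String): T = {
  var clazz: Class[_] = obj.getClass
  (0 until levels).foreach(_ => clazz = clazz.getSuperclass)
  val field: Field = clazz.getDeclaredField(name)
  field.setAccessible(true)
  field.get(obj).asInstanceOf[T]
}
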
@@ -24,7 +24,6 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.SessionHandle
import org.apache.hive.service.cli.session.SessionManager
import org.apache.hive.service.cli.thrift.TProtocolVersion
import org.apache.hive.service.server.HiveServer2

import org.apache.spark.sql.SQLContext
@@ -45,7 +44,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext:
}

override def openSession(
protocol: TProtocolVersion,
protocol: ThriftserverShimUtils.TProtocolVersion,
username: String,
passwd: String,
ipAddress: String,
@@ -35,7 +35,6 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.apache.hive.service.auth.PlainSaslHelper
import org.apache.hive.service.cli.{FetchOrientation, FetchType, GetInfoType}
import org.apache.hive.service.cli.thrift.TCLIService.Client
import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket
@@ -66,7 +65,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
val protocol = new TBinaryProtocol(transport)
val client = new ThriftCLIServiceClient(new Client(protocol))
val client = new ThriftCLIServiceClient(new ThriftserverShimUtils.Client(protocol))

transport.open()
try f(client) finally transport.close()
@@ -536,7 +535,11 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
conf += resultSet.getString(1) -> resultSet.getString(2)
}

assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
if (HiveUtils.isHive23) {
assert(conf.get("spark.sql.hive.version") === Some("2.3.5"))
} else {
assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
}
}
}

@@ -549,7 +552,11 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
conf += resultSet.getString(1) -> resultSet.getString(2)
}

assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
if (HiveUtils.isHive23) {
assert(conf.get("spark.sql.hive.version") === Some("2.3.5"))
} else {
assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
}
}
}
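
Both tests above pin the expected value of spark.sql.hive.version. If more version-sensitive assertions accumulate, a shared helper would keep the expectation in one place; a hypothetical sketch, not part of this PR:

// One place to update when the built-in Hive version changes again.
def expectedHiveVersion: String = if (HiveUtils.isHive23) "2.3.5" else "1.2.1"

// assert(conf.get("spark.sql.hive.version") === Some(expectedHiveVersion))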

@@ -627,7 +634,11 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
val sessionHandle = client.openSession(user, "")
val sessionID = sessionHandle.getSessionId

assert(pipeoutFileList(sessionID).length == 1)
if (HiveUtils.isHive23) {
assert(pipeoutFileList(sessionID).length == 2)
} else {
assert(pipeoutFileList(sessionID).length == 1)
}

client.closeSession(sessionHandle)

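
The expected pipeout count doubles under Hive 2.3, presumably because its SessionState opens an additional .pipeout scratch file per session. pipeoutFileList is defined elsewhere in this suite; a hypothetical sketch of such a helper, assuming session scratch files are named after the session UUID:

import java.io.{File, FileFilter}
import java.util.UUID

// List the session's *.pipeout scratch files under the given directory.
def pipeoutFileList(scratchDir: File)(sessionID: UUID): Array[File] =
  scratchDir.listFiles(new FileFilter {
    override def accept(f: File): Boolean =
      f.getName.startsWith(sessionID.toString) && f.getName.endsWith(".pipeout")
  })
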
@@ -27,8 +27,8 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite {
val tableSchema = StructType(Seq(field1, field2))
val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
assert(columns.size() == 2)
assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
assert(columns.get(0).getType().getName == "VOID")
assert(columns.get(1).getType().getName == "VOID")
}

test("SPARK-20146 Comment should be preserved") {
@@ -37,9 +37,9 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite {
val tableSchema = StructType(Seq(field1, field2))
val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
assert(columns.size() == 2)
assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
assert(columns.get(0).getType().getName == "STRING")
assert(columns.get(0).getComment() == "comment 1")
assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
assert(columns.get(1).getType().getName == "INT")
assert(columns.get(1).getComment() == "")
}
}
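
Comparing getType().getName against string literals, rather than testing identity against org.apache.hive.service.cli.Type constants, keeps these assertions source-compatible with both profiles: what getType() returns appears to differ between the two built-in Hive versions, while the type names ("VOID", "STRING", "INT") are stable.
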
@@ -19,9 +19,8 @@ package org.apache.spark.sql.hive.thriftserver

import java.util.{Arrays => JArrays, List => JList, Properties}

import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils}
import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet}
import org.apache.hive.service.auth.PlainSaslHelper
import org.apache.hive.service.cli.thrift._
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket

@@ -37,13 +36,13 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
val client = new TCLIService.Client(new TBinaryProtocol(transport))
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
transport.open()
var rs: HiveQueryResultSet = null
try {
val openResp = client.OpenSession(new TOpenSessionReq)
val openResp = client.OpenSession(new ThriftserverShimUtils.TOpenSessionReq)
val sessHandle = openResp.getSessionHandle
val schemaReq = new TGetSchemasReq(sessHandle)
val schemaReq = new ThriftserverShimUtils.TGetSchemasReq(sessHandle)

if (catalog != null) {
schemaReq.setCatalogName(catalog)
@@ -55,13 +54,10 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
schemaReq.setSchemaName(schemaPattern)
}

val schemaResp = client.GetSchemas(schemaReq)
JdbcUtils.verifySuccess(schemaResp.getStatus)

rs = new HiveQueryResultSet.Builder(connection)
.setClient(client)
.setSessionHandle(sessHandle)
.setStmtHandle(schemaResp.getOperationHandle)
.setStmtHandle(client.GetSchemas(schemaReq).getOperationHandle)
.build()
f(rs)
} finally {
@@ -110,28 +106,24 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
val client = new TCLIService.Client(new TBinaryProtocol(transport))
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
transport.open()

var rs: HiveQueryResultSet = null

try {
val openResp = client.OpenSession(new TOpenSessionReq)
val openResp = client.OpenSession(new ThriftserverShimUtils.TOpenSessionReq)
val sessHandle = openResp.getSessionHandle

val getTableReq = new TGetTablesReq(sessHandle)
val getTableReq = new ThriftserverShimUtils.TGetTablesReq(sessHandle)
getTableReq.setSchemaName(schema)
getTableReq.setTableName(tableNamePattern)
getTableReq.setTableTypes(tableTypes)

val getTableResp = client.GetTables(getTableReq)

JdbcUtils.verifySuccess(getTableResp.getStatus)

rs = new HiveQueryResultSet.Builder(connection)
.setClient(client)
.setSessionHandle(sessHandle)
.setStmtHandle(getTableResp.getOperationHandle)
.setStmtHandle(client.GetTables(getTableReq).getOperationHandle)
.build()

f(rs)
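
The explicit JdbcUtils.verifySuccess calls are dropped along with the Utils import, presumably because org.apache.hive.jdbc.Utils is not accessible under both Hive client versions. If the status check is still wanted, a local helper could restore it; a hypothetical sketch that assumes a TStatus alias were also added to ThriftserverShimUtils:

// Hypothetical: would require `type TStatus = ...` in ThriftserverShimUtils.
private def assertSuccess(status: ThriftserverShimUtils.TStatus): Unit = {
  val code = status.getStatusCode.toString
  assert(code == "SUCCESS_STATUS" || code == "SUCCESS_WITH_INFO_STATUS",
    s"Thrift call failed: ${status.getErrorMessage}")
}
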
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.thriftserver

import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema}

/**
* Various utilities for hive-thriftserver used to upgrade the built-in Hive.
*/
private[thriftserver] object ThriftserverShimUtils {

private[thriftserver] type TProtocolVersion = org.apache.hive.service.cli.thrift.TProtocolVersion
private[thriftserver] type Client = org.apache.hive.service.cli.thrift.TCLIService.Client
private[thriftserver] type TOpenSessionReq = org.apache.hive.service.cli.thrift.TOpenSessionReq
private[thriftserver] type TGetSchemasReq = org.apache.hive.service.cli.thrift.TGetSchemasReq
private[thriftserver] type TGetTablesReq = org.apache.hive.service.cli.thrift.TGetTablesReq

private[thriftserver] def getConsole: SessionState.LogHelper = {
val LOG = LogFactory.getLog(classOf[SparkSQLCLIDriver])
new SessionState.LogHelper(LOG)
}

private[thriftserver] def resultRowSet(
getResultSetSchema: TableSchema,
getProtocolVersion: TProtocolVersion): RowSet = {
RowSetFactory.create(getResultSetSchema, getProtocolVersion)
}

}
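
Only the v1.2 shim appears in this diff; the hadoop-3.2 profile presumably compiles a parallel source tree whose object keeps the same member names but points at Hive 2.x's relocated thrift package. A hypothetical sketch of that counterpart:

// Hypothetical v2.3 counterpart: identical names, relocated targets, so every
// call site in this diff compiles unchanged under either profile.
private[thriftserver] object ThriftserverShimUtils {
  private[thriftserver] type TProtocolVersion = org.apache.hive.service.rpc.thrift.TProtocolVersion
  private[thriftserver] type Client = org.apache.hive.service.rpc.thrift.TCLIService.Client
  private[thriftserver] type TOpenSessionReq = org.apache.hive.service.rpc.thrift.TOpenSessionReq
  private[thriftserver] type TGetSchemasReq = org.apache.hive.service.rpc.thrift.TGetSchemasReq
  private[thriftserver] type TGetTablesReq = org.apache.hive.service.rpc.thrift.TGetTablesReq
}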