From e4f7205cf07ab9230f1546b70ff115d92184c8b7 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 31 Oct 2018 18:36:58 +0800
Subject: [PATCH 1/3] Add GetSchemasOperation

---
 .../SparkGetSchemasOperation.scala            | 100 +++++++++++++++++
 .../server/SparkSQLOperationManager.scala     |  17 ++-
 .../HiveThriftServer2Suites.scala             |  16 +++
 .../SparkMetadataOperationSuite.scala         | 103 ++++++++++++++++++
 4 files changed, 234 insertions(+), 2 deletions(-)
 create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
 create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
new file mode 100644
index 000000000000..daf68ea95e59
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.UUID
+
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.GetSchemasOperation
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+
+/**
+ * Spark's own GetSchemasOperation
+ *
+ * @param sqlContext SQLContext to use
+ * @param parentSession a HiveSession from SessionManager
+ * @param catalogName catalog name. null if not applicable.
+ * @param schemaName database name, null or a concrete database name
+ */
+private[hive] class SparkGetSchemasOperation(
+    sqlContext: SQLContext,
+    parentSession: HiveSession,
+    catalogName: String,
+    schemaName: String)
+  extends GetSchemasOperation(parentSession, catalogName, schemaName) with Logging {
+
+  val catalog: SessionCatalog = sqlContext.sessionState.catalog
+
+  private final val RESULT_SET_SCHEMA = new TableSchema()
+    .addStringColumn("TABLE_SCHEM", "Schema name.")
+    .addStringColumn("TABLE_CATALOG", "Catalog name.")
+
+  private val rowSet = RowSetFactory.create(RESULT_SET_SCHEMA, getProtocolVersion)
+
+  private var statementId: String = _
+
+  override def close(): Unit = {
+    logInfo(s"Close get schemas with $statementId")
+    setState(OperationState.CLOSED)
+  }
+
+  override def runInternal(): Unit = {
+    statementId = UUID.randomUUID().toString
+    logInfo(s"Getting schemas with $statementId")
+    setState(OperationState.RUNNING)
+    // Always use the latest class loader provided by executionHive's state.
+    val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
+    if (isAuthV2Enabled) {
+      val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
+      authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr)
+    }
+
+    try {
+      catalog.listDatabases(convertSchemaPattern(schemaName)).foreach { dbName =>
+        rowSet.addRow(Array[AnyRef](dbName, ""))
+      }
+      setState(OperationState.FINISHED)
+    } catch {
+      case e: HiveSQLException =>
+        setState(OperationState.ERROR)
+        throw e
+    }
+  }
+
+  override def getNextRowSet(order: FetchOrientation, maxRows: Long): RowSet = {
+    validateDefaultFetchOrientation(order)
+    assertState(OperationState.FINISHED)
+    setHasResultSet(true)
+    if (order.equals(FetchOrientation.FETCH_FIRST)) {
+      rowSet.setStartOffset(0)
+    }
+    rowSet.extractSubset(maxRows.toInt)
+  }
+
+  override def cancel(): Unit = {
+    logInfo(s"Cancel get schemas with $statementId")
+    setState(OperationState.CANCELED)
+  }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index bf7c01f60fb5..f0368d22b4d7 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -21,13 +21,13 @@ import java.util.{Map => JMap}
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager}
 import org.apache.hive.service.cli.session.HiveSession
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
+import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -63,6 +63,19 @@ private[thriftserver] class SparkSQLOperationManager()
     operation
   }
 
+  override def newGetSchemasOperation(
+      parentSession: HiveSession,
+      catalogName: String,
+      schemaName: String): GetSchemasOperation = synchronized {
+    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
+    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
+      s" initialized or has already been closed.")
+    val operation = new SparkGetSchemasOperation(sqlContext, parentSession, catalogName, schemaName)
+    handleToOperation.put(operation.getHandle, operation)
+    logDebug(s"Created GetSchemasOperation with session=$parentSession.")
+    operation
+  }
+
   def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
     val iterator = confMap.entrySet().iterator()
     while (iterator.hasNext) {
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 70eb28cdd0c6..f9509aed4aaa 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -818,6 +818,22 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
     }
   }
 
+  def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) {
+    val user = System.getProperty("user.name")
+    val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
+    val statements = connections.map(_.createStatement())
+
+    try {
+      statements.zip(fs).foreach { case (s, f) => f(s) }
+    } finally {
+      dbNames.foreach { name =>
+        statements(0).execute(s"DROP DATABASE IF EXISTS $name")
+      }
+      statements.foreach(_.close())
+      connections.foreach(_.close())
+    }
+  }
+
   def withJdbcStatement(tableNames: String*)(f: Statement => Unit) {
     withMultipleConnectionJdbcStatement(tableNames: _*)(f)
   }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
new file mode 100644
index 000000000000..9a997ae01df9
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.Properties + +import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils} +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket + +class SparkMetadataOperationSuite extends HiveThriftJdbcTest { + + override def mode: ServerMode.Value = ServerMode.binary + + test("Spark's own GetSchemasOperation(SparkGetSchemasOperation)") { + def testGetSchemasOperation( + catalog: String, + schemaPattern: String)(f: HiveQueryResultSet => Unit): Unit = { + val rawTransport = new TSocket("localhost", serverPort) + val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val client = new TCLIService.Client(new TBinaryProtocol(transport)) + transport.open() + var rs: HiveQueryResultSet = null + try { + val openResp = client.OpenSession(new TOpenSessionReq) + val sessHandle = openResp.getSessionHandle + val schemaReq = new TGetSchemasReq(sessHandle) + + if (catalog != null) { + schemaReq.setCatalogName(catalog) + } + + if (schemaPattern == null) { + schemaReq.setSchemaName("%") + } else { + schemaReq.setSchemaName(schemaPattern) + } + + val schemaResp = client.GetSchemas(schemaReq) + JdbcUtils.verifySuccess(schemaResp.getStatus) + + rs = new HiveQueryResultSet.Builder(connection) + .setClient(client) + .setSessionHandle(sessHandle) + .setStmtHandle(schemaResp.getOperationHandle) + .build() + f(rs) + } finally { + rs.close() + connection.close() + transport.close() + rawTransport.close() + } + } + + def checkResult(dbNames: Seq[String], rs: HiveQueryResultSet): Unit = { + if (dbNames.nonEmpty) { + for (i <- dbNames.indices) { + assert(rs.next()) + assert(rs.getString("TABLE_SCHEM") === dbNames(i)) + } + } else { + assert(!rs.next()) + } + } + + withDatabase("db1", "db2") { statement => + Seq("CREATE DATABASE db1", "CREATE DATABASE db2").foreach(statement.execute) + + testGetSchemasOperation(null, "%") { rs => + checkResult(Seq("db1", "db2"), rs) + } + testGetSchemasOperation(null, "db1") { rs => + checkResult(Seq("db1"), rs) + } + testGetSchemasOperation(null, "db_not_exist") { rs => + checkResult(Seq.empty, rs) + } + testGetSchemasOperation(null, "db*") { rs => + checkResult(Seq("db1", "db2"), rs) + } + } + } +} From ca3a76774d673b4cadc96c1a4e7c9aeda3944303 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 29 Dec 2018 15:52:45 +0800 Subject: [PATCH 2/3] address comment --- .../SparkGetSchemasOperation.scala | 22 +++---------------- .../server/SparkSQLOperationManager.scala | 2 +- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala index daf68ea95e59..e890757eeb0a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.UUID - import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType import 
From ca3a76774d673b4cadc96c1a4e7c9aeda3944303 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Sat, 29 Dec 2018 15:52:45 +0800
Subject: [PATCH 2/3] address comment

---
 .../SparkGetSchemasOperation.scala            | 22 +++----------------
 .../server/SparkSQLOperationManager.scala     |  2 +-
 2 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
index daf68ea95e59..e890757eeb0a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.hive.thriftserver
 
-import java.util.UUID
-
 import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
 import org.apache.hive.service.cli._
 import org.apache.hive.service.cli.operation.GetSchemasOperation
@@ -51,16 +49,7 @@ private[hive] class SparkGetSchemasOperation(
 
   private val rowSet = RowSetFactory.create(RESULT_SET_SCHEMA, getProtocolVersion)
 
-  private var statementId: String = _
-
-  override def close(): Unit = {
-    logInfo(s"Close get schemas with $statementId")
-    setState(OperationState.CLOSED)
-  }
-
   override def runInternal(): Unit = {
-    statementId = UUID.randomUUID().toString
-    logInfo(s"Getting schemas with $statementId")
     setState(OperationState.RUNNING)
     // Always use the latest class loader provided by executionHive's state.
     val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
@@ -83,18 +72,13 @@ private[hive] class SparkGetSchemasOperation(
     }
   }
 
-  override def getNextRowSet(order: FetchOrientation, maxRows: Long): RowSet = {
-    validateDefaultFetchOrientation(order)
+  override def getNextRowSet(orientation: FetchOrientation, maxRows: Long): RowSet = {
     assertState(OperationState.FINISHED)
+    validateDefaultFetchOrientation(orientation)
     setHasResultSet(true)
-    if (order.equals(FetchOrientation.FETCH_FIRST)) {
+    if (orientation.equals(FetchOrientation.FETCH_FIRST)) {
       rowSet.setStartOffset(0)
     }
     rowSet.extractSubset(maxRows.toInt)
   }
-
-  override def cancel(): Unit = {
-    logInfo(s"Cancel get schemas with $statementId")
-    setState(OperationState.CANCELED)
-  }
 }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index f0368d22b4d7..85b6c7134755 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -69,7 +69,7 @@ private[thriftserver] class SparkSQLOperationManager()
       schemaName: String): GetSchemasOperation = synchronized {
     val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
     require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
-      s" initialized or has already been closed.")
+      " initialized or has already been closed.")
     val operation = new SparkGetSchemasOperation(sqlContext, parentSession, catalogName, schemaName)
     handleToOperation.put(operation.getHandle, operation)
     logDebug(s"Created GetSchemasOperation with session=$parentSession.")
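Both versions of runInternal delegate pattern handling to convertSchemaPattern, inherited from Hive's MetadataOperation: it rewrites the JDBC metadata pattern ('%' for any sequence, '_' for one character) into the '*'/'.' form that SessionCatalog.listDatabases matches against. A rough standalone approximation in Scala, ignoring the backslash-escape handling of the real method:

object SchemaPatternSketch extends App {
  // Approximation of MetadataOperation.convertSchemaPattern: a null or empty
  // JDBC pattern means "all schemas"; '%' and '_' are the JDBC wildcards.
  def convertSchemaPattern(pattern: String): String = {
    val jdbcPattern = if (pattern == null || pattern.isEmpty) "%" else pattern
    jdbcPattern.replace("%", "*").replace("_", ".")
  }

  assert(convertSchemaPattern("db%") == "db*") // matches db1 and db2 in the test
  assert(convertSchemaPattern(null) == "*")    // lists every database
}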
From ecc4e0d0d4d7b16157d7dad27e4e59f9a081b477 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Thu, 3 Jan 2019 12:40:21 +0800
Subject: [PATCH 3/3] Make RowSet protected

---
 .../cli/operation/GetSchemasOperation.java    |  2 +-
 .../SparkGetSchemasOperation.scala            | 28 ++++---------------
 2 files changed, 6 insertions(+), 24 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
index d6f6280f1c39..3516bc2ba242 100644
--- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
+++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java
@@ -41,7 +41,7 @@ public class GetSchemasOperation extends MetadataOperation {
       .addStringColumn("TABLE_SCHEM", "Schema name.")
       .addStringColumn("TABLE_CATALOG", "Catalog name.");
 
-  private RowSet rowSet;
+  protected RowSet rowSet;
 
   protected GetSchemasOperation(HiveSession parentSession,
       String catalogName, String schemaName) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
index e890757eeb0a..d585049c28e3 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -20,11 +20,10 @@ package org.apache.spark.sql.hive.thriftserver
 import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
 import org.apache.hive.service.cli._
 import org.apache.hive.service.cli.operation.GetSchemasOperation
+import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATALOG
 import org.apache.hive.service.cli.session.HiveSession
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 
 /**
  * Spark's own GetSchemasOperation
@@ -39,15 +38,7 @@ private[hive] class SparkGetSchemasOperation(
     parentSession: HiveSession,
     catalogName: String,
     schemaName: String)
-  extends GetSchemasOperation(parentSession, catalogName, schemaName) with Logging {
-
-  val catalog: SessionCatalog = sqlContext.sessionState.catalog
-
-  private final val RESULT_SET_SCHEMA = new TableSchema()
-    .addStringColumn("TABLE_SCHEM", "Schema name.")
-    .addStringColumn("TABLE_CATALOG", "Catalog name.")
-
-  private val rowSet = RowSetFactory.create(RESULT_SET_SCHEMA, getProtocolVersion)
+  extends GetSchemasOperation(parentSession, catalogName, schemaName) {
 
   override def runInternal(): Unit = {
     setState(OperationState.RUNNING)
@@ -61,8 +52,9 @@ private[hive] class SparkGetSchemasOperation(
     }
 
     try {
-      catalog.listDatabases(convertSchemaPattern(schemaName)).foreach { dbName =>
-        rowSet.addRow(Array[AnyRef](dbName, ""))
+      val schemaPattern = convertSchemaPattern(schemaName)
+      sqlContext.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
+        rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
       }
       setState(OperationState.FINISHED)
     } catch {
@@ -71,14 +63,4 @@ private[hive] class SparkGetSchemasOperation(
         setState(OperationState.ERROR)
         throw e
     }
   }
-
-  override def getNextRowSet(orientation: FetchOrientation, maxRows: Long): RowSet = {
-    assertState(OperationState.FINISHED)
-    validateDefaultFetchOrientation(orientation)
-    setHasResultSet(true)
-    if (orientation.equals(FetchOrientation.FETCH_FIRST)) {
-      rowSet.setStartOffset(0)
-    }
-    rowSet.extractSubset(maxRows.toInt)
-  }
 }
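Making rowSet protected in the Java base class is what lets patch 3 delete the subclass's duplicated TableSchema and RowSet fields along with its getNextRowSet override: runInternal now only fills the inherited row set, and GetSchemasOperation handles fetching and FETCH_FIRST paging. The same recipe would extend to other metadata operations; a hypothetical Scala sketch for catalogs, assuming GetCatalogsOperation's rowSet were opened up the same way (this class is not part of the patch series):

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.GetCatalogsOperation
import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATALOG
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.sql.SQLContext

// Hypothetical follow-up mirroring SparkGetSchemasOperation: populate the
// inherited rowSet and let the base class serve the rows to the client.
private[hive] class SparkGetCatalogsOperation(
    sqlContext: SQLContext,
    parentSession: HiveSession)
  extends GetCatalogsOperation(parentSession) {

  override def runInternal(): Unit = {
    setState(OperationState.RUNNING)
    try {
      // Spark exposes a single catalog: Hive's default one.
      rowSet.addRow(Array[AnyRef](DEFAULT_HIVE_CATALOG))
      setState(OperationState.FINISHED)
    } catch {
      case e: HiveSQLException =>
        setState(OperationState.ERROR)
        throw e
    }
  }
}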