diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index d536cc5097b2..24897a591a13 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -95,7 +95,7 @@ package object config {
   private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation")
     .internal()
     .stringConf
-    .checkValues(Set("hive", "in-memory"))
+    .checkValues(Set("hive", "in-memory", "provided"))
     .createWithDefault("in-memory")
 
   private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE =
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index 5dfe18ad4982..b67f4614f710 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -87,7 +87,10 @@ object Main extends Logging {
     }
 
     val builder = SparkSession.builder.config(conf)
-    if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
+    if (conf.get(CATALOG_IMPLEMENTATION.key, "").toLowerCase == "provided") {
+      sparkSession = builder.enableProvidedCatalog().getOrCreate()
+      logInfo("Created Spark session with provided external catalog")
+    } else if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
       if (SparkSession.hiveClassesArePresent) {
         // In the case that the property is not set at all, builder's config
         // does not have this value set to 'hive' yet. The original default
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 6d7ac0f6c1bb..e389858f7dee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
+import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState, SQLConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.types.{DataType, LongType, StructType}
@@ -772,6 +772,16 @@ object SparkSession {
       }
     }
 
+    /**
+     * Enables the use of provided ExternalCatalog and SessionState classes.
+     *
+     * @since 2.1.0
+     */
+    def enableProvidedCatalog(): Builder = synchronized {
+      // Assume that the classes exist in the classpath.
+      config(CATALOG_IMPLEMENTATION.key, "provided")
+    }
+
     /**
      * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
      * one based on the options set in this builder.
@@ -910,12 +920,11 @@ object SparkSession {
 
   /** Reference to the root SparkSession. */
   private val defaultSession = new AtomicReference[SparkSession]
 
-  private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
-
   private def sessionStateClassName(conf: SparkConf): String = {
     conf.get(CATALOG_IMPLEMENTATION) match {
-      case "hive" => HIVE_SESSION_STATE_CLASS_NAME
+      case "hive" => SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME.defaultValueString
       case "in-memory" => classOf[SessionState].getCanonicalName
+      case "provided" => conf.get(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME)
     }
   }
@@ -941,7 +950,7 @@ object SparkSession {
    */
   private[spark] def hiveClassesArePresent: Boolean = {
     try {
-      Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME)
+      Utils.classForName(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME.defaultValueString)
       Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
       true
     } catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fecdf792fd14..e609dc6f037a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -58,6 +58,16 @@ object SQLConf {
     .stringConf
     .createWithDefault("${system:user.dir}/spark-warehouse")
 
+  val EXTERNAL_CATALOG_CLASS_NAME = ConfigBuilder("spark.sql.externalCatalog")
+    .internal()
+    .stringConf
+    .createWithDefault("org.apache.spark.sql.hive.HiveExternalCatalog")
+
+  val EXTERNAL_SESSION_STATE_CLASS_NAME = ConfigBuilder("spark.sql.externalSessionState")
+    .internal()
+    .stringConf
+    .createWithDefault("org.apache.spark.sql.hive.HiveSessionState")
+
   val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations")
     .internal()
     .doc("The max number of iterations the optimizer and analyzer runs.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 6387f0150631..b6b0af6e6f0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -110,12 +110,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
 
 object SharedState {
 
-  private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog"
-
   private def externalCatalogClassName(conf: SparkConf): String = {
     conf.get(CATALOG_IMPLEMENTATION) match {
-      case "hive" => HIVE_EXTERNAL_CATALOG_CLASS_NAME
+      case "hive" => SQLConf.EXTERNAL_CATALOG_CLASS_NAME.defaultValueString
       case "in-memory" => classOf[InMemoryCatalog].getCanonicalName
+      case "provided" => conf.get(SQLConf.EXTERNAL_CATALOG_CLASS_NAME)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 386d13d07a95..8df0a065d870 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql
 
+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.catalyst.catalog._
 
 /**
  * Test cases for the builder pattern of [[SparkSession]].
@@ -123,4 +126,132 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
     session.stop()
   }
 }
+
+  test("SPARK-17767 Spark SQL ExternalCatalog API custom implementation support") {
+    val session = SparkSession.builder()
+      .master("local")
+      .config("spark.sql.externalCatalog", "org.apache.spark.sql.MyExternalCatalog")
+      .config("spark.sql.externalSessionState", "org.apache.spark.sql.MySessionState")
+      .enableProvidedCatalog()
+      .getOrCreate()
+    assert(session.sharedState.externalCatalog.isInstanceOf[MyExternalCatalog])
+    assert(session.sessionState.isInstanceOf[MySessionState])
+    session.stop()
+  }
+}
+
+class MyExternalCatalog(conf: SparkConf, hadoopConf: Configuration) extends ExternalCatalog {
+  import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+
+  def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
+
+  def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {}
+
+  def alterDatabase(dbDefinition: CatalogDatabase): Unit = {}
+
+  def getDatabase(db: String): CatalogDatabase = null
+
+  def databaseExists(db: String): Boolean = true
+
+  def listDatabases(): Seq[String] = Seq.empty
+
+  def listDatabases(pattern: String): Seq[String] = Seq.empty
+
+  def setCurrentDatabase(db: String): Unit = {}
+
+  def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {}
+
+  def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit = {}
+
+  def renameTable(db: String, oldName: String, newName: String): Unit = {}
+
+  def alterTable(tableDefinition: CatalogTable): Unit = {}
+
+  def getTable(db: String, table: String): CatalogTable = null
+
+  def getTableOption(db: String, table: String): Option[CatalogTable] = None
+
+  def tableExists(db: String, table: String): Boolean = true
+
+  def listTables(db: String): Seq[String] = Seq.empty
+
+  def listTables(db: String, pattern: String): Seq[String] = Seq.empty
+
+  def loadTable(
+      db: String,
+      table: String,
+      loadPath: String,
+      isOverwrite: Boolean,
+      holdDDLTime: Boolean): Unit = {}
+
+  def loadPartition(
+      db: String,
+      table: String,
+      loadPath: String,
+      partition: TablePartitionSpec,
+      isOverwrite: Boolean,
+      holdDDLTime: Boolean,
+      inheritTableSpecs: Boolean): Unit = {}
+
+  def loadDynamicPartitions(
+      db: String,
+      table: String,
+      loadPath: String,
+      partition: TablePartitionSpec,
+      replace: Boolean,
+      numDP: Int,
+      holdDDLTime: Boolean): Unit = {}
+
+  def createPartitions(
+      db: String,
+      table: String,
+      parts: Seq[CatalogTablePartition],
+      ignoreIfExists: Boolean): Unit = {}
+
+  def dropPartitions(
+      db: String,
+      table: String,
+      parts: Seq[TablePartitionSpec],
+      ignoreIfNotExists: Boolean,
+      purge: Boolean): Unit = {}
+
+  def renamePartitions(
+      db: String,
+      table: String,
+      specs: Seq[TablePartitionSpec],
+      newSpecs: Seq[TablePartitionSpec]): Unit = {}
+
+  def alterPartitions(
+      db: String,
+      table: String,
+      parts: Seq[CatalogTablePartition]): Unit = {}
+
+  def getPartition(db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition =
+    null
+
+  def getPartitionOption(
+      db: String,
+      table: String,
+      spec: TablePartitionSpec): Option[CatalogTablePartition] = None
+
+  def listPartitions(
+      db: String,
+      table: String,
+      partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = Seq.empty
+
+  def createFunction(db: String, funcDefinition: CatalogFunction): Unit = {}
+
+  def dropFunction(db: String, funcName: String): Unit = {}
+
+  def renameFunction(db: String, oldName: String, newName: String): Unit = {}
+
+  def getFunction(db: String, funcName: String): CatalogFunction = null
+
+  def functionExists(db: String, funcName: String): Boolean = true
+
+  def listFunctions(db: String, pattern: String): Seq[String] = Seq.empty
+}
+
+class MySessionState(sparkSession: SparkSession)
+  extends org.apache.spark.sql.internal.SessionState(sparkSession) { }
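For reference, a minimal usage sketch of the "provided" catalog mode introduced by this patch. The class names `com.example.MyCatalog` and `com.example.MySessionState` are hypothetical placeholders for a user's own implementations; as in the test above, the catalog class is expected to have a `(SparkConf, Configuration)` constructor, the session state class a `(SparkSession)` constructor, and both must be on the driver classpath.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: assumes the patch above is applied and that com.example.MyCatalog and
// com.example.MySessionState (hypothetical names) are available on the classpath.
object ProvidedCatalogExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Fully qualified names of the user-supplied ExternalCatalog and SessionState classes.
      .config("spark.sql.externalCatalog", "com.example.MyCatalog")
      .config("spark.sql.externalSessionState", "com.example.MySessionState")
      // Equivalent to setting spark.sql.catalogImplementation=provided.
      .enableProvidedCatalog()
      .getOrCreate()

    // Catalog lookups (databases, tables, functions) now resolve through the provided catalog.
    spark.catalog.listDatabases().show()

    spark.stop()
  }
}
```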