@@ -95,7 +95,7 @@ package object config {
   private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation")
     .internal()
     .stringConf
-    .checkValues(Set("hive", "in-memory"))
+    .checkValues(Set("hive", "in-memory", "provided"))
     .createWithDefault("in-memory")

   private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE =
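For illustration, a minimal sketch of how an application could opt in to the new value, using the two companion keys introduced later in this patch (spark.sql.externalCatalog, spark.sql.externalSessionState); the com.example class names are hypothetical placeholders, not part of the patch:

import org.apache.spark.SparkConf

// Sketch: opting in to a user-provided catalog. The com.example names are
// placeholders for real ExternalCatalog / SessionState implementations.
val conf = new SparkConf()
  .set("spark.sql.catalogImplementation", "provided")
  .set("spark.sql.externalCatalog", "com.example.MyExternalCatalog")
  .set("spark.sql.externalSessionState", "com.example.MySessionState")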
@@ -87,7 +87,10 @@ object Main extends Logging {
     }

     val builder = SparkSession.builder.config(conf)
-    if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
+    if (conf.get(CATALOG_IMPLEMENTATION.key, "").toLowerCase == "provided") {
+      sparkSession = builder.enableProvidedCatalog().getOrCreate()
+      logInfo("Created Spark session with provided external catalog")
+    } else if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
       if (SparkSession.hiveClassesArePresent) {
         // In the case that the property is not set at all, builder's config
         // does not have this value set to 'hive' yet. The original default
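Restated, the REPL's session creation now has three outcomes. A condensed sketch of the equivalent selection logic (illustrative only, not the patch's exact code; assumes `conf: SparkConf` and `builder` as in Main):

// Condensed restatement of the REPL's three-way catalog selection (sketch).
val sparkSession =
  conf.get("spark.sql.catalogImplementation", "hive").toLowerCase match {
    case "provided" =>
      builder.enableProvidedCatalog().getOrCreate()  // user-supplied classes
    case "hive" if SparkSession.hiveClassesArePresent =>
      builder.enableHiveSupport().getOrCreate()      // Hive-backed catalog
    case _ =>
      builder.getOrCreate()                          // in-memory fallback
  }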
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (14 additions, 5 deletions)
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
+import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState, SQLConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.types.{DataType, LongType, StructType}
@@ -772,6 +772,16 @@ object SparkSession {
       }
     }

+    /**
+     * Enables the use of provided ExternalCatalog and SessionState classes.
+     *
+     * @since 2.1.0
+     */
+    def enableProvidedCatalog(): Builder = synchronized {
+      // Assume that the classes exist on the classpath.
+      config(CATALOG_IMPLEMENTATION.key, "provided")
+    }
+
     /**
      * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
      * one based on the options set in this builder.
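The new method is thin sugar over the existing config hook; per the body above, these two builder chains set the same key:

// Equivalent ways to request a provided catalog; the second line is what
// enableProvidedCatalog() does internally.
SparkSession.builder().enableProvidedCatalog()
SparkSession.builder().config("spark.sql.catalogImplementation", "provided")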
@@ -910,12 +920,11 @@ object SparkSession {
   /** Reference to the root SparkSession. */
   private val defaultSession = new AtomicReference[SparkSession]

-  private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
-
   private def sessionStateClassName(conf: SparkConf): String = {
     conf.get(CATALOG_IMPLEMENTATION) match {
-      case "hive" => HIVE_SESSION_STATE_CLASS_NAME
+      case "hive" => SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME.defaultValueString
       case "in-memory" => classOf[SessionState].getCanonicalName
+      case "provided" => conf.get(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME)
     }
   }
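Note that the "hive" branch now reuses the default of the new config entry rather than a hard-coded literal; the two spellings name the same class. A small sanity-check sketch:

import org.apache.spark.sql.internal.SQLConf

// Sanity check (sketch): the entry's default equals the previously
// hard-coded class name that this hunk removes.
assert(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME.defaultValueString ==
  "org.apache.spark.sql.hive.HiveSessionState")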

@@ -941,7 +950,7 @@
    */
   private[spark] def hiveClassesArePresent: Boolean = {
     try {
-      Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME)
+      Utils.classForName(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME.defaultValueString)
       Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
       true
     } catch {
@@ -58,6 +58,16 @@ object SQLConf {
     .stringConf
     .createWithDefault("${system:user.dir}/spark-warehouse")

+  val EXTERNAL_CATALOG_CLASS_NAME = ConfigBuilder("spark.sql.externalCatalog")
+    .internal()
+    .stringConf
+    .createWithDefault("org.apache.spark.sql.hive.HiveExternalCatalog")
+
+  val EXTERNAL_SESSION_STATE_CLASS_NAME = ConfigBuilder("spark.sql.externalSessionState")
+    .internal()
+    .stringConf
+    .createWithDefault("org.apache.spark.sql.hive.HiveSessionState")
+
   val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations")
     .internal()
     .doc("The max number of iterations the optimizer and analyzer runs.")
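Call sites read these entries straight from a SparkConf, as the SparkSession and SharedState hunks in this patch do; with createWithDefault, an unset key yields the Hive class names. A sketch (note the typed get on a ConfigEntry is Spark-internal API):

import org.apache.spark.SparkConf
import org.apache.spark.sql.internal.SQLConf

// Reading the new entries from a SparkConf (sketch). When the user has not
// set them, the Hive defaults declared above are returned.
val conf = new SparkConf()
val catalogClass = conf.get(SQLConf.EXTERNAL_CATALOG_CLASS_NAME)
val sessionStateClass = conf.get(SQLConf.EXTERNAL_SESSION_STATE_CLASS_NAME)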
@@ -110,12 +110,11 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {

 object SharedState {

-  private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog"
-
   private def externalCatalogClassName(conf: SparkConf): String = {
     conf.get(CATALOG_IMPLEMENTATION) match {
-      case "hive" => HIVE_EXTERNAL_CATALOG_CLASS_NAME
+      case "hive" => SQLConf.EXTERNAL_CATALOG_CLASS_NAME.defaultValueString
       case "in-memory" => classOf[InMemoryCatalog].getCanonicalName
+      case "provided" => conf.get(SQLConf.EXTERNAL_CATALOG_CLASS_NAME)
     }
   }
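How the resolved class name becomes a live catalog is outside this hunk, but the helper classes in the test file below suggest the expected constructor shapes: (SparkConf, Configuration) for the catalog and (SparkSession) for the session state. A generic sketch of such reflective construction (an assumption for illustration, not the patch's actual helper):

import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog

// Generic sketch of reflective construction; Spark's real helper may differ.
def instantiateCatalog(
    className: String,
    conf: SparkConf,
    hadoopConf: Configuration): ExternalCatalog = {
  val ctor = Class.forName(className)
    .getDeclaredConstructor(classOf[SparkConf], classOf[Configuration])
  ctor.newInstance(conf, hadoopConf).asInstanceOf[ExternalCatalog]
}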
@@ -17,7 +17,10 @@

 package org.apache.spark.sql

+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.catalyst.catalog._

 /**
  * Test cases for the builder pattern of [[SparkSession]].
@@ -123,4 +126,132 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
     session.stop()
   }
-}
+
+  test("SPARK-17767 Spark SQL ExternalCatalog API custom implementation support") {
+    val session = SparkSession.builder()
+      .master("local")
+      .config("spark.sql.externalCatalog", "org.apache.spark.sql.MyExternalCatalog")
+      .config("spark.sql.externalSessionState", "org.apache.spark.sql.MySessionState")
+      .enableProvidedCatalog()
+      .getOrCreate()
+    assert(session.sharedState.externalCatalog.isInstanceOf[MyExternalCatalog])
+    assert(session.sessionState.isInstanceOf[MySessionState])
+    session.stop()
+  }
+}
+
+class MyExternalCatalog(conf: SparkConf, hadoopConf: Configuration) extends ExternalCatalog {
+  import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+
+  def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
+
+  def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {}
+
+  def alterDatabase(dbDefinition: CatalogDatabase): Unit = {}
+
+  def getDatabase(db: String): CatalogDatabase = null
+
+  def databaseExists(db: String): Boolean = true
+
+  def listDatabases(): Seq[String] = Seq.empty
+
+  def listDatabases(pattern: String): Seq[String] = Seq.empty
+
+  def setCurrentDatabase(db: String): Unit = {}
+
+  def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {}
+
+  def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit = {}
+
+  def renameTable(db: String, oldName: String, newName: String): Unit = {}
+
+  def alterTable(tableDefinition: CatalogTable): Unit = {}
+
+  def getTable(db: String, table: String): CatalogTable = null
+
+  def getTableOption(db: String, table: String): Option[CatalogTable] = None
+
+  def tableExists(db: String, table: String): Boolean = true
+
+  def listTables(db: String): Seq[String] = Seq.empty
+
+  def listTables(db: String, pattern: String): Seq[String] = Seq.empty
+
+  def loadTable(
+      db: String,
+      table: String,
+      loadPath: String,
+      isOverwrite: Boolean,
+      holdDDLTime: Boolean): Unit = {}
+
+  def loadPartition(
+      db: String,
+      table: String,
+      loadPath: String,
+      partition: TablePartitionSpec,
+      isOverwrite: Boolean,
+      holdDDLTime: Boolean,
+      inheritTableSpecs: Boolean): Unit = {}
+
+  def loadDynamicPartitions(
+      db: String,
+      table: String,
+      loadPath: String,
+      partition: TablePartitionSpec,
+      replace: Boolean,
+      numDP: Int,
+      holdDDLTime: Boolean): Unit = {}
+
+  def createPartitions(
+      db: String,
+      table: String,
+      parts: Seq[CatalogTablePartition],
+      ignoreIfExists: Boolean): Unit = {}
+
+  def dropPartitions(
+      db: String,
+      table: String,
+      parts: Seq[TablePartitionSpec],
+      ignoreIfNotExists: Boolean,
+      purge: Boolean): Unit = {}
+
+  def renamePartitions(
+      db: String,
+      table: String,
+      specs: Seq[TablePartitionSpec],
+      newSpecs: Seq[TablePartitionSpec]): Unit = {}
+
+  def alterPartitions(
+      db: String,
+      table: String,
+      parts: Seq[CatalogTablePartition]): Unit = {}
+
+  def getPartition(db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition =
+    null
+
+  def getPartitionOption(
+      db: String,
+      table: String,
+      spec: TablePartitionSpec): Option[CatalogTablePartition] = None
+
+  def listPartitions(
+      db: String,
+      table: String,
+      partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = Seq.empty
+
+  def createFunction(db: String, funcDefinition: CatalogFunction): Unit = {}
+
+  def dropFunction(db: String, funcName: String): Unit = {}
+
+  def renameFunction(db: String, oldName: String, newName: String): Unit = {}
+
+  def getFunction(db: String, funcName: String): CatalogFunction = null
+
+  def functionExists(db: String, funcName: String): Boolean = true
+
+  def listFunctions(db: String, pattern: String): Seq[String] = Seq.empty
+}
+
+class MySessionState(sparkSession: SparkSession)
+  extends org.apache.spark.sql.internal.SessionState(sparkSession) {
+}