sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala (6 additions, 42 deletions)
@@ -17,12 +17,11 @@

package org.apache.spark.sql

-import java.util.{Locale, Properties, ServiceConfigurationError}
+import java.util.{Locale, Properties}

import scala.jdk.CollectionConverters._
-import scala.util.{Failure, Success, Try}

-import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
+import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

-val isUserDefinedDataSource =
-sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
-
-Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
-case Success(providerOpt) =>
-// The source can be successfully loaded as either a V1 or a V2 data source.
-// Check if it is also a user-defined data source.
-if (isUserDefinedDataSource) {
-throw QueryCompilationErrors.foundMultipleDataSources(source)
-}
-providerOpt.flatMap { provider =>
-DataSourceV2Utils.loadV2Source(
-sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
-}.getOrElse(loadV1Source(paths: _*))
-case Failure(exception) =>
-// Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
-// For the following not found exceptions, if the user-defined data source is defined,
-// we can instead return the user-defined data source.
-val isNotFoundError = exception match {
-case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
-case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
-case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
-case _ => false
-}
-if (isNotFoundError && isUserDefinedDataSource) {
-loadUserDefinedDataSource(paths)
-} else {
-// Throw the original exception.
-throw exception
-}
-}
-}
-
-private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
-// Add `path` and `paths` options to the extra options if specified.
-val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
-val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
-Dataset.ofRows(sparkSession, plan)
+DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
+DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
+source, paths: _*)
+}.getOrElse(loadV1Source(paths: _*))
}

private def loadV1Source(paths: String*) = {
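With the special-casing removed, load() resolves every source, including registered Python data sources, through the same DataSource.lookupDataSourceV2 path. A minimal usage sketch of the resulting read path, assuming a Python data source registered under the hypothetical short name "my_source":

// Sketch only: "my_source" and the paths are hypothetical; assumes an active SparkSession `spark`.
val df = spark.read
  .format("my_source")             // resolved by DataSource.lookupDataSource
  .schema("id INT, value STRING")  // optional user-specified schema
  .load("/tmp/input")              // forwarded to the source as the `path` option
df.show()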
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
-source.getDeclaredConstructor().newInstance()
+DataSource.newDataSourceInstance(runner, source)
.asInstanceOf[ExternalCommandRunner], command, options))

case _ =>
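The call site above is SparkSession.executeCommand, which reflectively instantiates an ExternalCommandRunner; it now goes through the shared factory so construction stays uniform across provider types. A hedged sketch of that API (the runner class and command string are hypothetical):

// Sketch only: assumes com.example.MyRunner implements ExternalCommandRunner
// and is on the classpath; the command string is opaque to Spark.
val result = spark.executeCommand(
  "com.example.MyRunner",
  "REFRESH STATS",
  Map("timeout" -> "30s"))
result.show()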
sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
@@ -1022,7 +1022,9 @@ object DDLUtils extends Logging {

def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
-DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
+DataSource.newDataSourceInstance(
+provider,
+DataSource.lookupDataSource(provider, SQLConf.get))
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data column names.", e)
sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand(
}

if (DDLUtils.isDatasourceTable(catalogTable)) {
-DataSource.lookupDataSource(catalogTable.provider.get, conf).
-getConstructor().newInstance() match {
+DataSource.newDataSourceInstance(
+catalogTable.provider.get,
+DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
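AlterTableAddColumnsCommand instantiates the table's provider to verify that the underlying file format supports adding columns, so it needs the same factory treatment. A short sketch of the statements this code path guards (the table name is hypothetical):

// Sketch only: "my_tbl" is hypothetical. ALTER TABLE ... ADD COLUMNS triggers
// the provider instantiation shown in the hunk above.
spark.sql("CREATE TABLE my_tbl (id INT) USING parquet")
spark.sql("ALTER TABLE my_tbl ADD COLUMNS (extra STRING)")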
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -105,13 +105,14 @@ case class DataSource(
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
-cls.getDeclaredConstructor().newInstance() match {
+DataSource.newDataSourceInstance(className, cls) match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}

-private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
+private[sql] def providingInstance(): Any =
+DataSource.newDataSourceInstance(className, providingClass)

private def newHadoopConfiguration(): Configuration =
sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -622,6 +623,15 @@ object DataSource extends Logging {
"org.apache.spark.sql.sources.HadoopFsRelationProvider",
"org.apache.spark.Logging")

+/** Create the instance of the datasource */
+def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
+providingClass match {
+case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
+cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
+case cls => cls.getDeclaredConstructor().newInstance()
+}
+}

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -649,6 +659,9 @@
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
+// TODO(SPARK-45600): should be session-based.
+val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
+_.sessionState.dataSourceManager.dataSourceExists(provider))
if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
@@ -657,6 +670,8 @@
throw QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
+} else if (isUserDefinedDataSource) {
+classOf[PythonTableProvider]
} else {
throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -673,6 +688,14 @@
}
case head :: Nil =>
// there is exactly one registered alias
+// TODO(SPARK-45600): should be session-based.
+val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
+_.sessionState.dataSourceManager.dataSourceExists(provider))
+// The source can be successfully loaded as either a V1 or a V2 data source.
+// Check if it is also a user-defined data source.
+if (isUserDefinedDataSource) {
+throw QueryCompilationErrors.foundMultipleDataSources(provider)
+}
head.getClass
case sources =>
// There are multiple registered aliases for the input. If there is single datasource
@@ -708,17 +731,18 @@
def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
-val cls = lookupDataSource(provider, conf)
+val providingClass = lookupDataSource(provider, conf)
val instance = try {
-cls.getDeclaredConstructor().newInstance()
+newDataSourceInstance(provider, providingClass)
} catch {
// Throw the original error from the data source implementation.
case e: java.lang.reflect.InvocationTargetException => throw e.getCause
}
instance match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
case t: TableProvider
-if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+if !useV1Sources.contains(
+providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
Some(t)
case _ => None
}
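Taken together, these hunks centralize provider instantiation in newDataSourceInstance: PythonTableProvider is the only provider class whose constructor takes an argument (its registered short name), so the old zero-arg getDeclaredConstructor().newInstance() calls would fail for it. A sketch of the dispatch, with a hypothetical source name:

// Sketch only: "my_source" is hypothetical.
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.internal.SQLConf

val cls = DataSource.lookupDataSource("my_source", SQLConf.get)
// For a registered Python data source, lookupDataSource now falls back to
// classOf[PythonTableProvider]; the factory then passes the short name to
// its one-arg constructor, while other providers keep the zero-arg path.
val instance = DataSource.newDataSourceInstance("my_source", cls)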
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,12 +20,22 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

+import scala.jdk.CollectionConverters._

import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap


/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
@@ -40,6 +50,8 @@ class DataSourceManager extends Logging {
CaseInsensitiveMap[String] // options
) => LogicalPlan

+// TODO(SPARK-45917): Statically load Python Data Source so idempotently Python
+// Data Sources can be loaded even when the Driver is restarted.
private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
@@ -81,3 +93,60 @@
manager
}
}

+/**
+ * Data Source V2 wrapper for Python Data Source.
+ */
+class PythonTableProvider(shortName: String) extends TableProvider {
[Review comment — Contributor] Should we support external metadata for this data source? I.e., can users create a table using a Python data source with a user-defined table schema?
[Reply — Member (Author)] I believe it already does (?).

+private var sourceDataFrame: DataFrame = _
+
+private def getOrCreateSourceDataFrame(
+options: CaseInsensitiveStringMap, maybeSchema: Option[StructType]): DataFrame = {
+if (sourceDataFrame != null) return sourceDataFrame
+// TODO(SPARK-45600): should be session-based.
[Review comment — Contributor] This one should be fixed?
[Reply — HyukjinKwon (Author), Dec 7, 2023] For basic support, I think so. The thing is that we should take a look at session inheritance, test cases, etc., so I leave this as a TODO for now.
+val builder = SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
+val plan = builder(
+SparkSession.active,
[Review comment — Contributor] Does it get the correct session for Spark Connect?
[Reply — Member (Author)] Yes.
+shortName,
+maybeSchema,
+CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap))
+sourceDataFrame = Dataset.ofRows(SparkSession.active, plan)
+sourceDataFrame
+}

+override def inferSchema(options: CaseInsensitiveStringMap): StructType =
+getOrCreateSourceDataFrame(options, None).schema

+override def getTable(
+schema: StructType,
+partitioning: Array[Transform],
+properties: java.util.Map[String, String]): Table = {
+val givenSchema = schema
+new Table with SupportsRead {
+override def name(): String = shortName
+
+override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ)
+
+override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+new ScanBuilder with V1Scan {
+override def build(): Scan = this
+override def toV1TableScan[T <: BaseRelation with TableScan](
+context: SQLContext): T = {
+new BaseRelation with TableScan {
+// Avoid Row <> InternalRow conversion
+override val needConversion: Boolean = false
+override def buildScan(): RDD[Row] =
+getOrCreateSourceDataFrame(options, Some(givenSchema))
+.queryExecution.toRdd.asInstanceOf[RDD[Row]]
+override def schema: StructType = givenSchema
+override def sqlContext: SQLContext = context
+}.asInstanceOf[T]
+}
+override def readSchema(): StructType = givenSchema
+}
+}
+
+override def schema(): StructType = givenSchema
+}
+}
+}
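A rough sketch (not part of the diff; the option values are hypothetical) of how the DSv2 read path drives this wrapper. The design point is the V1Scan bridge: buildScan() returns the cached DataFrame's underlying RDD directly, and needConversion = false tells the V1 path to skip Row <-> InternalRow conversion:

// Sketch only: "my_source" and "/tmp/input" are hypothetical.
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val provider = new PythonTableProvider("my_source")
val options = new CaseInsensitiveStringMap(java.util.Map.of("path", "/tmp/input"))
val schema = provider.inferSchema(options) // plans the Python source once and caches the DataFrame
val table = provider.getTable(schema, Array.empty[Transform], options.asCaseSensitiveMap())
// table.newScanBuilder(options).build() yields a V1Scan whose BaseRelation
// reuses the cached plan's RDD, so rows are neither converted nor re-planned.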
sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,8 +156,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging {
extraOptions + ("path" -> path.get)
}

-val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf).
-getConstructor().newInstance()
+val ds = DataSource.newDataSourceInstance(
+source,
+DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
// We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2, since we don't know the
// writer or whether the query is continuous.
sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}

val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
-val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
+val provider = DataSource.newDataSourceInstance(source, cls).asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = df.sparkSession.sessionState.conf)
val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -219,6 +219,13 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))

+// Test SQL
+withTable("tblA") {
+sql("CREATE TABLE tblA USING test")
+// The path will be the actual temp path.
+checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1)))
+}
}

test("reader not implemented") {
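The added block extends the existing reader test to SQL: since lookupDataSource now resolves registered Python sources, CREATE TABLE ... USING test works with no JVM provider class. A hedged follow-on sketch in the same suite style (the table name is hypothetical, and the expected row is by analogy with the load("1") assertion above, not verified by the PR):

// Sketch only, not in the PR: passing the path through table OPTIONS.
withTable("tblB") {
  sql("CREATE TABLE tblB USING test OPTIONS (path '1')")
  checkAnswer(spark.table("tblB"), Seq(Row("1", 1)))
}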