FileScan.scala
@@ -18,22 +18,39 @@ package org.apache.spark.sql.execution.datasources.v2

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}

abstract class FileScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex) extends Scan with Batch {
fileIndex: PartitioningAwareFileIndex,
readSchema: StructType) extends Scan with Batch {
Contributor:

@gengliangwang, why validate the read schema here in FileScan instead of in the scan builder?

Member Author:

In the PR description:

the table schema is determined in TableProvider.getTable. The actual read schema can be a subset of the table schema. This PR proposes to validate the actual read schema in FileScan

/**
* Returns whether a file with `path` could be split or not.
*/
def isSplitable(path: Path): Boolean = {
false
}

/**
* Returns whether this format supports the given [[DataType]] in the read path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName(): String = "ORC"
* }}}
*/
def formatName: String

protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -57,5 +74,13 @@ abstract class FileScan(
partitions.toArray
}

override def toBatch: Batch = this
override def toBatch: Batch = {
readSchema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
this
}
}
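To see what the new constructor argument and the two members buy, here is a minimal sketch of a concrete scan for a made-up format. The TextishScan class, its "Textish" name and the ??? reader factory are purely illustrative; supportsDataType and formatName are the actual hooks that toBatch consults for every field of readSchema.

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types.{DataType, StringType, StructType}

// Illustrative subclass: reads nothing but flat string columns.
case class TextishScan(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    readSchema: StructType)
  extends FileScan(sparkSession, fileIndex, readSchema) {

  // Scheduling a read of, say, an interval column now fails in toBatch with
  // "Textish data source does not support interval data type."
  override def supportsDataType(dataType: DataType): Boolean = dataType == StringType

  override def formatName: String = "Textish"

  // Format-specific and irrelevant to the validation; left unimplemented here.
  override def createReaderFactory(): PartitionReaderFactory = ???
}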
FileWriteBuilder.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration

abstract class FileWriteBuilder(options: DataSourceOptions)
@@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory

/**
* Returns whether this format supports the given [[DataType]] in the write path.
* By default all data types are supported.
*/
def supportsDataType(dataType: DataType): Boolean = true

/**
* The string that represents the format that this data source provider uses. This is
* overridden by children to provide a nice alias for the data source. For example:
*
* {{{
* override def formatName(): String = "ORC"
* }}}
*/
def formatName: String

private def validateInputs(): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")
assert(mode != null, "Missing save mode")
assert(options.paths().length == 1)
DataSource.validateSchema(schema)
schema.foreach { field =>
if (!supportsDataType(field.dataType)) {
throw new AnalysisException(
s"$formatName data source does not support ${field.dataType.catalogString} data type.")
}
}
}

private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
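On the write side the same two members feed the new loop in validateInputs(), so an unsupported column is rejected before any Hadoop job is configured. A test-style sketch of the behavior this enables for ORC on the V2 path, mirroring the assertions in the updated FileBasedDataSourceSuite further down (tempDir and intercept as in that suite):

import java.util.Locale

import org.apache.spark.sql.AnalysisException

// With ORC handled by the V2 write path, an interval column now triggers the
// format-specific message from FileWriteBuilder instead of the generic V1
// "Cannot save interval data type into external storage." error.
val e = intercept[AnalysisException] {
  sql("select interval 1 days").write.format("orc").mode("overwrite").save(tempDir)
}
assert(e.getMessage.toLowerCase(Locale.ROOT)
  .contains("orc data source does not support calendarinterval data type."))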
OrcDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._

class OrcDataSourceV2 extends FileDataSourceV2 {

@@ -44,3 +44,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
OrcTable(tableName, sparkSession, fileIndex, Some(schema))
}
}

object OrcDataSourceV2 {
def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: AtomicType => true

case st: StructType => st.forall { f => supportsDataType(f.dataType) }

case ArrayType(elementType, _) => supportsDataType(elementType)

case MapType(keyType, valueType, _) =>
supportsDataType(keyType) && supportsDataType(valueType)

case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)

case _ => false
}
}
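A few spot checks of the recursion above, with OrcDataSourceV2 in scope; each result follows directly from the pattern match:

import org.apache.spark.sql.types._

OrcDataSourceV2.supportsDataType(IntegerType)                       // true: atomic type
OrcDataSourceV2.supportsDataType(ArrayType(StringType))             // true: atomic element type
OrcDataSourceV2.supportsDataType(
  StructType(Seq(StructField("m", MapType(StringType, LongType))))) // true: nested atomics
OrcDataSourceV2.supportsDataType(CalendarIntervalType)              // false: falls to the default case
OrcDataSourceV2.supportsDataType(ArrayType(NullType))               // false: NullType is not atomic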
OrcScan.scala
@@ -23,15 +23,15 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration

case class OrcScan(
sparkSession: SparkSession,
hadoopConf: Configuration,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
readSchema: StructType) extends FileScan(sparkSession, fileIndex) {
readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
@@ -40,4 +40,10 @@ case class OrcScan(
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}
OrcWriteBuilder.scala
@@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio
}
}
}

override def supportsDataType(dataType: DataType): Boolean = {
OrcDataSourceV2.supportsDataType(dataType)
}

override def formatName: String = "ORC"
}
FileBasedDataSourceSuite.scala
@@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath
// TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
// write path
Seq("csv", "json", "parquet", "orc").foreach { format =>
var msg = intercept[AnalysisException] {
sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.contains("Cannot save interval data type into external storage."))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new IntervalData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))
Seq(true, false).foreach { useV1 =>
val useV1List = if (useV1) {
"orc"
} else {
""
}
def errorMessage(format: String, isWrite: Boolean): String = {
if (isWrite && (useV1 || format != "orc")) {
"cannot save interval data type into external storage."
} else {
s"$format data source does not support calendarinterval data type."
}
}

withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
// write path
Seq("csv", "json", "parquet", "orc").foreach { format =>
var msg = intercept[AnalysisException] {
sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
}

// read path
Seq("parquet", "csv").foreach { format =>
var msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support calendarinterval data type."))
// read path
Seq("parquet", "csv").foreach { format =>
var msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
}
}
}
}
}

test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
// TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath

Seq("parquet", "csv", "orc").foreach { format =>
// write path
var msg = intercept[AnalysisException] {
sql("select null").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new NullData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

// read path
msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", NullType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(s"$format data source does not support null data type."))
Seq(true, false).foreach { useV1 =>
val useV1List = if (useV1) {
"orc"
} else {
""
}
def errorMessage(format: String): String = {
s"$format data source does not support null data type."
}
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List,
SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath

Seq("parquet", "csv", "orc").foreach { format =>
// write path
var msg = intercept[AnalysisException] {
sql("select null").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new NullData())
sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

// read path
msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", NullType, true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))

msg = intercept[AnalysisException] {
val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
spark.range(1).write.format(format).mode("overwrite").save(tempDir)
spark.read.schema(schema).format(format).load(tempDir).collect()
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
.contains(errorMessage(format)))
}
}
}
}
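A note on the errorMessage helpers in the two rewritten tests: only ORC has a V2 implementation to toggle here, so which message a failure produces depends on both the format and the useV1List setting. The cases the assertions expect, in comment form:

// Interval test, write path:
//   useV1 = true (any format)        -> V1 path: "cannot save interval data type into external storage."
//   useV1 = false, format != "orc"   -> still V1 (only ORC is migrated): same generic message
//   useV1 = false, format == "orc"   -> V2 FileWriteBuilder check:
//                                       "orc data source does not support calendarinterval data type."
// Interval test, read path (parquet, csv) and both paths of the null-type test:
//   always the format-specific "<format> data source does not support ... data type." message.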