@@ -40,7 +40,7 @@ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPl
       val relation = HadoopFsRelation(
         table.fileIndex,
         table.fileIndex.partitionSchema,
-        table.schema(),
+        table.schema,
         None,
         v1FileFormat,
         d.options.asScala.toMap)(sparkSession)

@@ -23,11 +23,13 @@ import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 abstract class FileScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    readSchema: StructType) extends Scan with Batch {
+    readSchema: StructType,
+    options: CaseInsensitiveStringMap) extends Scan with Batch {
   /**
    * Returns whether a file with `path` could be split or not.
    */
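
The new constructor parameter is a `CaseInsensitiveStringMap`, the option container used across the DataSource V2 API. A minimal sketch of its behaviour (the option key and value below are made up for illustration):

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Lookups ignore the case of the key; values are returned unchanged.
val options = new CaseInsensitiveStringMap(Map("compression" -> "snappy").asJava)
assert(options.get("COMPRESSION") == "snappy")
assert(options.getOrDefault("mergeSchema", "false") == "false")
```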

@@ -21,8 +21,7 @@ import org.apache.spark.sql.types.StructType
 
 abstract class FileScanBuilder(schema: StructType)
   extends ScanBuilder
-  with SupportsPushDownRequiredColumns
-  with SupportsPushDownFilters {
+  with SupportsPushDownRequiredColumns {
   protected var readSchema = schema
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
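
With the mixin removed here, filter pushdown becomes opt-in per file format: a concrete builder mixes in `SupportsPushDownFilters` itself, as `OrcScanBuilder` does further down in this diff. A rough sketch of that pattern (the class name and the pushdown policy are illustrative only):

```scala
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.reader.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.types.StructType

// Hypothetical format-specific builder that opts back into filter pushdown.
class ExampleScanBuilder(schema: StructType)
  extends FileScanBuilder(schema) with SupportsPushDownFilters {

  private var _pushedFilters: Array[Filter] = Array.empty

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    _pushedFilters = filters
    // Hand every filter back so Spark still evaluates them after the scan.
    filters
  }

  override def pushedFilters(): Array[Filter] = _pushedFilters

  override def build(): Scan = throw new UnsupportedOperationException("sketch only")
}
```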

@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.SchemaUtils
 
 abstract class FileTable(
     sparkSession: SparkSession,
@@ -52,10 +53,15 @@ abstract class FileTable(
       s"Unable to infer schema for $name. It must be specified manually.")
   }.asNullable
 
-  override def schema(): StructType = {
+  override lazy val schema: StructType = {
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    SchemaUtils.checkColumnNameDuplication(dataSchema.fieldNames,
+      "in the data schema", caseSensitive)
+    val partitionSchema = fileIndex.partitionSchema
+    SchemaUtils.checkColumnNameDuplication(partitionSchema.fieldNames,
+      "in the partition schema", caseSensitive)
     PartitioningUtils.mergeDataAndPartitionSchema(dataSchema,
-      fileIndex.partitionSchema, caseSensitive)._1
+      partitionSchema, caseSensitive)._1
   }
 
   /**
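
`SchemaUtils.checkColumnNameDuplication` throws an `AnalysisException` that names the duplicated columns and echoes the supplied context string; under case-insensitive analysis, names differing only by case collide. A small illustration of the call shape used above (the column names are made up):

```scala
import org.apache.spark.sql.util.SchemaUtils

// Passes: "c0" and "C0" are distinct when compared case-sensitively.
SchemaUtils.checkColumnNameDuplication(Seq("c0", "C0"), "in the data schema", true)

// Throws AnalysisException (roughly "Found duplicate column(s) in the data schema: ...")
// because the two names collide under case-insensitive analysis.
SchemaUtils.checkColumnNameDuplication(Seq("c0", "C0"), "in the data schema", false)
```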

@@ -36,6 +36,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.util.SerializableConfiguration
 
 abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
@@ -60,10 +61,11 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St
   }
 
   override def buildForBatch(): BatchWrite = {
-    validateInputs()
-    val path = new Path(paths.head)
     val sparkSession = SparkSession.active
+    validateInputs(sparkSession.sessionState.conf.caseSensitiveAnalysis)
+    val path = new Path(paths.head)
     val optionsAsScala = options.asScala.toMap
+
     val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala)
     val job = getJobInstance(hadoopConf, path)
     val committer = FileCommitProtocol.instantiate(
@@ -122,12 +124,20 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St
    */
   def formatName: String
 
-  private def validateInputs(): Unit = {
+  private def validateInputs(caseSensitiveAnalysis: Boolean): Unit = {
     assert(schema != null, "Missing input data schema")
     assert(queryId != null, "Missing query ID")
     assert(mode != null, "Missing save mode")
-    assert(paths.length == 1)
+
+    if (paths.length != 1) {
+      throw new IllegalArgumentException("Expected exactly one path to be specified, but " +
+        s"got: ${paths.mkString(", ")}")
+    }
+    val pathName = paths.head
+    SchemaUtils.checkColumnNameDuplication(schema.fields.map(_.name),
+      s"when inserting into $pathName", caseSensitiveAnalysis)
     DataSource.validateSchema(schema)
+
     schema.foreach { field =>
       if (!supportsDataType(field.dataType)) {
         throw new AnalysisException(
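
End to end, the check surfaces when a DataFrame whose column names collide under the session's case sensitivity is written out, which is what the ORC cases added to the test suite at the bottom of this diff exercise. A hedged sketch, assuming a local session with the default `spark.sql.caseSensitive=false` (the output path is made up):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// "c0" and "C0" are duplicates under case-insensitive analysis, so the ORC write
// is expected to fail schema validation before any files are written.
try {
  Seq((1, 1)).toDF("c0", "C0").write.format("orc").save("/tmp/orc-dup-columns")
} catch {
  case e: AnalysisException => println(s"rejected as expected: ${e.getMessage}")
}
```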

@@ -24,14 +24,17 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
 
 case class OrcScan(
     sparkSession: SparkSession,
     hadoopConf: Configuration,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
+    readSchema: StructType,
+    options: CaseInsensitiveStringMap)
+  extends FileScan(sparkSession, fileIndex, readSchema, options) {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {

@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.orc.OrcFilters
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.sources.v2.reader.Scan
+import org.apache.spark.sql.sources.v2.reader.{Scan, SupportsPushDownFilters}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -35,11 +35,12 @@ case class OrcScanBuilder(
     fileIndex: PartitioningAwareFileIndex,
     schema: StructType,
     dataSchema: StructType,
-    options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
+    options: CaseInsensitiveStringMap)
+  extends FileScanBuilder(schema) with SupportsPushDownFilters {
   lazy val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)
 
   override def build(): Scan = {
-    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema)
+    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema, options)
   }
 
   private var _pushedFilters: Array[Filter] = Array.empty

@@ -838,6 +838,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
       checkReadUserSpecifiedDataColumnDuplication(
         Seq((1, 1)).toDF("c0", "c1"), "parquet", c0, c1, src)
       checkReadPartitionColumnDuplication("parquet", c0, c1, src)
+
+      // Check ORC format
+      checkWriteDataColumnDuplication("orc", c0, c1, src)
+      checkReadUserSpecifiedDataColumnDuplication(
+        Seq((1, 1)).toDF("c0", "c1"), "orc", c0, c1, src)
+      checkReadPartitionColumnDuplication("orc", c0, c1, src)
     }
   }
 }