diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9bd113419ae4..90cf15f9f722 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -505,10 +505,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
-      CSVDataSource.checkHeader(
-        firstLine,
-        new CsvParser(parsedOptions.asParserSettings),
+      val parser = new CsvParser(parsedOptions.asParserSettings)
+      val columnNames = parser.parseLine(firstLine)
+      CSVDataSource.checkHeaderColumnNames(
         actualSchema,
+        columnNames,
         csvDataset.getClass.getCanonicalName,
         parsedOptions.enforceSchema,
         sparkSession.sessionState.conf.caseSensitiveAnalysis)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b7b46c7c86a2..2b86054c0ffc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -54,7 +54,8 @@ abstract class CSVDataSource extends Serializable {
       requiredSchema: StructType,
       // Actual schema of data in the csv file
       dataSchema: StructType,
-      caseSensitive: Boolean): Iterator[InternalRow]
+      caseSensitive: Boolean,
+      columnPruning: Boolean): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -181,25 +182,6 @@ object CSVDataSource extends Logging {
       }
     }
   }
-
-  /**
-   * Checks that CSV header contains the same column names as fields names in the given schema
-   * by taking into account case sensitivity.
-   */
-  def checkHeader(
-      header: String,
-      parser: CsvParser,
-      schema: StructType,
-      fileName: String,
-      enforceSchema: Boolean,
-      caseSensitive: Boolean): Unit = {
-    checkHeaderColumnNames(
-      schema,
-      parser.parseLine(header),
-      fileName,
-      enforceSchema,
-      caseSensitive)
-  }
 }
 
 object TextInputCSVDataSource extends CSVDataSource {
@@ -211,7 +193,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       parser: UnivocityParser,
       requiredSchema: StructType,
       dataSchema: StructType,
-      caseSensitive: Boolean): Iterator[InternalRow] = {
+      caseSensitive: Boolean,
+      columnPruning: Boolean): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
@@ -227,10 +210,11 @@ object TextInputCSVDataSource extends CSVDataSource {
       // Note: if there are only comments in the first block, the header would probably
       // be not extracted.
       CSVUtils.extractHeader(lines, parser.options).foreach { header =>
-        CSVDataSource.checkHeader(
-          header,
-          parser.tokenizer,
-          dataSchema,
+        val schema = if (columnPruning) requiredSchema else dataSchema
+        val columnNames = parser.tokenizer.parseLine(header)
+        CSVDataSource.checkHeaderColumnNames(
+          schema,
+          columnNames,
           file.filePath,
           parser.options.enforceSchema,
           caseSensitive)
@@ -308,10 +292,12 @@ object MultiLineCSVDataSource extends CSVDataSource {
       parser: UnivocityParser,
       requiredSchema: StructType,
       dataSchema: StructType,
-      caseSensitive: Boolean): Iterator[InternalRow] = {
+      caseSensitive: Boolean,
+      columnPruning: Boolean): Iterator[InternalRow] = {
     def checkHeader(header: Array[String]): Unit = {
+      val schema = if (columnPruning) requiredSchema else dataSchema
       CSVDataSource.checkHeaderColumnNames(
-        dataSchema,
+        schema,
         header,
         file.filePath,
         parser.options.enforceSchema,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index d59b9820bdee..9aad0bd55e73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -131,6 +131,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       )
     }
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val columnPruning = sparkSession.sessionState.conf.csvColumnPruning
 
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
@@ -144,7 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         parser,
         requiredSchema,
         dataSchema,
-        caseSensitive)
+        caseSensitive,
+        columnPruning)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 456b4535a0dc..8f3d89289653 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1603,6 +1603,39 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema")))
   }
 
+  test("SPARK-25134: check header on parsing of dataset with projection and column pruning") {
+    withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") {
+      Seq(false, true).foreach { multiLine =>
+        withTempPath { path =>
+          val dir = path.getAbsolutePath
+          Seq(("a", "b")).toDF("columnA", "columnB").write
+            .format("csv")
+            .option("header", true)
+            .save(dir)
+
+          // schema with one column
+          checkAnswer(spark.read
+            .format("csv")
+            .option("header", true)
+            .option("enforceSchema", false)
+            .option("multiLine", multiLine)
+            .load(dir)
+            .select("columnA"),
+            Row("a"))
+
+          // empty schema
+          assert(spark.read
+            .format("csv")
+            .option("header", true)
+            .option("enforceSchema", false)
+            .option("multiLine", multiLine)
+            .load(dir)
+            .count() === 1L)
+        }
+      }
+    }
+  }
+
   test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") {
     withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") {
       withTempPath { path =>