diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index fe69f252d43e..72694463cedb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -505,7 +505,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     val actualSchema =
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
-    val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+    val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) {
+      val firstLine = maybeFirstLine.get
       val parser = new CsvParser(parsedOptions.asParserSettings)
       val columnNames = parser.parseLine(firstLine)
       CSVDataSource.checkHeaderColumnNames(
@@ -515,7 +516,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         parsedOptions.enforceSchema,
         sparkSession.sessionState.conf.caseSensitiveAnalysis)
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
-    }.getOrElse(filteredLines.rdd)
+    } else {
+      filteredLines.rdd
+    }
 
     val parsed = linesWithoutHeader.mapPartitions { iter =>
       val rawParser = new UnivocityParser(actualSchema, parsedOptions)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index f70df0bcecde..5d4746cf90b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1820,4 +1820,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
     checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null))
     assert(spark.read.csv(input).collect().toSet == Set(Row()))
   }
+
+  test("field names of inferred schema shouldn't compare to the first row") {
+    val input = Seq("1,2").toDS()
+    val df = spark.read.option("enforceSchema", false).csv(input)
+    checkAnswer(df, Row("1", "2"))
+  }
 }
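
For context: the change guards the CSV header-name check behind parsedOptions.headerFlag, so a schema inferred from the data (field names _c0, _c1, ...) is no longer compared against the first data row when no header was declared. Below is a minimal standalone sketch of the read the new test pins down; it is not part of the patch, and the object name and local-mode session are assumptions of mine.

// Sketch only, not part of the patch. Assumes a local SparkSession;
// the object name is hypothetical.
import org.apache.spark.sql.SparkSession

object InferredSchemaHeaderCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("inferred-schema-header-check")
      .getOrCreate()
    import spark.implicits._

    // No "header" option is set, so the schema is inferred as (_c0, _c1).
    val input = Seq("1,2").toDS()

    // Before the patch, enforceSchema=false compared the inferred field names
    // (_c0, _c1) against the first data row ("1", "2") even though no header
    // exists. With the patch, the header check only runs when the header
    // option is enabled.
    spark.read.option("enforceSchema", false).csv(input).show()
    // +---+---+
    // |_c0|_c1|
    // +---+---+
    // |  1|  2|
    // +---+---+

    spark.stop()
  }
}

Run against a pre-patch build, this read should fail in the header check (as I read CSVDataSource.checkHeaderColumnNames, an IllegalArgumentException when enforceSchema is false and the names mismatch); with the patch applied it returns the two inferred columns.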