diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 448a4732001b..a0e20d39c20d 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- samplingRatio=None):
+ samplingRatio=None, enforceSchema=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will
+                              be ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or against the first
+                              header in the RDD if the ``header`` option is set to ``true``.
+                              Field names in the schema and column names in CSV headers are
+                              checked by their positions, taking ``spark.sql.caseSensitive``
+                              into account. If None is set, ``true`` is used by default.
+                              Although the default value is ``true``, it is recommended to
+                              disable the ``enforceSchema`` option to avoid incorrect results.
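+                              For example, a minimal illustrative sketch (``path`` and
+                              ``expected_schema`` below are hypothetical)::
+
+                                  df = spark.read.csv(path, header=True, schema=expected_schema,
+                                                      enforceSchema=False)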
:param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
+ enforceSchema=enforceSchema)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 15f940738986..fae50b3d5d53 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
- columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+ columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+ enforceSchema=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will
+                              be ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or against the first
+                              header in the RDD if the ``header`` option is set to ``true``.
+                              Field names in the schema and column names in CSV headers are
+                              checked by their positions, taking ``spark.sql.caseSensitive``
+                              into account. If None is set, ``true`` is used by default.
+                              Although the default value is ``true``, it is recommended to
+                              disable the ``enforceSchema`` option to avoid incorrect results.
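+                              For example, a minimal illustrative sketch (``path`` and
+                              ``expected_schema`` below are hypothetical)::
+
+                                  sdf = spark.readStream.csv(path, header=True,
+                                                             schema=expected_schema,
+                                                             enforceSchema=False)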
:param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a2450932e303..ea2dd7605dc5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self):
.csv(rdd, samplingRatio=0.5).schema
self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
+ def test_checking_csv_header(self):
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+ try:
+ self.spark.createDataFrame([[1, 1000], [2000, 2]])\
+ .toDF('f1', 'f2').write.option("header", "true").csv(path)
+ schema = StructType([
+ StructField('f2', IntegerType(), nullable=True),
+ StructField('f1', IntegerType(), nullable=True)])
+ df = self.spark.read.option('header', 'true').schema(schema)\
+ .csv(path, enforceSchema=False)
+ self.assertRaisesRegexp(
+ Exception,
+ "CSV header does not conform to the schema",
+ lambda: df.collect())
+ finally:
+ shutil.rmtree(path)
+
class HiveSparkSubmitTests(SparkSubmitTests):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ac4580a0919a..de6be5f76e15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -22,6 +22,7 @@ import java.util.{Locale, Properties}
import scala.collection.JavaConverters._
import com.fasterxml.jackson.databind.ObjectMapper
+import com.univocity.parsers.csv.CsvParser
import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
@@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* it determines the columns as string types and it reads only the first line to determine the
* names and the number of fields.
*
+   * If the `enforceSchema` option is set to `false`, only the CSV header in the first line is
+   * checked for conformance to the specified or inferred schema.
+ *
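+   * A rough usage sketch (the dataset and schema below are illustrative only, and
+   * `spark.implicits._` is assumed to be imported):
+   * {{{
+   *   val csvData = Seq("id,name", "1,a").toDS()
+   *   val schema = new StructType().add("id", IntegerType).add("name", StringType)
+   *   spark.read.schema(schema).option("header", true)
+   *     .option("enforceSchema", false).csv(csvData)
+   * }}}
+   *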
* @param csvDataset input Dataset with one CSV row per record
* @since 2.2.0
*/
@@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+ CSVDataSource.checkHeader(
+ firstLine,
+ new CsvParser(parsedOptions.asParserSettings),
+ actualSchema,
+ csvDataset.getClass.getCanonicalName,
+ parsedOptions.enforceSchema,
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)
@@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
   * `comment` (default empty string): sets a single character used for skipping lines
* beginning with this character. By default, it is disabled.
* `header` (default `false`): uses the first line as names of columns.
+   * `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema
+   * will be forcibly applied to datasource files, and headers in CSV files will be ignored.
+   * If the option is set to `false`, the schema will be validated against all headers in CSV
+   * files when the `header` option is set to `true`. Field names in the schema and column
+   * names in CSV headers are checked by their positions, taking `spark.sql.caseSensitive` into
+   * account. Although the default value is `true`, it is recommended to disable the
+   * `enforceSchema` option to avoid incorrect results (see the illustrative sketch below).
* `inferSchema` (default `false`): infers the input schema automatically from data. It
* requires one extra pass over the data.
* `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
@@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
* `multiLine` (default `false`): parse one record, which may span multiple lines.
*
+ *
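+   * A rough usage sketch of validating CSV headers against a user-specified schema (the path
+   * and schema below are hypothetical):
+   * {{{
+   *   val schema = new StructType().add("id", IntegerType).add("name", StringType)
+   *   spark.read.schema(schema).option("header", true)
+   *     .option("enforceSchema", false).csv("/path/to/data.csv")
+   * }}}
+   *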
* @since 2.0.0
*/
@scala.annotation.varargs
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index dc54d182651b..82322df40752 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow]
+ requiredSchema: StructType,
+      // Actual schema of the data in the CSV file
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow]
/**
* Infers the schema from `inputPaths` files.
@@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable {
}
}
-object CSVDataSource {
+object CSVDataSource extends Logging {
def apply(options: CSVOptions): CSVDataSource = {
if (options.multiLine) {
MultiLineCSVDataSource
@@ -118,6 +122,84 @@ object CSVDataSource {
TextInputCSVDataSource
}
}
+
+ /**
+   * Checks that the column names in a CSV header match the field names in the schema,
+   * taking case sensitivity into account.
+   *
+   * @param schema - provided (or inferred) schema to which the CSV must conform.
+   * @param columnNames - names of CSV columns that must be checked against the schema.
+   * @param fileName - name of the CSV file that is currently being checked. It is used in
+   *                   error messages.
+   * @param enforceSchema - if it is `true`, column names are ignored; otherwise the CSV column
+   *                        names are checked for conformance to the schema. If the column names
+   *                        don't conform to the schema, an exception is thrown.
+   * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
+   *                        names is not case sensitive.
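+   *
+   * For example, a rough sketch (the schema, header, and file name below are made up):
+   * {{{
+   *   val schema = new StructType().add("f1", StringType).add("f2", StringType)
+   *   // Throws IllegalArgumentException because the header columns are swapped
+   *   // and enforceSchema is disabled.
+   *   checkHeaderColumnNames(schema, Array("f2", "f1"), "file.csv",
+   *     enforceSchema = false, caseSensitive = false)
+   * }}}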
+ */
+ def checkHeaderColumnNames(
+ schema: StructType,
+ columnNames: Array[String],
+ fileName: String,
+ enforceSchema: Boolean,
+ caseSensitive: Boolean): Unit = {
+ if (columnNames != null) {
+ val fieldNames = schema.map(_.name).toIndexedSeq
+ val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+ var errorMessage: Option[String] = None
+
+ if (headerLen == schemaSize) {
+ var i = 0
+ while (errorMessage.isEmpty && i < headerLen) {
+ var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+ if (!caseSensitive) {
+ nameInSchema = nameInSchema.toLowerCase
+ nameInHeader = nameInHeader.toLowerCase
+ }
+ if (nameInHeader != nameInSchema) {
+ errorMessage = Some(
+ s"""|CSV header does not conform to the schema.
+ | Header: ${columnNames.mkString(", ")}
+ | Schema: ${fieldNames.mkString(", ")}
+ |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+ |CSV file: $fileName""".stripMargin)
+ }
+ i += 1
+ }
+ } else {
+ errorMessage = Some(
+          s"""|Number of columns in CSV header is not equal to number of fields in the schema:
+ | Header length: $headerLen, schema size: $schemaSize
+ |CSV file: $fileName""".stripMargin)
+ }
+
+ errorMessage.foreach { msg =>
+ if (enforceSchema) {
+ logWarning(msg)
+ } else {
+ throw new IllegalArgumentException(msg)
+ }
+ }
+ }
+ }
+
+ /**
+ * Checks that CSV header contains the same column names as fields names in the given schema
+ * by taking into account case sensitivity.
+ */
+ def checkHeader(
+ header: String,
+ parser: CsvParser,
+ schema: StructType,
+ fileName: String,
+ enforceSchema: Boolean,
+ caseSensitive: Boolean): Unit = {
+ checkHeaderColumnNames(
+ schema,
+ parser.parseLine(header),
+ fileName,
+ enforceSchema,
+ caseSensitive)
+ }
}
object TextInputCSVDataSource extends CSVDataSource {
@@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ requiredSchema: StructType,
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource {
}
}
- val shouldDropHeader = parser.options.headerFlag && file.start == 0
- UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
+ val hasHeader = parser.options.headerFlag && file.start == 0
+ if (hasHeader) {
+      // Check that the column names in the header match the field names of the schema.
+      // The header is removed from `lines`.
+      // Note: if there are only comments in the first block, the header may not be extracted.
+ CSVUtils.extractHeader(lines, parser.options).foreach { header =>
+ CSVDataSource.checkHeader(
+ header,
+ parser.tokenizer,
+ dataSchema,
+ file.filePath,
+ parser.options.enforceSchema,
+ caseSensitive)
+ }
+ }
+
+ UnivocityParser.parseIterator(lines, parser, requiredSchema)
}
override def infer(
@@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ requiredSchema: StructType,
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow] = {
+ def checkHeader(header: Array[String]): Unit = {
+ CSVDataSource.checkHeaderColumnNames(
+ dataSchema,
+ header,
+ file.filePath,
+ parser.options.enforceSchema,
+ caseSensitive)
+ }
+
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
parser.options.headerFlag,
parser,
- schema)
+ requiredSchema,
+ checkHeader)
}
override def infer(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 21279d6daf7a..b90275de9f40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
"df.filter($\"_corrupt_record\".isNotNull).count()."
)
}
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
(file: PartitionedFile) => {
val conf = broadcastedHadoopConf.value.value
@@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
parsedOptions)
- CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
+ CSVDataSource(parsedOptions).readFile(
+ conf,
+ file,
+ parser,
+ requiredSchema,
+ dataSchema,
+ caseSensitive)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 7119189a4e13..fab8d62da0c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -156,6 +156,12 @@ class CSVOptions(
val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+ /**
+ * Forcibly apply the specified or inferred schema to datasource files.
+ * If the option is enabled, headers of CSV files will be ignored.
+ */
+ val enforceSchema = getBool("enforceSchema", default = true)
+
def asWriterSettings: CsvWriterSettings = {
val writerSettings = new CsvWriterSettings()
val format = writerSettings.getFormat
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 9dae41b63e81..1012e774118e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -68,12 +68,8 @@ object CSVUtils {
}
}
- /**
- * Drop header line so that only data can remain.
- * This is similar with `filterHeaderLine` above and currently being used in CSV reading path.
- */
- def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
- val nonEmptyLines = if (options.isCommentSet) {
+ def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ if (options.isCommentSet) {
val commentPrefix = options.comment.toString
iter.dropWhile { line =>
line.trim.isEmpty || line.trim.startsWith(commentPrefix)
@@ -81,11 +77,19 @@ object CSVUtils {
} else {
iter.dropWhile(_.trim.isEmpty)
}
-
- if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
- iter
}
+ /**
+   * Extracts the header and moves the iterator forward so that only data remains in it.
+ */
+ def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = {
+ val nonEmptyLines = skipComments(iter, options)
+ if (nonEmptyLines.hasNext) {
+ Some(nonEmptyLines.next())
+ } else {
+ None
+ }
+ }
/**
* Helper method that converts string representation of a character to actual character.
* It handles some Java escaped strings and throws exception if given string is longer than one
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 4f00cc5eb3f3..5f7d5696b71a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -45,7 +45,7 @@ class UnivocityParser(
// A `ValueConverter` is responsible for converting the given value to a desired type.
private type ValueConverter = String => Any
- private val tokenizer = {
+ val tokenizer = {
val parserSetting = options.asParserSettings
if (options.columnPruning && requiredSchema.length < dataSchema.length) {
val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f)))
@@ -250,14 +250,15 @@ private[csv] object UnivocityParser {
inputStream: InputStream,
shouldDropHeader: Boolean,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ schema: StructType,
+ checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
val tokenizer = parser.tokenizer
val safeParser = new FailureSafeParser[Array[String]](
input => Seq(parser.convert(input)),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
- convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+ convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
safeParser.parse(tokens)
}.flatten
}
@@ -265,11 +266,14 @@ private[csv] object UnivocityParser {
private def convertStream[T](
inputStream: InputStream,
shouldDropHeader: Boolean,
- tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+ tokenizer: CsvParser,
+ checkHeader: Array[String] => Unit = _ => ())(
+ convert: Array[String] => T) = new Iterator[T] {
tokenizer.beginParsing(inputStream)
private var nextRecord = {
if (shouldDropHeader) {
- tokenizer.parseNext()
+ val firstRecord = tokenizer.parseNext()
+ checkHeader(firstRecord)
}
tokenizer.parseNext()
}
@@ -291,21 +295,11 @@ private[csv] object UnivocityParser {
*/
def parseIterator(
lines: Iterator[String],
- shouldDropHeader: Boolean,
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow] = {
val options = parser.options
- val linesWithoutHeader = if (shouldDropHeader) {
- // Note that if there are only comments in the first block, the header would probably
- // be not dropped.
- CSVUtils.dropHeaderLine(lines, options)
- } else {
- lines
- }
-
- val filteredLines: Iterator[String] =
- CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+ val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
val safeParser = new FailureSafeParser[String](
input => Seq(parser.parse(input)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index afe10bdc4de2..d2f166c7d187 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale
+import scala.collection.JavaConverters._
+
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.log4j.{AppenderSkeleton, LogManager}
+import org.apache.log4j.spi.LoggingEvent
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
@@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5)))
}
}
+
+ def checkHeader(multiLine: Boolean): Unit = {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ withTempPath { path =>
+ val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+ val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+ odf.write.option("header", true).csv(path.getCanonicalPath)
+ val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+ val exception = intercept[SparkException] {
+ spark.read
+ .schema(ischema)
+ .option("multiLine", multiLine)
+ .option("header", true)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+ .collect()
+ }
+ assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+
+ val shortSchema = new StructType().add("f1", DoubleType)
+ val exceptionForShortSchema = intercept[SparkException] {
+ spark.read
+ .schema(shortSchema)
+ .option("multiLine", multiLine)
+ .option("header", true)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+ .collect()
+ }
+ assert(exceptionForShortSchema.getMessage.contains(
+          "Number of columns in CSV header is not equal to number of fields in the schema"))
+
+ val longSchema = new StructType()
+ .add("f1", DoubleType)
+ .add("f2", DoubleType)
+ .add("f3", DoubleType)
+
+ val exceptionForLongSchema = intercept[SparkException] {
+ spark.read
+ .schema(longSchema)
+ .option("multiLine", multiLine)
+ .option("header", true)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+ .collect()
+ }
+ assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3"))
+
+ val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType)
+ val caseSensitiveException = intercept[SparkException] {
+ spark.read
+ .schema(caseSensitiveSchema)
+ .option("multiLine", multiLine)
+ .option("header", true)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+ .collect()
+ }
+ assert(caseSensitiveException.getMessage.contains(
+ "CSV header does not conform to the schema"))
+ }
+ }
+ }
+
+  test("SPARK-23786: Checking column names against schema in the multiline mode") {
+ checkHeader(multiLine = true)
+ }
+
+  test("SPARK-23786: Checking column names against schema in the per-line mode") {
+ checkHeader(multiLine = false)
+ }
+
+ test("SPARK-23786: CSV header must not be checked if it doesn't exist") {
+ withTempPath { path =>
+ val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+ val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+ odf.write.option("header", false).csv(path.getCanonicalPath)
+ val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+ val idf = spark.read
+ .schema(ischema)
+ .option("header", false)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+
+ checkAnswer(idf, odf)
+ }
+ }
+
+ test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ withTempPath { path =>
+ val oschema = new StructType().add("A", StringType)
+ val odf = spark.createDataFrame(List(Row("0")).asJava, oschema)
+ odf.write.option("header", true).csv(path.getCanonicalPath)
+ val ischema = new StructType().add("a", StringType)
+ val idf = spark.read.schema(ischema)
+ .option("header", true)
+ .option("enforceSchema", false)
+ .csv(path.getCanonicalPath)
+ checkAnswer(idf, odf)
+ }
+ }
+ }
+
+ test("SPARK-23786: check header on parsing of dataset of strings") {
+ val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
+ val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)
+ val exception = intercept[IllegalArgumentException] {
+ spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds)
+ }
+
+ assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+ }
+
+ test("SPARK-23786: enforce inferred schema") {
+ val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType)
+ val withHeader = spark.read
+ .option("inferSchema", true)
+ .option("enforceSchema", false)
+ .option("header", true)
+ .csv(Seq("_c0,_c1", "1.0,a").toDS())
+ assert(withHeader.schema == expectedSchema)
+ checkAnswer(withHeader, Seq(Row(1.0, "a")))
+
+    // Ignore the inferSchema flag if a user sets a schema.
+ val schema = new StructType().add("colA", DoubleType).add("colB", StringType)
+ val ds = spark.read
+ .option("inferSchema", true)
+ .option("enforceSchema", false)
+ .option("header", true)
+ .schema(schema)
+ .csv(Seq("colA,colB", "1.0,a").toDS())
+ assert(ds.schema == schema)
+ checkAnswer(ds, Seq(Row(1.0, "a")))
+
+ val exception = intercept[IllegalArgumentException] {
+ spark.read
+ .option("inferSchema", true)
+ .option("enforceSchema", false)
+ .option("header", true)
+ .schema(schema)
+ .csv(Seq("col1,col2", "1.0,a").toDS())
+ }
+ assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+ }
+
+ test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") {
+ class TestAppender extends AppenderSkeleton {
+      val events = new java.util.ArrayList[LoggingEvent]
+ override def close(): Unit = {}
+ override def requiresLayout: Boolean = false
+ protected def append(event: LoggingEvent): Unit = events.add(event)
+ }
+
+ val testAppender1 = new TestAppender
+ LogManager.getRootLogger.addAppender(testAppender1)
+ try {
+ val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
+ val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)
+
+ spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds)
+ } finally {
+ LogManager.getRootLogger.removeAppender(testAppender1)
+ }
+ assert(testAppender1.events.asScala
+ .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema")))
+
+ val testAppender2 = new TestAppender
+ LogManager.getRootLogger.addAppender(testAppender2)
+ try {
+ withTempPath { path =>
+ val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+ val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+ odf.write.option("header", true).csv(path.getCanonicalPath)
+ val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+ spark.read
+ .schema(ischema)
+ .option("header", true)
+ .option("enforceSchema", true)
+ .csv(path.getCanonicalPath)
+ .collect()
+ }
+ } finally {
+ LogManager.getRootLogger.removeAppender(testAppender2)
+ }
+ assert(testAppender2.events.asScala
+ .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema")))
+ }
}