diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 448a4732001b..a0e20d39c20d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None): + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f940738986..fae50b3d5d53 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. 
If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a2450932e303..ea2dd7605dc5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self): .csv(rdd, samplingRatio=0.5).schema self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ac4580a0919a..de6be5f76e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. 
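For illustration (assuming a SparkSession named `spark` with its implicits imported; not part of the patch), a minimal sketch of how the check surfaces on the `csv(csvDataset: Dataset[String])` API:

    import org.apache.spark.sql.types._
    import spark.implicits._

    // Header names and schema field names match, but in a different order.
    val lines = Seq("id,name", "1,Alice").toDS()
    val schema = new StructType().add("name", StringType).add("id", IntegerType)

    // With header=true and enforceSchema=false, the first line is parsed as the header
    // and compared position by position against the schema; the mismatch fails the call
    // with "CSV header does not conform to the schema".
    spark.read
      .schema(schema)
      .option("header", true)
      .option("enforceSchema", false)
      .csv(lines)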
+ * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ @@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + CSVDataSource.checkHeader( + firstLine, + new CsvParser(parsedOptions.asParserSettings), + actualSchema, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
  • *
  • `header` (default `false`): uses the first line as names of columns.
  • + *
  • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false`, the schema will be validated against all headers in CSV files + * when the `header` option is set to `true`. Field names in the schema and column names in + * CSV headers are checked by their positions, taking `spark.sql.caseSensitive` into account. + * Although the default value is `true`, it is recommended to disable the `enforceSchema` option + * to avoid incorrect results; see the sketch after this option list.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • @@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
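A hedged sketch of the file-based behaviour described in the `enforceSchema` entry above; the path and the SparkSession `spark` are assumptions, and the file is assumed to start with the header line "f1,f2".

    import org.apache.spark.sql.types._

    // Schema fields deliberately reversed relative to the file header "f1,f2".
    val reversed = new StructType().add("f2", DoubleType).add("f1", DoubleType)

    val df = spark.read
      .schema(reversed)
      .option("header", true)
      .option("enforceSchema", false)   // validate header names against the schema
      .csv("/tmp/example.csv")          // hypothetical path

    // For file sources the mismatch surfaces when tasks run, as a SparkException whose
    // message contains "CSV header does not conform to the schema".
    df.collect()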
  • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index dc54d182651b..82322df40752 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +122,84 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against to the schema. + * @param fileName - name of CSV file that are currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column + * names are checked for conformance to the schema. In the case if + * the column name don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. 
+ | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + /** + * Checks that CSV header contains the same column names as fields names in the given schema + * by taking into account case sensitivity. + */ + def checkHeader( + header: String, + parser: CsvParser, + schema: StructType, + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + checkHeaderColumnNames( + schema, + parser.parseLine(header), + fileName, + enforceSchema, + caseSensitive) + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. 
+ CSVUtils.extractHeader(lines, parser.options).foreach { header => + CSVDataSource.checkHeader( + header, + parser.tokenizer, + dataSchema, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + CSVDataSource.checkHeaderColumnNames( + dataSchema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 21279d6daf7a..b90275de9f40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." ) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 7119189a4e13..fab8d62da0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -156,6 +156,12 @@ class CSVOptions( val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. 
+ */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 9dae41b63e81..1012e774118e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -68,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -81,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f3..5f7d5696b71a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,7 +45,7 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. 
private type ValueConverter = String => Any - private val tokenizer = { + val tokenizer = { val parserSetting = options.asParserSettings if (options.columnPruning && requiredSchema.length < dataSchema.length) { val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) @@ -250,14 +250,15 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -265,11 +266,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -291,21 +295,11 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. 
- CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index afe10bdc4de2..d2f166c7d187 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} @@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) } } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in 
the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + 
LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } }
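For reference, a sketch (with an assumed SparkSession `spark` and a hypothetical path) contrasting the default warning-only mode with strict validation, mirroring the tests above:

    import org.apache.spark.sql.types._

    val reversed = new StructType().add("f2", DoubleType).add("f1", DoubleType)

    // Default enforceSchema=true: the schema is applied positionally and the header
    // mismatch against "f1,f2" is only logged as a warning, so columns can be
    // silently mislabelled.
    spark.read
      .schema(reversed)
      .option("header", true)
      .csv("/tmp/data.csv")
      .collect()

    // enforceSchema=false: the same mismatch fails the job instead of being logged.
    spark.read
      .schema(reversed)
      .option("header", true)
      .option("enforceSchema", false)
      .csv("/tmp/data.csv")
      .collect()   // SparkException: CSV header does not conform to the schema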