-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23786][SQL] Checking column names of csv headers #20894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 44 commits
112ce2d
a85ccce
75e1534
8eb45b8
9b1a986
6442633
9440d8a
9f91ce7
0878f7a
a341dd7
98c27ea
811df6f
691cfbc
efb0105
c9f5e14
e195838
d6d370d
acd6d2e
13892fd
476b517
f8167e4
d068f6c
08cfcf4
f6a1694
adbedf3
0904daf
191b415
718f7ca
75c1ce6
ab9c514
0405863
714c66d
78d9f66
b43a7c7
a5f2916
9b2d403
1fffc16
ad6cda4
4bdabe2
2bd2713
b4bfd1d
21f8b10
aca4db9
e3b4275
d704766
04199e0
d5fde52
795a878
05fc7cd
9606711
11c7591
7dce1e7
c008328
9f7c440
e83ad60
26ae4f9
4b6495b
c5ee207
a2cbb7b
70e2b75
e7c3ace
3b37712
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non | |
| negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, | ||
| maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, | ||
| columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, | ||
| samplingRatio=None): | ||
| samplingRatio=None, enforceSchema=None): | ||
| """Loads a CSV file and returns the result as a :class:`DataFrame`. | ||
|
|
||
| This function will go through the input once to determine the input schema if | ||
|
|
@@ -373,6 +373,13 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non | |
| default value, ``false``. | ||
| :param inferSchema: infers the input schema automatically from data. It requires one extra | ||
| pass over the data. If None is set, it uses the default value, ``false``. | ||
| :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be | ||
| forcibly applied to datasource files and headers in CSV files will be | ||
| ignored. If the option is set to ``false``, the schema will be | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it ignored?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same confusion from https://github.com/apache/spark/pull/20894/files#r188553979.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should not silently ignore the option. When |
||
| validated against headers in CSV files if the ``header`` option is set | ||
| to ``true``. The validation is performed in column ordering aware | ||
| manner by taking into account ``spark.sql.caseSensitive``. | ||
| If None is set, ``true`` is used by default. | ||
| :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from | ||
| values being read should be skipped. If None is set, it | ||
| uses the default value, ``false``. | ||
|
|
@@ -449,7 +456,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non | |
| maxCharsPerColumn=maxCharsPerColumn, | ||
| maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, | ||
| columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, | ||
| charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) | ||
| charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, | ||
| enforceSchema=enforceSchema) | ||
| if isinstance(path, basestring): | ||
| path = [path] | ||
| if type(path) == list: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3040,6 +3040,24 @@ def test_csv_sampling_ratio(self): | |
| .csv(rdd, samplingRatio=0.5).schema | ||
| self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) | ||
|
|
||
| def test_checking_csv_header(self): | ||
| tmpPath = tempfile.mkdtemp() | ||
|
||
| shutil.rmtree(tmpPath) | ||
| try: | ||
| self.spark.createDataFrame([[1, 1000], [2000, 2]])\ | ||
| .toDF('f1', 'f2').write.option("header", "true").csv(tmpPath) | ||
| schema = StructType([ | ||
| StructField('f2', IntegerType(), nullable=True), | ||
| StructField('f1', IntegerType(), nullable=True)]) | ||
| df = self.spark.read.option('header', 'true').schema(schema)\ | ||
| .csv(tmpPath, enforceSchema=False) | ||
| self.assertRaisesRegexp( | ||
| Exception, | ||
| "CSV file header does not contain the expected fields", | ||
|
||
| lambda: df.collect()) | ||
| finally: | ||
| shutil.rmtree(tmpPath) | ||
|
|
||
|
|
||
| class HiveSparkSubmitTests(SparkSubmitTests): | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ import java.util.{Locale, Properties} | |
| import scala.collection.JavaConverters._ | ||
|
|
||
| import com.fasterxml.jackson.databind.ObjectMapper | ||
| import com.univocity.parsers.csv.CsvParser | ||
|
|
||
| import org.apache.spark.Partition | ||
| import org.apache.spark.annotation.InterfaceStability | ||
|
|
@@ -497,6 +498,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { | |
| StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) | ||
|
|
||
| val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => | ||
| if (parsedOptions.enforceSchema == false) { | ||
| CSVDataSource.checkHeader(firstLine, new CsvParser(parsedOptions.asParserSettings), | ||
|
||
| actualSchema, csvDataset.getClass.getCanonicalName, checkHeaderFlag = true, | ||
|
||
| sparkSession.sessionState.conf.caseSensitiveAnalysis) | ||
| } | ||
| filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) | ||
| }.getOrElse(filteredLines.rdd) | ||
|
|
||
|
|
@@ -537,6 +543,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { | |
| * <li>`comment` (default empty string): sets a single character used for skipping lines | ||
| * beginning with this character. By default, it is disabled.</li> | ||
| * <li>`header` (default `false`): uses the first line as names of columns.</li> | ||
| * <li>`enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema | ||
| * will be forcibly applied to datasource files and headers in CSV files will be ignored. | ||
| * If the option is set to `false`, the schema will be validated against headers in CSV files | ||
| * in the case when the `header` option is set to `true`. The validation is performed in column | ||
| * ordering aware manner by taking into account `spark.sql.caseSensitive`.</li> | ||
| * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It | ||
| * requires one extra pass over the data.</li> | ||
| * <li>`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.</li> | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,10 @@ abstract class CSVDataSource extends Serializable { | |
| conf: Configuration, | ||
| file: PartitionedFile, | ||
| parser: UnivocityParser, | ||
| schema: StructType): Iterator[InternalRow] | ||
| requiredSchema: StructType, | ||
| // Actual schema of data in the csv file | ||
| dataSchema: StructType, | ||
| caseSensitive: Boolean): Iterator[InternalRow] | ||
|
|
||
| /** | ||
| * Infers the schema from `inputPaths` files. | ||
|
|
@@ -118,6 +121,64 @@ object CSVDataSource { | |
| TextInputCSVDataSource | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Checks that column names in a CSV header and field names in the schema are the same | ||
| * by taking into account case sensitivity. | ||
| */ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To help readability, how about at least adding |
||
| def checkHeaderColumnNames( | ||
| schema: StructType, | ||
| columnNames: Array[String], | ||
| fileName: String, | ||
| checkHeaderFlag: Boolean, | ||
|
||
| caseSensitive: Boolean): Unit = { | ||
| if (checkHeaderFlag && columnNames != null) { | ||
| val fieldNames = schema.map(_.name).toIndexedSeq | ||
| val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) | ||
|
|
||
| if (headerLen == schemaSize) { | ||
| var i = 0 | ||
| while (i < headerLen) { | ||
| var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) | ||
| if (!caseSensitive) { | ||
| nameInSchema = nameInSchema.toLowerCase | ||
| nameInHeader = nameInHeader.toLowerCase | ||
| } | ||
| if (nameInHeader != nameInSchema) { | ||
| throw new IllegalArgumentException( | ||
| s"""|CSV file header does not contain the expected fields. | ||
| | Header: ${columnNames.mkString(", ")} | ||
| | Schema: ${fieldNames.mkString(", ")} | ||
| |Expected: ${columnNames(i)} but found: ${fieldNames(i)} | ||
| |CSV file: $fileName""".stripMargin) | ||
| } | ||
| i += 1 | ||
| } | ||
| } else { | ||
| throw new IllegalArgumentException( | ||
| s"""|Number of column in CSV header is not equal to number of fields in the schema: | ||
| | Header length: $headerLen, schema size: $schemaSize | ||
| |CSV file: $fileName""".stripMargin) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Checks that CSV header contains the same column names as fields names in the given schema | ||
| * by taking into account case sensitivity. | ||
| */ | ||
| def checkHeader( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need to define this in
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| header: String, | ||
| parser: CsvParser, | ||
| schema: StructType, | ||
| fileName: String, | ||
| checkHeaderFlag: Boolean, | ||
| caseSensitive: Boolean): Unit = { | ||
| if (checkHeaderFlag) { | ||
| checkHeaderColumnNames(schema, parser.parseLine(header), fileName, checkHeaderFlag, | ||
|
||
| caseSensitive) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| object TextInputCSVDataSource extends CSVDataSource { | ||
|
|
@@ -127,7 +188,9 @@ object TextInputCSVDataSource extends CSVDataSource { | |
| conf: Configuration, | ||
| file: PartitionedFile, | ||
| parser: UnivocityParser, | ||
| schema: StructType): Iterator[InternalRow] = { | ||
| requiredSchema: StructType, | ||
| dataSchema: StructType, | ||
| caseSensitive: Boolean): Iterator[InternalRow] = { | ||
| val lines = { | ||
| val linesReader = new HadoopFileLinesReader(file, conf) | ||
| Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) | ||
|
|
@@ -136,8 +199,19 @@ object TextInputCSVDataSource extends CSVDataSource { | |
| } | ||
| } | ||
|
|
||
| val shouldDropHeader = parser.options.headerFlag && file.start == 0 | ||
| UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) | ||
| val hasHeader = parser.options.headerFlag && file.start == 0 | ||
| if (hasHeader) { | ||
| // Checking that column names in the header are matched to field names of the schema. | ||
| // The header will be removed from lines. | ||
| // Note: if there are only comments in the first block, the header would probably | ||
| // be not extracted. | ||
| CSVUtils.extractHeader(lines, parser.options).foreach { header => | ||
| CSVDataSource.checkHeader(header, parser.tokenizer, dataSchema, file.filePath, | ||
| checkHeaderFlag = !parser.options.enforceSchema, caseSensitive) | ||
| } | ||
| } | ||
|
|
||
| UnivocityParser.parseIterator(lines, parser, requiredSchema) | ||
| } | ||
|
|
||
| override def infer( | ||
|
|
@@ -206,24 +280,33 @@ object MultiLineCSVDataSource extends CSVDataSource { | |
| conf: Configuration, | ||
| file: PartitionedFile, | ||
| parser: UnivocityParser, | ||
| schema: StructType): Iterator[InternalRow] = { | ||
| requiredSchema: StructType, | ||
| dataSchema: StructType, | ||
| caseSensitive: Boolean): Iterator[InternalRow] = { | ||
| def checkHeader(header: Array[String]): Unit = { | ||
| CSVDataSource.checkHeaderColumnNames(dataSchema, header, file.filePath, | ||
| checkHeaderFlag = !parser.options.enforceSchema, caseSensitive) | ||
| } | ||
|
|
||
| UnivocityParser.parseStream( | ||
| CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), | ||
| parser.options.headerFlag, | ||
| parser, | ||
| schema) | ||
| parser.options.headerFlag, parser, requiredSchema, checkHeader) | ||
|
||
| } | ||
|
|
||
| override def infer( | ||
| sparkSession: SparkSession, | ||
| inputPaths: Seq[FileStatus], | ||
| parsedOptions: CSVOptions): StructType = { | ||
| val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) | ||
| // The header is not checked because there is no schema against which it could be checked | ||
| def checkHeader(header: Array[String]): Unit = () | ||
|
||
|
|
||
|
||
| csv.flatMap { lines => | ||
| val path = new Path(lines.getPath()) | ||
| UnivocityParser.tokenizeStream( | ||
| CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), | ||
| shouldDropHeader = false, | ||
| dropFirstRecord = false, | ||
| checkHeader, | ||
|
||
| new CsvParser(parsedOptions.asParserSettings)) | ||
| }.take(1).headOption match { | ||
| case Some(firstRow) => | ||
|
|
@@ -235,6 +318,7 @@ object MultiLineCSVDataSource extends CSVDataSource { | |
| lines.getConfiguration, | ||
| new Path(lines.getPath())), | ||
| parsedOptions.headerFlag, | ||
| checkHeader, | ||
| new CsvParser(parsedOptions.asParserSettings)) | ||
| } | ||
| val sampled = CSVUtils.sample(tokenRDD, parsedOptions) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -122,14 +122,16 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { | |
| "df.filter($\"_corrupt_record\".isNotNull).count()." | ||
| ) | ||
| } | ||
| val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis | ||
|
|
||
| (file: PartitionedFile) => { | ||
| val conf = broadcastedHadoopConf.value.value | ||
| val parser = new UnivocityParser( | ||
| StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), | ||
| StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), | ||
| parsedOptions) | ||
| CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) | ||
| CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema, dataSchema, | ||
| caseSensitive) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: the same here. |
||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -153,6 +153,12 @@ class CSVOptions( | |
| val samplingRatio = | ||
| parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) | ||
|
|
||
| /** | ||
| * Forcibly apply the specified or inferred schema to datasource files. | ||
| * If the option is enabled, headers of CSV files will be ignored. | ||
| */ | ||
| val enforceSchema = getBool("enforceSchema", true) | ||
|
||
|
|
||
| def asWriterSettings: CsvWriterSettings = { | ||
| val writerSettings = new CsvWriterSettings() | ||
| val format = writerSettings.getFormat | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add this option to streaming reader and writer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added