diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index e42ea3fa391f..b1d7702079ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -254,6 +254,17 @@ class UnivocityParser(
     }
   }
 
+  /**
+   * Handles records that fail to parse in PERMISSIVE mode. Parsing fails either because
+   * more tokens are found than the schema expects, or because a value cannot be converted
+   * to its target type (e.g., NumberFormatException).
+   */
+  private def failedRecordWithPermissiveMode(): Option[InternalRow] = {
+    val row = new GenericInternalRow(requiredSchema.length)
+    corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
+    Some(row)
+  }
+
   private def convertWithParseMode(
       tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
     if (options.dropMalformed && dataSchema.length != tokens.length) {
@@ -271,43 +282,37 @@ class UnivocityParser(
       throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
         s"${tokens.mkString(options.delimiter.toString)}")
     } else {
-      // If a length of parsed tokens is not equal to expected one, it makes the length the same
-      // with the expected. If the length is shorter, it adds extra tokens in the tail.
-      // If longer, it drops extra tokens.
-      //
-      // TODO: Revisit this; if a length of tokens does not match an expected length in the schema,
-      // we probably need to treat it as a malformed record.
-      // See an URL below for related discussions:
-      // https://github.com/apache/spark/pull/16928#discussion_r102657214
-      val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) {
-        if (dataSchema.length > tokens.length) {
+      // If more tokens were parsed than the schema expects, treat the record as malformed.
+      if (options.permissive && dataSchema.length < tokens.length) {
+        failedRecordWithPermissiveMode()
+      } else {
+        // If fewer tokens were parsed than expected, pad the tail with null tokens.
+        val checkedTokens = if (options.permissive && dataSchema.length > tokens.length) {
           tokens ++ new Array[String](dataSchema.length - tokens.length)
         } else {
-          tokens.take(dataSchema.length)
+          tokens
         }
-      } else {
-        tokens
-      }
-      try {
-        Some(convert(checkedTokens))
-      } catch {
-        case NonFatal(e) if options.permissive =>
-          val row = new GenericInternalRow(requiredSchema.length)
-          corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
-          Some(row)
-        case NonFatal(e) if options.dropMalformed =>
-          if (numMalformedRecords < options.maxMalformedLogPerPartition) {
-            logWarning("Parse exception. " +
-              s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
-          }
-          if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
-            logWarning(
-              s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
-                "found on this partition. Malformed records from now on will not be logged.")
-          }
-          numMalformedRecords += 1
-          None
+        try {
+          Some(convert(checkedTokens))
+        } catch {
+          case NonFatal(e) if options.permissive =>
+            val row = new GenericInternalRow(requiredSchema.length)
+            corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
+            Some(row)
+          case NonFatal(e) if options.dropMalformed =>
+            if (numMalformedRecords < options.maxMalformedLogPerPartition) {
+              logWarning("Parse exception. " +
+                s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
+            }
+            if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
+              logWarning(
+                s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
+                  "found on this partition. Malformed records from now on will not be logged.")
+            }
+            numMalformedRecords += 1
+            None
+        }
       }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 95dfdf5b298e..bbbec937504a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -297,17 +297,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }
 
-  test("test for tokens more than the fields in the schema") {
-    val cars = spark
-      .read
-      .format("csv")
-      .option("header", "false")
-      .option("comment", "~")
-      .load(testFile(carsMalformedFile))
-
-    verifyCars(cars, withHeader = false, checkTypes = false)
-  }
-
   test("test with null quote character") {
     val cars = spark.read
       .format("csv")
@@ -1116,4 +1105,21 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(df2.schema === schema)
   }
 
+  test("SPARK-19783 test for tokens more than the fields in the schema") {
+    val columnNameOfCorruptRecord = "_unparsed"
+    withTempPath { path =>
+      Seq("1,2", "1,2,3,4").toDF().write.text(path.getAbsolutePath)
+      val schema = StructType(
+        StructField("a", IntegerType, true) ::
+        StructField("b", IntegerType, true) ::
+        StructField(columnNameOfCorruptRecord, StringType, true) :: Nil)
+      val df = spark.read
+        .schema(schema)
+        .option("header", "false")
+        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        .csv(path.getAbsolutePath)
+
+      checkAnswer(df, Row(1, 2, null) :: Row(null, null, "1,2,3,4") :: Nil)
+    }
+  }
 }
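
For reviewers, here is a minimal sketch (not part of this patch) of how the new behavior surfaces through the public DataFrameReader API. It assumes a local SparkSession, a hypothetical input path "/tmp/tokens.csv" containing the lines "1,2" and "1,2,3,4", and uses the default corrupt-record column name instead of the test's "_unparsed":

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object PermissiveCsvExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("permissive-csv").getOrCreate()

    // Same shape as the schema in the new test: two data columns plus a corrupt-record column.
    val schema = StructType(
      StructField("a", IntegerType, nullable = true) ::
      StructField("b", IntegerType, nullable = true) ::
      StructField("_corrupt_record", StringType, nullable = true) :: Nil)

    // "/tmp/tokens.csv" is a hypothetical file containing "1,2" and "1,2,3,4".
    val df = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .csv("/tmp/tokens.csv")

    // With this patch, "1,2,3,4" (more tokens than the schema) is no longer silently truncated;
    // it is kept as a malformed record whose original text lands in _corrupt_record.
    df.show(truncate = false)

    spark.stop()
  }
}

Under the patched parser the expected result is Row(1, 2, null) for the well-formed line and Row(null, null, "1,2,3,4") for the over-long one, matching the checkAnswer expectation in the test above.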