diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index 288179fc480da..32ff2c90bfa28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -230,64 +230,55 @@ class UnivocityParser(
         () => getCurrentInput,
         () => None,
         new RuntimeException("Malformed CSV record"))
-    } else if (tokens.length != parsedSchema.length) {
+    }
+
+    var checkedTokens = tokens
+    var badRecordException: Option[Throwable] = None
+
+    if (tokens.length != parsedSchema.length) {
       // If the number of tokens doesn't match the schema, we should treat it as a malformed record.
       // However, we still have chance to parse some of the tokens, by adding extra null tokens in
       // the tail if the number is smaller, or by dropping extra tokens if the number is larger.
-      val checkedTokens = if (parsedSchema.length > tokens.length) {
+      checkedTokens = if (parsedSchema.length > tokens.length) {
         tokens ++ new Array[String](parsedSchema.length - tokens.length)
       } else {
         tokens.take(parsedSchema.length)
       }
-      def getPartialResult(): Option[InternalRow] = {
-        try {
-          convert(checkedTokens).headOption
-        } catch {
-          case _: BadRecordException => None
-        }
-      }
-      // For records with less or more tokens than the schema, tries to return partial results
-      // if possible.
-      throw BadRecordException(
-        () => getCurrentInput,
-        () => getPartialResult(),
-        new RuntimeException("Malformed CSV record"))
-    } else {
-      // When the length of the returned tokens is identical to the length of the parsed schema,
-      // we just need to:
-      //  1. Convert the tokens that correspond to the required schema.
-      //  2. Apply the pushdown filters to `requiredRow`.
-      var i = 0
-      val row = requiredRow.head
-      var skipRow = false
-      var badRecordException: Option[Throwable] = None
-      while (i < requiredSchema.length) {
-        try {
-          if (!skipRow) {
-            row(i) = valueConverters(i).apply(getToken(tokens, i))
-            if (csvFilters.skipRow(row, i)) {
-              skipRow = true
-            }
-          }
-          if (skipRow) {
-            row.setNullAt(i)
+      badRecordException = Some(new RuntimeException("Malformed CSV record"))
+    }
+    // When the length of the returned tokens is identical to the length of the parsed schema,
+    // we just need to:
+    //  1. Convert the tokens that correspond to the required schema.
+    //  2. Apply the pushdown filters to `requiredRow`.
+    var i = 0
+    val row = requiredRow.head
+    var skipRow = false
+    while (i < requiredSchema.length) {
+      try {
+        if (!skipRow) {
+          row(i) = valueConverters(i).apply(getToken(tokens, i))
+          if (csvFilters.skipRow(row, i)) {
+            skipRow = true
           }
-        } catch {
-          case NonFatal(e) =>
-            badRecordException = badRecordException.orElse(Some(e))
-            row.setNullAt(i)
         }
-        i += 1
+        if (skipRow) {
+          row.setNullAt(i)
+        }
+      } catch {
+        case NonFatal(e) =>
+          badRecordException = badRecordException.orElse(Some(e))
+          row.setNullAt(i)
       }
-      if (skipRow) {
-        noRows
+      i += 1
+    }
+    if (skipRow) {
+      noRows
+    } else {
+      if (badRecordException.isDefined) {
+        throw BadRecordException(
+          () => getCurrentInput, () => requiredRow.headOption, badRecordException.get)
       } else {
-        if (badRecordException.isDefined) {
-          throw BadRecordException(
-            () => getCurrentInput, () => requiredRow.headOption, badRecordException.get)
-        } else {
-          requiredRow
-        }
+        requiredRow
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 846b5c594d42e..d88ec62822b50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -2270,4 +2270,28 @@ class CSVSuite extends QueryTest with SharedSparkSession with TestCsvData {
       }
     }
   }
+
+  test("SPARK-30530: apply filters to malformed rows") {
+    withSQLConf(SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withTempPath { path =>
+        Seq(
+          "100.0,1.0,",
+          "200.0,,",
+          "300.0,3.0,",
+          "1.0,4.0,",
+          ",4.0,",
+          "500.0,,",
+          ",6.0,",
+          "-500.0,50.5").toDF("data")
+          .repartition(1)
+          .write.text(path.getAbsolutePath)
+        val schema = new StructType().add("floats", FloatType).add("more_floats", FloatType)
+        val readback = spark.read
+          .schema(schema)
+          .csv(path.getAbsolutePath)
+          .filter("floats is null")
+        checkAnswer(readback, Seq(Row(null, 4.0), Row(null, 6.0)))
+      }
+    }
+  }
 }
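
For reviewers, here is a minimal, self-contained Scala sketch of the control flow the UnivocityParser change moves to. It is not Spark code: BadRecord, parseRecord, keepRow and the Double-only row are invented for illustration. The point it models is that a wrong token count no longer throws immediately; the error is recorded, every convertible token is converted, the pushed-down filter is applied to the partial row, and only then is the row skipped, returned, or surfaced as a bad record together with its partial result.

// Illustration only: a simplified model of the patched convert() logic.
// BadRecord, parseRecord and the Double-only "row" are made up for this sketch.
final case class BadRecord(partial: Option[Seq[Option[Double]]], cause: Throwable)
  extends Exception(cause)

def parseRecord(
    tokens: Array[String],
    schemaWidth: Int,
    keepRow: Seq[Option[Double]] => Boolean): Option[Seq[Option[Double]]] = {
  // Record (rather than throw) the "wrong number of tokens" error.
  var badRecord: Option[Throwable] =
    if (tokens.length != schemaWidth) Some(new RuntimeException("Malformed record")) else None

  // Convert whatever tokens line up with the schema; remember the first conversion error.
  val row: Seq[Option[Double]] = (0 until schemaWidth).map { i =>
    try {
      if (i < tokens.length && tokens(i).nonEmpty) Some(tokens(i).toDouble) else None
    } catch {
      case e: NumberFormatException =>
        badRecord = badRecord.orElse(Some(e))
        None
    }
  }

  if (!keepRow(row)) {
    None                                       // pushed-down filter drops the row, malformed or not
  } else if (badRecord.isDefined) {
    throw BadRecord(Some(row), badRecord.get)  // malformed but matching: report with the partial row
  } else {
    Some(row)
  }
}

// parseRecord(Array("", "6.0", ""), 2, _.head.isEmpty) throws
// BadRecord(Some(Vector(None, Some(6.0))), ...); PERMISSIVE mode would turn that partial
// result into the row (null, 6.0) instead of dropping it, which is what the new test checks.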
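
For context on what the new CSVSuite test exercises end to end, a spark-shell sketch follows. It assumes the key behind SQLConf.CSV_FILTER_PUSHDOWN_ENABLED is spark.sql.csv.filterPushdown.enabled; the output path and column names are made up for the example.

// Hypothetical spark-shell session (spark and its implicits are pre-imported there).
import org.apache.spark.sql.types.{FloatType, StructType}

// Lines such as ",6.0," carry three tokens while the schema has two columns, so the parser
// treats them as malformed and can only convert them partially (missing values become null).
Seq("100.0,1.0,", ",6.0,").toDF("data")
  .repartition(1)
  .write.mode("overwrite").text("/tmp/csv-malformed")

val schema = new StructType().add("floats", FloatType).add("more_floats", FloatType)

// Assumed key for SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.
spark.conf.set("spark.sql.csv.filterPushdown.enabled", "true")

// SPARK-30530: before this patch, pushed-down filters could mishandle such malformed rows;
// with it, the ",6.0," line is returned as (null, 6.0) by the query below.
spark.read.schema(schema).csv("/tmp/csv-malformed").filter("floats is null").show()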