Patch summary (free text ahead of the first "diff --git" header):

  SPARK-25387: bad CSV input should not cause a NullPointerException.

  * CSVDataSource.scala  - schema inference parses the first line up front and
                           returns the empty schema when the parsed row is null.
  * UnivocityParser.scala - convert() throws BadRecordException when the token
                           array is null instead of dereferencing it.
  * CSVSuite.scala        - drops the UDT import and adds a regression test.

  NOTE(review): this file was recovered from a copy whose newlines had been
  collapsed into spaces. Line breaks and leading indentation inside the hunks
  were reconstructed from the hunk headers and Scala nesting; confirm the
  whitespace against the upstream commit before applying.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 2b86054c0ffcb..e840ff1682502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -240,23 +240,25 @@ object TextInputCSVDataSource extends CSVDataSource {
       sparkSession: SparkSession,
       csv: Dataset[String],
       maybeFirstLine: Option[String],
-      parsedOptions: CSVOptions): StructType = maybeFirstLine match {
-    case Some(firstLine) =>
-      val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
-      val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-      val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-      val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
-      val tokenRDD = sampled.rdd.mapPartitions { iter =>
-        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
-        val linesWithoutHeader =
-          CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
-        val parser = new CsvParser(parsedOptions.asParserSettings)
-        linesWithoutHeader.map(parser.parseLine)
-      }
-      CSVInferSchema.infer(tokenRDD, header, parsedOptions)
-    case None =>
-      // If the first line could not be read, just return the empty schema.
-      StructType(Nil)
+      parsedOptions: CSVOptions): StructType = {
+    val csvParser = new CsvParser(parsedOptions.asParserSettings)
+    maybeFirstLine.map(csvParser.parseLine(_)) match {
+      case Some(firstRow) if firstRow != null =>
+        val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+        val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+        val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
+        val tokenRDD = sampled.rdd.mapPartitions { iter =>
+          val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+          val linesWithoutHeader =
+            CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions)
+          val parser = new CsvParser(parsedOptions.asParserSettings)
+          linesWithoutHeader.map(parser.parseLine)
+        }
+        CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+      case _ =>
+        // If the first line could not be read, just return the empty schema.
+        StructType(Nil)
+    }
   }
 
   private def createBaseDataset(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index e15af425b2649..9088d43905e28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -216,7 +216,12 @@ class UnivocityParser(
   }
 
   private def convert(tokens: Array[String]): InternalRow = {
-    if (tokens.length != parsedSchema.length) {
+    if (tokens == null) {
+      throw BadRecordException(
+        () => getCurrentInput,
+        () => None,
+        new RuntimeException("Malformed CSV record"))
+    } else if (tokens.length != parsedSchema.length) {
       // If the number of tokens doesn't match the schema, we should treat it as a malformed record.
       // However, we still have chance to parse some of the tokens, by adding extra null tokens in
       // the tail if the number is smaller, or by dropping extra tokens if the number is larger.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 5a1d6679ebbdb..ba9215381c6d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -34,7 +34,7 @@ import org.apache.log4j.{AppenderSkeleton, LogManager}
 import org.apache.log4j.spi.LoggingEvent
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -1700,4 +1700,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
     checkCount(2)
     countForMalformedCSV(0, Seq(""))
   }
+
+  test("SPARK-25387: bad input should not cause NPE") {
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val input = spark.createDataset(Seq("\u0000\u0000\u0001234"))
+
+    checkAnswer(spark.read.schema(schema).csv(input), Row(null))
+    checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null))
+    assert(spark.read.csv(input).collect().toSet == Set(Row()))
+  }
 }