diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b5e5b18bcbefa..ec47618e73a6c 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -308,7 +308,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -385,6 +385,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse records, which may span multiple lines. If None is + set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -398,7 +400,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bd19fd4e385b4..7587875cb9849 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -562,7 +562,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, - columnNameOfCorruptRecord=None): + columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -637,6 +637,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param wholeFile: parse one record, which may span multiple lines. If None is + set, it uses the default value, ``false``. 
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -652,7 +654,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, - columnNameOfCorruptRecord=columnNameOfCorruptRecord) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fd083e4868cd6..e943f8da3db14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -437,12 +437,19 @@ def test_udf_with_order_by_and_limit(self): self.assertEqual(res.collect(), [Row(id=0, copy=0)]) def test_wholefile_json(self): - from pyspark.sql.types import StringType people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", wholeFile=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_wholefile_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", wholeFile=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType diff --git a/python/test_support/sql/ages_newlines.csv b/python/test_support/sql/ages_newlines.csv new file mode 100644 index 0000000000000..d19f6731625fa --- /dev/null +++ b/python/test_support/sql/ages_newlines.csv @@ -0,0 +1,6 @@ +Joe,20,"Hi, +I am Jeo" +Tom,30,"My name is Tom" +Hyukjin,25,"I am Hyukjin + +I love Spark!" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 59baf6e567721..63be1e5302302 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
<li>`columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 0762d1b7daaea..54549f698aca5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils +import org.apache.spark.TaskContext + object CodecStreams { private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { val compressionCodecs = new CompressionCodecFactory(config) @@ -42,6 +44,16 @@ object CodecStreams { .getOrElse(inputStream) } + /** + * Creates an input stream from the string path and add a closure for the input stream to be + * closed on task completion. + */ + def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { + val inputStream = createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala new file mode 100644 index 0000000000000..73e6abc6dad37 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.InputStream +import java.nio.charset.{Charset, StandardCharsets} + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.types.StructType + +/** + * Common functions for parsing CSV files + */ +abstract class CSVDataSource extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into [[InternalRow]] instances. + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] + + /** + * Infers the schema from `inputPaths` files. + */ + def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + protected def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. 
+ s"_c$index" + } + } + } +} + +object CSVDataSource { + def apply(options: CSVOptions): CSVDataSource = { + if (options.wholeFile) { + WholeFileCSVDataSource + } else { + TextInputCSVDataSource + } + } +} + +object TextInputCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = true + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + val lines = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => + new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + } + } + + val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) + } else { + val charset = options.charset + val rdd = sparkSession.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(",")) + .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + sparkSession.createDataset(rdd)(Encoders.STRING) + } + } +} + +object WholeFileCSVDataSource extends CSVDataSource { + override val isSplitable: Boolean = false + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: UnivocityParser, + parsedOptions: CSVOptions): Iterator[InternalRow] = { + UnivocityParser.parseStream( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + parsedOptions.headerFlag, + parser) + } + + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): Option[StructType] = { + val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) + val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + false, + new CsvParser(parsedOptions.asParserSettings)) + }.take(1).headOption + + if (maybeFirstRow.isDefined) { + 
val firstRow = maybeFirstRow.get + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + } else { + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) + } + } + + private def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + options: CSVOptions): RDD[PortableDataStream] = { + val paths = inputPaths.map(_.getPath) + val name = paths.mkString(",") + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + val conf = job.getConfiguration + + val rdd = new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + + // Only returns `PortableDataStream`s without paths. + rdd.setName(s"CSVFile: $name").values + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 59f2919edfe2e..29c41455279e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.{Charset, StandardCharsets} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -43,11 +37,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "csv" - override def toString: String = "CSV" - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvDataSource = CSVDataSource(parsedOptions) + csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, @@ -55,11 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { files: Seq[FileStatus]): Option[StructType] = { require(files.nonEmpty, "Cannot infer schema 
from an empty set of files") - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val paths = files.map(_.getPath.toString) - val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) + val parsedOptions = + new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + + CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) } override def prepareWrite( @@ -115,49 +112,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val lines = { - val conf = broadcastedHadoopConf.value.value - val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) - } - } - - val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, parsedOptions) - } else { - lines - } - - val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions) + val conf = broadcastedHadoopConf.value.value val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - filteredLines.flatMap(parser.parse) + CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) } } - private def createBaseDataset( - sparkSession: SparkSession, - options: CSVOptions, - inputPaths: Seq[String]): Dataset[String] = { - if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = inputPaths, - className = classOf[TextFileFormat].getName - ).resolveRelation(checkFilesExist = false)) - .select("value").as[String](Encoders.STRING) - } else { - val charset = options.charset - val rdd = sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(",")) - .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) - sparkSession.createDataset(rdd)(Encoders.STRING) - } - } + override def toString: String = "CSV" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3fa30fe2401e1..b64d71bb4eef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -21,11 +21,9 @@ import java.math.BigDecimal import scala.util.control.Exception._ -import com.univocity.parsers.csv.CsvParser - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ private[csv] object CSVInferSchema { @@ -37,24 +35,13 @@ private[csv] object CSVInferSchema { * 3. 
Replace any null types with string type */ def infer( - csv: Dataset[String], - caseSensitive: Boolean, + tokenRDD: RDD[Array[String]], + header: Array[String], options: CSVOptions): StructType = { - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first() - val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) - val header = makeSafeHeader(firstRow, caseSensitive, options) - val fields = if (options.inferSchemaFlag) { - val tokenRdd = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options) - val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options) - val parser = new CsvParser(options.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) header.zip(rootTypes).map { case (thisHeader, rootType) => val dType = rootType match { @@ -71,44 +58,6 @@ private[csv] object CSVInferSchema { StructType(fields) } - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - private def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. 
- s"_c$index" - } - } - } - private def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1caeec7c63945..50503385ad6d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -130,6 +130,8 @@ private[csv] class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val maxColumns = getInt("maxColumns", 20480) val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index eb471651db2e3..804031a5bb5f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import java.io.InputStream import java.math.BigDecimal import java.text.NumberFormat import java.util.Locale @@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] class UnivocityParser( schema: StructType, requiredSchema: StructType, - options: CSVOptions) extends Logging { + private val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -56,12 +57,15 @@ private[csv] class UnivocityParser( private val valueConverters = dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - private val parser = new CsvParser(options.asParserSettings) + private val tokenizer = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 private val row = new GenericInternalRow(requiredSchema.length) + // This gets the raw input that is parsed lately. + private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + // This parser loads an `indexArr._1`-th position value in input tokens, // then put the value in `row(indexArr._2)`. private val indexArr: Array[(Int, Int)] = { @@ -188,12 +192,13 @@ private[csv] class UnivocityParser( } /** - * Parses a single CSV record (in the form of an array of strings in which - * each element represents a column) and turns it into either one resulting row or no row (if the + * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). 
*/ - def parse(input: String): Option[InternalRow] = { - convertWithParseMode(input) { tokens => + def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): Option[InternalRow] = { + convertWithParseMode(tokens) { tokens => var i: Int = 0 while (i < indexArr.length) { val (pos, rowIdx) = indexArr(i) @@ -211,8 +216,7 @@ private[csv] class UnivocityParser( } private def convertWithParseMode( - input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = { - val tokens = parser.parseLine(input) + tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { if (options.dropMalformed && dataSchema.length != tokens.length) { if (numMalformedRecords < options.maxMalformedLogPerPartition) { logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") @@ -251,7 +255,7 @@ private[csv] class UnivocityParser( } catch { case NonFatal(e) if options.permissive => val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input)) + corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) Some(row) case NonFatal(e) if options.dropMalformed => if (numMalformedRecords < options.maxMalformedLogPerPartition) { @@ -269,3 +273,75 @@ private[csv] class UnivocityParser( } } } + +private[csv] object UnivocityParser { + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of tokens. + */ + def tokenizeStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser): Iterator[Array[String]] = { + convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens) + } + + /** + * Parses a stream that contains CSV strings and turns it into an iterator of rows. + */ + def parseStream( + inputStream: InputStream, + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val tokenizer = parser.tokenizer + convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + parser.convert(tokens) + }.flatten + } + + private def convertStream[T]( + inputStream: InputStream, + shouldDropHeader: Boolean, + tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer.beginParsing(inputStream) + private var nextRecord = { + if (shouldDropHeader) { + tokenizer.parseNext() + } + tokenizer.parseNext() + } + + override def hasNext: Boolean = nextRecord != null + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + val curRecord = convert(nextRecord) + nextRecord = tokenizer.parseNext() + curRecord + } + } + + /** + * Parses an iterator that contains CSV strings and turns it into an iterator of rows. + */ + def parseIterator( + lines: Iterator[String], + shouldDropHeader: Boolean, + parser: UnivocityParser): Iterator[InternalRow] = { + val options = parser.options + + val linesWithoutHeader = if (shouldDropHeader) { + // Note that if there are only comments in the first block, the header would probably + // be not dropped. 
+ CSVUtils.dropHeaderLine(lines, options) + } else { + lines + } + + val filteredLines: Iterator[String] = + CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + filteredLines.flatMap(line => parser.parse(line)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3e984effcb8d8..18843bfc307b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.InputStream - import scala.reflect.ClassTag import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} @@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { } } - private def createInputStream(config: Configuration, path: String): InputStream = { - val inputStream = CodecStreams.createInputStream(config, new Path(path)) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) - inputStream - } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, - createInputStream(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) } override def readFile( @@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { - Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + Utils.tryWithResource { + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } parser.parse( - createInputStream(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), CreateJacksonParser.inputStream, partitionedFileString).toIterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f78e73f319de7..6a275281d8697 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -261,6 +261,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
<li>`columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+ * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 371d4311baa3b..d94eb66201112 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -24,11 +24,12 @@ import java.text.SimpleDateFormat import java.util.Locale import org.apache.commons.lang3.time.FastDateFormat -import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -243,12 +244,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - val cars = spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { wholeFile => + val cars = spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } test("test for blank column names on read and select columns") { @@ -263,14 +267,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - val exception = intercept[SparkException]{ - spark.read - .format("csv") - .options(Map("header" -> "true", "mode" -> "failfast")) - .load(testFile(carsFile)).collect() - } + Seq(false, true).foreach { wholeFile => + val exception = intercept[SparkException] { + spark.read + .format("csv") + .option("wholeFile", wholeFile) + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } } test("test for tokens more than the fields in the schema") { @@ -961,56 +968,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - val schema = new StructType().add("a", IntegerType).add("b", TimestampType) - val df1 = spark - .read - .option("mode", "PERMISSIVE") - .schema(schema) - .csv(testFile(valueMalformedFile)) - checkAnswer(df1, - Row(null, null) :: - Row(1, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records - val columnNameOfCorruptRecord = "_unparsed" - val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) - val df2 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField1) - .csv(testFile(valueMalformedFile)) - checkAnswer(df2, - Row(null, null, "0,2013-111-11 12:13:14") :: - Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: - Nil) - - // We put a 
`columnNameOfCorruptRecord` field in the middle of a schema - val schemaWithCorrField2 = new StructType() - .add("a", IntegerType) - .add(columnNameOfCorruptRecord, StringType) - .add("b", TimestampType) - val df3 = spark - .read - .option("mode", "PERMISSIVE") - .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schemaWithCorrField2) - .csv(testFile(valueMalformedFile)) - checkAnswer(df3, - Row(null, "0,2013-111-11 12:13:14", null) :: - Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: - Nil) - - val errMsg = intercept[AnalysisException] { - spark + Seq(false, true).foreach { wholeFile => + val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val df1 = spark + .read + .option("mode", "PERMISSIVE") + .option("wholeFile", wholeFile) + .schema(schema) + .csv(testFile(valueMalformedFile)) + checkAnswer(df1, + Row(null, null) :: + Row(1, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records + val columnNameOfCorruptRecord = "_unparsed" + val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) + val df2 = spark .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) - .collect - }.getMessage - assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + checkAnswer(df2, + Row(null, null, "0,2013-111-11 12:13:14") :: + Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: + Nil) + + // We put a `columnNameOfCorruptRecord` field in the middle of a schema + val schemaWithCorrField2 = new StructType() + .add("a", IntegerType) + .add(columnNameOfCorruptRecord, StringType) + .add("b", TimestampType) + val df3 = spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schemaWithCorrField2) + .csv(testFile(valueMalformedFile)) + checkAnswer(df3, + Row(null, "0,2013-111-11 12:13:14", null) :: + Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: + Nil) + + val errMsg = intercept[AnalysisException] { + spark + .read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) + .option("wholeFile", wholeFile) + .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) + .csv(testFile(valueMalformedFile)) + .collect + }.getMessage + assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + } + } + + test("SPARK-19610: Parse normal multi-line CSV files") { + val primitiveFieldAndType = Seq( + """" + |string","integer + | + | + |","long + | + |","bigInteger",double,boolean,null""".stripMargin, + """"this is a + |simple + |string."," + | + |10"," + |21474836470","92233720368547758070"," + | + |1.7976931348623157E308",true,""".stripMargin) + + withTempPath { path => + primitiveFieldAndType.toDF("value").coalesce(1).write.text(path.getAbsolutePath) + + val df = spark.read + .option("header", true) + .option("wholeFile", true) + .csv(path.getAbsolutePath) + + // Check if headers have new lines in the names. 
+ val actualFields = df.schema.fieldNames.toSeq + val expectedFields = + Seq("\nstring", "integer\n\n\n", "long\n\n", "bigInteger", "double", "boolean", "null") + assert(actualFields === expectedFields) + + // Check if the rows have new lines in the values. + val expected = Row( + "this is a\nsimple\nstring.", + "\n\n10", + "\n21474836470", + "92233720368547758070", + "\n\n1.7976931348623157E308", + "true", + null) + checkAnswer(df, expected) + } + } + + test("Empty file produces empty dataframe with empty schema - wholeFile option") { + withTempPath { path => + path.createNewFile() + + val df = spark.read.format("csv") + .option("header", true) + .option("wholeFile", true) + .load(path.getAbsolutePath) + + assert(df.schema === spark.emptyDataFrame.schema) + checkAnswer(df, spark.emptyDataFrame) + } } }
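
Usage sketch (illustrative only, not part of the patch): reading the `ages_newlines.csv` fixture added above through the Scala API with the new option. This assumes a `SparkSession` named `spark` is in scope; the expected values simply mirror the PySpark test added in python/pyspark/sql/tests.py.

    // With wholeFile=true each input file is fed to the CSV parser as a single
    // stream, so quoted values may contain embedded newlines; the default
    // (false) keeps the existing line-by-line parsing behaviour.
    val df = spark.read
      .option("wholeFile", true)
      .csv("python/test_support/sql/ages_newlines.csv")

    df.show(truncate = false)
    // Expect three rows; the first row's _c2 is "Hi,\nI am Jeo" with the
    // newline preserved inside the quoted field.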