diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b5e5b18bcbefa..ec47618e73a6c 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -308,7 +308,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
- columnNameOfCorruptRecord=None):
+ columnNameOfCorruptRecord=None, wholeFile=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -385,6 +385,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ :param wholeFile: parse records, which may span multiple lines. If None is
+ set, it uses the default value, ``false``.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -398,7 +400,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
if isinstance(path, basestring):
path = [path]
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index bd19fd4e385b4..7587875cb9849 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -562,7 +562,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
- columnNameOfCorruptRecord=None):
+ columnNameOfCorruptRecord=None, wholeFile=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -637,6 +637,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the value specified in
``spark.sql.columnNameOfCorruptRecord``.
+ :param wholeFile: parse records, which may span multiple lines. If None is
+ set, it uses the default value, ``false``.
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -652,7 +654,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
- columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd083e4868cd6..e943f8da3db14 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -437,12 +437,19 @@ def test_udf_with_order_by_and_limit(self):
self.assertEqual(res.collect(), [Row(id=0, copy=0)])
def test_wholefile_json(self):
- from pyspark.sql.types import StringType
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
wholeFile=True)
self.assertEqual(people1.collect(), people_array.collect())
+ def test_wholefile_csv(self):
+ ages_newlines = self.spark.read.csv(
+ "python/test_support/sql/ages_newlines.csv", wholeFile=True)
+ expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Joe'),
+ Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
+ Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
+ self.assertEqual(ages_newlines.collect(), expected)
+
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
diff --git a/python/test_support/sql/ages_newlines.csv b/python/test_support/sql/ages_newlines.csv
new file mode 100644
index 0000000000000..d19f6731625fa
--- /dev/null
+++ b/python/test_support/sql/ages_newlines.csv
@@ -0,0 +1,6 @@
+Joe,20,"Hi,
+I am Jeo"
+Tom,30,"My name is Tom"
+Hyukjin,25,"I am Hyukjin
+
+I love Spark!"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 59baf6e567721..63be1e5302302 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -463,6 +463,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
* `columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
+ * `wholeFile` (default `false`): parse records, which may span multiple lines.
*
* @since 2.0.0
*/
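
For orientation, a minimal usage sketch of the new option through the Scala reader; the file path is hypothetical and `spark` is an existing SparkSession:

    // Read a CSV whose quoted fields contain embedded newlines.
    val df = spark.read
      .option("wholeFile", "true")  // records may span multiple lines
      .option("header", "true")
      .csv("/data/multiline.csv")   // hypothetical path
    df.show(truncate = false)
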
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
index 0762d1b7daaea..54549f698aca5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
@@ -27,6 +27,8 @@ import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.ReflectionUtils
+import org.apache.spark.TaskContext
+
object CodecStreams {
private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = {
val compressionCodecs = new CompressionCodecFactory(config)
@@ -42,6 +44,16 @@ object CodecStreams {
.getOrElse(inputStream)
}
+ /**
+ * Creates an input stream from the given path and registers a task-completion
+ * listener so the stream is closed when the task finishes.
+ */
+ def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = {
+ val inputStream = createInputStream(config, new Path(path))
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
+ inputStream
+ }
+
private def getCompressionCodec(
context: JobContext,
file: Option[Path] = None): Option[CompressionCodec] = {
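
The close-on-task-completion pattern above generalizes to any Closeable opened inside a task. A hedged sketch, assuming it runs on an executor where `TaskContext.get()` is non-null (on the driver it returns null, hence the `Option` guard); the file path is hypothetical:

    import org.apache.spark.TaskContext

    val rdd = spark.sparkContext.parallelize(1 to 10)
    rdd.mapPartitions { iter =>
      // Hypothetical resource opened once per partition.
      val in = new java.io.FileInputStream("/tmp/example")
      // Close it when the task finishes, even if the iterator is not drained.
      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => in.close()))
      iter
    }.count()
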
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
new file mode 100644
index 0000000000000..73e6abc6dad37
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.InputStream
+import java.nio.charset.{Charset, StandardCharsets}
+
+import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+
+import org.apache.spark.TaskContext
+import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.rdd.{BinaryFileRDD, RDD}
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Common functions for parsing CSV files
+ */
+abstract class CSVDataSource extends Serializable {
+ def isSplitable: Boolean
+
+ /**
+ * Parses a [[PartitionedFile]] into [[InternalRow]] instances.
+ */
+ def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow]
+
+ /**
+ * Infers the schema from `inputPaths` files.
+ */
+ def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType]
+
+ /**
+ * Generates a null-safe, duplicate-safe header from the given row.
+ */
+ protected def makeSafeHeader(
+ row: Array[String],
+ caseSensitive: Boolean,
+ options: CSVOptions): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When the value is null, empty, or equal to `nullValue`, fall back to the
+ // positional default name, `_c` plus the column index.
+ s"_c$index"
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // When there are case-insensitive duplicates, put the index as the suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses the default column names, "_c#" where # is the field position,
+ // when the header option is disabled.
+ s"_c$index"
+ }
+ }
+ }
+}
+
+object CSVDataSource {
+ def apply(options: CSVOptions): CSVDataSource = {
+ if (options.wholeFile) {
+ WholeFileCSVDataSource
+ } else {
+ TextInputCSVDataSource
+ }
+ }
+}
+
+object TextInputCSVDataSource extends CSVDataSource {
+ override val isSplitable: Boolean = true
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow] = {
+ val lines = {
+ val linesReader = new HadoopFileLinesReader(file, conf)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ linesReader.map { line =>
+ new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
+ }
+ }
+
+ val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
+ UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
+ }
+
+ override def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType] = {
+ val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+ val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
+ val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+ val linesWithoutHeader =
+ CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+ val parser = new CsvParser(parsedOptions.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+
+ Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ }
+
+ private def createBaseDataset(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ options: CSVOptions): Dataset[String] = {
+ val paths = inputPaths.map(_.getPath.toString)
+ if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
+ sparkSession.baseRelationToDataFrame(
+ DataSource.apply(
+ sparkSession,
+ paths = paths,
+ className = classOf[TextFileFormat].getName
+ ).resolveRelation(checkFilesExist = false))
+ .select("value").as[String](Encoders.STRING)
+ } else {
+ val charset = options.charset
+ val rdd = sparkSession.sparkContext
+ .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(","))
+ .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
+ sparkSession.createDataset(rdd)(Encoders.STRING)
+ }
+ }
+}
+
+object WholeFileCSVDataSource extends CSVDataSource {
+ override val isSplitable: Boolean = false
+
+ override def readFile(
+ conf: Configuration,
+ file: PartitionedFile,
+ parser: UnivocityParser,
+ parsedOptions: CSVOptions): Iterator[InternalRow] = {
+ UnivocityParser.parseStream(
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
+ parsedOptions.headerFlag,
+ parser)
+ }
+
+ override def infer(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ parsedOptions: CSVOptions): Option[StructType] = {
+ val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+ val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+ UnivocityParser.tokenizeStream(
+ CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+ false,
+ new CsvParser(parsedOptions.asParserSettings))
+ }.take(1).headOption
+
+ if (maybeFirstRow.isDefined) {
+ val firstRow = maybeFirstRow.get
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.flatMap { lines =>
+ UnivocityParser.tokenizeStream(
+ CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
+ parsedOptions.headerFlag,
+ new CsvParser(parsedOptions.asParserSettings))
+ }
+ Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ } else {
+ // If the first row could not be read, just return the empty schema.
+ Some(StructType(Nil))
+ }
+ }
+
+ private def createBaseRdd(
+ sparkSession: SparkSession,
+ inputPaths: Seq[FileStatus],
+ options: CSVOptions): RDD[PortableDataStream] = {
+ val paths = inputPaths.map(_.getPath)
+ val name = paths.mkString(",")
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ FileInputFormat.setInputPaths(job, paths: _*)
+ val conf = job.getConfiguration
+
+ val rdd = new BinaryFileRDD(
+ sparkSession.sparkContext,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ conf,
+ sparkSession.sparkContext.defaultMinPartitions)
+
+ // Drop the path keys and return only the `PortableDataStream` values.
+ rdd.setName(s"CSVFile: $name").values
+ }
+}
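
To make the header sanitization above concrete, a small self-contained sketch of the same logic (condensed: the duplicate check folds the case-sensitive and case-insensitive branches together, and the `nullValue` comparison is omitted):

    val row = Array("a", null, "A", "a")
    val caseSensitive = false
    val duplicates = {
      val names = row.filter(_ != null).map(n => if (caseSensitive) n else n.toLowerCase)
      names.diff(names.distinct).distinct
    }
    val header = row.zipWithIndex.map {
      case (v, i) if v == null || v.isEmpty => s"_c$i"
      case (v, i) if duplicates.contains(if (caseSensitive) v else v.toLowerCase) => s"$v$i"
      case (v, _) => v
    }
    // header == Array("a0", "_c1", "A2", "a3"): nulls fall back to positional
    // names; duplicates get their column index appended.
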
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 59f2919edfe2e..29c41455279e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -17,21 +17,15 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.nio.charset.{Charset, StandardCharsets}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, Text}
-import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce._
-import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -43,11 +37,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def shortName(): String = "csv"
- override def toString: String = "CSV"
-
- override def hashCode(): Int = getClass.hashCode()
-
- override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ val parsedOptions =
+ new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val csvDataSource = CSVDataSource(parsedOptions)
+ csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ }
override def inferSchema(
sparkSession: SparkSession,
@@ -55,11 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
files: Seq[FileStatus]): Option[StructType] = {
require(files.nonEmpty, "Cannot infer schema from an empty set of files")
- val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
- val paths = files.map(_.getPath.toString)
- val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
+ val parsedOptions =
+ new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+
+ CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions)
}
override def prepareWrite(
@@ -115,49 +112,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
(file: PartitionedFile) => {
- val lines = {
- val conf = broadcastedHadoopConf.value.value
- val linesReader = new HadoopFileLinesReader(file, conf)
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
- linesReader.map { line =>
- new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
- }
- }
-
- val linesWithoutHeader = if (parsedOptions.headerFlag && file.start == 0) {
- // Note that if there are only comments in the first block, the header would probably
- // be not dropped.
- CSVUtils.dropHeaderLine(lines, parsedOptions)
- } else {
- lines
- }
-
- val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, parsedOptions)
+ val conf = broadcastedHadoopConf.value.value
val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
- filteredLines.flatMap(parser.parse)
+ CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions)
}
}
- private def createBaseDataset(
- sparkSession: SparkSession,
- options: CSVOptions,
- inputPaths: Seq[String]): Dataset[String] = {
- if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
- sparkSession.baseRelationToDataFrame(
- DataSource.apply(
- sparkSession,
- paths = inputPaths,
- className = classOf[TextFileFormat].getName
- ).resolveRelation(checkFilesExist = false))
- .select("value").as[String](Encoders.STRING)
- } else {
- val charset = options.charset
- val rdd = sparkSession.sparkContext
- .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
- .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
- sparkSession.createDataset(rdd)(Encoders.STRING)
- }
- }
+ override def toString: String = "CSV"
+
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
}
private[csv] class CsvOutputWriter(
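
One observable consequence of the `isSplitable` override: with `wholeFile` enabled a quoted field may straddle any byte offset, so files are no longer split and the scan produces at most one task per input file. A quick way to observe this (path hypothetical):

    val df = spark.read
      .option("wholeFile", "true")
      .csv("/data/big.csv")          // hypothetical large file
    // Bounded by the number of input files rather than by block splits.
    println(df.rdd.getNumPartitions)
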
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 3fa30fe2401e1..b64d71bb4eef2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -21,11 +21,9 @@ import java.math.BigDecimal
import scala.util.control.Exception._
-import com.univocity.parsers.csv.CsvParser
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
private[csv] object CSVInferSchema {
@@ -37,24 +35,13 @@ private[csv] object CSVInferSchema {
* 3. Replace any null types with string type
*/
def infer(
- csv: Dataset[String],
- caseSensitive: Boolean,
+ tokenRDD: RDD[Array[String]],
+ header: Array[String],
options: CSVOptions): StructType = {
- val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
- val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
- val header = makeSafeHeader(firstRow, caseSensitive, options)
-
val fields = if (options.inferSchemaFlag) {
- val tokenRdd = csv.rdd.mapPartitions { iter =>
- val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
- val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
- val parser = new CsvParser(options.asParserSettings)
- linesWithoutHeader.map(parser.parseLine)
- }
-
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
- tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
+ tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
@@ -71,44 +58,6 @@ private[csv] object CSVInferSchema {
StructType(fields)
}
- /**
- * Generates a header from the given row which is null-safe and duplicate-safe.
- */
- private def makeSafeHeader(
- row: Array[String],
- caseSensitive: Boolean,
- options: CSVOptions): Array[String] = {
- if (options.headerFlag) {
- val duplicates = {
- val headerNames = row.filter(_ != null)
- .map(name => if (caseSensitive) name else name.toLowerCase)
- headerNames.diff(headerNames.distinct).distinct
- }
-
- row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`, put the
- // index as the suffix.
- s"_c$index"
- } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
- // When there are case-insensitive duplicates, put the index as the suffix.
- s"$value$index"
- } else if (duplicates.contains(value)) {
- // When there are duplicates, put the index as the suffix.
- s"$value$index"
- } else {
- value
- }
- }
- } else {
- row.zipWithIndex.map { case (_, index) =>
- // Uses default column names, "_c#" where # is its position of fields
- // when header option is disabled.
- s"_c$index"
- }
- }
- }
-
private def inferRowType(options: CSVOptions)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 1caeec7c63945..50503385ad6d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -130,6 +130,8 @@ private[csv] class CSVOptions(
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)
+ val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)
+
val maxColumns = getInt("maxColumns", 20480)
val maxCharsPerColumn = getInt("maxCharsPerColumn", -1)
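
Since the option value is read with `.toBoolean`, both the Boolean and String forms of `option` select the whole-file path (paths hypothetical):

    spark.read.option("wholeFile", true).csv("/data/a.csv")    // Boolean overload
    spark.read.option("wholeFile", "true").csv("/data/a.csv")  // String, parsed via .toBoolean
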
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index eb471651db2e3..804031a5bb5f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.csv
+import java.io.InputStream
import java.math.BigDecimal
import java.text.NumberFormat
import java.util.Locale
@@ -36,7 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[csv] class UnivocityParser(
schema: StructType,
requiredSchema: StructType,
- options: CSVOptions) extends Logging {
+ private val options: CSVOptions) extends Logging {
require(requiredSchema.toSet.subsetOf(schema.toSet),
"requiredSchema should be the subset of schema.")
@@ -56,12 +57,15 @@ private[csv] class UnivocityParser(
private val valueConverters =
dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
- private val parser = new CsvParser(options.asParserSettings)
+ private val tokenizer = new CsvParser(options.asParserSettings)
private var numMalformedRecords = 0
private val row = new GenericInternalRow(requiredSchema.length)
+ // Returns the raw input most recently parsed by the tokenizer.
+ private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd
+
// This parser loads an `indexArr._1`-th position value in input tokens,
// then put the value in `row(indexArr._2)`.
private val indexArr: Array[(Int, Int)] = {
@@ -188,12 +192,13 @@ private[csv] class UnivocityParser(
}
/**
- * Parses a single CSV record (in the form of an array of strings in which
- * each element represents a column) and turns it into either one resulting row or no row (if the
+ * Parses a single CSV string and turns it into either one resulting row or no row (if
* the record is malformed).
*/
- def parse(input: String): Option[InternalRow] = {
- convertWithParseMode(input) { tokens =>
+ def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input))
+
+ private def convert(tokens: Array[String]): Option[InternalRow] = {
+ convertWithParseMode(tokens) { tokens =>
var i: Int = 0
while (i < indexArr.length) {
val (pos, rowIdx) = indexArr(i)
@@ -211,8 +216,7 @@ private[csv] class UnivocityParser(
}
private def convertWithParseMode(
- input: String)(convert: Array[String] => InternalRow): Option[InternalRow] = {
- val tokens = parser.parseLine(input)
+ tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
if (options.dropMalformed && dataSchema.length != tokens.length) {
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
@@ -251,7 +255,7 @@ private[csv] class UnivocityParser(
} catch {
case NonFatal(e) if options.permissive =>
val row = new GenericInternalRow(requiredSchema.length)
- corruptFieldIndex.foreach(row(_) = UTF8String.fromString(input))
+ corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
Some(row)
case NonFatal(e) if options.dropMalformed =>
if (numMalformedRecords < options.maxMalformedLogPerPartition) {
@@ -269,3 +273,75 @@ private[csv] class UnivocityParser(
}
}
}
+
+private[csv] object UnivocityParser {
+
+ /**
+ * Parses a stream that contains CSV strings and turns it into an iterator of tokens.
+ */
+ def tokenizeStream(
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ tokenizer: CsvParser): Iterator[Array[String]] = {
+ convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens)
+ }
+
+ /**
+ * Parses a stream that contains CSV strings and turns it into an iterator of rows.
+ */
+ def parseStream(
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ parser: UnivocityParser): Iterator[InternalRow] = {
+ val tokenizer = parser.tokenizer
+ convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+ parser.convert(tokens)
+ }.flatten
+ }
+
+ private def convertStream[T](
+ inputStream: InputStream,
+ shouldDropHeader: Boolean,
+ tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+ tokenizer.beginParsing(inputStream)
+ private var nextRecord = {
+ if (shouldDropHeader) {
+ tokenizer.parseNext()
+ }
+ tokenizer.parseNext()
+ }
+
+ override def hasNext: Boolean = nextRecord != null
+
+ override def next(): T = {
+ if (!hasNext) {
+ throw new NoSuchElementException("End of stream")
+ }
+ val curRecord = convert(nextRecord)
+ nextRecord = tokenizer.parseNext()
+ curRecord
+ }
+ }
+
+ /**
+ * Parses an iterator that contains CSV strings and turns it into an iterator of rows.
+ */
+ def parseIterator(
+ lines: Iterator[String],
+ shouldDropHeader: Boolean,
+ parser: UnivocityParser): Iterator[InternalRow] = {
+ val options = parser.options
+
+ val linesWithoutHeader = if (shouldDropHeader) {
+ // Note that if the first block contains only comments, the header will
+ // probably not be dropped.
+ CSVUtils.dropHeaderLine(lines, options)
+ } else {
+ lines
+ }
+
+ val filteredLines: Iterator[String] =
+ CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+ filteredLines.flatMap(line => parser.parse(line))
+ }
+}
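
The `convertStream` helper above is an instance of a pull-based iterator: prefetch one record, report `hasNext` from it, and fetch the following record inside `next()`. A generic, self-contained sketch of the same shape:

    def pullIterator[A](fetch: () => A)(isEnd: A => Boolean): Iterator[A] = new Iterator[A] {
      private var cur = fetch()                // prefetch the first record
      override def hasNext: Boolean = !isEnd(cur)
      override def next(): A = {
        if (!hasNext) throw new NoSuchElementException("End of stream")
        val rec = cur
        cur = fetch()                          // advance before returning
        rec
      }
    }

    // Usage, mirroring CsvParser.parseNext(), which returns null at end of input:
    val data = Iterator("a", "b", "c")
    val it = pullIterator(() => if (data.hasNext) data.next() else null)(_ == null)
    it.foreach(println)  // a, b, c
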
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3e984effcb8d8..18843bfc307b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.execution.datasources.json
-import java.io.InputStream
-
import scala.reflect.ClassTag
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
@@ -186,16 +184,10 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
}
}
- private def createInputStream(config: Configuration, path: String): InputStream = {
- val inputStream = CodecStreams.createInputStream(config, new Path(path))
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close()))
- inputStream
- }
-
override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
CreateJacksonParser.inputStream(
jsonFactory,
- createInputStream(record.getConfiguration, record.getPath()))
+ CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
}
override def readFile(
@@ -203,13 +195,15 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
file: PartitionedFile,
parser: JacksonParser): Iterator[InternalRow] = {
def partitionedFileString(ignored: Any): UTF8String = {
- Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream =>
+ Utils.tryWithResource {
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
+ } { inputStream =>
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
}
parser.parse(
- createInputStream(conf, file.filePath),
+ CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
CreateJacksonParser.inputStream,
partitionedFileString).toIterator
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index f78e73f319de7..6a275281d8697 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -261,6 +261,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
+ * `wholeFile` (default `false`): parse records, which may span multiple lines.
*
*
* @since 2.0.0
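
A hedged streaming sketch; file streams require an explicit schema, and the directory path is hypothetical:

    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("name", StringType)
      .add("age", IntegerType)
    val csvStream = spark.readStream
      .schema(schema)
      .option("wholeFile", "true")   // records may span multiple lines
      .csv("/data/incoming-csv")     // hypothetical directory
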
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 371d4311baa3b..d94eb66201112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -24,11 +24,12 @@ import java.text.SimpleDateFormat
import java.util.Locale
import org.apache.commons.lang3.time.FastDateFormat
-import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
+import org.apache.spark.sql.functions.{col, regexp_replace}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
@@ -243,12 +244,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test for DROPMALFORMED parsing mode") {
- val cars = spark.read
- .format("csv")
- .options(Map("header" -> "true", "mode" -> "dropmalformed"))
- .load(testFile(carsFile))
+ Seq(false, true).foreach { wholeFile =>
+ val cars = spark.read
+ .format("csv")
+ .option("wholeFile", wholeFile)
+ .options(Map("header" -> "true", "mode" -> "dropmalformed"))
+ .load(testFile(carsFile))
- assert(cars.select("year").collect().size === 2)
+ assert(cars.select("year").collect().size === 2)
+ }
}
test("test for blank column names on read and select columns") {
@@ -263,14 +267,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("test for FAILFAST parsing mode") {
- val exception = intercept[SparkException]{
- spark.read
- .format("csv")
- .options(Map("header" -> "true", "mode" -> "failfast"))
- .load(testFile(carsFile)).collect()
- }
+ Seq(false, true).foreach { wholeFile =>
+ val exception = intercept[SparkException] {
+ spark.read
+ .format("csv")
+ .option("wholeFile", wholeFile)
+ .options(Map("header" -> "true", "mode" -> "failfast"))
+ .load(testFile(carsFile)).collect()
+ }
- assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
+ assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
+ }
}
test("test for tokens more than the fields in the schema") {
@@ -961,56 +968,121 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") {
- val schema = new StructType().add("a", IntegerType).add("b", TimestampType)
- val df1 = spark
- .read
- .option("mode", "PERMISSIVE")
- .schema(schema)
- .csv(testFile(valueMalformedFile))
- checkAnswer(df1,
- Row(null, null) ::
- Row(1, java.sql.Date.valueOf("1983-08-04")) ::
- Nil)
-
- // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records
- val columnNameOfCorruptRecord = "_unparsed"
- val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType)
- val df2 = spark
- .read
- .option("mode", "PERMISSIVE")
- .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .schema(schemaWithCorrField1)
- .csv(testFile(valueMalformedFile))
- checkAnswer(df2,
- Row(null, null, "0,2013-111-11 12:13:14") ::
- Row(1, java.sql.Date.valueOf("1983-08-04"), null) ::
- Nil)
-
- // We put a `columnNameOfCorruptRecord` field in the middle of a schema
- val schemaWithCorrField2 = new StructType()
- .add("a", IntegerType)
- .add(columnNameOfCorruptRecord, StringType)
- .add("b", TimestampType)
- val df3 = spark
- .read
- .option("mode", "PERMISSIVE")
- .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .schema(schemaWithCorrField2)
- .csv(testFile(valueMalformedFile))
- checkAnswer(df3,
- Row(null, "0,2013-111-11 12:13:14", null) ::
- Row(1, null, java.sql.Date.valueOf("1983-08-04")) ::
- Nil)
-
- val errMsg = intercept[AnalysisException] {
- spark
+ Seq(false, true).foreach { wholeFile =>
+ val schema = new StructType().add("a", IntegerType).add("b", TimestampType)
+ val df1 = spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("wholeFile", wholeFile)
+ .schema(schema)
+ .csv(testFile(valueMalformedFile))
+ checkAnswer(df1,
+ Row(null, null) ::
+ Row(1, java.sql.Date.valueOf("1983-08-04")) ::
+ Nil)
+
+ // If `schema` has `columnNameOfCorruptRecord`, it should handle corrupt records
+ val columnNameOfCorruptRecord = "_unparsed"
+ val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType)
+ val df2 = spark
.read
.option("mode", "PERMISSIVE")
.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
- .schema(schema.add(columnNameOfCorruptRecord, IntegerType))
+ .option("wholeFile", wholeFile)
+ .schema(schemaWithCorrField1)
.csv(testFile(valueMalformedFile))
- .collect
- }.getMessage
- assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+ checkAnswer(df2,
+ Row(null, null, "0,2013-111-11 12:13:14") ::
+ Row(1, java.sql.Date.valueOf("1983-08-04"), null) ::
+ Nil)
+
+ // We put a `columnNameOfCorruptRecord` field in the middle of a schema
+ val schemaWithCorrField2 = new StructType()
+ .add("a", IntegerType)
+ .add(columnNameOfCorruptRecord, StringType)
+ .add("b", TimestampType)
+ val df3 = spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .option("wholeFile", wholeFile)
+ .schema(schemaWithCorrField2)
+ .csv(testFile(valueMalformedFile))
+ checkAnswer(df3,
+ Row(null, "0,2013-111-11 12:13:14", null) ::
+ Row(1, null, java.sql.Date.valueOf("1983-08-04")) ::
+ Nil)
+
+ val errMsg = intercept[AnalysisException] {
+ spark
+ .read
+ .option("mode", "PERMISSIVE")
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .option("wholeFile", wholeFile)
+ .schema(schema.add(columnNameOfCorruptRecord, IntegerType))
+ .csv(testFile(valueMalformedFile))
+ .collect
+ }.getMessage
+ assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+ }
+ }
+
+ test("SPARK-19610: Parse normal multi-line CSV files") {
+ val primitiveFieldAndType = Seq(
+ """"
+ |string","integer
+ |
+ |
+ |","long
+ |
+ |","bigInteger",double,boolean,null""".stripMargin,
+ """"this is a
+ |simple
+ |string.","
+ |
+ |10","
+ |21474836470","92233720368547758070","
+ |
+ |1.7976931348623157E308",true,""".stripMargin)
+
+ withTempPath { path =>
+ primitiveFieldAndType.toDF("value").coalesce(1).write.text(path.getAbsolutePath)
+
+ val df = spark.read
+ .option("header", true)
+ .option("wholeFile", true)
+ .csv(path.getAbsolutePath)
+
+ // Check if headers have new lines in the names.
+ val actualFields = df.schema.fieldNames.toSeq
+ val expectedFields =
+ Seq("\nstring", "integer\n\n\n", "long\n\n", "bigInteger", "double", "boolean", "null")
+ assert(actualFields === expectedFields)
+
+ // Check if the rows have new lines in the values.
+ val expected = Row(
+ "this is a\nsimple\nstring.",
+ "\n\n10",
+ "\n21474836470",
+ "92233720368547758070",
+ "\n\n1.7976931348623157E308",
+ "true",
+ null)
+ checkAnswer(df, expected)
+ }
+ }
+
+ test("Empty file produces empty dataframe with empty schema - wholeFile option") {
+ withTempPath { path =>
+ path.createNewFile()
+
+ val df = spark.read.format("csv")
+ .option("header", true)
+ .option("wholeFile", true)
+ .load(path.getAbsolutePath)
+
+ assert(df.schema === spark.emptyDataFrame.schema)
+ checkAnswer(df, spark.emptyDataFrame)
+ }
}
}