diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 49f4e6b2ede1..3ca5d548ae7d 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -349,7 +349,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- samplingRatio=None, enforceSchema=None):
+ samplingRatio=None, enforceSchema=None, emptyValue=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -444,6 +444,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
different, ``\0`` otherwise.
:param samplingRatio: defines fraction of rows used for schema inferring.
If None is set, it uses the default value, ``1.0``.
+ :param emptyValue: sets the string representation of an empty value. If None is set, it uses
+ the default value, empty string.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -463,7 +465,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
- enforceSchema=enforceSchema)
+ enforceSchema=enforceSchema, emptyValue=emptyValue)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -859,7 +861,7 @@ def text(self, path, compression=None, lineSep=None):
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
- charToEscapeQuoteEscaping=None, encoding=None):
+ charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -911,6 +913,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
different, ``\0`` otherwise.
:param encoding: sets the encoding (charset) of saved csv files. If None is set,
the default UTF-8 charset will be used.
+ :param emptyValue: sets the string representation of an empty value. If None is set, it uses
+ the default value, ``""``.
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
@@ -921,7 +925,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
- encoding=encoding)
+ encoding=encoding, emptyValue=emptyValue)
self._jwrite.csv(path)
@since(1.5)
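With that, `emptyValue` flows straight through the Python batch API. A minimal sketch of the new keyword in use (session setup and paths are illustrative, not part of the patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, "John Doe"), (2, "")], ["id", "name"])

# On write, empty strings are emitted as a visible token instead of the default "".
df.write.mode("overwrite").csv("/tmp/empty-value-demo", emptyValue="-")

# Reading with the same token preserves it; without it, the empty name in
# row 2 would collapse to null on the round trip.
back = spark.read.schema(df.schema).csv("/tmp/empty-value-demo", emptyValue="-")
```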
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index ee13778a7dcd..522900bf6684 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -564,7 +564,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- enforceSchema=None):
+ enforceSchema=None, emptyValue=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -658,6 +658,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
the quote character. If None is set, the default value is
escape character when escape and quote characters are
different, ``\0`` otherwise.
+ :param emptyValue: sets the string representation of an empty value. If None is set, it uses
+ the default value, empty string.
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -674,7 +676,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
+ emptyValue=emptyValue)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
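The streaming reader gains the same knob. A hedged sketch, assuming the `spark` session from the first sketch and a directory of arriving files (streaming CSV sources require an explicit schema):

```python
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType([
    StructField("id", IntegerType()),
    StructField("name", StringType()),
])

# Quoted empty fields ("") in arriving files surface as the string "empty"
# instead of being folded into nulls.
csv_sdf = spark.readStream.schema(schema).csv("/tmp/incoming", emptyValue="empty")
assert csv_sdf.isStreaming
```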
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 0cfcc45fb3d3..e6c2cba79841 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -571,6 +571,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* whitespaces from values being read should be skipped.
*
* `nullValue` (default empty string): sets the string representation of a null value. Since
* 2.0.1, this applies to all supported types including the string type.
+ * `emptyValue` (default empty string): sets the string representation of an empty value.
* `nanValue` (default `NaN`): sets the string representation of a non-number value.
* `positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index eca2d5b97190..dfb8c4718550 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -635,6 +635,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* enclosed in quotes. Default is to only escape values containing a quote character.
* `header` (default `false`): writes the names of columns as the first line.
* `nullValue` (default empty string): sets the string representation of a null value.
+ * `emptyValue` (default `""`): sets the string representation of an empty value.
* `encoding` (by default it is not set): specifies encoding (charset) of saved csv
* files. If it is not set, the UTF-8 charset will be used.
* `compression` (default `null`): compression codec to use when saving to file. This can be
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 2b86054c0ffc..5f427239ced1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -91,9 +91,10 @@ abstract class CSVDataSource extends Serializable {
}
row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`, put the
- // index as the suffix.
+ if (value == null || value.isEmpty || value == options.nullValue ||
+ value == options.emptyValueInRead) {
+ // When there are empty strings or the values set in `nullValue` or in `emptyValue`,
+ // put the index as the suffix.
s"_c$index"
} else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
// When there are case-insensitive duplicates, put the index as the suffix.
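A subtle consequence of this hunk, sketched below with made-up file contents: a header cell that happens to equal `emptyValue` is now treated like an empty header, so the column falls back to a positional `_c<index>` name.

```python
import os
import tempfile

d = tempfile.mkdtemp()
with open(os.path.join(d, "part.csv"), "w") as f:
    f.write("id,NA,name\n1,x,John Doe\n")

# Assumes the `spark` session from the first sketch.
df = spark.read.csv(d, header=True, emptyValue="NA")
print(df.columns)  # expected per the change above: ['id', '_c1', 'name']
```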
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index a585cbed2551..e7743b07f866 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -79,7 +79,8 @@ private[csv] object CSVInferSchema {
* point checking if it is an Int, as the final type must be Double or higher.
*/
def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
- if (field == null || field.isEmpty || field == options.nullValue) {
+ if (field == null || field.isEmpty || field == options.nullValue ||
+ field == options.emptyValueInRead) {
typeSoFar
} else {
typeSoFar match {
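Schema inference gets the matching treatment: a cell equal to `emptyValue` leaves `typeSoFar` untouched, exactly as `nullValue` cells already did. A small sketch of the effect (token and file contents are made up):

```python
import os
import tempfile

d = tempfile.mkdtemp()
with open(os.path.join(d, "part.csv"), "w") as f:
    f.write("1\nempty\n3\n")

# Assumes the `spark` session from the first sketch. The "empty" cell no
# longer drags the column down to StringType; 1 and 3 still win.
df = spark.read.csv(d, inferSchema=True, emptyValue="empty")
df.printSchema()  # the lone column should infer as int
```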
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index fab8d62da0c1..f84f783604e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -117,6 +117,9 @@ class CSVOptions(
val nullValue = parameters.getOrElse("nullValue", "")
+ val emptyValueInRead = parameters.getOrElse("emptyValue", "")
+ val emptyValueInWrite = parameters.getOrElse("emptyValue", "\"\"")
+
val nanValue = parameters.getOrElse("nanValue", "NaN")
val positiveInf = parameters.getOrElse("positiveInf", "Inf")
@@ -173,7 +176,7 @@ class CSVOptions(
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
- writerSettings.setEmptyValue("\"\"")
+ writerSettings.setEmptyValue(emptyValueInWrite)
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -194,7 +197,7 @@ class CSVOptions(
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
settings.setNullValue(nullValue)
- settings.setEmptyValue("")
+ settings.setEmptyValue(emptyValueInRead)
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings
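Note the deliberately asymmetric defaults: on read, `emptyValue` defaults to the empty string (quoted empty fields keep materializing as the empty string), while on write it defaults to the two-character token `""`, matching what the writer hard-coded before. The round trip below (paths illustrative, `spark` as in the first sketch) shows why a caller may want the option on both sides: with the defaults, an empty string written out comes back as null, because the reader-side empty string coincides with the default `nullValue`.

```python
import tempfile

df = spark.createDataFrame([(1, "John Doe"), (2, "")], ["id", "name"])

p1 = tempfile.mkdtemp()
df.write.mode("overwrite").csv(p1)
spark.read.schema(df.schema).csv(p1).show()  # row 2's name reads back as null

p2 = tempfile.mkdtemp()
df.write.mode("overwrite").csv(p2, emptyValue="<empty>")
spark.read.schema(df.schema).csv(p2, emptyValue="<empty>").show()
# With a shared token, row 2's name survives as "<empty>".
```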
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 39e9e1ad426b..2a4db4afbe00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -327,6 +327,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* whitespaces from values being read should be skipped.
* `nullValue` (default empty string): sets the string representation of a null value. Since
* 2.0.1, this applies to all supported types including the string type.
+ * `emptyValue` (default empty string): sets the string representation of an empty value.
* `nanValue` (default `NaN`): sets the string representation of a non-number value.
* `positiveInf` (default `Inf`): sets the string representation of a positive infinity
* value.
diff --git a/sql/core/src/test/resources/test-data/cars-empty-value.csv b/sql/core/src/test/resources/test-data/cars-empty-value.csv
new file mode 100644
index 000000000000..0f20a2f23ac0
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/cars-empty-value.csv
@@ -0,0 +1,4 @@
+year,make,model,comment,blank
+"2012","Tesla","S","",""
+1997,Ford,E350,"Go get one now they are going fast",
+2015,Chevy,Volt,,""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 57e36e082653..40273251c378 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -105,6 +105,20 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
}
+ test("Empty fields are handled properly when an emptyValue is specified") {
+ var options = new CSVOptions(Map("emptyValue" -> "empty"), false, "GMT")
+ assert(CSVInferSchema.inferField(NullType, "empty", options) == NullType)
+ assert(CSVInferSchema.inferField(StringType, "empty", options) == StringType)
+ assert(CSVInferSchema.inferField(LongType, "empty", options) == LongType)
+
+ options = new CSVOptions(Map("emptyValue" -> "\\N"), false, "GMT")
+ assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
+ assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
+ assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
+ assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
+ assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
+ }
+
test("Merging Nulltypes should yield Nulltype.") {
val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
assert(mergedNullTypes.deep == Array(NullType).deep)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 5a1d6679ebbd..2b39a0b1f52e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -50,6 +50,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
private val carsAltFile = "test-data/cars-alternative.csv"
private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv"
private val carsNullFile = "test-data/cars-null.csv"
+ private val carsEmptyValueFile = "test-data/cars-empty-value.csv"
private val carsBlankColName = "test-data/cars-blank-column-name.csv"
private val emptyFile = "test-data/empty.csv"
private val commentsFile = "test-data/comments.csv"
@@ -668,6 +669,70 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
}
+ test("empty fields with user defined empty values") {
+
+ // year,make,model,comment,blank
+ val dataSchema = StructType(List(
+ StructField("year", IntegerType, nullable = true),
+ StructField("make", StringType, nullable = false),
+ StructField("model", StringType, nullable = false),
+ StructField("comment", StringType, nullable = true),
+ StructField("blank", StringType, nullable = true)))
+ val cars = spark.read
+ .format("csv")
+ .schema(dataSchema)
+ .option("header", "true")
+ .option("emptyValue", "empty")
+ .load(testFile(carsEmptyValueFile))
+
+ verifyCars(cars, withHeader = true, checkValues = false)
+ val results = cars.collect()
+ assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty"))
+ assert(results(1).toSeq ===
+ Array(1997, "Ford", "E350", "Go get one now they are going fast", null))
+ assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty"))
+ }
+
+ test("save csv with empty fields with user defined empty values") {
+ withTempDir { dir =>
+ val csvDir = new File(dir, "csv").getCanonicalPath
+
+ // year,make,model,comment,blank
+ val dataSchema = StructType(List(
+ StructField("year", IntegerType, nullable = true),
+ StructField("make", StringType, nullable = false),
+ StructField("model", StringType, nullable = false),
+ StructField("comment", StringType, nullable = true),
+ StructField("blank", StringType, nullable = true)))
+ val cars = spark.read
+ .format("csv")
+ .schema(dataSchema)
+ .option("header", "true")
+ .option("nullValue", "NULL")
+ .load(testFile(carsEmptyValueFile))
+
+ cars.coalesce(1).write
+ .format("csv")
+ .option("header", "true")
+ .option("emptyValue", "empty")
+ .option("nullValue", null)
+ .save(csvDir)
+
+ val carsCopy = spark.read
+ .format("csv")
+ .schema(dataSchema)
+ .option("header", "true")
+ .load(csvDir)
+
+ verifyCars(carsCopy, withHeader = true, checkValues = false)
+ val results = carsCopy.collect()
+ assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty"))
+ assert(results(1).toSeq ===
+ Array(1997, "Ford", "E350", "Go get one now they are going fast", null))
+ assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty"))
+ }
+ }
+
test("save csv with compression codec option") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
@@ -1375,6 +1440,52 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
}
}
+ test("SPARK-25241: An empty string should not be coerced to null when emptyValue is passed.") {
+ val litNull: String = null
+ val df = Seq(
+ (1, "John Doe"),
+ (2, ""),
+ (3, "-"),
+ (4, litNull)
+ ).toDF("id", "name")
+
+    // Checks the new behavior: an empty string is not coerced to null when `emptyValue` is
+    // set to anything but an empty string literal.
+ withTempPath { path =>
+ df.write
+ .option("emptyValue", "-")
+ .csv(path.getAbsolutePath)
+ val computed = spark.read
+ .option("emptyValue", "-")
+ .schema(df.schema)
+ .csv(path.getAbsolutePath)
+ val expected = Seq(
+ (1, "John Doe"),
+ (2, "-"),
+ (3, "-"),
+ (4, "-")
+ ).toDF("id", "name")
+
+ checkAnswer(computed, expected)
+ }
+    // Keeps the old behavior, where an empty string is coerced to null when `emptyValue`
+    // is not passed.
+ withTempPath { path =>
+ df.write
+ .csv(path.getAbsolutePath)
+ val computed = spark.read
+ .schema(df.schema)
+ .csv(path.getAbsolutePath)
+ val expected = Seq(
+ (1, "John Doe"),
+ (2, litNull),
+ (3, "-"),
+ (4, litNull)
+ ).toDF("id", "name")
+
+ checkAnswer(computed, expected)
+ }
+ }
+
test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") {
val schema = new StructType().add("colA", StringType)
val ds = spark