diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b0c51b1e9992..874eed436c9e 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -661,7 +661,7 @@ def text(self, path, compression=None):
@since(2.0)
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
- timestampFormat=None):
+ timestampFormat=None, encoding=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -701,13 +701,15 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
+ :param encoding: sets the encoding (charset) of the saved CSV files. If None is set,
+ it uses the default value, ``UTF-8``.
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll,
- dateFormat=dateFormat, timestampFormat=timestampFormat)
+ dateFormat=dateFormat, timestampFormat=timestampFormat, encoding=encoding)
self._jwrite.csv(path)
@since(1.5)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 9c5660a3780a..35d75dd72d5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -572,6 +572,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*
`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
+ * `encoding` (default `UTF-8`): sets the encoding (charset) of the saved CSV
+ * files.
*
*
* @since 2.0.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 6239508ec942..025de5b29276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.io.{CharArrayWriter, OutputStream, StringReader}
-import java.nio.charset.StandardCharsets
+import java.io.OutputStream
+import java.nio.charset.Charset
import com.univocity.parsers.csv._
@@ -71,6 +71,7 @@ private[csv] class LineCsvWriter(
output: OutputStream) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat
+ private val writerCharset = Charset.forName(params.charset)
format.setDelimiter(params.delimiter)
format.setQuote(params.quote)
@@ -84,7 +85,7 @@ private[csv] class LineCsvWriter(
writerSettings.setHeaders(headers: _*)
writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
- private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
+ private val writer = new CsvWriter(output, writerCharset, writerSettings)
def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
if (includeHeader) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 491ff72337a8..fe962cc90e79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -905,4 +905,23 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(df, Row(1, null))
}
}
+
+  test("save data with gb18030") {
+    withTempPath { path =>
+      // scalastyle:off
+      val df = Seq(("1", "中文")).toDF("num", "language")
+      // scalastyle:on
+      df.write
+        .option("header", "true")
+        .option("encoding", "GB18030") // write with a non-default charset
+        .csv(path.getAbsolutePath)
+
+      val readBack = spark.read
+        .option("header", "true")
+        .option("encoding", "GB18030") // decode with the same charset
+        .csv(path.getAbsolutePath)
+
+      checkAnswer(readBack, df) // DataFrame under test first, expected answer second
+    }
+  }
}