30 changes: 19 additions & 11 deletions python/pyspark/sql/readwriter.py
@@ -218,7 +218,10 @@ def mode(self, saveMode):

>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite = self._jwrite.mode(saveMode)
# On the JVM side, the default value of mode is already set to "error",
# so if the given saveMode is None we do not call the JVM-side mode method.
if saveMode is not None:
self._jwrite = self._jwrite.mode(saveMode)
return self

@since(1.4)
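As a quick hedged illustration of what this guard changes (not part of the patch itself): when ``mode`` is left as ``None``, the writer now falls back to the JVM-side default (``error``) instead of Python always pushing a value across. It assumes a DataFrame ``df`` like the one used in the doctests.

```python
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'data')
df.write.parquet(path)                 # mode not given -> JVM default ("error")
df.write.mode('append').parquet(path)  # explicit mode behaves as before
```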
@@ -253,11 +256,12 @@ def partitionBy(self, *cols):
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
if len(cols) > 0:
self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
return self

@since(1.4)
def save(self, path=None, format=None, mode="error", **options):
def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
"""Saves the contents of the :class:`DataFrame` to a data source.

The data source is specified by the ``format`` and a set of ``options``.
@@ -272,11 +276,12 @@ def save(self, path=None, format=None, mode="error", **options):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
:param partitionBy: names of partitioning columns
:param options: all other string options

>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode).options(**options)
self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
if path is None:
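A hedged usage sketch for the new ``save()`` keywords shown above (illustrative only, not taken from the patch); it assumes the doctest DataFrame ``df`` and partitions by a hypothetical column ``name``, so substitute a column your DataFrame actually has.

```python
import os
import tempfile

out = os.path.join(tempfile.mkdtemp(), 'data')
# mode and partitionBy are now ordinary keyword arguments of save().
df.write.save(path=out, format='parquet', mode='overwrite', partitionBy=['name'])
```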
@@ -296,7 +301,7 @@ def insertInto(self, tableName, overwrite=False):
self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)

@since(1.4)
def saveAsTable(self, name, format=None, mode="error", **options):
def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
"""Saves the content of the :class:`DataFrame` as the specified table.

In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ def saveAsTable(self, name, format=None, mode="error", **options):
:param name: the table name
:param format: the format used to save
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param partitionBy: names of partitioning columns
:param options: all other string options
"""
self.mode(mode).options(**options)
self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
self._jwrite.saveAsTable(name)

@since(1.4)
def json(self, path, mode="error"):
def json(self, path, mode=None):
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.

:param path: the path in any Hadoop supported file system
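For completeness, a hedged sketch of calling the updated ``saveAsTable`` signature (table and column names are hypothetical; whether the partitioning is honored depends on the underlying data source and catalog, e.g. a ``HiveContext``):

```python
# Persist df as a catalog table, requesting parquet output
# partitioned by a hypothetical column 'name'.
df.write.saveAsTable('people_saved', format='parquet',
                     mode='overwrite', partitionBy=['name'])
```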
@@ -333,10 +339,10 @@ def json(self, path, mode="error"):

>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self._jwrite.mode(mode).json(path)
self.mode(mode)._jwrite.json(path)

@since(1.4)
def parquet(self, path, mode="error"):
def parquet(self, path, mode=None, partitionBy=()):
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.

:param path: the path in any Hadoop supported file system
@@ -346,13 +352,15 @@ def parquet(self, path, mode="error"):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
:param partitionBy: names of partitioning columns

>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
Contributor (review comment on the line above): new line here
"""
self._jwrite.mode(mode).parquet(path)
self.partitionBy(partitionBy).mode(mode)
self._jwrite.parquet(path)

@since(1.4)
def jdbc(self, url, table, mode="error", properties={}):
def jdbc(self, url, table, mode=None, properties={}):
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.

.. note:: Don't create too many partitions in parallel on a large cluster;\
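And a hedged example of the reshaped ``parquet(path, mode=None, partitionBy=())`` signature from the hunk above (the column name is again hypothetical); it should be equivalent to chaining ``df.write.partitionBy('name').mode('overwrite').parquet(out)``:

```python
import os
import tempfile

out = os.path.join(tempfile.mkdtemp(), 'data')
df.write.parquet(out, mode='overwrite', partitionBy=['name'])
```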
32 changes: 32 additions & 0 deletions python/pyspark/sql/tests.py
@@ -539,6 +539,38 @@ def test_save_and_load(self):

shutil.rmtree(tmpPath)

def test_save_and_load_builder(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.write.json(tmpPath)
actual = self.sqlCtx.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

schema = StructType([StructField("value", StringType(), True)])
actual = self.sqlCtx.read.json(tmpPath, schema)
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))

df.write.mode("overwrite").json(tmpPath)
actual = self.sqlCtx.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
.format("json").save(path=tmpPath)
actual =\
self.sqlCtx.read.format("json")\
.load(path=tmpPath, noUse="this option will not be used in load.")
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
actual = self.sqlCtx.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

shutil.rmtree(tmpPath)

def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])