
Use spark.write.csv in to_csv of Series and DataFrame
HyukjinKwon committed Sep 5, 2019
1 parent af8803d commit 15ba56f
Showing 3 changed files with 134 additions and 126 deletions.
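
For orientation, here is a minimal sketch of the call pattern this change introduces. It is not taken verbatim from the diff below; the output directory is hypothetical and a running Spark session is assumed.

import databricks.koalas as ks

kdf = ks.DataFrame({'date': ['2012-01-31', '2012-02-29', '2012-03-31'],
                    'country': ['KR', 'US', 'JP'],
                    'code': [1, 2, 3]})

# Without a path, to_csv falls back to pandas and returns the CSV content
# as a string (no index column is written).
print(kdf.to_csv())

# With a path, Spark writes a directory of part-... files; num_files
# controls how many files are produced.
kdf.to_csv('/tmp/to_csv_example', num_files=1)
ks.read_csv('/tmp/to_csv_example')

This mirrors the docstring examples added to generic.py below.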
111 changes: 83 additions & 28 deletions databricks/koalas/generic.py
@@ -456,21 +456,24 @@ def to_numpy(self):
"""
return self.to_pandas().values

def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
def to_csv(self, path=None, sep=',', na_rep='', columns=None, header=True,
encoding=None, quotechar='"', date_format=None, escapechar=None,
num_files=None):
r"""
Write object to a comma-separated values (csv) file.
.. note:: Spark writes files to HDFS by default.
If you want to save the file locally, you need to use path like below
`'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
you will write the file to the HDFS path where the spark program starts.
.. note:: Koalas `to_csv` writes files to a path or URI. Unlike pandas' `to_csv`,
Koalas respects HDFS properties such as 'fs.default.name'.
.. note:: Koalas writes CSV files into the directory, `path`, and writes
multiple `part-...` files in the directory when `path` is specified.
This behaviour was inherited from Apache Spark. The number of files can
be controlled by `num_files`.
Parameters
----------
path_or_buf : str or file handle, default None
File path or object, if None is provided the result is returned as
a string.
path : str, default None
File path. If None is provided the result is returned as a string.
sep : str, default ','
String of length 1. Field delimiter for the output file.
na_rep : str, default ''
@@ -480,8 +483,6 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
header : bool or list of str, default True
Write out the column names. If a list of strings is given it is
assumed to be aliases for the column names.
index : bool, default True
Write row names (index).
encoding : str, optional
A string representing the encoding to use in the output file,
defaults to 'utf-8'.
@@ -492,6 +493,8 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
escapechar : str, default None
String of length 1. Character used to escape `sep` and `quotechar`
when appropriate.
num_files : the number of files to be written in the `path` directory when
`path` is specified.
See Also
--------
@@ -500,40 +503,92 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
DataFrame.to_table
DataFrame.to_parquet
DataFrame.to_spark_io
Examples
--------
>>> df = ks.DataFrame(dict(
... date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
... country=['KR', 'US', 'JP'],
... code=[1, 2 ,3]), columns=['date', 'country', 'code'])
>>> df
date country code
0 2012-01-31 12:00:00 KR 1
1 2012-02-29 12:00:00 US 2
2 2012-03-31 12:00:00 JP 3
>>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
>>> df.sort_values(by="date") # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
date country code
... 2012-01-31 12:00:00 KR 1
... 2012-02-29 12:00:00 US 2
... 2012-03-31 12:00:00 JP 3
>>> path = "/tmp/a"
>>> print(df.to_csv()) # doctest: +NORMALIZE_WHITESPACE
date,country,code
2012-01-31 12:00:00,KR,1
2012-02-29 12:00:00,US,2
2012-03-31 12:00:00,JP,3
>>> df.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
>>> ks.read_csv(
... path=r'%s/to_csv/foo.csv' % path
... ).sort_values(by="date") # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
date country code
... 2012-01-31 12:00:00 KR 1
... 2012-02-29 12:00:00 US 2
... 2012-03-31 12:00:00 JP 3
In case of Series,
>>> print(df.date.to_csv()) # doctest: +NORMALIZE_WHITESPACE
date
2012-01-31 12:00:00
2012-02-29 12:00:00
2012-03-31 12:00:00
>>> df.date.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
>>> ks.read_csv(
... path=r'%s/to_csv/foo.csv' % path
... ).sort_values(by="date") # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
date
... 2012-01-31 12:00:00
... 2012-02-29 12:00:00
... 2012-03-31 12:00:00
"""
if path is None:
# If path is None, just collect and use pandas' to_csv.
kdf = self
if isinstance(self, ks.DataFrame):
f = pd.DataFrame.to_csv
elif isinstance(self, ks.Series):
f = pd.Series.to_csv
else:
raise TypeError('Constructor expects DataFrame or Series; however, '
'got [%s]' % (self,))
return kdf.to_pandas().to_csv(
path_or_buf=None, sep=sep, na_rep=na_rep, columns=columns,
header=header, encoding=encoding, quotechar=quotechar,
date_format=date_format, escapechar=escapechar, index=False)

if columns is not None:
data_columns = columns
else:
data_columns = self._internal.data_columns

if index:
index_columns = self._internal.index_columns
else:
index_columns = []
kdf = self
if isinstance(self, ks.Series):
kdf = self._kdf

if isinstance(header, list):
sdf = self._sdf.select(index_columns +
[self._internal.scol_for(old_name).alias(new_name)
for (old_name, new_name) in zip(data_columns, header)])
sdf = kdf._sdf.select(
[self._internal.scol_for(old_name).alias(new_name)
for (old_name, new_name) in zip(data_columns, header)])
header = True
else:
sdf = self._sdf.select(index_columns + data_columns)
sdf = kdf._sdf.select(data_columns)

if num_files is not None:
sdf = sdf.repartition(num_files)

sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
encoding=encoding, quote=quotechar, dateFormat=date_format,
charToEscapeQuoteEscaping=escapechar)
sdf.write.mode("overwrite").csv(
path=path, sep=sep, nullValue=na_rep, header=header,
encoding=encoding, quote=quotechar, dateFormat=date_format,
charToEscapeQuoteEscaping=escapechar)

def to_json(self, path_or_buf=None, orient=None, date_format=None,
double_precision=10, force_ascii=True, date_unit='ms',
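A short sketch of the two knobs in the implementation above: passing a list as `header` aliases the output columns before the write, and `num_files` is applied as a `repartition` so Spark emits that many `part-` files; the write itself uses `mode("overwrite")`. The paths below are hypothetical and a Spark session is assumed.

import databricks.koalas as ks

kdf = ks.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})

# header as a list of aliases: columns 'a' and 'b' are written as 'col_a' and 'col_b'.
kdf.to_csv('/tmp/aliased', header=['col_a', 'col_b'], num_files=1)

# num_files=3 repartitions the Spark DataFrame, so the output directory holds
# three part- files; rerunning the call overwrites the previous contents.
kdf.to_csv('/tmp/three_files', num_files=3)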
79 changes: 0 additions & 79 deletions databricks/koalas/series.py
@@ -1018,85 +1018,6 @@ def reset_index(self, level=None, drop=False, name=None, inplace=False):
else:
return kdf

def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
r"""
Write object to a comma-separated values (csv) file.
.. note:: Spark writes files to HDFS by default.
If you want to save the file locally, you need to use path like below
`'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
you will write the file to the HDFS path where the spark program starts.
Parameters
----------
path_or_buf : str or file handle, default None
File path or object, if None is provided the result is returned as
a string.
sep : str, default ','
String of length 1. Field delimiter for the output file.
na_rep : str, default ''
Missing data representation.
columns : sequence, optional
Columns to write.
header : bool or list of str, default True
Write out the column names. If a list of strings is given it is
assumed to be aliases for the column names.
index : bool, default True
Write row names (index).
encoding : str, optional
A string representing the encoding to use in the output file,
defaults to 'utf-8'.
quotechar : str, default '\"'
String of length 1. Character used to quote fields.
date_format : str, default None
Format string for datetime objects.
escapechar : str, default None
String of length 1. Character used to escape `sep` and `quotechar`
when appropriate.
See Also
--------
read_csv
DataFrame.to_delta
DataFrame.to_table
DataFrame.to_parquet
DataFrame.to_spark_io
Examples
--------
>>> df = ks.DataFrame(dict(
... date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
... country=['KR', 'US', 'JP'],
... code=[1, 2 ,3]), columns=['date', 'country', 'code'])
>>> df
date country code
0 2012-01-31 12:00:00 KR 1
1 2012-02-29 12:00:00 US 2
2 2012-03-31 12:00:00 JP 3
>>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
"""
if columns is not None:
data_columns = columns
else:
data_columns = self._internal.data_columns

if index:
index_columns = self._internal.index_columns
else:
index_columns = []

if isinstance(header, list):
sdf = self._sdf.select(index_columns +
[self._internal.scol_for(old_name).alias(new_name)
for (old_name, new_name) in zip(data_columns, header)])
header = True
else:
sdf = self._sdf.select(index_columns + data_columns)

sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
encoding=encoding, quote=quotechar, dateFormat=date_format,
charToEscapeQuoteEscaping=escapechar)

def to_frame(self, name=None) -> spark.DataFrame:
"""
Convert Series to DataFrame.
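Series.to_csv is deleted here because the shared implementation in generic.py above now handles the Series case (it selects only the Series' own column from the parent frame). A sketch of that path, with a hypothetical output directory and a Spark session assumed:

import databricks.koalas as ks

kdf = ks.DataFrame({'date': ['2012-01-31', '2012-02-29'], 'code': [1, 2]})

# Returns the single column as a CSV string when no path is given.
print(kdf.date.to_csv())

# Writes only the 'date' column as part- files under the given directory.
kdf.date.to_csv('/tmp/series_csv', num_files=1)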
70 changes: 51 additions & 19 deletions databricks/koalas/tests/test_csv.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import tempfile
from contextlib import contextmanager
from distutils.version import LooseVersion

import pandas as pd

from databricks import koalas
from databricks import koalas as ks
from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils


@@ -29,6 +31,12 @@ def normalize_text(s):

class CsvTest(ReusedSQLTestCase, TestUtils):

def setUp(self):
self.tmp_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.tmp_dir)

@property
def csv_text(self):
return normalize_text(
@@ -84,7 +92,7 @@ def test_read_csv(self):

def check(header='infer', names=None, usecols=None):
expected = pd.read_csv(fn, header=header, names=names, usecols=usecols)
actual = koalas.read_csv(fn, header=header, names=names, usecols=usecols)
actual = ks.read_csv(fn, header=header, names=names, usecols=usecols)
self.assertPandasAlmostEqual(expected, actual.toPandas())

check()
@@ -107,51 +115,75 @@ def check(header='infer', names=None, usecols=None):

# check with pyspark patch.
expected = pd.read_csv(fn)
actual = koalas.read_csv(fn)
actual = ks.read_csv(fn)
self.assertPandasAlmostEqual(expected, actual.toPandas())

self.assertRaisesRegex(ValueError, 'non-unique',
lambda: koalas.read_csv(fn, names=['n', 'n']))
lambda: ks.read_csv(fn, names=['n', 'n']))
self.assertRaisesRegex(ValueError, 'does not match the number.*3',
lambda: koalas.read_csv(fn, names=['n', 'a', 'b']))
lambda: ks.read_csv(fn, names=['n', 'a', 'b']))
self.assertRaisesRegex(ValueError, 'does not match the number.*3',
lambda: koalas.read_csv(fn, header=0, names=['n', 'a', 'b']))
lambda: ks.read_csv(fn, header=0, names=['n', 'a', 'b']))
self.assertRaisesRegex(ValueError, 'Usecols do not match.*3',
lambda: koalas.read_csv(fn, usecols=[1, 3]))
lambda: ks.read_csv(fn, usecols=[1, 3]))
self.assertRaisesRegex(ValueError, 'Usecols do not match.*col',
lambda: koalas.read_csv(fn, usecols=['amount', 'col']))
lambda: ks.read_csv(fn, usecols=['amount', 'col']))
self.assertRaisesRegex(ValueError, 'Unknown header argument 1',
lambda: koalas.read_csv(fn, header='1'))
lambda: ks.read_csv(fn, header='1'))
expected_error_message = ("'usecols' must either be list-like of all strings, "
"all unicode, all integers or a callable.")
self.assertRaisesRegex(ValueError, expected_error_message,
lambda: koalas.read_csv(fn, usecols=[1, 'amount']))
lambda: ks.read_csv(fn, usecols=[1, 'amount']))

def test_read_with_spark_schema(self):
with self.csv_file(self.csv_text_2) as fn:
actual = koalas.read_csv(fn, names="A string, B string, C long, D long, E long")
actual = ks.read_csv(fn, names="A string, B string, C long, D long, E long")
expected = pd.read_csv(fn, names=['A', 'B', 'C', 'D', 'E'])
self.assertEqual(repr(expected), repr(actual))

def test_read_csv_with_comment(self):
with self.csv_file(self.csv_text_with_comments) as fn:
expected = pd.read_csv(fn, comment='#')
actual = koalas.read_csv(fn, comment='#')
actual = ks.read_csv(fn, comment='#')
self.assertPandasAlmostEqual(expected, actual.toPandas())

self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
lambda: koalas.read_csv(fn, comment='').show())
lambda: ks.read_csv(fn, comment='').show())
self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
lambda: koalas.read_csv(fn, comment='##').show())
lambda: ks.read_csv(fn, comment='##').show())
self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
lambda: koalas.read_csv(fn, comment=1))
lambda: ks.read_csv(fn, comment=1))
self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
lambda: koalas.read_csv(fn, comment=[1]))
lambda: ks.read_csv(fn, comment=[1]))

def test_read_csv_with_mangle_dupe_cols(self):
self.assertRaisesRegex(ValueError, 'mangle_dupe_cols',
lambda: koalas.read_csv('path', mangle_dupe_cols=False))
lambda: ks.read_csv('path', mangle_dupe_cols=False))

def test_read_csv_with_parse_dates(self):
self.assertRaisesRegex(ValueError, 'parse_dates',
lambda: koalas.read_csv('path', parse_dates=True))
lambda: ks.read_csv('path', parse_dates=True))

def test_to_csv(self):
pdf = pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
kdf = ks.DataFrame(pdf)

kdf.to_csv(self.tmp_dir, num_files=1)
expected = pdf.to_csv(index=False)

output_paths = [path for path in os.listdir(self.tmp_dir) if path.startswith("part-")]
assert len(output_paths) > 0
output_path = "%s/%s" % (self.tmp_dir, output_paths[0])
self.assertEqual(open(output_path).read(), expected)

def test_to_csv_with_basic_options(self):
pdf = pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
kdf = ks.DataFrame(pdf)

kdf.to_csv(self.tmp_dir, num_files=1, sep='|', columns=['a'], header=False)
expected = pdf.to_csv(index=False, sep='|', columns=['a'], header=False)

output_paths = [path for path in os.listdir(self.tmp_dir) if path.startswith("part-")]
assert len(output_paths) > 0
output_path = "%s/%s" % (self.tmp_dir, output_paths[0])
self.assertEqual(open(output_path).read(), expected)
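
The new tests read the single part- file back directly and compare it against pandas' output; a similar round-trip check can be sketched outside the test harness (scratch directory and running Spark session assumed):

import os
import tempfile

import pandas as pd
import databricks.koalas as ks

pdf = pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
kdf = ks.DataFrame(pdf)

out_dir = tempfile.mkdtemp()      # hypothetical scratch directory
kdf.to_csv(out_dir, num_files=1)  # one part- file written into out_dir

part = [f for f in os.listdir(out_dir) if f.startswith('part-')][0]
with open(os.path.join(out_dir, part)) as f:
    assert f.read() == pdf.to_csv(index=False)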
