
Commit 7db2d08

Use spark.write.csv in to_csv of Series and DataFrame

1 parent 28f29da commit 7db2d08

File tree

4 files changed, +182 -161 lines changed

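In short, this commit switches `DataFrame.to_csv` and `Series.to_csv` over to `spark.write.csv` when a path is given, while still collecting to pandas and returning a string when it is not. A minimal usage sketch of the new behaviour, assuming a working Koalas/Spark session; the output path and the `compression` option are illustrative, not part of the patch:

import databricks.koalas as ks

kdf = ks.DataFrame({'country': ['KR', 'US', 'JP'], 'code': [1, 2, 3]})

# No path: the frame is collected and pandas' to_csv returns a string (index dropped).
print(kdf.to_csv())

# With a path: Spark writes a directory of part-... files. num_files controls how
# many part files are produced, and extra keyword options are forwarded to
# spark.write.csv (here, compression).
kdf.to_csv(path='/tmp/to_csv_example', num_files=1, compression='gzip')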

databricks/koalas/generic.py

+92-32
@@ -20,12 +20,14 @@
 import warnings
 from collections import Counter
 from collections.abc import Iterable
+from distutils.version import LooseVersion

 import numpy as np
 import pandas as pd

 from pyspark import sql as spark
 from pyspark.sql import functions as F
+from pyspark.sql.readwriter import OptionUtils
 from pyspark.sql.types import DataType, DoubleType, FloatType

 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
@@ -456,21 +458,24 @@ def to_numpy(self):
         """
         return self.to_pandas().values

-    def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
-               index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
+    def to_csv(self, path=None, sep=',', na_rep='', columns=None, header=True,
+               quotechar='"', date_format=None, escapechar=None, num_files=None,
+               **options):
         r"""
         Write object to a comma-separated values (csv) file.

-        .. note:: Spark writes files to HDFS by default.
-            If you want to save the file locally, you need to use path like below
-            `'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
-            you will write the file to the HDFS path where the spark program starts.
+        .. note:: Koalas `to_csv` writes files to a path or URI. Unlike pandas', Koalas
+            respects HDFS's property such as 'fs.default.name'.
+
+        .. note:: Koalas writes CSV files into the directory, `path`, and writes
+            multiple `part-...` files in the directory when `path` is specified.
+            This behaviour was inherited from Apache Spark. The number of files can
+            be controlled by `num_files`.

         Parameters
         ----------
-        path_or_buf : str or file handle, default None
-            File path or object, if None is provided the result is returned as
-            a string.
+        path : str, default None
+            File path. If None is provided the result is returned as a string.
         sep : str, default ','
             String of length 1. Field delimiter for the output file.
         na_rep : str, default ''
@@ -480,18 +485,20 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
         header : bool or list of str, default True
             Write out the column names. If a list of strings is given it is
             assumed to be aliases for the column names.
-        index : bool, default True
-            Write row names (index).
-        encoding : str, optional
-            A string representing the encoding to use in the output file,
-            defaults to 'utf-8'.
         quotechar : str, default '\"'
             String of length 1. Character used to quote fields.
         date_format : str, default None
             Format string for datetime objects.
         escapechar : str, default None
             String of length 1. Character used to escape `sep` and `quotechar`
             when appropriate.
+        num_files : the number of files to be written in `path` directory when
+            this is a path.
+        options: keyword arguments for additional options specific to PySpark.
+            This kwargs are specific to PySpark's CSV options to pass. Check
+            the options in PySpark's API documentation for spark.write.csv(...).
+            It has higher priority and overwrites all other options.
+            This parameter only works when `path` is specified.

         See Also
         --------
@@ -500,40 +507,93 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
         DataFrame.to_table
         DataFrame.to_parquet
         DataFrame.to_spark_io
+
         Examples
         --------
         >>> df = ks.DataFrame(dict(
         ...    date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
         ...    country=['KR', 'US', 'JP'],
         ...    code=[1, 2 ,3]), columns=['date', 'country', 'code'])
-        >>> df
-                         date country  code
-        0 2012-01-31 12:00:00      KR     1
-        1 2012-02-29 12:00:00      US     2
-        2 2012-03-31 12:00:00      JP     3
-        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
+        >>> df.sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date country  code
+        ... 2012-01-31 12:00:00      KR     1
+        ... 2012-02-29 12:00:00      US     2
+        ... 2012-03-31 12:00:00      JP     3
+
+        >>> print(df.to_csv())  # doctest: +NORMALIZE_WHITESPACE
+        date,country,code
+        2012-01-31 12:00:00,KR,1
+        2012-02-29 12:00:00,US,2
+        2012-03-31 12:00:00,JP,3
+
+        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
+        >>> ks.read_csv(
+        ...    path=r'%s/to_csv/foo.csv' % path
+        ... ).sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date country  code
+        ... 2012-01-31 12:00:00      KR     1
+        ... 2012-02-29 12:00:00      US     2
+        ... 2012-03-31 12:00:00      JP     3
+
+        In case of Series,
+
+        >>> print(df.date.to_csv())  # doctest: +NORMALIZE_WHITESPACE
+        date
+        2012-01-31 12:00:00
+        2012-02-29 12:00:00
+        2012-03-31 12:00:00
+
+        >>> df.date.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
+        >>> ks.read_csv(
+        ...    path=r'%s/to_csv/foo.csv' % path
+        ... ).sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date
+        ... 2012-01-31 12:00:00
+        ... 2012-02-29 12:00:00
+        ... 2012-03-31 12:00:00
         """
+        if path is None:
+            # If path is none, just collect and use pandas's to_csv.
+            kdf_or_ser = self
+            if (LooseVersion("0.24") > LooseVersion(pd.__version__)) and \
+                    isinstance(self, ks.Series):
+                # 0.23 seems not having 'columns' parameter in Series' to_csv.
+                return kdf_or_ser.to_pandas().to_csv(
+                    None, sep=sep, na_rep=na_rep, header=header,
+                    date_format=date_format, index=False)
+            else:
+                return kdf_or_ser.to_pandas().to_csv(
+                    None, sep=sep, na_rep=na_rep, columns=columns,
+                    header=header, quotechar=quotechar,
+                    date_format=date_format, escapechar=escapechar, index=False)
+
         if columns is not None:
             data_columns = columns
         else:
             data_columns = self._internal.data_columns

-        if index:
-            index_columns = self._internal.index_columns
-        else:
-            index_columns = []
+        kdf = self
+        if isinstance(self, ks.Series):
+            kdf = self._kdf

         if isinstance(header, list):
-            sdf = self._sdf.select(index_columns +
-                                   [self._internal.scol_for(old_name).alias(new_name)
-                                    for (old_name, new_name) in zip(data_columns, header)])
+            sdf = kdf._sdf.select(
+                [self._internal.scol_for(old_name).alias(new_name)
+                 for (old_name, new_name) in zip(data_columns, header)])
             header = True
         else:
-            sdf = self._sdf.select(index_columns + data_columns)
-
-        sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
-                      encoding=encoding, quote=quotechar, dateFormat=date_format,
-                      charToEscapeQuoteEscaping=escapechar)
+            sdf = kdf._sdf.select(data_columns)
+
+        if num_files is not None:
+            sdf = sdf.repartition(num_files)
+
+        builder = sdf.write.mode("overwrite")
+        OptionUtils._set_opts(
+            builder,
+            path=path, sep=sep, nullValue=na_rep, header=header,
+            quote=quotechar, dateFormat=date_format,
+            charToEscapeQuoteEscaping=escapechar)
+        builder.options(**options).format("csv").save(path)

     def to_json(self, path_or_buf=None, orient=None, date_format=None,
                 double_precision=10, force_ascii=True, date_unit='ms',
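For readers less familiar with the writer API used above, the save path built through `OptionUtils._set_opts` is roughly equivalent to chaining the options on PySpark's `DataFrameWriter` directly. A hedged sketch of that equivalence for a plain Spark DataFrame `sdf` (the path and option values are illustrative):

# Roughly what the patched code path does, spelled out with explicit writer calls.
(sdf.repartition(1)                  # num_files=1
    .write
    .mode("overwrite")               # the patch always overwrites the target directory
    .option("sep", ",")              # sep
    .option("nullValue", "")         # na_rep
    .option("header", True)          # header
    .option("quote", '"')            # quotechar
    .format("csv")
    .save("/tmp/to_csv_example"))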

databricks/koalas/series.py

-79
@@ -1016,85 +1016,6 @@ def reset_index(self, level=None, drop=False, name=None, inplace=False):
         else:
             return kdf

-    def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
-               index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
-        r"""
-        Write object to a comma-separated values (csv) file.
-
-        .. note:: Spark writes files to HDFS by default.
-            If you want to save the file locally, you need to use path like below
-            `'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
-            you will write the file to the HDFS path where the spark program starts.
-
-        Parameters
-        ----------
-        path_or_buf : str or file handle, default None
-            File path or object, if None is provided the result is returned as
-            a string.
-        sep : str, default ','
-            String of length 1. Field delimiter for the output file.
-        na_rep : str, default ''
-            Missing data representation.
-        columns : sequence, optional
-            Columns to write.
-        header : bool or list of str, default True
-            Write out the column names. If a list of strings is given it is
-            assumed to be aliases for the column names.
-        index : bool, default True
-            Write row names (index).
-        encoding : str, optional
-            A string representing the encoding to use in the output file,
-            defaults to 'utf-8'.
-        quotechar : str, default '\"'
-            String of length 1. Character used to quote fields.
-        date_format : str, default None
-            Format string for datetime objects.
-        escapechar : str, default None
-            String of length 1. Character used to escape `sep` and `quotechar`
-            when appropriate.
-
-        See Also
-        --------
-        read_csv
-        DataFrame.to_delta
-        DataFrame.to_table
-        DataFrame.to_parquet
-        DataFrame.to_spark_io
-        Examples
-        --------
-        >>> df = ks.DataFrame(dict(
-        ...    date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
-        ...    country=['KR', 'US', 'JP'],
-        ...    code=[1, 2 ,3]), columns=['date', 'country', 'code'])
-        >>> df
-                         date country  code
-        0 2012-01-31 12:00:00      KR     1
-        1 2012-02-29 12:00:00      US     2
-        2 2012-03-31 12:00:00      JP     3
-        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
-        """
-        if columns is not None:
-            data_columns = columns
-        else:
-            data_columns = self._internal.data_columns
-
-        if index:
-            index_columns = self._internal.index_columns
-        else:
-            index_columns = []
-
-        if isinstance(header, list):
-            sdf = self._sdf.select(index_columns +
-                                   [self._internal.scol_for(old_name).alias(new_name)
-                                    for (old_name, new_name) in zip(data_columns, header)])
-            header = True
-        else:
-            sdf = self._sdf.select(index_columns + data_columns)
-
-        sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
-                      encoding=encoding, quote=quotechar, dateFormat=date_format,
-                      charToEscapeQuoteEscaping=escapechar)
-
     def to_frame(self, name=None) -> spark.DataFrame:
         """
         Convert Series to DataFrame.