diff --git a/databricks/koalas/generic.py b/databricks/koalas/generic.py
index 6a24860cb6..af9293d521 100644
--- a/databricks/koalas/generic.py
+++ b/databricks/koalas/generic.py
@@ -456,21 +456,24 @@ def to_numpy(self):
         """
         return self.to_pandas().values

-    def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
-               index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
+    def to_csv(self, path=None, sep=',', na_rep='', columns=None, header=True,
+               encoding=None, quotechar='"', date_format=None, escapechar=None,
+               num_files=None):
         r"""
         Write object to a comma-separated values (csv) file.

-        .. note:: Spark writes files to HDFS by default.
-            If you want to save the file locally, you need to use path like below
-            `'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
-            you will write the file to the HDFS path where the spark program starts.
+        .. note:: Koalas `to_csv` writes files to a path or URI. Unlike pandas,
+            Koalas respects HDFS properties such as 'fs.default.name'.
+
+        .. note:: Koalas writes CSV files into a directory, `path`, and writes
+            multiple `part-...` files in that directory when `path` is specified.
+            This behaviour was inherited from Apache Spark. The number of files can
+            be controlled by `num_files`.

         Parameters
         ----------
-        path_or_buf : str or file handle, default None
-            File path or object, if None is provided the result is returned as
-            a string.
+        path : str, default None
+            File path. If None is provided the result is returned as a string.
         sep : str, default ','
             String of length 1. Field delimiter for the output file.
         na_rep : str, default ''
@@ -480,8 +483,6 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
         header : bool or list of str, default True
             Write out the column names. If a list of strings is given it is
             assumed to be aliases for the column names.
-        index : bool, default True
-            Write row names (index).
         encoding : str, optional
             A string representing the encoding to use in the output file,
             defaults to 'utf-8'.
@@ -492,6 +493,8 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
         escapechar : str, default None
             String of length 1. Character used to escape `sep` and `quotechar`
             when appropriate.
+        num_files : the number of files to be written in the `path` directory when
+            this is a path.

         See Also
         --------
@@ -500,40 +503,92 @@ def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True
         DataFrame.to_table
         DataFrame.to_parquet
         DataFrame.to_spark_io
+
         Examples
         --------
         >>> df = ks.DataFrame(dict(
         ...    date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
         ...    country=['KR', 'US', 'JP'],
         ...    code=[1, 2 ,3]), columns=['date', 'country', 'code'])
-        >>> df
-                         date country  code
-        0 2012-01-31 12:00:00      KR     1
-        1 2012-02-29 12:00:00      US     2
-        2 2012-03-31 12:00:00      JP     3
-        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
+        >>> df.sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date country  code
+        ... 2012-01-31 12:00:00      KR     1
+        ... 2012-02-29 12:00:00      US     2
+        ... 2012-03-31 12:00:00      JP     3
+
+        >>> path = "/tmp/a"
+
+        >>> print(df.to_csv())  # doctest: +NORMALIZE_WHITESPACE
+        date,country,code
+        2012-01-31 12:00:00,KR,1
+        2012-02-29 12:00:00,US,2
+        2012-03-31 12:00:00,JP,3
+
+        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
+        >>> ks.read_csv(
+        ...    path=r'%s/to_csv/foo.csv' % path
+        ... ).sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date country  code
+        ... 2012-01-31 12:00:00      KR     1
+        ... 2012-02-29 12:00:00      US     2
+        ... 2012-03-31 12:00:00      JP     3
+
+        In the case of a Series,
+
+        >>> print(df.date.to_csv())  # doctest: +NORMALIZE_WHITESPACE
+        date
+        2012-01-31 12:00:00
+        2012-02-29 12:00:00
+        2012-03-31 12:00:00
+
+        >>> df.date.to_csv(path=r'%s/to_csv/foo.csv' % path, num_files=1)
+        >>> ks.read_csv(
+        ...    path=r'%s/to_csv/foo.csv' % path
+        ... ).sort_values(by="date")  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+                           date
+        ... 2012-01-31 12:00:00
+        ... 2012-02-29 12:00:00
+        ... 2012-03-31 12:00:00
         """
+        if path is None:
+            # If path is None, just collect and use pandas' to_csv.
+            kdf = self
+            if isinstance(self, ks.DataFrame):
+                f = pd.DataFrame.to_csv
+            elif isinstance(self, ks.Series):
+                f = pd.Series.to_csv
+            else:
+                raise TypeError('to_csv() expects DataFrame or Series; however, '
+                                'got [%s]' % (self,))
+            return kdf.to_pandas().to_csv(
+                path_or_buf=None, sep=sep, na_rep=na_rep, columns=columns,
+                header=header, encoding=encoding, quotechar=quotechar,
+                date_format=date_format, escapechar=escapechar, index=False)
+
         if columns is not None:
             data_columns = columns
         else:
             data_columns = self._internal.data_columns

-        if index:
-            index_columns = self._internal.index_columns
-        else:
-            index_columns = []
+        kdf = self
+        if isinstance(self, ks.Series):
+            kdf = self._kdf

         if isinstance(header, list):
-            sdf = self._sdf.select(index_columns +
-                                   [self._internal.scol_for(old_name).alias(new_name)
-                                    for (old_name, new_name) in zip(data_columns, header)])
+            sdf = kdf._sdf.select(
+                [self._internal.scol_for(old_name).alias(new_name)
+                 for (old_name, new_name) in zip(data_columns, header)])
             header = True
         else:
-            sdf = self._sdf.select(index_columns + data_columns)
+            sdf = kdf._sdf.select(data_columns)
+
+        if num_files is not None:
+            sdf = sdf.repartition(num_files)

-        sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
-                      encoding=encoding, quote=quotechar, dateFormat=date_format,
-                      charToEscapeQuoteEscaping=escapechar)
+        sdf.write.mode("overwrite").csv(
+            path=path, sep=sep, nullValue=na_rep, header=header,
+            encoding=encoding, quote=quotechar, dateFormat=date_format,
+            charToEscapeQuoteEscaping=escapechar)

     def to_json(self, path_or_buf=None, orient=None, date_format=None,
                 double_precision=10, force_ascii=True, date_unit='ms',
diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py
index 1a09b5c572..94e66c4106 100644
--- a/databricks/koalas/series.py
+++ b/databricks/koalas/series.py
@@ -1018,85 +1018,6 @@ def reset_index(self, level=None, drop=False, name=None, inplace=False):
         else:
             return kdf

-    def to_csv(self, path_or_buf=None, sep=',', na_rep='', columns=None, header=True,
-               index=True, encoding=None, quotechar='"', date_format=None, escapechar=None):
-        r"""
-        Write object to a comma-separated values (csv) file.
-
-        .. note:: Spark writes files to HDFS by default.
-            If you want to save the file locally, you need to use path like below
-            `'files:/' + local paths` like 'files:/work/data.csv'. Otherwise,
-            you will write the file to the HDFS path where the spark program starts.
-
-        Parameters
-        ----------
-        path_or_buf : str or file handle, default None
-            File path or object, if None is provided the result is returned as
-            a string.
-        sep : str, default ','
-            String of length 1. Field delimiter for the output file.
-        na_rep : str, default ''
-            Missing data representation.
-        columns : sequence, optional
-            Columns to write.
-        header : bool or list of str, default True
-            Write out the column names. If a list of strings is given it is
-            assumed to be aliases for the column names.
-        index : bool, default True
-            Write row names (index).
-        encoding : str, optional
-            A string representing the encoding to use in the output file,
-            defaults to 'utf-8'.
-        quotechar : str, default '\"'
-            String of length 1. Character used to quote fields.
-        date_format : str, default None
-            Format string for datetime objects.
-        escapechar : str, default None
-            String of length 1. Character used to escape `sep` and `quotechar`
-            when appropriate.
-
-        See Also
-        --------
-        read_csv
-        DataFrame.to_delta
-        DataFrame.to_table
-        DataFrame.to_parquet
-        DataFrame.to_spark_io
-        Examples
-        --------
-        >>> df = ks.DataFrame(dict(
-        ...    date=list(pd.date_range('2012-1-1 12:00:00', periods=3, freq='M')),
-        ...    country=['KR', 'US', 'JP'],
-        ...    code=[1, 2 ,3]), columns=['date', 'country', 'code'])
-        >>> df
-                         date country  code
-        0 2012-01-31 12:00:00      KR     1
-        1 2012-02-29 12:00:00      US     2
-        2 2012-03-31 12:00:00      JP     3
-        >>> df.to_csv(path=r'%s/to_csv/foo.csv' % path)
-        """
-        if columns is not None:
-            data_columns = columns
-        else:
-            data_columns = self._internal.data_columns
-
-        if index:
-            index_columns = self._internal.index_columns
-        else:
-            index_columns = []
-
-        if isinstance(header, list):
-            sdf = self._sdf.select(index_columns +
-                                   [self._internal.scol_for(old_name).alias(new_name)
-                                    for (old_name, new_name) in zip(data_columns, header)])
-            header = True
-        else:
-            sdf = self._sdf.select(index_columns + data_columns)
-
-        sdf.write.csv(path=path_or_buf, sep=sep, nullValue=na_rep, header=header,
-                      encoding=encoding, quote=quotechar, dateFormat=date_format,
-                      charToEscapeQuoteEscaping=escapechar)
-
     def to_frame(self, name=None) -> spark.DataFrame:
         """
         Convert Series to DataFrame.
diff --git a/databricks/koalas/tests/test_csv.py b/databricks/koalas/tests/test_csv.py
index 3bdc0e6bcd..4efe091333 100644
--- a/databricks/koalas/tests/test_csv.py
+++ b/databricks/koalas/tests/test_csv.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
+import shutil
+import tempfile
 from contextlib import contextmanager
 from distutils.version import LooseVersion

 import pandas as pd

-from databricks import koalas
+from databricks import koalas as ks
 from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils


@@ -29,6 +31,12 @@ def normalize_text(s):


 class CsvTest(ReusedSQLTestCase, TestUtils):
+    def setUp(self):
+        self.tmp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+
     @property
     def csv_text(self):
         return normalize_text(
@@ -84,7 +92,7 @@ def test_read_csv(self):

         def check(header='infer', names=None, usecols=None):
             expected = pd.read_csv(fn, header=header, names=names, usecols=usecols)
-            actual = koalas.read_csv(fn, header=header, names=names, usecols=usecols)
+            actual = ks.read_csv(fn, header=header, names=names, usecols=usecols)
             self.assertPandasAlmostEqual(expected, actual.toPandas())

         check()
@@ -107,51 +115,75 @@ def check(header='infer', names=None, usecols=None):

             # check with pyspark patch.
             expected = pd.read_csv(fn)
-            actual = koalas.read_csv(fn)
+            actual = ks.read_csv(fn)
             self.assertPandasAlmostEqual(expected, actual.toPandas())

             self.assertRaisesRegex(ValueError, 'non-unique',
-                                   lambda: koalas.read_csv(fn, names=['n', 'n']))
+                                   lambda: ks.read_csv(fn, names=['n', 'n']))
             self.assertRaisesRegex(ValueError, 'does not match the number.*3',
-                                   lambda: koalas.read_csv(fn, names=['n', 'a', 'b']))
+                                   lambda: ks.read_csv(fn, names=['n', 'a', 'b']))
             self.assertRaisesRegex(ValueError, 'does not match the number.*3',
-                                   lambda: koalas.read_csv(fn, header=0, names=['n', 'a', 'b']))
+                                   lambda: ks.read_csv(fn, header=0, names=['n', 'a', 'b']))
             self.assertRaisesRegex(ValueError, 'Usecols do not match.*3',
-                                   lambda: koalas.read_csv(fn, usecols=[1, 3]))
+                                   lambda: ks.read_csv(fn, usecols=[1, 3]))
             self.assertRaisesRegex(ValueError, 'Usecols do not match.*col',
-                                   lambda: koalas.read_csv(fn, usecols=['amount', 'col']))
+                                   lambda: ks.read_csv(fn, usecols=['amount', 'col']))
             self.assertRaisesRegex(ValueError, 'Unknown header argument 1',
-                                   lambda: koalas.read_csv(fn, header='1'))
+                                   lambda: ks.read_csv(fn, header='1'))
             expected_error_message = ("'usecols' must either be list-like of all strings, "
                                       "all unicode, all integers or a callable.")
             self.assertRaisesRegex(ValueError, expected_error_message,
-                                   lambda: koalas.read_csv(fn, usecols=[1, 'amount']))
+                                   lambda: ks.read_csv(fn, usecols=[1, 'amount']))

     def test_read_with_spark_schema(self):
         with self.csv_file(self.csv_text_2) as fn:
-            actual = koalas.read_csv(fn, names="A string, B string, C long, D long, E long")
+            actual = ks.read_csv(fn, names="A string, B string, C long, D long, E long")
             expected = pd.read_csv(fn, names=['A', 'B', 'C', 'D', 'E'])
             self.assertEqual(repr(expected), repr(actual))

     def test_read_csv_with_comment(self):
         with self.csv_file(self.csv_text_with_comments) as fn:
             expected = pd.read_csv(fn, comment='#')
-            actual = koalas.read_csv(fn, comment='#')
+            actual = ks.read_csv(fn, comment='#')
             self.assertPandasAlmostEqual(expected, actual.toPandas())

             self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
-                                   lambda: koalas.read_csv(fn, comment='').show())
+                                   lambda: ks.read_csv(fn, comment='').show())
             self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
-                                   lambda: koalas.read_csv(fn, comment='##').show())
+                                   lambda: ks.read_csv(fn, comment='##').show())
             self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
-                                   lambda: koalas.read_csv(fn, comment=1))
+                                   lambda: ks.read_csv(fn, comment=1))
             self.assertRaisesRegex(ValueError, 'Only length-1 comment characters supported',
-                                   lambda: koalas.read_csv(fn, comment=[1]))
+                                   lambda: ks.read_csv(fn, comment=[1]))

     def test_read_csv_with_mangle_dupe_cols(self):
         self.assertRaisesRegex(ValueError, 'mangle_dupe_cols',
-                               lambda: koalas.read_csv('path', mangle_dupe_cols=False))
+                               lambda: ks.read_csv('path', mangle_dupe_cols=False))

     def test_read_csv_with_parse_dates(self):
         self.assertRaisesRegex(ValueError, 'parse_dates',
-                               lambda: koalas.read_csv('path', parse_dates=True))
+                               lambda: ks.read_csv('path', parse_dates=True))
+
+    def test_to_csv(self):
+        pdf = pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
+        kdf = ks.DataFrame(pdf)
+
+        kdf.to_csv(self.tmp_dir, num_files=1)
+        expected = pdf.to_csv(index=False)
+
+        output_paths = [path for path in os.listdir(self.tmp_dir) if path.startswith("part-")]
+        assert len(output_paths) > 0
+        output_path = "%s/%s" % (self.tmp_dir, output_paths[0])
+        self.assertEqual(open(output_path).read(), expected)
+
+    def test_to_csv_with_basic_options(self):
+        pdf = pd.DataFrame({'a': [1, 2, 3], 'b': ['a', 'b', 'c']})
+        kdf = ks.DataFrame(pdf)
+
+        kdf.to_csv(self.tmp_dir, num_files=1, sep='|', columns=['a'], header=False)
+        expected = pdf.to_csv(index=False, sep='|', columns=['a'], header=False)
+
+        output_paths = [path for path in os.listdir(self.tmp_dir) if path.startswith("part-")]
+        assert len(output_paths) > 0
+        output_path = "%s/%s" % (self.tmp_dir, output_paths[0])
+        self.assertEqual(open(output_path).read(), expected)
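
A minimal usage sketch of the reworked `to_csv`, assuming the patch above is applied. The sample DataFrame, column names, and temporary directory are illustrative only and are not part of the patch; the behaviour shown follows the new implementation in generic.py and the tests above.

    import shutil
    import tempfile

    from databricks import koalas as ks

    # Illustrative data; any small DataFrame behaves the same way.
    kdf = ks.DataFrame({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})

    # With no path, the data is collected to the driver and rendered by
    # pandas' to_csv as a single CSV string (the index is always dropped).
    print(kdf.to_csv())

    # With a path, Spark writes one or more 'part-...' files into that
    # directory; num_files repartitions the data to control how many.
    out_dir = tempfile.mkdtemp()
    kdf.to_csv(out_dir, num_files=1)

    # The directory written by to_csv can be read back with ks.read_csv.
    print(ks.read_csv(out_dir).sort_values(by='x'))

    shutil.rmtree(out_dir)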