diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 88bfe9e014c..d60454c018d 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -3187,6 +3187,7 @@ def to_csv(
         self,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ) -> int:
         """Exports the dataset to csv
@@ -3195,6 +3196,10 @@ def to_csv(
             path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO.
             batch_size (Optional ``int``): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (:obj:`int`, optional): Number of processes for multiprocessing. By default it doesn't
+                use multiprocessing. ``batch_size`` in this case defaults to
+                :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`, but feel free to make it 5x or 10x the default
+                value if you have sufficient compute power.
             to_csv_kwargs: Parameters to pass to pandas's :func:`pandas.DataFrame.to_csv`

         Returns:
@@ -3203,7 +3208,7 @@ def to_csv(
         # Dynamic import to avoid circular dependency
         from .io.csv import CsvDatasetWriter

-        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, **to_csv_kwargs).write()
+        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_csv_kwargs).write()

     def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
         """Returns the dataset as a Python dict. Can also return a generator for large datasets.
diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py
index a63e5bd61cf..97402e0538b 100644
--- a/src/datasets/io/csv.py
+++ b/src/datasets/io/csv.py
@@ -1,9 +1,11 @@
+import multiprocessing
 import os
 from typing import BinaryIO, Optional, Union

-from .. import Dataset, Features, NamedSplit, config
+from .. import Dataset, Features, NamedSplit, config, utils
 from ..formatting import query_table
 from ..packaged_modules.csv.csv import Csv
+from ..utils import logging
 from ..utils.typing import NestedDataStructureLike, PathLike
 from .abc import AbstractDatasetReader

@@ -58,41 +60,69 @@ def __init__(
         dataset: Dataset,
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **to_csv_kwargs,
     ):
+        assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
         self.dataset = dataset
         self.path_or_buf = path_or_buf
-        self.batch_size = batch_size
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.encoding = "utf-8"
         self.to_csv_kwargs = to_csv_kwargs

     def write(self) -> int:
-        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        _ = self.to_csv_kwargs.pop("path_or_buf", None)

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, batch_size=batch_size, **self.to_csv_kwargs)
+                written = self._write(file_obj=buffer, **self.to_csv_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.to_csv_kwargs)
+            written = self._write(file_obj=self.path_or_buf, **self.to_csv_kwargs)
         return written

-    def _write(
-        self, file_obj: BinaryIO, batch_size: int, header: bool = True, encoding: str = "utf-8", **to_csv_kwargs
-    ) -> int:
+    def _batch_csv(self, args):
+        offset, header, to_csv_kwargs = args
+
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        csv_str = batch.to_pandas().to_csv(
+            path_or_buf=None, header=header if (offset == 0) else False, **to_csv_kwargs
+        )
+        return csv_str.encode(self.encoding)
+
+    def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
         """Writes the pyarrow table as CSV to a binary file handle.

         Caller is responsible for opening and closing the handle.
         """
         written = 0
-        _ = to_csv_kwargs.pop("path_or_buf", None)
-
-        for offset in range(0, len(self.dataset), batch_size):
-            batch = query_table(
-                table=self.dataset._data,
-                key=slice(offset, offset + batch_size),
-                indices=self.dataset._indices if self.dataset._indices is not None else None,
-            )
-            csv_str = batch.to_pandas().to_csv(
-                path_or_buf=None, header=header if (offset == 0) else False, encoding=encoding, **to_csv_kwargs
-            )
-            written += file_obj.write(csv_str.encode(encoding))
+
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in utils.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=bool(logging.get_verbosity() == logging.NOTSET),
+                desc="Creating CSV from Arrow format",
+            ):
+                csv_str = self._batch_csv((offset, header, to_csv_kwargs))
+                written += file_obj.write(csv_str)
+
+        else:
+            with multiprocessing.Pool(self.num_proc) as pool:
+                for csv_str in utils.tqdm(
+                    pool.imap(
+                        self._batch_csv,
+                        [(offset, header, to_csv_kwargs) for offset in range(0, len(self.dataset), self.batch_size)],
+                    ),
+                    total=(len(self.dataset) // self.batch_size) + 1,
+                    unit="ba",
+                    disable=bool(logging.get_verbosity() == logging.NOTSET),
+                    desc="Creating CSV from Arrow format",
+                ):
+                    written += file_obj.write(csv_str)
+
         return written
diff --git a/tests/conftest.py b/tests/conftest.py
index f929e00bb4e..de238deda2e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -224,7 +224,7 @@ def arrow_path(tmp_path_factory):
 @pytest.fixture(scope="session")
 def csv_path(tmp_path_factory):
     path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
-    with open(path, "w") as f:
+    with open(path, "w", newline="") as f:
         writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
         writer.writeheader()
         for item in DATA:
@@ -235,7 +235,7 @@ def csv2_path(tmp_path_factory):
     path = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
-    with open(path, "w") as f:
open(path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"]) writer.writeheader() for item in DATA: diff --git a/tests/io/test_csv.py b/tests/io/test_csv.py index 7053ae09910..69c1e1d5092 100644 --- a/tests/io/test_csv.py +++ b/tests/io/test_csv.py @@ -1,7 +1,10 @@ +import csv +import os + import pytest from datasets import Dataset, DatasetDict, Features, NamedSplit, Value -from datasets.io.csv import CsvDatasetReader +from datasets.io.csv import CsvDatasetReader, CsvDatasetWriter from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases @@ -121,3 +124,34 @@ def test_csv_datasetdict_reader_split(split, csv_path, tmp_path): dataset = CsvDatasetReader(path, cache_dir=cache_dir).read() _check_csv_datasetdict(dataset, expected_features, splits=list(path.keys())) assert all(dataset[split].split == split for split in path.keys()) + + +def iter_csv_file(csv_path): + with open(csv_path, "r", encoding="utf-8") as csvfile: + yield from csv.reader(csvfile) + + +def test_dataset_to_csv(csv_path, tmp_path): + cache_dir = tmp_path / "cache" + output_csv = os.path.join(cache_dir, "tmp.csv") + dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read() + CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=1).write() + + original_csv = iter_csv_file(csv_path) + expected_csv = iter_csv_file(output_csv) + + for row1, row2 in zip(original_csv, expected_csv): + assert row1 == row2 + + +def test_dataset_to_csv_multiproc(csv_path, tmp_path): + cache_dir = tmp_path / "cache" + output_csv = os.path.join(cache_dir, "tmp.csv") + dataset = CsvDatasetReader({"train": csv_path}, cache_dir=cache_dir).read() + CsvDatasetWriter(dataset["train"], output_csv, index=False, num_proc=2).write() + + original_csv = iter_csv_file(csv_path) + expected_csv = iter_csv_file(output_csv) + + for row1, row2 in zip(original_csv, expected_csv): + assert row1 == row2