diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 695bf310562..5581d04f57b 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -5282,6 +5282,7 @@ def to_sql(
         name: str,
         con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
         batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
         **sql_writer_kwargs,
     ) -> int:
         """Exports the dataset to a SQL database.
@@ -5294,6 +5295,11 @@ def to_sql(
             batch_size (`int`, *optional*):
                 Size of the batch to load in memory and write at once.
                 Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            num_proc (`int`, *optional*):
+                Number of processes for multiprocessing. By default, it doesn't
+                use multiprocessing. `batch_size` in this case defaults to
+                `datasets.config.DEFAULT_MAX_BATCH_SIZE` but feel free to make it 5x or 10x of the default
+                value if you have sufficient compute power.
             **sql_writer_kwargs (additional keyword arguments):
                 Parameters to pass to pandas's [`pandas.DataFrame.to_sql`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html).
 
@@ -5324,7 +5330,7 @@ def to_sql(
         # Dynamic import to avoid circular dependency
         from .io.sql import SqlDatasetWriter
 
-        return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write()
+        return SqlDatasetWriter(self, name, con, batch_size=batch_size, num_proc=num_proc, **sql_writer_kwargs).write()
 
     def _estimate_nbytes(self) -> int:
         dataset_nbytes = self.data.nbytes
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4661e8c6dd7..ec2a988bb7f 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2712,6 +2712,15 @@ def test_to_sql(self, in_memory):
                 self.assertEqual(sql_dset.shape, dset.shape)
                 self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
 
+            # Test writing with multiprocessing
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                file_path = os.path.join(tmp_dir, "test_path.sqlite")
+                _ = dset.to_sql("data", "sqlite:///" + file_path, num_proc=3, if_exists="replace")
+                self.assertTrue(os.path.isfile(file_path))
+                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
+                self.assertEqual(sql_dset.shape, dset.shape)
+                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
+
     def test_train_test_split(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
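
Below is a minimal usage sketch, not part of the patch, of how the new `num_proc` argument could be exercised once this change is applied. It assumes `datasets` with this patch plus `pandas` and `sqlalchemy` installed; the toy dataset, file name, and chosen values are illustrative only.

```python
# Hedged sketch of Dataset.to_sql with the new `num_proc` argument.
# The dataset contents and "example.sqlite" path are made up for illustration.
from datasets import Dataset


def main():
    ds = Dataset.from_dict({"id": list(range(1000)), "text": [f"row {i}" for i in range(1000)]})

    # Using a SQLAlchemy-style URI (as the added test does) keeps the connection
    # argument picklable, so each worker process can open its own connection.
    num_rows = ds.to_sql(
        "data",
        "sqlite:///example.sqlite",
        num_proc=2,           # split the write across 2 worker processes
        batch_size=1000,      # rows loaded in memory and written per batch
        if_exists="replace",  # forwarded to pandas.DataFrame.to_sql via **sql_writer_kwargs
    )
    print(num_rows)  # the writer's return value


if __name__ == "__main__":
    # The guard matters when multiprocessing uses the "spawn" start method.
    main()
```

A URI string rather than an open `sqlite3.Connection` is used here on purpose: the added test also passes `"sqlite:///" + file_path`, which avoids sharing a live connection object across processes.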