feat(export): allow passing keyword arguments to PyArrow `ParquetWriter` and `CSVWriter`
jayceslesar authored and cpcloud committed Nov 18, 2023
1 parent e3b9611 commit 40558fd
Showing 3 changed files with 49 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/base/__init__.py
@@ -547,7 +547,7 @@ def to_parquet(
        import pyarrow.parquet as pq

        with expr.to_pyarrow_batches(params=params) as batch_reader:
-           with pq.ParquetWriter(path, batch_reader.schema) as writer:
+           with pq.ParquetWriter(path, batch_reader.schema, **kwargs) as writer:
                for batch in batch_reader:
                    writer.write_batch(batch)
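With this change, extra keyword arguments to the base backend's `to_parquet` are forwarded verbatim to `pyarrow.parquet.ParquetWriter`. A minimal usage sketch, assuming a backend that inherits this base implementation; the connection and table name are hypothetical, and `compression`/`version` are ordinary `ParquetWriter` options rather than ibis parameters:

    import ibis

    con = ibis.sqlite.connect("db.sqlite")  # hypothetical connection
    t = con.table("events")                 # hypothetical table
    # compression/version are not interpreted by ibis; they land on
    # pq.ParquetWriter(path, schema, **kwargs)
    t.to_parquet("out.parquet", compression="zstd", version="2.6")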

@@ -582,7 +582,7 @@ def to_csv(
        import pyarrow.csv as pcsv

        with expr.to_pyarrow_batches(params=params) as batch_reader:
-           with pcsv.CSVWriter(path, batch_reader.schema) as writer:
+           with pcsv.CSVWriter(path, batch_reader.schema, **kwargs) as writer:
                for batch in batch_reader:
                    writer.write_batch(batch)

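The analogous pass-through for `to_csv` means a `pyarrow.csv.WriteOptions` can now be supplied, mirroring the tests below. A sketch with the same hypothetical table `t`:

    import pyarrow.csv as pcsv

    # write_options is forwarded untouched to pcsv.CSVWriter(path, schema, **kwargs)
    opts = pcsv.WriteOptions(delimiter=";", include_header=True)
    t.to_csv("out.csv", write_options=opts)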
22 changes: 20 additions & 2 deletions ibis/backends/duckdb/__init__.py
@@ -98,6 +98,24 @@ class Backend(AlchemyCrossSchemaBackend, CanCreateSchema):
    name = "duckdb"
    compiler = DuckDBSQLCompiler
    supports_create_or_replace = True
+   reserved_csv_copy_args = [
+       "COMPRESSION",
+       "FORCE_QUOTE",
+       "DATEFORMAT",
+       "DELIM",
+       "SEP",
+       "ESCAPE",
+       "HEADER",
+       "NULLSTR",
+       "QUOTE",
+       "TIMESTAMP_FORMAT",
+   ]
+   reserved_parquet_copy_args = [
+       "COMPRESSION",
+       "ROW_GROUP_SIZE",
+       "ROW_GROUP_SIZE_BYTES",
+       "FIELD_IDS",
+   ]

    @property
    def settings(self) -> _Settings:
@@ -1089,7 +1107,7 @@ def to_parquet(
"""
self._run_pre_execute_hooks(expr)
query = self._to_sql(expr, params=params)
args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items())]
args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_parquet_copy_args)]
copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})"
with self.begin() as con:
con.exec_driver_sql(copy_cmd)
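Note the behavioral difference from the base backend: the DuckDB override filters kwargs against the `reserved_parquet_copy_args` whitelist above, so options outside it are silently dropped rather than reaching the COPY statement. A sketch of what the comprehension renders, with illustrative values:

    kwargs = {"compression": "zstd", "row_group_size": 100000, "typo_arg": 1}
    args = [
        "FORMAT 'parquet'",
        *(f"{k.upper()} {v!r}" for k, v in kwargs.items()
          if k.upper() in Backend.reserved_parquet_copy_args),
    ]
    # args == ["FORMAT 'parquet'", "COMPRESSION 'zstd'", "ROW_GROUP_SIZE 100000"]
    # "typo_arg" never makes it into the command, which looks roughly like:
    # COPY (<query>) TO 'out.parquet' (FORMAT 'parquet', COMPRESSION 'zstd', ROW_GROUP_SIZE 100000)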
@@ -1127,7 +1145,7 @@ def to_csv(
        args = [
            "FORMAT 'csv'",
            f"HEADER {int(header)}",
-           *(f"{k.upper()} {v!r}" for k, v in kwargs.items()),
+           *(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_csv_copy_args),
        ]
        copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})"
        with self.begin() as con:
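As a usage sketch for the DuckDB CSV path (connection and table hypothetical): a whitelisted option such as `delim` is uppercased and spliced into the COPY command alongside the always-present FORMAT and HEADER arguments:

    import ibis

    con = ibis.duckdb.connect()
    t = con.table("awards_players")  # hypothetical table
    t.to_csv("out.csv", delim=";")
    # roughly: COPY (SELECT ...) TO 'out.csv' (FORMAT 'csv', HEADER 1, DELIM ';')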
27 changes: 27 additions & 0 deletions ibis/backends/tests/test_export.py
@@ -3,6 +3,7 @@
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
+import pyarrow.csv as pcsv
import pytest
import sqlalchemy as sa
from pytest import param
@@ -220,6 +221,21 @@ def test_table_to_parquet(tmp_path, backend, awards_players):
    backend.assert_frame_equal(awards_players.to_pandas(), df)


+@pytest.mark.notimpl(["flink"])
+@pytest.mark.parametrize(("kwargs"), [({"version": "1.0"}), ({"version": "2.6"})])
+def test_table_to_parquet_writer_kwargs(kwargs, tmp_path, backend, awards_players):
+    outparquet = tmp_path / "out.parquet"
+    awards_players.to_parquet(outparquet, **kwargs)
+
+    df = pd.read_parquet(outparquet)
+
+    backend.assert_frame_equal(awards_players.to_pandas(), df)
+
+    file = pa.parquet.ParquetFile(outparquet)
+
+    assert file.metadata.format_version == kwargs["version"]


@pytest.mark.notimpl(
    [
        "bigquery",
@@ -299,6 +315,17 @@ def test_table_to_csv(tmp_path, backend, awards_players):
    backend.assert_frame_equal(awards_players.to_pandas(), df)


+@pytest.mark.notimpl(["flink"])
+@pytest.mark.parametrize(("kwargs", "delimiter"), [({"write_options": pcsv.WriteOptions(delimiter=";")}, ";"), ({"write_options": pcsv.WriteOptions(delimiter="\t")}, "\t")])
+def test_table_to_csv_writer_kwargs(kwargs, delimiter, tmp_path, backend, awards_players):
+    outcsv = tmp_path / "out.csv"
+    # avoid pandas NaNonense
+    awards_players = awards_players.select("playerID", "awardID", "yearID", "lgID")
+
+    awards_players.to_csv(outcsv, **kwargs)
+    pd.read_csv(outcsv, delimiter=delimiter)


@pytest.mark.parametrize(
("dtype", "pyarrow_dtype"),
[
