Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(polars): accept list of CSVs to read_csv #9232

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@
from fsspec import AbstractFileSystem


def normalize_filenames(source_list):
# Promote to list
source_list = util.promote_list(source_list)

return list(map(util.normalize_filename, source_list))


_UDF_INPUT_TYPE_MAPPING = {
InputType.PYARROW: duckdb.functional.ARROW,
InputType.PYTHON: duckdb.functional.NATIVE,
Expand Down Expand Up @@ -649,7 +642,7 @@ def read_json(
table_name,
sg.select(STAR).from_(
self.compiler.f.read_json_auto(
normalize_filenames(source_list), *options
util.normalize_filenames(source_list), *options
)
),
)
Expand Down Expand Up @@ -682,7 +675,7 @@ def read_csv(
The just-registered table

"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

if not table_name:
table_name = util.gen_name("read_csv")
Expand Down Expand Up @@ -807,7 +800,7 @@ def read_parquet(
The just-registered table

"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

table_name = table_name or util.gen_name("read_parquet")

Expand Down Expand Up @@ -910,7 +903,7 @@ def read_delta(
The just-registered table.

"""
source_table = normalize_filenames(source_table)[0]
source_table = util.normalize_filenames(source_table)[0]

table_name = table_name or util.gen_name("read_delta")

Expand Down
17 changes: 12 additions & 5 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ibis.backends.sql.dialects import Polars
from ibis.expr.rewrites import lower_stringslice
from ibis.formats.polars import PolarsSchema
from ibis.util import gen_name, normalize_filename
from ibis.util import gen_name, normalize_filename, normalize_filenames

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -158,7 +158,10 @@ def sql(
return self.table(name)

def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
self,
path: str | Path | list[str | Path] | tuple[str | Path],
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a CSV file as a table.

Expand All @@ -180,16 +183,20 @@ def read_csv(
The just-registered table

"""
path = normalize_filename(path)
source_list = normalize_filenames(path)
# Flatten the list if there's only one element because Polars
# can't handle glob strings, or compressed CSVs in a single-element list
if len(source_list) == 1:
source_list = source_list[0]
table_name = table_name or gen_name("read_csv")
try:
table = pl.scan_csv(path, **kwargs)
table = pl.scan_csv(source_list, **kwargs)
# triggers a schema computation to handle compressed csv inference
# and raise a compute error
table.schema # noqa: B018
except pl.exceptions.ComputeError:
# handles compressed csvs
table = pl.read_csv(path, **kwargs)
table = pl.read_csv(source_list, **kwargs)

self._add_table(table_name, table)
return self.table(table_name)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def gzip_csv(data_dir, tmp_path):
"fancy_stones2",
id="multi_csv",
marks=pytest.mark.notyet(
["polars", "datafusion"],
["datafusion"],
reason="doesn't accept multiple files to scan or read",
),
),
Expand Down
7 changes: 7 additions & 0 deletions ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def _absolufy_paths(name):
return source


def normalize_filenames(source_list):
# Promote to list
source_list = promote_list(source_list)

return list(map(normalize_filename, source_list))


def gen_name(namespace: str) -> str:
"""Create a unique identifier."""
uid = base64.b32encode(uuid.uuid4().bytes).decode().rstrip("=").lower()
Expand Down