Skip to content

Commit

Permalink
feat(polars): accept list of CSVs to read_csv (#9232)
Browse files Browse the repository at this point in the history
Polars now accepts a list of CSVs to `scan_csv` so we can expose this to
end-users.

A few caveats:
- it doesn't accept lists of globs, or lists of compressed CSVs, so we
  flatten the list of paths if it only has one element in case it is a
  glob or csv.gz

Fixes #9230
  • Loading branch information
gforsyth authored May 22, 2024
1 parent c06285e commit 7a272e3
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
15 changes: 4 additions & 11 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@
from fsspec import AbstractFileSystem


def normalize_filenames(source_list):
# Promote to list
source_list = util.promote_list(source_list)

return list(map(util.normalize_filename, source_list))


_UDF_INPUT_TYPE_MAPPING = {
InputType.PYARROW: duckdb.functional.ARROW,
InputType.PYTHON: duckdb.functional.NATIVE,
Expand Down Expand Up @@ -649,7 +642,7 @@ def read_json(
table_name,
sg.select(STAR).from_(
self.compiler.f.read_json_auto(
normalize_filenames(source_list), *options
util.normalize_filenames(source_list), *options
)
),
)
Expand Down Expand Up @@ -682,7 +675,7 @@ def read_csv(
The just-registered table
"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

if not table_name:
table_name = util.gen_name("read_csv")
Expand Down Expand Up @@ -807,7 +800,7 @@ def read_parquet(
The just-registered table
"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

table_name = table_name or util.gen_name("read_parquet")

Expand Down Expand Up @@ -910,7 +903,7 @@ def read_delta(
The just-registered table.
"""
source_table = normalize_filenames(source_table)[0]
source_table = util.normalize_filenames(source_table)[0]

table_name = table_name or util.gen_name("read_delta")

Expand Down
17 changes: 12 additions & 5 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ibis.backends.sql.dialects import Polars
from ibis.expr.rewrites import lower_stringslice
from ibis.formats.polars import PolarsSchema
from ibis.util import gen_name, normalize_filename
from ibis.util import gen_name, normalize_filename, normalize_filenames

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -158,7 +158,10 @@ def sql(
return self.table(name)

def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
self,
path: str | Path | list[str | Path] | tuple[str | Path],
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a CSV file as a table.
Expand All @@ -180,16 +183,20 @@ def read_csv(
The just-registered table
"""
path = normalize_filename(path)
source_list = normalize_filenames(path)
# Flatten the list if there's only one element because Polars
# can't handle glob strings, or compressed CSVs in a single-element list
if len(source_list) == 1:
source_list = source_list[0]
table_name = table_name or gen_name("read_csv")
try:
table = pl.scan_csv(path, **kwargs)
table = pl.scan_csv(source_list, **kwargs)
# triggers a schema computation to handle compressed csv inference
# and raise a compute error
table.schema # noqa: B018
except pl.exceptions.ComputeError:
# handles compressed csvs
table = pl.read_csv(path, **kwargs)
table = pl.read_csv(source_list, **kwargs)

self._add_table(table_name, table)
return self.table(table_name)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def gzip_csv(data_dir, tmp_path):
"fancy_stones2",
id="multi_csv",
marks=pytest.mark.notyet(
["polars", "datafusion"],
["datafusion"],
reason="doesn't accept multiple files to scan or read",
),
),
Expand Down
7 changes: 7 additions & 0 deletions ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ def _absolufy_paths(name):
return source


def normalize_filenames(source_list):
# Promote to list
source_list = promote_list(source_list)

return list(map(normalize_filename, source_list))


def gen_name(namespace: str) -> str:
"""Create a unique identifier."""
uid = base64.b32encode(uuid.uuid4().bytes).decode().rstrip("=").lower()
Expand Down

0 comments on commit 7a272e3

Please sign in to comment.