From a88659a0158affb914c87d394223212628a888cd Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 22 May 2024 10:05:54 -0400 Subject: [PATCH] feat(polars): accept list of CSVs to read_csv Polars now accepts a list of CSVs to `scan_csv` so we can expose this to end-users. A few caveats: - it doesn't accept lists of globs, or lists of compressed CSVs, so we flatten the list of paths if it only has one element in case it is a glob or csv.gz --- ibis/backends/duckdb/__init__.py | 15 ++++----------- ibis/backends/polars/__init__.py | 17 ++++++++++++----- ibis/backends/tests/test_register.py | 2 +- ibis/util.py | 7 +++++++ 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 04542d690202..712f6fcbedad 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -38,13 +38,6 @@ from fsspec import AbstractFileSystem -def normalize_filenames(source_list): - # Promote to list - source_list = util.promote_list(source_list) - - return list(map(util.normalize_filename, source_list)) - - _UDF_INPUT_TYPE_MAPPING = { InputType.PYARROW: duckdb.functional.ARROW, InputType.PYTHON: duckdb.functional.NATIVE, @@ -649,7 +642,7 @@ def read_json( table_name, sg.select(STAR).from_( self.compiler.f.read_json_auto( - normalize_filenames(source_list), *options + util.normalize_filenames(source_list), *options ) ), ) @@ -682,7 +675,7 @@ def read_csv( The just-registered table """ - source_list = normalize_filenames(source_list) + source_list = util.normalize_filenames(source_list) if not table_name: table_name = util.gen_name("read_csv") @@ -807,7 +800,7 @@ def read_parquet( The just-registered table """ - source_list = normalize_filenames(source_list) + source_list = util.normalize_filenames(source_list) table_name = table_name or util.gen_name("read_parquet") @@ -910,7 +903,7 @@ def read_delta( The just-registered table. """ - source_table = normalize_filenames(source_table)[0] + source_table = util.normalize_filenames(source_table)[0] table_name = table_name or util.gen_name("read_delta") diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index d16a0d0b4f2f..d2620b9d9e0d 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -22,7 +22,7 @@ from ibis.backends.sql.dialects import Polars from ibis.expr.rewrites import lower_stringslice from ibis.formats.polars import PolarsSchema -from ibis.util import gen_name, normalize_filename +from ibis.util import gen_name, normalize_filename, normalize_filenames if TYPE_CHECKING: from collections.abc import Iterable @@ -158,7 +158,10 @@ def sql( return self.table(name) def read_csv( - self, path: str | Path, table_name: str | None = None, **kwargs: Any + self, + path: str | Path | list[str | Path] | tuple[str | Path], + table_name: str | None = None, + **kwargs: Any, ) -> ir.Table: """Register a CSV file as a table. @@ -180,16 +183,20 @@ def read_csv( The just-registered table """ - path = normalize_filename(path) + source_list = normalize_filenames(path) + # Flatten the list if there's only one element because Polars + # can't handle glob strings, or compressed CSVs in a single-element list + if len(source_list) == 1: + source_list = source_list[0] table_name = table_name or gen_name("read_csv") try: - table = pl.scan_csv(path, **kwargs) + table = pl.scan_csv(source_list, **kwargs) # triggers a schema computation to handle compressed csv inference # and raise a compute error table.schema # noqa: B018 except pl.exceptions.ComputeError: # handles compressed csvs - table = pl.read_csv(path, **kwargs) + table = pl.read_csv(source_list, **kwargs) self._add_table(table_name, table) return self.table(table_name) diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index dec45a4571e6..7ce1dd85ac4e 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -74,7 +74,7 @@ def gzip_csv(data_dir, tmp_path): "fancy_stones2", id="multi_csv", marks=pytest.mark.notyet( - ["polars", "datafusion"], + ["datafusion"], reason="doesn't accept multiple files to scan or read", ), ), diff --git a/ibis/util.py b/ibis/util.py index 10454b9e879a..59051ac20669 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -512,6 +512,13 @@ def _absolufy_paths(name): return source +def normalize_filenames(source_list): + # Promote to list + source_list = promote_list(source_list) + + return list(map(normalize_filename, source_list)) + + def gen_name(namespace: str) -> str: """Create a unique identifier.""" uid = base64.b32encode(uuid.uuid4().bytes).decode().rstrip("=").lower()