From d44952161ae57338cd431d2899abcdbcf12b1a9e Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 6 Jan 2025 10:12:29 +0100 Subject: [PATCH 01/17] feat: LazyFrame.collect kwargs --- narwhals/_dask/dataframe.py | 5 +- narwhals/_polars/dataframe.py | 4 +- narwhals/dataframe.py | 120 +++++++++++++++++++++++++++------ narwhals/stable/v1/__init__.py | 111 ++++++++++++++++++++++++------ 4 files changed, 194 insertions(+), 46 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 5e652a937e..b65f5f3928 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -29,6 +29,7 @@ from narwhals._dask.group_by import DaskLazyGroupBy from narwhals._dask.namespace import DaskNamespace from narwhals._dask.typing import IntoDaskExpr + from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals.dtypes import DType from narwhals.utils import Version @@ -78,12 +79,12 @@ def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = df.assign(**new_series) return self._from_native_frame(df) - def collect(self) -> Any: + def collect(self: Self, **kwargs: Any) -> PandasLikeDataFrame: import pandas as pd from narwhals._pandas_like.dataframe import PandasLikeDataFrame - result = self._native_frame.compute() + result = self._native_frame.compute(**kwargs) return PandasLikeDataFrame( result, implementation=Implementation.PANDAS, diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index d5e115284c..cc6f622037 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -425,11 +425,11 @@ def collect_schema(self: Self) -> dict[str, DType]: for name, dtype in self._native_frame.collect_schema().items() } - def collect(self: Self) -> PolarsDataFrame: + def collect(self: Self, **kwargs: Any) -> PolarsDataFrame: import polars as pl try: - result = self._native_frame.collect() + result = self._native_frame.collect(**kwargs) except pl.exceptions.ColumnNotFoundError as e: raise ColumnNotFoundError(str(e)) from e diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index dd786ef3df..304e985db4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3610,16 +3610,35 @@ def __getitem__(self, item: str | slice) -> NoReturn: msg = "Slicing is not supported on LazyFrame" raise TypeError(msg) - def collect(self) -> DataFrame[Any]: + def collect( + self: Self, + *, + polars_kwargs: dict[str, Any] | None = None, + dask_kwargs: dict[str, Any] | None = None, + ) -> DataFrame[Any]: r"""Materialize this LazyFrame into a DataFrame. + As each underlying lazyframe has different arguments to set when materializing + the lazyframe into a dataframe, we allow to pass them separately into its own + keyword argument. + + Arguments: + polars_kwargs: [polars.LazyFrame.collect](https://docs.pola.rs/api/python/dev/reference/lazyframe/api/polars.LazyFrame.collect.html) + arguments. Used only if the `LazyFrame` is backed by a `polars.LazyFrame`. + If not provided, it uses the polars default values. + dask_kwargs: [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html) + arguments. Used only if the `LazyFrame` is backed by a `dask.dataframe.DataFrame`. + If not provided, it uses the dask default values. + Returns: DataFrame Examples: - >>> import narwhals as nw >>> import polars as pl >>> import dask.dataframe as dd + >>> import narwhals as nw + >>> from narwhals.typing import IntoDataFrame, IntoFrame + >>> >>> data = { ... "a": ["a", "b", "a", "b", "b", "c"], ... "b": [1, 2, 3, 4, 5, 6], @@ -3628,28 +3647,14 @@ def collect(self) -> DataFrame[Any]: >>> lf_pl = pl.LazyFrame(data) >>> lf_dask = dd.from_dict(data, npartitions=2) - >>> lf = nw.from_native(lf_pl) - >>> lf # doctest:+ELLIPSIS + >>> nw.from_native(lf_pl) # doctest:+ELLIPSIS ┌─────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------| |>> df = lf.group_by("a").agg(nw.all().sum()).collect() - >>> df.to_native().sort("a") - shape: (3, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ a ┆ 4 ┆ 10 │ - │ b ┆ 11 ┆ 10 │ - │ c ┆ 6 ┆ 1 │ - └─────┴─────┴─────┘ - >>> lf = nw.from_native(lf_dask) - >>> lf + >>> nw.from_native(lf_dask) ┌───────────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------------| @@ -3662,15 +3667,88 @@ def collect(self) -> DataFrame[Any]: |Dask Name: frompandas, 1 expression| |Expr=df | └───────────────────────────────────┘ - >>> df = lf.group_by("a").agg(nw.col("b", "c").sum()).collect() - >>> df.to_native() + + Let's define a dataframe-agnostic that does some grouping computation and + finally collects to a DataFrame: + + >>> def agnostic_group_by_and_collect(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect() + ... .to_native() + ... ) + + We can then pass any supported library such as Polars or Dask + to `agnostic_group_by_and_collect`: + + >>> agnostic_group_by_and_collect(lf_pl) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 4 ┆ 10 │ + │ b ┆ 11 ┆ 10 │ + │ c ┆ 6 ┆ 1 │ + └─────┴─────┴─────┘ + + >>> agnostic_group_by_and_collect(lf_dask) + a b c + 0 a 4 10 + 1 b 11 10 + 2 c 6 1 + + Now for whatever reason, let's suppose that we want to run lazily, yet without + query optimization (e.g. for debugging purpose). As this is achieved + differently in polars and dask, to keep a unified workflow we can specify + the native kwargs for each backend: + + >>> def agnostic_collect_no_opt(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect( + ... polars_kwargs={"no_optimization": True}, + ... dask_kwargs={"optimize_graph": False}, + ... ) + ... .to_native() + ... ) + + >>> agnostic_collect_no_opt(lf_pl) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 4 ┆ 10 │ + │ b ┆ 11 ┆ 10 │ + │ c ┆ 6 ┆ 1 │ + └─────┴─────┴─────┘ + + >>> agnostic_collect_no_opt(lf_dask) a b c 0 a 4 10 1 b 11 10 2 c 6 1 """ + from narwhals.utils import Implementation + + if self.implementation is Implementation.POLARS and polars_kwargs is not None: + kwargs = polars_kwargs + elif self.implementation is Implementation.DASK and dask_kwargs is not None: + kwargs = dask_kwargs + else: + kwargs = {} + return self._dataframe( - self._compliant_frame.collect(), + self._compliant_frame.collect(**kwargs), level="full", ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 5ffc475e5a..d7b45954e0 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -413,16 +413,35 @@ class LazyFrame(NwLazyFrame[IntoFrameT]): def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame - def collect(self) -> DataFrame[Any]: + def collect( + self: Self, + *, + polars_kwargs: dict[str, Any] | None = None, + dask_kwargs: dict[str, Any] | None = None, + ) -> DataFrame[Any]: r"""Materialize this LazyFrame into a DataFrame. + As each underlying lazyframe has different arguments to set when materializing + the lazyframe into a dataframe, we allow to pass them separately into its own + keyword argument. + + Arguments: + polars_kwargs: [polars.LazyFrame.collect](https://docs.pola.rs/api/python/dev/reference/lazyframe/api/polars.LazyFrame.collect.html) + arguments. Used only if the `LazyFrame` is backed by a `polars.LazyFrame`. + If not provided, it uses the polars default values. + dask_kwargs: [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html) + arguments. Used only if the `LazyFrame` is backed by a `dask.dataframe.DataFrame`. + If not provided, it uses the dask default values. + Returns: DataFrame Examples: - >>> import narwhals as nw >>> import polars as pl >>> import dask.dataframe as dd + >>> import narwhals as nw + >>> from narwhals.typing import IntoDataFrame, IntoFrame + >>> >>> data = { ... "a": ["a", "b", "a", "b", "b", "c"], ... "b": [1, 2, 3, 4, 5, 6], @@ -431,28 +450,14 @@ def collect(self) -> DataFrame[Any]: >>> lf_pl = pl.LazyFrame(data) >>> lf_dask = dd.from_dict(data, npartitions=2) - >>> lf = nw.from_native(lf_pl) - >>> lf # doctest:+ELLIPSIS + >>> nw.from_native(lf_pl) # doctest:+ELLIPSIS ┌─────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------| |>> df = lf.group_by("a").agg(nw.all().sum()).collect() - >>> df.to_native().sort("a") - shape: (3, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ a ┆ 4 ┆ 10 │ - │ b ┆ 11 ┆ 10 │ - │ c ┆ 6 ┆ 1 │ - └─────┴─────┴─────┘ - >>> lf = nw.from_native(lf_dask) - >>> lf + >>> nw.from_native(lf_dask) ┌───────────────────────────────────┐ | Narwhals LazyFrame | |-----------------------------------| @@ -465,14 +470,78 @@ def collect(self) -> DataFrame[Any]: |Dask Name: frompandas, 1 expression| |Expr=df | └───────────────────────────────────┘ - >>> df = lf.group_by("a").agg(nw.col("b", "c").sum()).collect() - >>> df.to_native() + + Let's define a dataframe-agnostic that does some grouping computation and + finally collects to a DataFrame: + + >>> def agnostic_group_by_and_collect(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect() + ... .to_native() + ... ) + + We can then pass any supported library such as Polars or Dask + to `agnostic_group_by_and_collect`: + + >>> agnostic_group_by_and_collect(lf_pl) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 4 ┆ 10 │ + │ b ┆ 11 ┆ 10 │ + │ c ┆ 6 ┆ 1 │ + └─────┴─────┴─────┘ + + >>> agnostic_group_by_and_collect(lf_dask) + a b c + 0 a 4 10 + 1 b 11 10 + 2 c 6 1 + + Now for whatever reason, let's suppose that we want to run lazily, yet without + query optimization (e.g. for debugging purpose). As this is achieved + differently in polars and dask, to keep a unified workflow we can specify + the native kwargs for each backend: + + >>> def agnostic_collect_no_opt(lf_native: IntoFrame) -> IntoDataFrame: + ... lf = nw.from_native(lf_native) + ... return ( + ... lf.group_by("a") + ... .agg(nw.col("b", "c").sum()) + ... .sort("a") + ... .collect( + ... polars_kwargs={"no_optimization": True}, + ... dask_kwargs={"optimize_graph": False}, + ... ) + ... .to_native() + ... ) + + >>> agnostic_collect_no_opt(lf_pl) + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 4 ┆ 10 │ + │ b ┆ 11 ┆ 10 │ + │ c ┆ 6 ┆ 1 │ + └─────┴─────┴─────┘ + + >>> agnostic_collect_no_opt(lf_dask) a b c 0 a 4 10 1 b 11 10 2 c 6 1 """ - return super().collect() # type: ignore[return-value] + return super().collect(polars_kwargs=polars_kwargs, dask_kwargs=dask_kwargs) # type: ignore[return-value] def _l1_norm(self: Self) -> Self: """Private, just used to test the stable API. From 1046a9eaa418f04bc796989d744ce63653a4c764 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 6 Jan 2025 10:19:37 +0100 Subject: [PATCH 02/17] phrasing --- narwhals/stable/v1/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index d7b45954e0..5e0f15da95 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -505,7 +505,7 @@ def collect( 1 b 11 10 2 c 6 1 - Now for whatever reason, let's suppose that we want to run lazily, yet without + Now, let's suppose that we want to run lazily, yet without query optimization (e.g. for debugging purpose). As this is achieved differently in polars and dask, to keep a unified workflow we can specify the native kwargs for each backend: From 6cd4b0b084e7fb559928fac0bc0ff3fb44fa8bee Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 6 Jan 2025 10:27:47 +0100 Subject: [PATCH 03/17] forgot to save one file... --- narwhals/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 304e985db4..5abed88b89 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3702,7 +3702,7 @@ def collect( 1 b 11 10 2 c 6 1 - Now for whatever reason, let's suppose that we want to run lazily, yet without + Now, let's suppose that we want to run lazily, yet without query optimization (e.g. for debugging purpose). As this is achieved differently in polars and dask, to keep a unified workflow we can specify the native kwargs for each backend: From e9e573c57721005786ec0060d1cbd193b3737769 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 6 Jan 2025 10:59:30 +0100 Subject: [PATCH 04/17] tests --- tests/frame/collect_test.py | 22 ++++++++++++++++++++++ tests/utils.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 tests/frame/collect_test.py diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py new file mode 100644 index 0000000000..97bb2e9145 --- /dev/null +++ b/tests/frame/collect_test.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import narwhals.stable.v1 as nw +from tests.utils import Constructor +from tests.utils import assert_equal_data + + +def test_collect_kwargs(constructor: Constructor) -> None: + data = {"a": [1, 2], "b": [3, 4]} + df = nw.from_native(constructor(data)) + + result = ( + df.lazy() + .select(nw.all().sum()) + .collect( + polars_kwargs={"no_optimization": True}, + dask_kwargs={"optimize_graph": False}, + ) + ) + + expected = {"a": [3], "b": [7]} + assert_equal_data(result, expected) diff --git a/tests/utils.py b/tests/utils.py index 34f1bfa1ef..7d085aa90b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -77,7 +77,7 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: if result.implementation is Implementation.POLARS and os.environ.get( "NARWHALS_POLARS_GPU", False ): # pragma: no cover - result = result.to_native().collect(engine="gpu") + result.collect(polars_kwargs={"engine": "gpu"}) else: result = result.collect() From aedbff2c90ded4298199d038ac28342b0c20d045 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 6 Jan 2025 16:42:21 +0100 Subject: [PATCH 05/17] duckdb_kwargs --- narwhals/_duckdb/dataframe.py | 75 ++++++++++++++++++++++++++++------ narwhals/dataframe.py | 7 ++++ narwhals/stable/v1/__init__.py | 11 ++++- tests/frame/collect_test.py | 53 ++++++++++++++++++++++-- 4 files changed, 130 insertions(+), 16 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 76ff68ae0d..e0da8ece13 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -6,6 +6,7 @@ from typing import Iterable from typing import Literal from typing import Sequence +from typing import overload from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals._duckdb.utils import parse_exprs_and_named_exprs @@ -27,10 +28,13 @@ import pyarrow as pa from typing_extensions import Self + from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._duckdb.expr import DuckDBExpr from narwhals._duckdb.group_by import DuckDBGroupBy from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._duckdb.series import DuckDBInterchangeSeries + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + from narwhals._polars.dataframe import PolarsDataFrame from narwhals.dtypes import DType @@ -76,20 +80,67 @@ def __getitem__(self, item: str) -> DuckDBInterchangeSeries: self._native_frame.select(item), version=self._version ) - def collect(self) -> Any: - try: - import pyarrow as pa # ignore-banned-import - except ModuleNotFoundError as exc: # pragma: no cover - msg = "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDcollect `LazyFrame` backed by DuckDB" - raise ModuleNotFoundError(msg) from exc + @overload + def collect(self, return_type: Literal["pyarrow"] = "pyarrow") -> ArrowDataFrame: ... - from narwhals._arrow.dataframe import ArrowDataFrame + @overload + def collect(self, return_type: Literal["pandas"]) -> PandasLikeDataFrame: ... - return ArrowDataFrame( - native_dataframe=self._native_frame.arrow(), - backend_version=parse_version(pa.__version__), - version=self._version, - ) + @overload + def collect(self, return_type: Literal["polars"]) -> PolarsDataFrame: ... + + def collect( + self, + return_type: Literal["pyarrow", "pandas", "polars"] = "pyarrow", + ) -> ArrowDataFrame | PandasLikeDataFrame | PolarsDataFrame: + if return_type == "pyarrow": + try: + import pyarrow as pa # ignore-banned-import + except ModuleNotFoundError as exc: # pragma: no cover + msg = ( + "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDB" + ) + raise ModuleNotFoundError(msg) from exc + + from narwhals._arrow.dataframe import ArrowDataFrame + + return ArrowDataFrame( + native_dataframe=self._native_frame.arrow(), + backend_version=parse_version(pa.__version__), + version=self._version, + ) + + elif return_type == "pandas": + import pandas as pd # ignore-banned-import + + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + from narwhals.utils import Implementation + + return PandasLikeDataFrame( + native_dataframe=self._native_frame.df(), + implementation=Implementation.PANDAS, + backend_version=parse_version(pd.__version__), + version=self._version, + ) + + elif return_type == "polars": + import polars as pl # ignore-banned-import + + from narwhals._polars.dataframe import PolarsDataFrame + from narwhals.utils import Implementation + + return PolarsDataFrame( + df=self._native_frame.pl(), + backend_version=parse_version(pl.__version__), + version=self._version, + ) + + else: + msg = ( + "Only the following `return_type`'s are supported: pyarrow, pandas and " + f"polars. Found '{return_type}'." + ) + raise ValueError(msg) def head(self, n: int) -> Self: return self._from_native_frame(self._native_frame.limit(n)) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 5abed88b89..4a26da20bc 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3615,6 +3615,7 @@ def collect( *, polars_kwargs: dict[str, Any] | None = None, dask_kwargs: dict[str, Any] | None = None, + duckdb_kwargs: dict[str, str] | None = None, ) -> DataFrame[Any]: r"""Materialize this LazyFrame into a DataFrame. @@ -3629,6 +3630,10 @@ def collect( dask_kwargs: [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html) arguments. Used only if the `LazyFrame` is backed by a `dask.dataframe.DataFrame`. If not provided, it uses the dask default values. + duckdb_kwargs: Allows to specify in which eager backend to materialize a + DuckDBPyRelation backed LazyFrame. It is possible to choose among + `pyarrow`, `pandas` or `polars` by declaring + `duckdb_kwargs={"return_type": ""}`. Returns: DataFrame @@ -3744,6 +3749,8 @@ def collect( kwargs = polars_kwargs elif self.implementation is Implementation.DASK and dask_kwargs is not None: kwargs = dask_kwargs + elif self.implementation is Implementation.DUCKDB and duckdb_kwargs is not None: + kwargs = duckdb_kwargs else: kwargs = {} diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index ef9a8197d7..a82ec83a28 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -418,6 +418,7 @@ def collect( *, polars_kwargs: dict[str, Any] | None = None, dask_kwargs: dict[str, Any] | None = None, + duckdb_kwargs: dict[str, str] | None = None, ) -> DataFrame[Any]: r"""Materialize this LazyFrame into a DataFrame. @@ -432,6 +433,10 @@ def collect( dask_kwargs: [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html) arguments. Used only if the `LazyFrame` is backed by a `dask.dataframe.DataFrame`. If not provided, it uses the dask default values. + duckdb_kwargs: Allows to specify in which eager backend to materialize a + DuckDBPyRelation backed LazyFrame. It is possible to choose among + `pyarrow`, `pandas` or `polars` by declaring + `duckdb_kwargs={"return_type": ""}`. Returns: DataFrame @@ -541,7 +546,11 @@ def collect( 1 b 11 10 2 c 6 1 """ - return super().collect(polars_kwargs=polars_kwargs, dask_kwargs=dask_kwargs) # type: ignore[return-value] + return super().collect( + polars_kwargs=polars_kwargs, + dask_kwargs=dask_kwargs, + duckdb_kwargs=duckdb_kwargs, + ) # type: ignore[return-value] def _l1_norm(self: Self) -> Self: """Private, just used to test the stable API. diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 97bb2e9145..84bf8dd63e 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -1,22 +1,69 @@ from __future__ import annotations -import narwhals.stable.v1 as nw +from typing import Literal + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 from tests.utils import Constructor from tests.utils import assert_equal_data def test_collect_kwargs(constructor: Constructor) -> None: data = {"a": [1, 2], "b": [3, 4]} - df = nw.from_native(constructor(data)) + df = nw_v1.from_native(constructor(data)) result = ( df.lazy() - .select(nw.all().sum()) + .select(nw_v1.all().sum()) .collect( polars_kwargs={"no_optimization": True}, dask_kwargs={"optimize_graph": False}, + duckdb_kwargs={"return_type": "pyarrow"}, ) ) expected = {"a": [3], "b": [7]} assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("return_type", "expected_cls"), + [ + ("pyarrow", pa.Table), + ("polars", pl.DataFrame), + ("pandas", pd.DataFrame), + ], +) +def test_collect_duckdb( + return_type: Literal["pyarrow", "polars", "pandas"], expected_cls: type +) -> None: + duckdb = pytest.importorskip("duckdb") + + data = {"a": [1, 2], "b": [3, 4]} + df_pl = pl.DataFrame(data) # noqa: F841 + df = nw.from_native(duckdb.sql("select * from df_pl")) + + result = df.lazy().collect(duckdb_kwargs={"return_type": return_type}).to_native() + assert isinstance(result, expected_cls) + + +def test_collect_duckdb_raise() -> None: + duckdb = pytest.importorskip("duckdb") + + data = {"a": [1, 2], "b": [3, 4]} + df_pl = pl.DataFrame(data) # noqa: F841 + df = nw.from_native(duckdb.sql("select * from df_pl")) + + with pytest.raises( + ValueError, + match=( + "Only the following `return_type`'s are supported: pyarrow, pandas and " + "polars. Found 'foo'." + ), + ): + df.lazy().collect(duckdb_kwargs={"return_type": "foo"}) From e2459e3003febb98b73d7e11de80d1f250fc7f46 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 11 Jan 2025 12:31:40 +0100 Subject: [PATCH 06/17] return_type -> eager_backend --- narwhals/_duckdb/dataframe.py | 20 +++++++++++--------- narwhals/dataframe.py | 2 +- narwhals/stable/v1/__init__.py | 2 +- tests/frame/collect_test.py | 12 ++++++------ 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 1067d38798..0c1929fa65 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -81,19 +81,21 @@ def __getitem__(self, item: str) -> DuckDBInterchangeSeries: ) @overload - def collect(self, return_type: Literal["pyarrow"] = "pyarrow") -> ArrowDataFrame: ... + def collect( + self, eager_backend: Literal["pyarrow"] = "pyarrow" + ) -> ArrowDataFrame: ... @overload - def collect(self, return_type: Literal["pandas"]) -> PandasLikeDataFrame: ... + def collect(self, eager_backend: Literal["pandas"]) -> PandasLikeDataFrame: ... @overload - def collect(self, return_type: Literal["polars"]) -> PolarsDataFrame: ... + def collect(self, eager_backend: Literal["polars"]) -> PolarsDataFrame: ... def collect( self, - return_type: Literal["pyarrow", "pandas", "polars"] = "pyarrow", + eager_backend: Literal["pyarrow", "pandas", "polars"] = "pyarrow", ) -> ArrowDataFrame | PandasLikeDataFrame | PolarsDataFrame: - if return_type == "pyarrow": + if eager_backend == "pyarrow": try: import pyarrow as pa # ignore-banned-import except ModuleNotFoundError as exc: # pragma: no cover @@ -110,7 +112,7 @@ def collect( version=self._version, ) - elif return_type == "pandas": + elif eager_backend == "pandas": import pandas as pd # ignore-banned-import from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -123,7 +125,7 @@ def collect( version=self._version, ) - elif return_type == "polars": + elif eager_backend == "polars": import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -137,8 +139,8 @@ def collect( else: msg = ( - "Only the following `return_type`'s are supported: pyarrow, pandas and " - f"polars. Found '{return_type}'." + "Only the following `eager_backend`'s are supported: pyarrow, pandas and " + f"polars. Found '{eager_backend}'." ) raise ValueError(msg) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 4a26da20bc..fce645fee3 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3633,7 +3633,7 @@ def collect( duckdb_kwargs: Allows to specify in which eager backend to materialize a DuckDBPyRelation backed LazyFrame. It is possible to choose among `pyarrow`, `pandas` or `polars` by declaring - `duckdb_kwargs={"return_type": ""}`. + `duckdb_kwargs={"eager_backend": ""}`. Returns: DataFrame diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 1ce9be958f..73d1aaa7e2 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -436,7 +436,7 @@ def collect( duckdb_kwargs: Allows to specify in which eager backend to materialize a DuckDBPyRelation backed LazyFrame. It is possible to choose among `pyarrow`, `pandas` or `polars` by declaring - `duckdb_kwargs={"return_type": ""}`. + `duckdb_kwargs={"eager_backend": ""}`. Returns: DataFrame diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 84bf8dd63e..2700499abd 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -23,7 +23,7 @@ def test_collect_kwargs(constructor: Constructor) -> None: .collect( polars_kwargs={"no_optimization": True}, dask_kwargs={"optimize_graph": False}, - duckdb_kwargs={"return_type": "pyarrow"}, + duckdb_kwargs={"eager_backend": "pyarrow"}, ) ) @@ -32,7 +32,7 @@ def test_collect_kwargs(constructor: Constructor) -> None: @pytest.mark.parametrize( - ("return_type", "expected_cls"), + ("eager_backend", "expected_cls"), [ ("pyarrow", pa.Table), ("polars", pl.DataFrame), @@ -40,7 +40,7 @@ def test_collect_kwargs(constructor: Constructor) -> None: ], ) def test_collect_duckdb( - return_type: Literal["pyarrow", "polars", "pandas"], expected_cls: type + eager_backend: Literal["pyarrow", "polars", "pandas"], expected_cls: type ) -> None: duckdb = pytest.importorskip("duckdb") @@ -48,7 +48,7 @@ def test_collect_duckdb( df_pl = pl.DataFrame(data) # noqa: F841 df = nw.from_native(duckdb.sql("select * from df_pl")) - result = df.lazy().collect(duckdb_kwargs={"return_type": return_type}).to_native() + result = df.lazy().collect(duckdb_kwargs={"eager_backend": eager_backend}).to_native() assert isinstance(result, expected_cls) @@ -62,8 +62,8 @@ def test_collect_duckdb_raise() -> None: with pytest.raises( ValueError, match=( - "Only the following `return_type`'s are supported: pyarrow, pandas and " + "Only the following `eager_backend`'s are supported: pyarrow, pandas and " "polars. Found 'foo'." ), ): - df.lazy().collect(duckdb_kwargs={"return_type": "foo"}) + df.lazy().collect(duckdb_kwargs={"eager_backend": "foo"}) From 892a91c27403f64c96308aa19fb5064d1f4ddc3a Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Thu, 30 Jan 2025 23:17:35 +0100 Subject: [PATCH 07/17] skip old pandas and simplify --- narwhals/_pandas_like/dataframe.py | 16 ++-------------- tests/frame/collect_test.py | 2 ++ 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 4b832be70b..a453e2a838 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -533,13 +533,8 @@ def collect( from narwhals._arrow.dataframe import ArrowDataFrame - if self._implementation is Implementation.CUDF: - pa_native = self._native_frame.to_arrow(preserve_index=False) - else: - pa_native = pa.Table.from_pandas(self._native_frame) - return ArrowDataFrame( - native_dataframe=pa_native, + native_dataframe=self.to_arrow(), backend_version=parse_version(pa.__version__), version=self._version, ) @@ -549,15 +544,8 @@ def collect( from narwhals._polars.dataframe import PolarsDataFrame - if self._implementation is Implementation.PANDAS: - pl_native = pl.from_pandas(self._native_frame) - elif self._implementation is Implementation.CUDF: # pragma: no cover - pl_native = pl.from_pandas(self._native_frame.to_pandas()) - elif self._implementation is Implementation.MODIN: - pl_native = pl.from_pandas(self._native_frame._to_pandas()) - return PolarsDataFrame( - df=pl_native, + df=self.to_polars(), backend_version=parse_version(pl.__version__), version=self._version, ) diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 9a90a4a7b6..69b9c30f2d 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -12,6 +12,7 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_modin from narwhals.utils import Implementation +from tests.utils import PANDAS_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -42,6 +43,7 @@ def test_collect_to_default_backend(constructor: Constructor) -> None: assert isinstance(result, expected_cls) +@pytest.mark.skipif(PANDAS_VERSION < (1,), reason="too old for pyarrow") @pytest.mark.parametrize( ("backend", "expected_cls"), [ From 9906862a467077828af981f525c47bd117c63034 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 31 Jan 2025 10:26:09 +0100 Subject: [PATCH 08/17] fail old and filter warnings --- tests/frame/collect_test.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 69b9c30f2d..1d006eb214 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -43,7 +43,9 @@ def test_collect_to_default_backend(constructor: Constructor) -> None: assert isinstance(result, expected_cls) -@pytest.mark.skipif(PANDAS_VERSION < (1,), reason="too old for pyarrow") +@pytest.mark.filterwarnings( + "ignore:is_sparse is deprecated and will be removed in a future version." +) @pytest.mark.parametrize( ("backend", "expected_cls"), [ @@ -62,7 +64,11 @@ def test_collect_to_valid_backend( constructor: Constructor, backend: ModuleType | Implementation | str | None, expected_cls: type, + request: pytest.FixtureRequest, ) -> None: + if "pandas" in str(constructor) and PANDAS_VERSION < (1,): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.lazy().collect(backend=backend).to_native() assert isinstance(result, expected_cls) From 485209ca969a648367f208df5124faaaec4977f8 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 31 Jan 2025 12:52:13 +0100 Subject: [PATCH 09/17] skip for pandas<1 --- tests/frame/collect_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 1d006eb214..166f8c8b7c 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -19,6 +19,9 @@ if TYPE_CHECKING: from types import ModuleType +if PANDAS_VERSION < (1,): + pytest.skip(allow_module_level=True) + data = {"a": [1, 2], "b": [3, 4]} @@ -64,11 +67,7 @@ def test_collect_to_valid_backend( constructor: Constructor, backend: ModuleType | Implementation | str | None, expected_cls: type, - request: pytest.FixtureRequest, ) -> None: - if "pandas" in str(constructor) and PANDAS_VERSION < (1,): - request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) result = df.lazy().collect(backend=backend).to_native() assert isinstance(result, expected_cls) From 4de9ace83952da603edef8dc775d5f35a3fdf78f Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 31 Jan 2025 13:06:48 +0100 Subject: [PATCH 10/17] no cover skip statement --- tests/frame/collect_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 166f8c8b7c..5f9182bcba 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from types import ModuleType -if PANDAS_VERSION < (1,): +if PANDAS_VERSION < (1,): # pragma: no cover pytest.skip(allow_module_level=True) From 871c0acdda3433186949800d746d431c3ca702b2 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 31 Jan 2025 13:18:03 +0100 Subject: [PATCH 11/17] test utils fix --- tests/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 04bd8a9803..59fe42eb3b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -88,12 +88,12 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: if is_duckdb: result = from_native(result.to_native().arrow()) if hasattr(result, "collect"): - if result.implementation is Implementation.POLARS and os.environ.get( - "NARWHALS_POLARS_GPU", False - ): # pragma: no cover - result.collect(polars_kwargs={"engine": "gpu"}) - else: - result = result.collect() + kwargs = { + Implementation.POLARS: ( + {"engine": "gpu"} if os.environ.get("NARWHALS_POLARS_GPU", False) else {} + ) # pragma: no cover + } + result = result.collect(**kwargs.get(result.implementation, {})) if hasattr(result, "columns"): for idx, (col, key) in enumerate(zip(result.columns, expected.keys())): From a5044c42bb01bc2ef5e2cfb4c59887701f27df24 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:21:12 +0000 Subject: [PATCH 12/17] simplify --- narwhals/_arrow/dataframe.py | 15 ++++++--------- narwhals/_dask/dataframe.py | 15 ++++++--------- narwhals/_duckdb/dataframe.py | 14 +++++--------- narwhals/_pandas_like/dataframe.py | 16 ++++++---------- narwhals/_polars/dataframe.py | 15 ++++++--------- narwhals/dataframe.py | 23 ++++++++++++++++++++--- narwhals/utils.py | 23 +++++++++++++++++++++++ 7 files changed, 72 insertions(+), 49 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 4b081ac3f5..7b0cac0a21 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -17,8 +17,6 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._arrow.utils import select_rows from narwhals._expression_parsing import evaluate_into_exprs -from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_polars from narwhals.dependencies import is_numpy_array from narwhals.utils import Implementation from narwhals.utils import Version @@ -564,10 +562,10 @@ def lazy(self: Self, *, backend: Implementation | None = None) -> CompliantLazyF def collect( self: Self, - backend: ModuleType | Implementation | str | None, + backend: Implementation | None, **kwargs: Any, ) -> CompliantDataFrame: - if backend in (None, "pyarrow", Implementation.PYARROW, pa): + if backend is Implementation.PYARROW or backend is None: from narwhals._arrow.dataframe import ArrowDataFrame return ArrowDataFrame( @@ -576,7 +574,7 @@ def collect( version=self._version, ) - elif backend in ("pandas", Implementation.PANDAS, get_pandas()): + if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -588,7 +586,7 @@ def collect( version=self._version, ) - elif backend in ("polars", Implementation.POLARS, get_polars()): + if backend is Implementation.POLARS: import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -599,9 +597,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise AssertionError(msg) # pragma: no cover def clone(self: Self) -> Self: msg = "clone is not yet supported on PyArrow tables" diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 8fe6954f19..66206e766e 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -13,8 +13,6 @@ from narwhals._dask.utils import parse_exprs_and_named_exprs from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import select_columns_by_name -from narwhals.dependencies import get_polars -from narwhals.dependencies import get_pyarrow from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation @@ -83,14 +81,14 @@ def with_columns(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: def collect( self: Self, - backend: ModuleType | Implementation | str | None, + backend: Implementation | None, **kwargs: Any, ) -> CompliantDataFrame: import pandas as pd result = self._native_frame.compute(**kwargs) - if backend in (None, "pandas", Implementation.PANDAS, pd): + if backend is None or backend is Implementation.PANDAS: from narwhals._pandas_like.dataframe import PandasLikeDataFrame return PandasLikeDataFrame( @@ -100,7 +98,7 @@ def collect( version=self._version, ) - elif backend in ("polars", Implementation.POLARS, get_polars()): + if backend is Implementation.POLARS: import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -111,7 +109,7 @@ def collect( version=self._version, ) - elif backend in ("pyarrow", Implementation.PYARROW, get_pyarrow()): + if backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame @@ -122,9 +120,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover @property def columns(self: Self) -> list[str]: diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 12500e5566..0cb898d06a 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -13,9 +13,6 @@ from narwhals._duckdb.utils import native_to_narwhals_dtype from narwhals._duckdb.utils import parse_exprs_and_named_exprs from narwhals.dependencies import get_duckdb -from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_polars -from narwhals.dependencies import get_pyarrow from narwhals.exceptions import ColumnNotFoundError from narwhals.typing import CompliantDataFrame from narwhals.utils import Implementation @@ -88,7 +85,7 @@ def collect( backend: ModuleType | Implementation | str | None, **kwargs: Any, ) -> CompliantDataFrame: - if backend in (None, "pyarrow", Implementation.PYARROW, get_pyarrow()): + if backend is None or backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame @@ -99,7 +96,7 @@ def collect( version=self._version, ) - elif backend in ("pandas", Implementation.PANDAS, get_pandas()): + if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -111,7 +108,7 @@ def collect( version=self._version, ) - elif backend in ("polars", Implementation.POLARS, get_polars()): + if backend is Implementation.POLARS: import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -122,9 +119,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover def head(self, n: int) -> Self: return self._from_native_frame(self._native_frame.limit(n)) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 32902cc0fd..760df69f39 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -18,9 +18,6 @@ from narwhals._pandas_like.utils import pivot_table from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name -from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_polars -from narwhals.dependencies import get_pyarrow from narwhals.dependencies import is_numpy_array from narwhals.utils import Implementation from narwhals.utils import check_column_exists @@ -507,7 +504,7 @@ def sort( # --- convert --- def collect( self: Self, - backend: ModuleType | Implementation | str | None, + backend: Implementation | None, **kwargs: Any, ) -> CompliantDataFrame: if backend is None: @@ -518,7 +515,7 @@ def collect( version=self._version, ) - elif backend in ("pandas", Implementation.PANDAS, get_pandas()): + if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import return PandasLikeDataFrame( @@ -528,7 +525,7 @@ def collect( version=self._version, ) - elif backend in ("pyarrow", Implementation.PYARROW, get_pyarrow()): + if backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame @@ -539,7 +536,7 @@ def collect( version=self._version, ) - elif backend in ("polars", Implementation.POLARS, get_polars()): + if backend is Implementation.POLARS: import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -550,9 +547,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover # --- actions --- def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy: diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 8447d1a4f6..7a8cbfdce9 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -12,8 +12,6 @@ from narwhals._polars.utils import convert_str_slice_to_int_slice from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import native_to_narwhals_dtype -from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_pyarrow from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import is_sequence_but_not_str @@ -446,7 +444,7 @@ def collect_schema(self: Self) -> dict[str, DType]: def collect( self: Self, - backend: ModuleType | Implementation | str | None, + backend: Implementation | None, **kwargs: Any, ) -> CompliantDataFrame: import polars as pl @@ -456,7 +454,7 @@ def collect( except pl.exceptions.ColumnNotFoundError as e: raise ColumnNotFoundError(str(e)) from e - if backend in (None, "polars", Implementation.POLARS, pl): + if backend is None or backend is Implementation.POLARS: from narwhals._polars.dataframe import PolarsDataFrame return PolarsDataFrame( @@ -465,7 +463,7 @@ def collect( version=self._version, ) - elif backend in ("pandas", Implementation.PANDAS, get_pandas()): + if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -477,7 +475,7 @@ def collect( version=self._version, ) - elif backend in ("pyarrow", Implementation.PYARROW, get_pyarrow()): + if backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame @@ -488,9 +486,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index c51045fda3..9199ea0623 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3827,7 +3827,7 @@ def collect( - `polars.LazyFrame` -> `polars.DataFrame` - `dask.DataFrame` -> `pandas.DataFrame` - `duckdb.PyRelation` -> `pyarrow.Table` - - `pyspark.DataFrame` -> `pandas.DataFrame` + - `pyspark.DataFrame` -> `pyarrow.Table` `backend` can be specified in various ways: @@ -3957,6 +3957,23 @@ def collect( b: [[4,11,6]] c: [[10,10,1]] """ + eager_backend = ( + None + if backend is None + else Implementation.from_string(backend) + if isinstance(backend, str) + else backend + if isinstance(backend, Implementation) + else Implementation.from_native_namespace(backend) + ) + supported_eager_backends = ( + Implementation.POLARS, + Implementation.PANDAS, + Implementation.PYARROW, + ) + if eager_backend is not None and eager_backend not in supported_eager_backends: + msg = f"Expected one of {supported_eager_backends} or None, got: {eager_backend}." + raise ValueError(msg) return self._dataframe( self._compliant_frame.collect(backend=backend, **kwargs), level="full", @@ -5375,9 +5392,9 @@ def clone(self: Self) -> Self: return super().clone() def lazy(self: Self) -> Self: - """Lazify the DataFrame (if possible). + """Restrict available API methods to lazy-only ones. - If a library does not support lazy execution, then this is a no-op. + This is a no-op, and exists only for compatibility with `DataFrame.lazy`. Returns: A LazyFrame. diff --git a/narwhals/utils.py b/narwhals/utils.py index bbfe3eeaf4..f9ae0620dc 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -110,6 +110,29 @@ def from_native_namespace( } return mapping.get(native_namespace, Implementation.UNKNOWN) + @classmethod + def from_string(cls: type[Self], backend: str) -> Implementation: # pragma: no cover + """Instantiate Implementation object from a native namespace module. + + Arguments: + backend: Name of backend, expressed as string. + + Returns: + Implementation. + """ + mapping = { + "pandas": Implementation.PANDAS, + "modin": Implementation.MODIN, + "cudf": Implementation.CUDF, + "pyarrow": Implementation.PYARROW, + "pyspark": Implementation.PYSPARK, + "polars": Implementation.POLARS, + "dask": Implementation.DASK, + "duckdb": Implementation.DUCKDB, + "ibis": Implementation.IBIS, + } + return mapping.get(backend, Implementation.UNKNOWN) + def to_native_namespace(self: Self) -> ModuleType: """Return the native namespace module corresponding to Implementation. From 6b2ffd178254e15200f750274f00f96202aac7c6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:24:59 +0000 Subject: [PATCH 13/17] simplify --- narwhals/dataframe.py | 10 +--------- narwhals/utils.py | 28 +++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 9199ea0623..421c41ffae 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3957,15 +3957,7 @@ def collect( b: [[4,11,6]] c: [[10,10,1]] """ - eager_backend = ( - None - if backend is None - else Implementation.from_string(backend) - if isinstance(backend, str) - else backend - if isinstance(backend, Implementation) - else Implementation.from_native_namespace(backend) - ) + eager_backend = None if backend is None else Implementation.from_backend(backend) supported_eager_backends = ( Implementation.POLARS, Implementation.PANDAS, diff --git a/narwhals/utils.py b/narwhals/utils.py index f9ae0620dc..29f831fd30 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -111,11 +111,13 @@ def from_native_namespace( return mapping.get(native_namespace, Implementation.UNKNOWN) @classmethod - def from_string(cls: type[Self], backend: str) -> Implementation: # pragma: no cover + def from_string( + cls: type[Self], backend_name: str + ) -> Implementation: # pragma: no cover """Instantiate Implementation object from a native namespace module. Arguments: - backend: Name of backend, expressed as string. + backend_name: Name of backend, expressed as string. Returns: Implementation. @@ -131,7 +133,27 @@ def from_string(cls: type[Self], backend: str) -> Implementation: # pragma: no "duckdb": Implementation.DUCKDB, "ibis": Implementation.IBIS, } - return mapping.get(backend, Implementation.UNKNOWN) + return mapping.get(backend_name, Implementation.UNKNOWN) + + @classmethod + def from_backend( + cls: type[Self], backend: str | Implementation | ModuleType + ) -> Implementation: + """Instantiate from native namespace module, string, or Implementation. + + Arguments: + backend: Backend to instantiate Implementation from. + + Returns: + Implementation. + """ + return ( + cls.from_string(backend) + if isinstance(backend, str) + else backend + if isinstance(backend, Implementation) + else cls.from_native_namespace(backend) + ) def to_native_namespace(self: Self) -> ModuleType: """Return the native namespace module corresponding to Implementation. From 766f172264051e1d909e1fc8b80aee85c7e7f7cb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:56:26 +0000 Subject: [PATCH 14/17] wip --- narwhals/_spark_like/dataframe.py | 40 +++++++++++++------ narwhals/dataframe.py | 4 +- narwhals/stable/v1/__init__.py | 2 +- tests/expr_and_series/str/to_datetime_test.py | 22 +++++++--- tests/frame/collect_test.py | 4 +- 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 15f07f9cb6..d4c834b869 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -9,9 +9,6 @@ from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs -from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_polars -from narwhals.dependencies import get_pyarrow from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame @@ -120,7 +117,7 @@ def collect( backend: ModuleType | Implementation | str | None, **kwargs: Any, ) -> CompliantDataFrame: - if backend in (None, "pandas", Implementation.PANDAS, get_pandas()): + if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -132,20 +129,40 @@ def collect( version=self._version, ) - elif backend in ("pyarrow", Implementation.PYARROW, get_pyarrow()): + elif backend is None or backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame - return ArrowDataFrame( - native_dataframe=pa.Table.from_batches( + try: + native_pyarrow_frame = pa.Table.from_batches( self._native_frame._collect_as_arrow() - ), + ) + except ValueError as exc: + if "at least one RecordBatch" in str(exc): + # Empty dataframe + from narwhals._arrow.utils import narwhals_to_native_dtype + + data: dict[str, list[Any]] = {} + schema = [] + current_schema = self.collect_schema() + for key, value in current_schema.items(): + data[key] = [] + schema.append( + (key, narwhals_to_native_dtype(value, self._version)) + ) + native_pyarrow_frame = pa.Table.from_pydict( + data, schema=pa.schema(schema) + ) + else: # pragma: no cover + raise + return ArrowDataFrame( + native_pyarrow_frame, backend_version=parse_version(pa.__version__), version=self._version, ) - elif backend in ("polars", Implementation.POLARS, get_polars()): + elif backend is Implementation.POLARS: import polars as pl # ignore-banned-import import pyarrow as pa # ignore-banned-import @@ -159,9 +176,8 @@ def collect( version=self._version, ) - else: - msg = f"Unsupported `backend` value: {backend}" - raise ValueError(msg) + msg = f"Unsupported `backend` value: {backend}" # pragma: no cover + raise ValueError(msg) # pragma: no cover def simple_select(self: Self, *column_names: str) -> Self: return self._from_native_frame(self._native_frame.select(*column_names)) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 421c41ffae..ec6d0f780c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -3964,10 +3964,10 @@ def collect( Implementation.PYARROW, ) if eager_backend is not None and eager_backend not in supported_eager_backends: - msg = f"Expected one of {supported_eager_backends} or None, got: {eager_backend}." + msg = f"Unsupported `backend` value.\nExpected one of {supported_eager_backends} or None, got: {eager_backend}." raise ValueError(msg) return self._dataframe( - self._compliant_frame.collect(backend=backend, **kwargs), + self._compliant_frame.collect(backend=eager_backend, **kwargs), level="full", ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index a049258431..43ab578cd4 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -297,7 +297,7 @@ def collect( - `polars.LazyFrame` -> `polars.DataFrame` - `dask.DataFrame` -> `pandas.DataFrame` - `duckdb.PyRelation` -> `pyarrow.Table` - - `pyspark.DataFrame` -> `pandas.DataFrame` + - `pyspark.DataFrame` -> `pyarrow.Table` `backend` can be specified in various ways: diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 99f886a123..1558dd8ee3 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -22,6 +22,8 @@ def test_to_datetime(constructor: Constructor, request: pytest.FixtureRequest) - request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = "2020-01-01T12:34:56.000000000" + elif "pyspark" in str(constructor): + expected = "2020-01-01 12:34:56+00:00" else: expected = "2020-01-01 12:34:56" @@ -50,22 +52,25 @@ def test_to_datetime_series(constructor_eager: ConstructorEager) -> None: @pytest.mark.parametrize( - ("data", "expected", "expected_cudf"), + ("data", "expected", "expected_cudf", "expected_pyspark"), [ ( {"a": ["2020-01-01T12:34:56"]}, "2020-01-01 12:34:56", "2020-01-01T12:34:56.000000000", + "2020-01-01T12:34:56+00:00", ), ( {"a": ["2020-01-01T12:34"]}, "2020-01-01 12:34:00", "2020-01-01T12:34:00.000000000", + "2020-01-01T12:34:00+00:00", ), ( {"a": ["20240101123456"]}, "2024-01-01 12:34:56", "2024-01-01T12:34:56.000000000", + "2024-01-01T12:34:56+00:00", ), ], ) @@ -75,15 +80,20 @@ def test_to_datetime_infer_fmt( data: dict[str, list[str]], expected: str, expected_cudf: str, + expected_pyspark: str, ) -> None: - if "polars" in str(constructor) and str(data["a"][0]).isdigit(): + if ( + ("polars" in str(constructor) and str(data["a"][0]).isdigit()) + or "duckdb" in str(constructor) + or ("pyspark" in str(constructor) and data["a"][0] == "20240101123456") + ): request.applymarker(pytest.mark.xfail) + if "cudf" in str(constructor): expected = expected_cudf - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - if "pyspark" in str(constructor) and data["a"][0] == "20240101123456": - request.applymarker(pytest.mark.xfail) + elif "pyspark" in str(constructor): + expected = expected_pyspark + result = ( nw.from_native(constructor(data)) .lazy() diff --git a/tests/frame/collect_test.py b/tests/frame/collect_test.py index 5f9182bcba..c357d4adf0 100644 --- a/tests/frame/collect_test.py +++ b/tests/frame/collect_test.py @@ -32,7 +32,7 @@ def test_collect_to_default_backend(constructor: Constructor) -> None: if "polars" in str(constructor): expected_cls = pl.DataFrame - elif any(x in str(constructor) for x in ("pandas", "dask", "pyspark")): + elif any(x in str(constructor) for x in ("pandas", "dask")): expected_cls = pd.DataFrame elif "modin" in str(constructor): mpd = get_modin() @@ -40,7 +40,7 @@ def test_collect_to_default_backend(constructor: Constructor) -> None: elif "cudf" in str(constructor): cudf = get_cudf() expected_cls = cudf.DataFrame - else: # pyarrow and duckdb + else: # pyarrow, duckdb, and PySpark expected_cls = pa.Table assert isinstance(result, expected_cls) From 125f1489b4d644b0bc79d1a68ec98ac411f1b0ba Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:22:09 +0000 Subject: [PATCH 15/17] default to PyArrow instead of pandas for PySpark --- narwhals/_spark_like/expr_str.py | 14 ++++++++++++-- tests/expr_and_series/str/to_datetime_test.py | 18 ++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/narwhals/_spark_like/expr_str.py b/narwhals/_spark_like/expr_str.py index 8bae6a0307..5d260ea2e4 100644 --- a/narwhals/_spark_like/expr_str.py +++ b/narwhals/_spark_like/expr_str.py @@ -127,14 +127,24 @@ def to_lowercase(self: Self) -> SparkLikeExpr: ) def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: # noqa: A002 + is_naive = format is not None and "%s" not in format and "%z" not in format + function = ( + self._compliant_expr._F.to_timestamp_ntz + if is_naive + else self._compliant_expr._F.to_timestamp + ) + pyspark_format = strptime_to_pyspark_format(format) + format = ( + self._compliant_expr._F.lit(pyspark_format) if is_naive else pyspark_format + ) return self._compliant_expr._from_call( - lambda _input: self._compliant_expr._F.to_timestamp( + lambda _input: function( self._compliant_expr._F.replace( _input, self._compliant_expr._F.lit("T"), self._compliant_expr._F.lit(" "), ), - format=strptime_to_pyspark_format(format), + format=format, ), "to_datetime", expr_kind=self._compliant_expr._expr_kind, diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 1558dd8ee3..b88432384e 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime +from datetime import timezone from typing import TYPE_CHECKING import pyarrow as pa @@ -22,8 +23,6 @@ def test_to_datetime(constructor: Constructor, request: pytest.FixtureRequest) - request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = "2020-01-01T12:34:56.000000000" - elif "pyspark" in str(constructor): - expected = "2020-01-01 12:34:56+00:00" else: expected = "2020-01-01 12:34:56" @@ -58,19 +57,19 @@ def test_to_datetime_series(constructor_eager: ConstructorEager) -> None: {"a": ["2020-01-01T12:34:56"]}, "2020-01-01 12:34:56", "2020-01-01T12:34:56.000000000", - "2020-01-01T12:34:56+00:00", + "2020-01-01 12:34:56+00:00", ), ( {"a": ["2020-01-01T12:34"]}, "2020-01-01 12:34:00", "2020-01-01T12:34:00.000000000", - "2020-01-01T12:34:00+00:00", + "2020-01-01 12:34:00+00:00", ), ( {"a": ["20240101123456"]}, "2024-01-01 12:34:56", "2024-01-01T12:34:56.000000000", - "2024-01-01T12:34:56+00:00", + "2024-01-01 12:34:56+00:00", ), ], ) @@ -148,7 +147,14 @@ def test_to_datetime_infer_fmt_from_date( if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"z": ["2020-01-01", "2020-01-02", None]} - expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] + if "pyspark" in str(constructor): + expected = [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + None, + ] + else: + expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] result = ( nw.from_native(constructor(data)).lazy().select(nw.col("z").str.to_datetime()) ) From 923096500372731eb5dd40085651630f840523c2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:41:23 +0000 Subject: [PATCH 16/17] restore Self --- narwhals/_duckdb/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 0cb898d06a..59bcb757f3 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -122,7 +122,7 @@ def collect( msg = f"Unsupported `backend` value: {backend}" # pragma: no cover raise ValueError(msg) # pragma: no cover - def head(self, n: int) -> Self: + def head(self: Self, n: int) -> Self: return self._from_native_frame(self._native_frame.limit(n)) def simple_select(self, *column_names: str) -> Self: From f2bd45db75df531a2c2705f5afac9a666e2f8029 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:53:08 +0000 Subject: [PATCH 17/17] coverage --- narwhals/_spark_like/expr_str.py | 7 ++++++- pyproject.toml | 3 ++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/narwhals/_spark_like/expr_str.py b/narwhals/_spark_like/expr_str.py index 5d260ea2e4..2873972c38 100644 --- a/narwhals/_spark_like/expr_str.py +++ b/narwhals/_spark_like/expr_str.py @@ -127,7 +127,12 @@ def to_lowercase(self: Self) -> SparkLikeExpr: ) def to_datetime(self: Self, format: str | None) -> SparkLikeExpr: # noqa: A002 - is_naive = format is not None and "%s" not in format and "%z" not in format + is_naive = ( + format is not None + and "%s" not in format + and "%z" not in format + and "Z" not in format + ) function = ( self._compliant_expr._F.to_timestamp_ntz if is_naive diff --git a/pyproject.toml b/pyproject.toml index 9162e39b49..a3377fef13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,7 +203,8 @@ exclude_also = [ "if .*implementation.is_cudf", 'request.applymarker\(pytest.mark.xfail', 'backend_version <', - 'if "cudf" in str\(constructor' + 'if "cudf" in str\(constructor', + 'if "pyspark" in str\(constructor' ] [tool.mypy]