diff --git a/dataframe_api_compat/pandas_standard/__init__.py b/dataframe_api_compat/pandas_standard/__init__.py index 1334b202..339820c7 100644 --- a/dataframe_api_compat/pandas_standard/__init__.py +++ b/dataframe_api_compat/pandas_standard/__init__.py @@ -168,7 +168,12 @@ def convert_to_standard_compliant_column( raise ValueError(msg) if ser.name is None: ser = ser.rename("") - return Column(ser, api_version=api_version or "2023.11-beta", df=None) + return Column( + ser, + api_version=api_version or "2023.11-beta", + df=None, + is_persisted=True, + ) def convert_to_standard_compliant_dataframe( @@ -253,15 +258,14 @@ def dataframe_from_columns( api_versions.add(col._api_version) # type: ignore[attr-defined] return DataFrame(pd.DataFrame(data), api_version=list(api_versions)[0]) - def column_from_1d_array( + def column_from_1d_array( # type: ignore[override] self, data: Any, *, - dtype: DType, name: str | None = None, ) -> Column: - ser = pd.Series(data, dtype=map_standard_dtype_to_pandas_dtype(dtype), name=name) - return Column(ser, api_version=self._api_version, df=None) + ser = pd.Series(data, name=name) + return Column(ser, api_version=self._api_version, df=None, is_persisted=True) def column_from_sequence( self, @@ -272,10 +276,11 @@ def column_from_sequence( ) -> Column: ser = pd.Series( sequence, + # todo make optional? dtype=map_standard_dtype_to_pandas_dtype(dtype), name=name, ) - return Column(ser, api_version=self._api_version, df=None) + return Column(ser, api_version=self._api_version, df=None, is_persisted=True) def concat( self, diff --git a/dataframe_api_compat/pandas_standard/column_object.py b/dataframe_api_compat/pandas_standard/column_object.py index e276ecd4..fb8ab3d9 100644 --- a/dataframe_api_compat/pandas_standard/column_object.py +++ b/dataframe_api_compat/pandas_standard/column_object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import datetime from typing import TYPE_CHECKING from typing import Any @@ -51,15 +52,23 @@ def __init__( df DataFrame this column originates from. """ - from dataframe_api_compat.pandas_standard.scalar_object import Scalar self._name = series.name or "" self._series = series self._api_version = api_version self._df = df - self._scalar = Scalar self._is_persisted = is_persisted + def _to_scalar(self, value: Any) -> Scalar: + from dataframe_api_compat.pandas_standard.scalar_object import Scalar + + return Scalar( + value, + api_version=self._api_version, + df=self._df, + is_persisted=self._is_persisted, + ) + def __repr__(self) -> str: # pragma: no cover header = f" Standard Column (api_version={self._api_version}) " length = len(header) @@ -83,6 +92,7 @@ def _from_series(self, series: pd.Series) -> Column: series.reset_index(drop=True), api_version=self._api_version, df=self._df, + is_persisted=self._is_persisted, ) def _materialise(self) -> pd.Series: @@ -102,6 +112,12 @@ def __column_namespace__( ) def persist(self) -> Column: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on Column that was already persisted", + UserWarning, + stacklevel=2, + ) return Column( self.column, df=None, @@ -136,11 +152,8 @@ def filter(self, mask: Column) -> Column: def get_value(self, row_number: int) -> Any: ser = self.column - return self._scalar( + return self._to_scalar( ser.iloc[row_number], - api_version=self._api_version, - df=self._df, - is_persisted=self._is_persisted, ) def slice_rows( @@ -270,35 +283,35 @@ def __invert__(self: Column) -> Column: def any(self, *, skip_nulls: bool | Scalar = True) -> Scalar: ser = self.column - return self._scalar(ser.any(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.any()) def all(self, *, skip_nulls: bool | Scalar = True) -> Scalar: ser = self.column - return self._scalar(ser.all(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.all()) def min(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.min(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.min()) def max(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.max(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.max()) def sum(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.sum(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.sum()) def prod(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.prod(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.prod()) def median(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.median(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.median()) def mean(self, *, skip_nulls: bool | Scalar = True) -> Any: ser = self.column - return self._scalar(ser.mean(), api_version=self._api_version, df=self._df) + return self._to_scalar(ser.mean()) def std( self, @@ -307,10 +320,8 @@ def std( skip_nulls: bool | Scalar = True, ) -> Any: ser = self.column - return self._scalar( + return self._to_scalar( ser.std(ddof=correction), - api_version=self._api_version, - df=self._df, ) def var( @@ -320,16 +331,20 @@ def var( skip_nulls: bool | Scalar = True, ) -> Any: ser = self.column - return self._scalar( + return self._to_scalar( ser.var(ddof=correction), - api_version=self._api_version, - df=self._df, ) def __len__(self) -> int: ser = self._materialise() return len(ser) + def n_unique(self) -> Scalar: # pragma: no cover (todo, still needs adding upstream) + ser = self.column + return self._to_scalar( + ser.nunique(), + ) + # Transformations def is_null(self) -> Column: diff --git a/dataframe_api_compat/pandas_standard/dataframe_object.py b/dataframe_api_compat/pandas_standard/dataframe_object.py index d23cb8be..5cb34e78 100644 --- a/dataframe_api_compat/pandas_standard/dataframe_object.py +++ b/dataframe_api_compat/pandas_standard/dataframe_object.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +import warnings from typing import TYPE_CHECKING from typing import Any from typing import Iterator @@ -91,6 +92,7 @@ def _from_dataframe(self, df: pd.DataFrame) -> DataFrame: return DataFrame( df, api_version=self._api_version, + is_persisted=self._is_persisted, ) # Properties @@ -515,6 +517,12 @@ def join( ) def persist(self) -> DataFrame: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on DataFrame that was already persisted", + UserWarning, + stacklevel=2, + ) return DataFrame( self.dataframe, api_version=self._api_version, diff --git a/dataframe_api_compat/pandas_standard/scalar_object.py b/dataframe_api_compat/pandas_standard/scalar_object.py index d609efb8..7a30752c 100644 --- a/dataframe_api_compat/pandas_standard/scalar_object.py +++ b/dataframe_api_compat/pandas_standard/scalar_object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from typing import Any @@ -35,7 +36,12 @@ def __scalar_namespace__(self) -> Namespace: return Namespace(api_version=self._api_version) def _from_scalar(self, scalar: Scalar) -> Scalar: - return Scalar(scalar, df=self._df, api_version=self._api_version) + return Scalar( + scalar, + df=self._df, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) @property def dtype(self) -> DType: # pragma: no cover # todo @@ -57,6 +63,12 @@ def _materialise(self) -> Any: return self._value def persist(self) -> Scalar: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on Scalar that was already persisted", + UserWarning, + stacklevel=2, + ) return Scalar( self._value, df=None, diff --git a/dataframe_api_compat/polars_standard/__init__.py b/dataframe_api_compat/polars_standard/__init__.py index bce1c11e..336355ce 100644 --- a/dataframe_api_compat/polars_standard/__init__.py +++ b/dataframe_api_compat/polars_standard/__init__.py @@ -148,19 +148,14 @@ def dataframe_from_columns( api_version=list(api_version)[0], ) - def column_from_1d_array( + def column_from_1d_array( # type: ignore[override] self, array: Any, *, - dtype: DType, name: str = "", ) -> Column: - ser = pl.Series( - values=array, - dtype=_map_standard_to_polars_dtypes(dtype), - name=name, - ) - return Column(pl.lit(ser), api_version=self.api_version, df=None) + ser = pl.Series(values=array, name=name) + return Column(ser, api_version=self.api_version, df=None, is_persisted=True) def column_from_sequence( self, @@ -174,7 +169,7 @@ def column_from_sequence( dtype=_map_standard_to_polars_dtypes(dtype), name=name, ) - return Column(pl.lit(ser), api_version=self.api_version, df=None) + return Column(ser, api_version=self.api_version, df=None, is_persisted=True) def dataframe_from_2d_array( self, @@ -303,7 +298,7 @@ def concat( dataframes: Sequence[DataFrameT], ) -> DataFrame: dataframes = cast("Sequence[DataFrame]", dataframes) - dfs: list[pl.LazyFrame] = [] + dfs: list[pl.LazyFrame | pl.DataFrame] = [] api_versions: set[str] = set() for df in dataframes: dfs.append(df.dataframe) @@ -311,8 +306,9 @@ def concat( if len(api_versions) > 1: # pragma: no cover msg = f"Multiple api versions found: {api_versions}" raise ValueError(msg) + # todo raise if not all share persistedness return DataFrame( - pl.concat(dfs), + pl.concat(dfs), # type: ignore[type-var] api_version=api_versions.pop(), ) @@ -427,7 +423,12 @@ def convert_to_standard_compliant_column( ser: pl.Series, api_version: str | None = None, ) -> Column: - return Column(pl.lit(ser), api_version=api_version or "2023.11-beta", df=None) + return Column( + ser, + api_version=api_version or "2023.11-beta", + df=None, + is_persisted=True, + ) def convert_to_standard_compliant_dataframe( diff --git a/dataframe_api_compat/polars_standard/column_object.py b/dataframe_api_compat/polars_standard/column_object.py index e7eaa4e5..033a70f4 100644 --- a/dataframe_api_compat/polars_standard/column_object.py +++ b/dataframe_api_compat/polars_standard/column_object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from typing import Any from typing import Literal @@ -24,26 +25,32 @@ ColumnT = object +def _extract_name(expr: pl.Expr | pl.Series, df: DataFrame | None) -> str: + if isinstance(expr, pl.Expr): + try: + return expr.meta.output_name() + except pl.ComputeError: # pragma: no cover + # can remove if/when requiring polars >= 0.19.13 + if df is not None: + # Unexpected error. Just let it raise. + raise + return "" + return expr.name + + class Column(ColumnT): def __init__( self, - expr: pl.Expr, + expr: pl.Expr | pl.Series, *, df: DataFrame | None, api_version: str, is_persisted: bool = False, ) -> None: self._expr = expr + self._name = _extract_name(expr, df) self._df = df self._api_version = api_version - try: - self._name = expr.meta.output_name() - except pl.ComputeError: # pragma: no cover - # can remove if/when requiring polars >= 0.19.13 - if df is not None: - # Unexpected error. Just let it raise. - raise - self._name = "" self._is_persisted = is_persisted def __repr__(self) -> str: # pragma: no cover @@ -63,8 +70,13 @@ def __repr__(self) -> str: # pragma: no cover def __iter__(self) -> NoReturn: raise NotImplementedError - def _from_expr(self, expr: pl.Expr) -> Self: - return self.__class__(expr, df=self._df, api_version=self._api_version) + def _from_expr(self, expr: pl.Expr | pl.Series) -> Self: + return self.__class__( + expr, + df=self._df, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) def _materialise(self) -> pl.Series: if not self._is_persisted: @@ -72,11 +84,7 @@ def _materialise(self) -> pl.Series: raise RuntimeError( msg, ) - if self._df is not None: - df = self._df.dataframe.collect().select(self._expr) - else: - df = pl.select(self._expr) - return df.get_column(df.columns[0]) + return self._expr # type: ignore[return-value] # In the standard def __column_namespace__(self) -> Namespace: # pragma: no cover @@ -86,24 +94,30 @@ def __column_namespace__(self) -> Namespace: # pragma: no cover api_version=self._api_version, ) - def _to_scalar(self, value: pl.Expr, *, is_persisted: bool = False) -> Scalar: + def _to_scalar(self, value: Any) -> Scalar: from dataframe_api_compat.polars_standard.scalar_object import Scalar return Scalar( value, api_version=self._api_version, df=self._df, - is_persisted=is_persisted, + is_persisted=self._is_persisted, ) def persist(self) -> Column: if self._df is not None: - df = self._df.dataframe.collect().select(self._expr) + assert isinstance(self._df.dataframe, pl.LazyFrame) # help mypy + df = self._df.dataframe.select(self._expr).collect() else: + warnings.warn( + "Calling `.persist` on Column that was already persisted", + UserWarning, + stacklevel=2, + ) df = pl.select(self._expr) column = df.get_column(df.columns[0]) return Column( - pl.lit(column), + column, df=None, api_version=self._api_version, is_persisted=True, @@ -114,7 +128,7 @@ def name(self) -> str: return self._name @property - def column(self) -> pl.Expr: + def column(self) -> pl.Expr | pl.Series: return self._expr @property @@ -139,18 +153,16 @@ def get_rows(self, indices: Column) -> Column: return self._from_expr(self._expr.gather(indices._expr)) def filter(self, mask: Column) -> Column: - return self._from_expr(self._expr.filter(mask._expr)) + return self._from_expr(self._expr.filter(mask._expr)) # type: ignore[arg-type] def get_value(self, row_number: int) -> Any: if POLARS_VERSION < (0, 19, 14): - return self._to_scalar( - self._expr.take(row_number), - is_persisted=self._is_persisted, - ) - return self._to_scalar( - self._expr.gather(row_number), - is_persisted=self._is_persisted, - ) + result = self._expr.take(row_number) + else: + result = self._expr.gather(row_number) + if isinstance(result, pl.Series): + return self._to_scalar(result.item()) + return self._to_scalar(result) def slice_rows( self, @@ -370,7 +382,7 @@ def sort( return self._from_expr(expr) def is_in(self, values: Self) -> Self: - return self._from_expr(self._expr.is_in(values._expr)) + return self._from_expr(self._expr.is_in(values._expr)) # type: ignore[arg-type] def sorted_indices( self, diff --git a/dataframe_api_compat/polars_standard/dataframe_object.py b/dataframe_api_compat/polars_standard/dataframe_object.py index 49cc3d6c..5ed9c573 100644 --- a/dataframe_api_compat/polars_standard/dataframe_object.py +++ b/dataframe_api_compat/polars_standard/dataframe_object.py @@ -2,6 +2,7 @@ import collections import secrets +import warnings from typing import TYPE_CHECKING from typing import Any from typing import Iterator @@ -48,7 +49,7 @@ def generate_random_token(column_names: list[str]) -> str: class DataFrame(DataFrameT): def __init__( self, - df: pl.LazyFrame, + df: pl.LazyFrame | pl.DataFrame, *, api_version: str, is_persisted: bool = False, @@ -65,7 +66,7 @@ def _validate_is_persisted(self) -> pl.DataFrame: raise ValueError( msg, ) - return self.dataframe.collect() + return self.dataframe # type: ignore[return-value] def __repr__(self) -> str: # pragma: no cover header = f" Standard DataFrame (api_version={self._api_version}) " @@ -88,10 +89,11 @@ def _validate_booleanness(self) -> None: msg, ) - def _from_dataframe(self, df: pl.LazyFrame) -> DataFrame: + def _from_dataframe(self, df: pl.LazyFrame | pl.DataFrame) -> DataFrame: return DataFrame( df, api_version=self._api_version, + is_persisted=self._is_persisted, ) # Properties @@ -109,7 +111,7 @@ def column_names(self) -> list[str]: return self.dataframe.columns @property - def dataframe(self) -> pl.LazyFrame: + def dataframe(self) -> pl.LazyFrame | pl.DataFrame: return self._df # In the Standard @@ -125,11 +127,18 @@ def columns_iter(self) -> Iterator[Column]: def col(self, value: str) -> Column: from dataframe_api_compat.polars_standard.column_object import Column + if isinstance(self.dataframe, pl.DataFrame): + return Column( + self.dataframe.get_column(value), + df=None, + api_version=self._api_version, + is_persisted=True, + ) return Column( pl.col(value), df=self, api_version=self._api_version, - is_persisted=self._is_persisted, + is_persisted=False, ) def shape(self) -> tuple[int, int]: @@ -522,7 +531,7 @@ def join( ) result = self.dataframe.join( - other_df, + other_df, # type: ignore[arg-type] left_on=left_on, right_on=right_on, how=how, @@ -532,8 +541,17 @@ def join( return self._from_dataframe(result) def persist(self) -> DataFrame: + if isinstance(self.dataframe, pl.DataFrame): + warnings.warn( + "Calling `.persist` on DataFrame that was already persisted", + UserWarning, + stacklevel=2, + ) + df = self.dataframe + else: + df = self.dataframe.collect() return DataFrame( - self.dataframe.collect().lazy(), + df, api_version=self._api_version, is_persisted=True, ) diff --git a/dataframe_api_compat/polars_standard/group_by_object.py b/dataframe_api_compat/polars_standard/group_by_object.py index d54c5f38..b02b0f48 100644 --- a/dataframe_api_compat/polars_standard/group_by_object.py +++ b/dataframe_api_compat/polars_standard/group_by_object.py @@ -21,7 +21,12 @@ class GroupBy(GroupByT): - def __init__(self, df: pl.LazyFrame, keys: Sequence[str], api_version: str) -> None: + def __init__( + self, + df: pl.LazyFrame | pl.DataFrame, + keys: Sequence[str], + api_version: str, + ) -> None: for key in keys: if key not in df.columns: msg = f"key {key} not present in DataFrame's columns" diff --git a/dataframe_api_compat/polars_standard/scalar_object.py b/dataframe_api_compat/polars_standard/scalar_object.py index a48ac62e..719afb1a 100644 --- a/dataframe_api_compat/polars_standard/scalar_object.py +++ b/dataframe_api_compat/polars_standard/scalar_object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from typing import Any @@ -37,7 +38,12 @@ def __scalar_namespace__(self) -> Namespace: return Namespace(api_version=self._api_version) def _from_scalar(self, scalar: Scalar) -> Scalar: - return Scalar(scalar, df=self._df, api_version=self._api_version) + return Scalar( + scalar, + df=self._df, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) @property def dtype(self) -> DType: # pragma: no cover # todo @@ -56,18 +62,19 @@ def _materialise(self) -> Any: msg = "Can't call __bool__ on Scalar. Please use .persist() first." raise RuntimeError(msg) - if self._df is None: - value = pl.select(self._value).item() - else: - df = self._df.dataframe.collect().select(self._value) - value = df.get_column(df.columns[0]).item() - return value + return pl.select(self._value).item() def persist(self) -> Scalar: if self._df is None: - value = pl.select(self._value).item() + warnings.warn( + "Calling `.persist` on Scalar that was already persisted", + UserWarning, + stacklevel=2, + ) + value = self._value else: - df = self._df.dataframe.collect().select(self._value) + assert isinstance(self._df.dataframe, pl.LazyFrame) # help mypy + df = self._df.dataframe.select(self._value).collect() value = df.get_column(df.columns[0]).item() return Scalar( value, diff --git a/docs/basics/persist.md b/docs/basics/persist.md index adf65338..10c3ee94 100644 --- a/docs/basics/persist.md +++ b/docs/basics/persist.md @@ -19,9 +19,18 @@ then you'll likely be fine. The `dataframe-api-compat` package is written with lazy computation in mind. For the Polars implementation, all objects are backed by lazy constructs: -- `DataFrame`: backed by `polars.LazyFrame`. -- `Column`: backed by `polars.Expr`. -- `Scalar`: backed by `polars.Expr`. +- `DataFrame`: + - by default, backed by `polars.LazyFrame` + - if you call `persist`, backed by `polars.DataFrame` +- `Column`: + - by default, backed by `polars.Expr` + - if you call `persist`, or if you called `persist` on + the dataframe it was derived from, backed by `polars.Series` +- `Scalar`: + - by default, backed by `polars.Expr` + - if you call `persist`, or if you called `persist` on + the dataframe or column it was derived from, backed by + a Python scalar. All operations can be done lazily, except for: - `DataFrame.to_array()`, diff --git a/tests/column/any_all_test.py b/tests/column/any_all_test.py index b8082a6d..a775dfff 100644 --- a/tests/column/any_all_test.py +++ b/tests/column/any_all_test.py @@ -11,10 +11,11 @@ def test_expr_any(library: str) -> None: bool(df.col("a").any()) df = df.persist() result = df.col("a").any() - assert bool(result.persist()) + with pytest.warns(UserWarning): + assert bool(result.persist()) def test_expr_all(library: str) -> None: df = bool_dataframe_1(library).persist() result = df.col("a").all() - assert not bool(result.persist()) + assert not bool(result) diff --git a/tests/column/fill_null_test.py b/tests/column/fill_null_test.py index d52478cc..611c9efb 100644 --- a/tests/column/fill_null_test.py +++ b/tests/column/fill_null_test.py @@ -19,11 +19,11 @@ def test_fill_null_noop_column(library: str) -> None: result = df.assign(ser.fill_null(0).rename("result")).persist().col("result") if library != "pandas-numpy": # nan should not have changed! - assert float(result.get_value(2).persist()) != float( # type: ignore[arg-type] - result.get_value(2).persist(), # type: ignore[arg-type] + assert float(result.get_value(2)) != float( # type: ignore[arg-type] + result.get_value(2), # type: ignore[arg-type] ) else: # nan was filled with 0 - assert float(result.get_value(2).persist()) == 0 # type: ignore[arg-type] - assert float(result.get_value(1).persist()) != 0.0 # type: ignore[arg-type] - assert float(result.get_value(0).persist()) != 0.0 # type: ignore[arg-type] + assert float(result.get_value(2)) == 0 # type: ignore[arg-type] + assert float(result.get_value(1)) != 0.0 # type: ignore[arg-type] + assert float(result.get_value(0)) != 0.0 # type: ignore[arg-type] diff --git a/tests/column/get_value_test.py b/tests/column/get_value_test.py index 98e7b37f..49a30c89 100644 --- a/tests/column/get_value_test.py +++ b/tests/column/get_value_test.py @@ -5,9 +5,9 @@ def test_get_value(library: str) -> None: result = integer_dataframe_1(library).persist().col("a").get_value(0) - assert int(result.persist()) == 1 # type: ignore[call-overload] + assert int(result) == 1 # type: ignore[call-overload] def test_mean_scalar(library: str) -> None: result = integer_dataframe_1(library).persist().col("a").max() - assert int(result.persist()) == 3 # type: ignore[call-overload] + assert int(result) == 3 # type: ignore[call-overload] diff --git a/tests/integration/free_vs_w_parent_test.py b/tests/integration/free_vs_w_parent_test.py index b75a9fba..dab61de4 100644 --- a/tests/integration/free_vs_w_parent_test.py +++ b/tests/integration/free_vs_w_parent_test.py @@ -8,14 +8,12 @@ def test_free_vs_w_parent(library: str) -> None: df1 = integer_dataframe_1(library) namespace = df1.__dataframe_namespace__() - free_ser1 = namespace.column_from_1d_array( - np.array([1, 2, 3]), - dtype=namespace.Int64(), + free_ser1 = namespace.column_from_1d_array( # type: ignore[call-arg] + np.array([1, 2, 3], dtype="int64"), name="preds", ) - free_ser2 = namespace.column_from_1d_array( - np.array([4, 5, 6]), - dtype=namespace.Int64(), + free_ser2 = namespace.column_from_1d_array( # type: ignore[call-arg] + np.array([4, 5, 6], dtype="int64"), name="preds", ) diff --git a/tests/integration/persistedness_test.py b/tests/integration/persistedness_test.py index 2c512dde..bd8f8d08 100644 --- a/tests/integration/persistedness_test.py +++ b/tests/integration/persistedness_test.py @@ -18,8 +18,7 @@ def test_within_df_propagation(library: str) -> None: df1 = df1.persist() df1 = df1 + 1 # the call below would recompute `df1 + 1` multiple times - with pytest.raises(RuntimeError): - _ = int(df1.col("a").get_value(0)) # type: ignore[call-overload] + assert int(df1.col("a").get_value(0)) == 2 # type: ignore[call-overload] # this is the correct way df1 = integer_dataframe_1(library) @@ -36,8 +35,7 @@ def test_within_df_propagation(library: str) -> None: df1 = integer_dataframe_1(library) df1 = df1 + 1 col = df1.col("a").persist() - with pytest.raises(RuntimeError): - assert int((col + 1).get_value(0)) == 2 # type: ignore[call-overload] + assert int((col + 1).get_value(0)) == 3 # type: ignore[call-overload] # persisting the scalar works too df1 = integer_dataframe_1(library) @@ -48,21 +46,20 @@ def test_within_df_propagation(library: str) -> None: df1 = integer_dataframe_1(library) df1 = df1 + 1 scalar = df1.col("a").get_value(0).persist() - with pytest.raises(RuntimeError): - assert int(scalar + 1) == 2 # type: ignore[call-overload] + assert int(scalar + 1) == 3 # type: ignore[call-overload] def test_within_df_within_col_propagation(library: str) -> None: df1 = integer_dataframe_1(library) df1 = df1 + 1 df1 = df1.persist() - assert int((df1.col("a") + 1).mean().persist()) == 4 # type: ignore[call-overload] + assert int((df1.col("a") + 1).mean()) == 4 # type: ignore[call-overload] def test_cross_df_propagation(library: str) -> None: df1 = integer_dataframe_1(library) df2 = integer_dataframe_2(library) - df1 = df1 + 1 + df1 = (df1 + 1).persist() df2 = df2.rename_columns({"b": "c"}).persist() result = df1.join(df2, how="left", left_on="a", right_on="a") result_pd = convert_dataframe_to_pandas_numpy(interchange_to_pandas(result)) @@ -81,10 +78,11 @@ def test_multiple_propagations(library: str) -> None: # multiple times to do things optimally df = integer_dataframe_1(library) df = df.persist() - df1 = df.filter(df.col("a") > 1).persist() - df2 = df.filter(df.col("a") <= 1).persist() - assert int(df1.col("a").mean().persist()) == 2 # type: ignore[call-overload] - assert int(df2.col("a").mean().persist()) == 1 # type: ignore[call-overload] + with pytest.warns(UserWarning): + df1 = df.filter(df.col("a") > 1).persist() + df2 = df.filter(df.col("a") <= 1).persist() + assert int(df1.col("a").mean()) == 2 # type: ignore[call-overload] + assert int(df2.col("a").mean()) == 1 # type: ignore[call-overload] # But what if I want to do this df = integer_dataframe_1(library) @@ -94,9 +92,8 @@ def test_multiple_propagations(library: str) -> None: df1 = df1 + 1 # without this persist, `df1 + 1` will be computed twice - df1 = df1.persist() - int(df1.col("a").mean().persist()) # type: ignore[call-overload] - int(df1.col("a").mean().persist()) # type: ignore[call-overload] + int(df1.col("a").mean()) # type: ignore[call-overload] + int(df1.col("a").mean()) # type: ignore[call-overload] def test_parent_propagations(library: str) -> None: diff --git a/tests/namespace/column_from_1d_array_test.py b/tests/namespace/column_from_1d_array_test.py index 10bff425..4b8c158b 100644 --- a/tests/namespace/column_from_1d_array_test.py +++ b/tests/namespace/column_from_1d_array_test.py @@ -15,89 +15,69 @@ @pytest.mark.parametrize( - ("namespace_dtype", "pandas_dtype"), + "pandas_dtype", [ - ("Float64", "float64"), - ("Float32", "float32"), - ("Int64", "int64"), - ("Int32", "int32"), - ("Int16", "int16"), - ("Int8", "int8"), - ("UInt64", "uint64"), - ("UInt32", "uint32"), - ("UInt16", "uint16"), - ("UInt8", "uint8"), + "float64", + "float32", + "int64", + "int32", + "int16", + "int8", + "uint64", + "uint32", + "uint16", + "uint8", ], ) def test_column_from_1d_array( library: str, - namespace_dtype: str, pandas_dtype: str, ) -> None: ser = integer_dataframe_1(library).col("a").persist() namespace = ser.__column_namespace__() - arr = np.array([1, 2, 3]) + arr = np.array([1, 2, 3], dtype=pandas_dtype) result = namespace.dataframe_from_columns( - namespace.column_from_1d_array( + namespace.column_from_1d_array( # type: ignore[call-arg] arr, name="result", - dtype=getattr(namespace, namespace_dtype)(), - ).persist(), + ), ) result_pd = interchange_to_pandas(result)["result"] expected = pd.Series([1, 2, 3], name="result", dtype=pandas_dtype) pd.testing.assert_series_equal(result_pd, expected) -@pytest.mark.parametrize( - ("namespace_dtype", "pandas_dtype"), - [ - ("String", "object"), - ], -) def test_column_from_1d_array_string( library: str, - namespace_dtype: str, - pandas_dtype: str, ) -> None: ser = integer_dataframe_1(library).persist().col("a") namespace = ser.__column_namespace__() arr = np.array(["a", "b", "c"]) result = namespace.dataframe_from_columns( - namespace.column_from_1d_array( + namespace.column_from_1d_array( # type: ignore[call-arg] arr, name="result", - dtype=getattr(namespace, namespace_dtype)(), - ).persist(), + ), ) result_pd = interchange_to_pandas(result)["result"] - expected = pd.Series(["a", "b", "c"], name="result", dtype=pandas_dtype) + expected = pd.Series(["a", "b", "c"], name="result", dtype="object") pd.testing.assert_series_equal(result_pd, expected) -@pytest.mark.parametrize( - ("namespace_dtype", "pandas_dtype"), - [ - ("Bool", "bool"), - ], -) def test_column_from_1d_array_bool( library: str, - namespace_dtype: str, - pandas_dtype: str, ) -> None: ser = integer_dataframe_1(library).persist().col("a") namespace = ser.__column_namespace__() arr = np.array([True, False, True]) result = namespace.dataframe_from_columns( - namespace.column_from_1d_array( + namespace.column_from_1d_array( # type: ignore[call-arg] arr, name="result", - dtype=getattr(namespace, namespace_dtype)(), - ).persist(), + ), ) result_pd = interchange_to_pandas(result)["result"] - expected = pd.Series([True, False, True], name="result", dtype=pandas_dtype) + expected = pd.Series([True, False, True], name="result") pd.testing.assert_series_equal(result_pd, expected) @@ -106,11 +86,10 @@ def test_datetime_from_1d_array(library: str) -> None: namespace = ser.__column_namespace__() arr = np.array([date(2020, 1, 1), date(2020, 1, 2)], dtype="datetime64[ms]") result = namespace.dataframe_from_columns( - namespace.column_from_1d_array( + namespace.column_from_1d_array( # type: ignore[call-arg] arr, name="result", - dtype=namespace.Datetime("ms"), - ).persist(), + ), ) result_pd = interchange_to_pandas(result)["result"] expected = pd.Series( @@ -134,11 +113,10 @@ def test_duration_from_1d_array(library: str) -> None: namespace = ser.__column_namespace__() arr = np.array([timedelta(1), timedelta(2)], dtype="timedelta64[ms]") result = namespace.dataframe_from_columns( - namespace.column_from_1d_array( + namespace.column_from_1d_array( # type: ignore[call-arg] arr, name="result", - dtype=namespace.Duration("ms"), - ).persist(), + ), ) if library == "polars-lazy": # https://github.com/data-apis/dataframe-api/issues/329 diff --git a/tests/namespace/column_from_sequence_test.py b/tests/namespace/column_from_sequence_test.py index 4f5eec68..7ccabc2d 100644 --- a/tests/namespace/column_from_sequence_test.py +++ b/tests/namespace/column_from_sequence_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import datetime +from datetime import timedelta from typing import Any import pandas as pd @@ -10,31 +12,67 @@ @pytest.mark.parametrize( - ("values", "dtype", "expected"), + ("values", "dtype", "kwargs", "expected"), [ - ([1, 2, 3], "Int64", pd.Series([1, 2, 3], dtype="int64", name="result")), - ([1, 2, 3], "Int32", pd.Series([1, 2, 3], dtype="int32", name="result")), + ([1, 2, 3], "Int64", {}, pd.Series([1, 2, 3], dtype="int64", name="result")), + ([1, 2, 3], "Int32", {}, pd.Series([1, 2, 3], dtype="int32", name="result")), + ([1, 2, 3], "Int16", {}, pd.Series([1, 2, 3], dtype="int16", name="result")), + ([1, 2, 3], "Int8", {}, pd.Series([1, 2, 3], dtype="int8", name="result")), + ([1, 2, 3], "UInt64", {}, pd.Series([1, 2, 3], dtype="uint64", name="result")), + ([1, 2, 3], "UInt32", {}, pd.Series([1, 2, 3], dtype="uint32", name="result")), + ([1, 2, 3], "UInt16", {}, pd.Series([1, 2, 3], dtype="uint16", name="result")), + ([1, 2, 3], "UInt8", {}, pd.Series([1, 2, 3], dtype="uint8", name="result")), ( [1.0, 2.0, 3.0], "Float64", + {}, pd.Series([1, 2, 3], dtype="float64", name="result"), ), ( [1.0, 2.0, 3.0], "Float32", + {}, pd.Series([1, 2, 3], dtype="float32", name="result"), ), ( [True, False, True], "Bool", + {}, pd.Series([True, False, True], dtype=bool, name="result"), ), + ( + ["express", "yourself"], + "String", + {}, + pd.Series(["express", "yourself"], dtype=object, name="result"), + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2)], + "Datetime", + {"time_unit": "us"}, + pd.Series( + [datetime(2020, 1, 1), datetime(2020, 1, 2)], + dtype="datetime64[us]", + name="result", + ), + ), + ( + [timedelta(1), timedelta(2)], + "Duration", + {"time_unit": "us"}, + pd.Series( + [timedelta(1), timedelta(2)], + dtype="timedelta64[us]", + name="result", + ), + ), ], ) def test_column_from_sequence( library: str, values: list[Any], dtype: str, + kwargs: dict[str, Any], expected: pd.Series[Any], ) -> None: df = integer_dataframe_1(library) @@ -44,9 +82,9 @@ def test_column_from_sequence( result = namespace.dataframe_from_columns( namespace.column_from_sequence( values, - dtype=getattr(namespace, dtype)(), + dtype=getattr(namespace, dtype)(**kwargs), name="result", - ).persist(), + ), ) result_pd = interchange_to_pandas(result)["result"] pd.testing.assert_series_equal(result_pd, expected) diff --git a/tests/namespace/convert_to_standard_column_test.py b/tests/namespace/convert_to_standard_column_test.py index 2fa2f644..e13485dd 100644 --- a/tests/namespace/convert_to_standard_column_test.py +++ b/tests/namespace/convert_to_standard_column_test.py @@ -14,10 +14,10 @@ ) def test_convert_to_std_column() -> None: s = pl.Series([1, 2, 3]).__column_consortium_standard__() - assert float(s.mean().persist()) == 2 + assert float(s.mean()) == 2 s = pl.Series("bob", [1, 2, 3]).__column_consortium_standard__() - assert float(s.mean().persist()) == 2 + assert float(s.mean()) == 2 s = pd.Series([1, 2, 3]).__column_consortium_standard__() - assert float(s.mean().persist()) == 2 + assert float(s.mean()) == 2 s = pd.Series([1, 2, 3], name="alice").__column_consortium_standard__() - assert float(s.mean().persist()) == 2 + assert float(s.mean()) == 2 diff --git a/tests/scalars/float_test.py b/tests/scalars/float_test.py index 48c2ff44..bfcfd5d0 100644 --- a/tests/scalars/float_test.py +++ b/tests/scalars/float_test.py @@ -36,13 +36,13 @@ def test_float_binary(library: str, attr: str) -> None: other = 0.5 df = integer_dataframe_2(library).persist() scalar = df.col("a").mean() - float_scalar = float(scalar.persist()) # type: ignore[arg-type] - assert (getattr(scalar, attr)(other) == getattr(float_scalar, attr)(other)).persist() + float_scalar = float(scalar) # type: ignore[arg-type] + assert getattr(scalar, attr)(other) == getattr(float_scalar, attr)(other) def test_float_binary_invalid(library: str) -> None: - lhs = integer_dataframe_2(library).persist().col("a").mean() - rhs = integer_dataframe_1(library).persist().col("b").mean() + lhs = integer_dataframe_2(library).col("a").mean() + rhs = integer_dataframe_1(library).col("b").mean() with pytest.raises(ValueError): _ = lhs > rhs @@ -52,7 +52,7 @@ def test_float_binary_lazy_valid(library: str) -> None: lhs = df.col("a").mean() rhs = df.col("b").mean() result = lhs > rhs - assert not bool(result.persist()) + assert not bool(result) @pytest.mark.parametrize( @@ -64,9 +64,10 @@ def test_float_binary_lazy_valid(library: str) -> None: ) def test_float_unary(library: str, attr: str) -> None: df = integer_dataframe_2(library).persist() - scalar = df.col("a").mean() - float_scalar = float(scalar.persist()) # type: ignore[arg-type] - assert (getattr(scalar, attr)() == getattr(float_scalar, attr)()).persist() + with pytest.warns(UserWarning): + scalar = df.col("a").persist().mean() + float_scalar = float(scalar) # type: ignore[arg-type] + assert getattr(scalar, attr)() == getattr(float_scalar, attr)() @pytest.mark.parametrize( @@ -78,7 +79,7 @@ def test_float_unary(library: str, attr: str) -> None: ], ) def test_float_unary_invalid(library: str, attr: str) -> None: - df = integer_dataframe_2(library).persist() + df = integer_dataframe_2(library) scalar = df.col("a").mean() float_scalar = float(scalar.persist()) # type: ignore[arg-type] with pytest.raises(RuntimeError): @@ -88,12 +89,11 @@ def test_float_unary_invalid(library: str, attr: str) -> None: def test_free_standing(library: str) -> None: df = integer_dataframe_1(library) namespace = df.__dataframe_namespace__() - ser = namespace.column_from_1d_array( + ser = namespace.column_from_1d_array( # type: ignore[call-arg] np.array([1, 2, 3]), - dtype=namespace.Int64(), name="a", ) - result = float((ser.mean() + 1).persist()) # type: ignore[arg-type] + result = float(ser.mean() + 1) # type: ignore[arg-type] assert result == 3.0 diff --git a/tests/utils.py b/tests/utils.py index a4f72a6b..d434e3db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -475,6 +475,9 @@ def interchange_to_pandas(result: Any) -> pd.DataFrame: if isinstance(result.dataframe, pl.LazyFrame): df = result.dataframe.collect() df = df.to_pandas() + elif isinstance(result.dataframe, pl.DataFrame): + df = result.dataframe + df = df.to_pandas() else: df = result.dataframe df = convert_dataframe_to_pandas_numpy(df)