diff --git a/.github/workflows/earliest_versions.yml b/.github/workflows/earliest_versions.yml index 51570902..f536efde 100644 --- a/.github/workflows/earliest_versions.yml +++ b/.github/workflows/earliest_versions.yml @@ -31,4 +31,4 @@ jobs: - name: install-reqs run: python -m pip install --upgrade tox virtualenv setuptools pip && python -m pip install -r requirements-dev.txt - name: Run pytest - run: pytest tests --cov=dataframe_api_compat --cov=tests --cov-fail-under=80 + run: pytest tests --cov=dataframe_api_compat/pandas_standard --cov=dataframe_api_compat/polars_standard --cov=tests --cov-fail-under=80 diff --git a/.github/workflows/mkdocs.yml b/.github/workflows/mkdocs.yml index 62f66047..5716005d 100644 --- a/.github/workflows/mkdocs.yml +++ b/.github/workflows/mkdocs.yml @@ -19,14 +19,11 @@ jobs: with: python-version: 3.x - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - - uses: actions/cache@v3 with: key: mkdocs-material-${{ env.cache_id }} path: .cache restore-keys: | mkdocs-material- - - run: pip install -r docs/requirements-docs.txt -e . pandas polars - + - run: pip install -r docs/requirements-docs.txt -e . 
pandas "polars<0.20.8" - run: mkdocs gh-deploy --force diff --git a/.github/workflows/random_version.yml b/.github/workflows/random_version.yml index a20d0270..7c2a0f43 100644 --- a/.github/workflows/random_version.yml +++ b/.github/workflows/random_version.yml @@ -35,4 +35,4 @@ jobs: - name: install-reqs run: python -m pip install --upgrade tox virtualenv setuptools pip && python -m pip install -r requirements-dev.txt - name: Run pytest - run: pytest tests --cov=dataframe_api_compat --cov=tests --cov-fail-under=80 + run: pytest tests --cov=dataframe_api_compat/pandas_standard --cov=dataframe_api_compat/polars_standard --cov=tests --cov-fail-under=80 diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 9fd35414..3c2f7d47 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -28,17 +28,32 @@ jobs: key: ${{ runner.os }}-build-${{ matrix.python-version }} - name: install-reqs run: python -m pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt - - name: Run pytest - run: pytest tests --cov=dataframe_api_compat --cov=tests --cov-fail-under=100 + - name: Run pytest for pandas and polars + run: | + pytest tests --cov=dataframe_api_compat/pandas_standard --cov=tests --cov-append --cov-fail-under=50 --cov-report= --library pandas-numpy + pytest tests --cov=dataframe_api_compat/pandas_standard --cov=tests --cov-append --cov-fail-under=50 --cov-report= --library pandas-nullable + pytest tests --cov=dataframe_api_compat/polars_standard --cov=tests --cov-append --cov-fail-under=95 --library polars-lazy - name: install type-checking reqs run: python -m pip install 'git+https://github.com/data-apis/dataframe-api.git#egg=dataframe_api&subdirectory=spec/API_specification' mypy typing-extensions - name: run mypy - run: mypy dataframe_api_compat tests + run: mypy dataframe_api_compat/pandas_standard dataframe_api_compat/polars_standard - name: run polars integration tests run: pip uninstall pandas -y && pytest 
tests/integration/upstream_test.py::TestPolars && pip install -U pandas - name: run pandas integration tests run: pip uninstall polars -y && pytest tests/integration/upstream_test.py::TestPandas - + - name: Update env for modin + run: | + python -m pip uninstall -r requirements-dev.txt -y + python -m pip install -r requirements-dev-modin.txt + - name: Run pytest for modin + run: | + pytest tests --cov=dataframe_api_compat/modin_standard --cov=tests --cov-append --cov-fail-under=100 --library modin + - name: install type-checking reqs + run: python -m pip install 'git+https://github.com/data-apis/dataframe-api.git#egg=dataframe_api&subdirectory=spec/API_specification' mypy typing-extensions + - name: run mypy + run: | + mypy dataframe_api_compat/modin_standard + mypy tests --follow-imports silent tox-all-supported: strategy: matrix: diff --git a/dataframe_api_compat/__init__.py b/dataframe_api_compat/__init__.py index 837e502d..7950bf4e 100644 --- a/dataframe_api_compat/__init__.py +++ b/dataframe_api_compat/__init__.py @@ -1,13 +1,19 @@ from __future__ import annotations -import contextlib +import importlib +from typing import TYPE_CHECKING -with contextlib.suppress(ModuleNotFoundError): - from dataframe_api_compat import pandas_standard +if TYPE_CHECKING: + from types import ModuleType -with contextlib.suppress(ModuleNotFoundError): - from dataframe_api_compat import polars_standard +__all__ = ["pandas_standard", "polars_standard", "modin_standard"] + + +def __getattr__(name: str) -> ModuleType: + if name in __all__: + return importlib.import_module("." 
+ name, __name__) + msg = f"module {__name__!r} has no attribute {name!r}" + raise AttributeError(msg) -__all__ = ["pandas_standard", "polars_standard"] __version__ = "0.2.6" diff --git a/dataframe_api_compat/modin_standard/__init__.py b/dataframe_api_compat/modin_standard/__init__.py new file mode 100644 index 00000000..fcc0e9a6 --- /dev/null +++ b/dataframe_api_compat/modin_standard/__init__.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import datetime as dt +import re +from functools import reduce +from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import cast + +import modin.pandas as pd + +from dataframe_api_compat.modin_standard.column_object import Column +from dataframe_api_compat.modin_standard.dataframe_object import DataFrame +from dataframe_api_compat.modin_standard.scalar_object import Scalar + +if TYPE_CHECKING: + from collections.abc import Sequence + + from dataframe_api.groupby_object import Aggregation as AggregationT + from dataframe_api.typing import Column as ColumnT + from dataframe_api.typing import DataFrame as DataFrameT + from dataframe_api.typing import DType + from dataframe_api.typing import Namespace as NamespaceT + from dataframe_api.typing import Scalar as ScalarT + + BoolT = NamespaceT.Bool + DateT = NamespaceT.Date + DatetimeT = NamespaceT.Datetime + DurationT = NamespaceT.Duration + Float32T = NamespaceT.Float32 + Float64T = NamespaceT.Float64 + Int8T = NamespaceT.Int8 + Int16T = NamespaceT.Int16 + Int32T = NamespaceT.Int32 + Int64T = NamespaceT.Int64 + StringT = NamespaceT.String + UInt8T = NamespaceT.UInt8 + UInt16T = NamespaceT.UInt16 + UInt32T = NamespaceT.UInt32 + UInt64T = NamespaceT.UInt64 + NullTypeT = NamespaceT.NullType +else: + NamespaceT = object + BoolT = object + DateT = object + DatetimeT = object + DurationT = object + Float32T = object + Float64T = object + Int8T = object + Int16T = object + Int32T = object + Int64T = object + StringT = object + UInt8T = 
def map_pandas_dtype_to_standard_dtype(dtype: Any) -> DType:
    """Translate a pandas-style dtype into the matching standard dtype object.

    Parameters
    ----------
    dtype
        Either a dtype spelling such as ``"int64"`` or an object that
        compares equal to such a spelling. Objects exposing a ``name``
        attribute are reduced to that name before the temporal checks.

    Returns
    -------
    DType
        A freshly constructed ``Namespace`` dtype instance.

    Raises
    ------
    AssertionError
        If no mapping matches the given dtype.
    """
    # this is copied from the pandas implementation;
    # TODO: need similar tests for modin (tests with extension dtypes)
    #
    # The table is walked in order and matched with ``==`` rather than
    # hashed, so the argument does not have to be a plain string.
    # "bool"/"boolean": the latter is also for
    # `pandas.core.arrays.boolean.BooleanDtype`.
    simple_dtypes = (
        ("int64", Namespace.Int64),
        ("Int64", Namespace.Int64),
        ("int32", Namespace.Int32),
        ("Int32", Namespace.Int32),
        ("int16", Namespace.Int16),
        ("Int16", Namespace.Int16),
        ("int8", Namespace.Int8),
        ("Int8", Namespace.Int8),
        ("uint64", Namespace.UInt64),
        ("UInt64", Namespace.UInt64),
        ("uint32", Namespace.UInt32),
        ("UInt32", Namespace.UInt32),
        ("uint16", Namespace.UInt16),
        ("UInt16", Namespace.UInt16),
        ("uint8", Namespace.UInt8),
        ("UInt8", Namespace.UInt8),
        ("float64", Namespace.Float64),
        ("Float64", Namespace.Float64),
        ("float32", Namespace.Float32),
        ("Float32", Namespace.Float32),
        ("bool", Namespace.Bool),
        ("boolean", Namespace.Bool),
        ("object", Namespace.String),
        ("string", Namespace.String),
    )
    for spelling, standard_cls in simple_dtypes:
        if dtype == spelling:
            return standard_cls()

    # For types like `numpy.dtypes.DateTime64DType`
    dtype = dtype.name if hasattr(dtype, "name") else dtype

    # (required prefix, pattern capturing the time unit, standard dtype class)
    temporal_dtypes = (
        ("datetime64[", r"datetime64\[(\w{1,2})", Namespace.Datetime),
        ("timedelta64[", r"timedelta64\[(\w{1,2})", Namespace.Duration),
    )
    for prefix, pattern, standard_cls in temporal_dtypes:
        if dtype.startswith(prefix):
            match = re.search(pattern, dtype)
            assert match is not None
            time_unit = cast(Literal["ms", "us"], match.group(1))
            return standard_cls(time_unit)

    msg = f"Unsupported dtype! {dtype}"  # pragma: no cover
    raise AssertionError(msg)
ser, + api_version=api_version or "2023.11-beta", + df=None, + is_persisted=True, + ) + + +def convert_to_standard_compliant_dataframe( + df: pd.DataFrame, + api_version: str | None = None, +) -> DataFrame: + return DataFrame(df, api_version=api_version or "2023.11-beta") + + +class Namespace(NamespaceT): + def __init__(self, *, api_version: str) -> None: + self.__dataframe_api_version__ = api_version + self._api_version = api_version + + class Int64(Int64T): + ... + + class Int32(Int32T): + ... + + class Int16(Int16T): + ... + + class Int8(Int8T): + ... + + class UInt64(UInt64T): + ... + + class UInt32(UInt32T): + ... + + class UInt16(UInt16T): + ... + + class UInt8(UInt8T): + ... + + class Float64(Float64T): + ... + + class Float32(Float32T): + ... + + class Bool(BoolT): + ... + + class String(StringT): + ... + + class Date(DateT): + ... + + class Datetime(DatetimeT): + def __init__( + self, + time_unit: Literal["ms", "us"], + time_zone: str | None = None, + ) -> None: + self.time_unit = time_unit + # TODO validate time zone + self.time_zone = time_zone + + class Duration(DurationT): + def __init__(self, time_unit: Literal["ms", "us"]) -> None: + self.time_unit = time_unit + + class NullType(NullTypeT): + ... 
def dataframe_from_columns(
    self,
    *columns: ColumnT,
) -> DataFrame:
    """Assemble a standard DataFrame out of persisted standard columns.

    Parameters
    ----------
    *columns
        Columns to combine. Each one is materialised via ``_materialise()``
        and stored under its series name.

    Returns
    -------
    DataFrame
        A DataFrame wrapping the combined data, tagged with the API version
        shared by the columns (or this namespace's version when no columns
        are given).

    Raises
    ------
    ValueError
        If the columns do not all carry the same API version.
    """
    data = {}
    api_versions: set[str] = set()
    for col in columns:
        ser = col._materialise()  # type: ignore[attr-defined]
        data[ser.name] = ser
        api_versions.add(col._api_version)  # type: ignore[attr-defined]
    # Mirror `concat`: refuse to guess when the inputs disagree, instead of
    # picking whichever version happens to come first out of the set.
    if len(api_versions) > 1:
        msg = f"Multiple api versions found: {api_versions}"
        raise ValueError(msg)
    # With no columns there is nothing to read a version from, so fall back
    # to the version this namespace was created with.
    api_version = api_versions.pop() if api_versions else self._api_version
    return DataFrame(pd.DataFrame(data), api_version=api_version)
def is_dtype(self, dtype: DType, kind: str | tuple[str, ...]) -> bool:
    """Report whether ``dtype`` belongs to one of the requested dtype kinds.

    Parameters
    ----------
    dtype
        Standard dtype instance to classify.
    kind
        A single kind name or a tuple of kind names. Recognised names are
        ``"bool"``, ``"signed integer"``, ``"unsigned integer"``,
        ``"integral"``, ``"floating"``, ``"numeric"`` and ``"string"``.
        Unrecognised names contribute nothing.

    Returns
    -------
    bool
        ``True`` if ``dtype`` is an instance of any dtype class selected by
        ``kind``.
    """
    signed = (
        Namespace.Int64,
        Namespace.Int32,
        Namespace.Int16,
        Namespace.Int8,
    )
    unsigned = (
        Namespace.UInt64,
        Namespace.UInt32,
        Namespace.UInt16,
        Namespace.UInt8,
    )
    floating = (Namespace.Float64, Namespace.Float32)
    # "integral" and "numeric" are unions of the narrower groups.
    classes_by_kind = {
        "bool": (Namespace.Bool,),
        "signed integer": signed,
        "unsigned integer": unsigned,
        "integral": signed + unsigned,
        "floating": floating,
        "numeric": signed + unsigned + floating,
        "string": (Namespace.String,),
    }

    requested_kinds = (kind,) if isinstance(kind, str) else kind
    accepted = set()
    for requested in requested_kinds:
        accepted.update(classes_by_kind.get(requested, ()))
    return isinstance(dtype, tuple(accepted))
def _replace(self, **kwargs: str) -> AggregationT: + return self.__class__( + column_name=kwargs.get("column_name", self.column_name), + output_name=kwargs.get("output_name", self.output_name), + aggregation=kwargs.get("aggregation", self.aggregation), + ) + + def rename(self, name: str | ScalarT) -> AggregationT: + return self.__class__(self.column_name, name, self.aggregation) # type: ignore[arg-type] + + @classmethod + def any( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "any") + + @classmethod + def all( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "all") + + @classmethod + def min( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "min") + + @classmethod + def max( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "max") + + @classmethod + def sum( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "sum") + + @classmethod + def prod( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "prod") + + @classmethod + def median( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "median") + + @classmethod + def mean( + cls: AggregationT, + column: str, + *, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "mean") + + @classmethod + def std( + cls: AggregationT, + column: str, + *, + correction: float | ScalarT | NullTypeT = 1, + skip_nulls: bool 
| ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "std") + + @classmethod + def var( + cls: AggregationT, + column: str, + *, + correction: float | ScalarT | NullTypeT = 1, + skip_nulls: bool | ScalarT = True, + ) -> AggregationT: + return Namespace.Aggregation(column, column, "var") + + @classmethod + def size( + cls: AggregationT, + ) -> AggregationT: + return Namespace.Aggregation("__placeholder__", "size", "size") diff --git a/dataframe_api_compat/modin_standard/column_object.py b/dataframe_api_compat/modin_standard/column_object.py new file mode 100644 index 00000000..624401fe --- /dev/null +++ b/dataframe_api_compat/modin_standard/column_object.py @@ -0,0 +1,571 @@ +from __future__ import annotations + +import warnings +from datetime import datetime +from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import NoReturn + +import modin.pandas as pd +import numpy as np +from pandas.api.types import is_extension_array_dtype + +import dataframe_api_compat.modin_standard +from dataframe_api_compat.utils import validate_comparand + +if TYPE_CHECKING: + from dataframe_api import Column as ColumnT + from dataframe_api.typing import DType + from dataframe_api.typing import NullType + from dataframe_api.typing import Scalar + + from dataframe_api_compat.modin_standard.dataframe_object import DataFrame +else: + ColumnT = object + + +NUMPY_MAPPING = { + "Int64": "int64", + "Int32": "int32", + "Int16": "int16", + "Int8": "int8", + "UInt64": "uint64", + "UInt32": "uint32", + "UInt16": "uint16", + "UInt8": "uint8", + "boolean": "bool", + "Float64": "float64", + "Float32": "float32", +} + + +class Column(ColumnT): + def __init__( + self, + series: pd.Series[Any], + *, + df: DataFrame | None, + api_version: str, + is_persisted: bool = False, + ) -> None: + """Parameters + ---------- + df + DataFrame this column originates from. 
+ """ + + self._name = series.name + assert self._name is not None + self._series = series + self._api_version = api_version + self._df = df + self._is_persisted = is_persisted + assert is_persisted ^ (df is not None) + + def _to_scalar(self, value: Any) -> Scalar: + from dataframe_api_compat.modin_standard.scalar_object import Scalar + + return Scalar( + value, + api_version=self._api_version, + df=self._df, + is_persisted=self._is_persisted, + ) + + def __repr__(self) -> str: # pragma: no cover + header = f" Standard Column (api_version={self._api_version}) " + length = len(header) + return ( + "┌" + + "─" * length + + "┐\n" + + f"|{header}|\n" + + "| Add `.column` to see native output |\n" + + "└" + + "─" * length + + "┘\n" + ) + + def __iter__(self) -> NoReturn: + msg = "" + raise NotImplementedError(msg) + + def _from_series(self, series: pd.Series) -> Column: + return Column( + series.reset_index(drop=True).rename(series.name), + api_version=self._api_version, + df=self._df, + is_persisted=self._is_persisted, + ) + + def _materialise(self) -> pd.Series: + if not self._is_persisted: + msg = "Column is not persisted, please call `.persist()` first.\nNote: `persist` forces computation, use it with care, only when you need to,\nand as late and little as possible." 
+ raise RuntimeError( + msg, + ) + return self.column + + # In the standard + def __column_namespace__( + self, + ) -> dataframe_api_compat.modin_standard.Namespace: + return dataframe_api_compat.modin_standard.Namespace( + api_version=self._api_version, + ) + + def persist(self) -> Column: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on Column that was already persisted", + UserWarning, + stacklevel=2, + ) + return Column( + self.column, + df=None, + api_version=self._api_version, + is_persisted=True, + ) + + @property + def name(self) -> str: + return self._name # type: ignore[no-any-return] + + @property + def column(self) -> pd.Series[Any]: + return self._series + + @property + def dtype(self) -> DType: + return dataframe_api_compat.modin_standard.map_pandas_dtype_to_standard_dtype( + self._series.dtype, + ) + + @property + def parent_dataframe(self) -> DataFrame | None: + return self._df + + def take(self, indices: Column) -> Column: + return self._from_series(self.column.iloc[indices.column]) + + def filter(self, mask: Column) -> Column: + ser = self.column + return self._from_series(ser.loc[mask.column]) + + def get_value(self, row_number: int) -> Any: + ser = self.column + return self._to_scalar( + ser.iloc[row_number], + ) + + def slice_rows( + self, + start: int | None, + stop: int | None, + step: int | None, + ) -> Column: + return self._from_series(self.column.iloc[start:stop:step]) + + # Binary comparisons + + def __eq__(self, other: Column | Any) -> Column: # type: ignore[override] + other = validate_comparand(self, other) + ser = self.column + return self._from_series((ser == other).rename(ser.name)) + + def __ne__(self, other: Column | Any) -> Column: # type: ignore[override] + other = validate_comparand(self, other) + ser = self.column + return self._from_series((ser != other).rename(ser.name)) + + def __ge__(self, other: Column | Any) -> Column: + other = validate_comparand(self, other) + ser = self.column + return 
self._from_series((ser >= other).rename(ser.name)) + + def __gt__(self, other: Column | Any) -> Column: + other = validate_comparand(self, other) + ser = self.column + return self._from_series((ser > other).rename(ser.name)) + + def __le__(self, other: Column | Any) -> Column: + other = validate_comparand(self, other) + ser = self.column + return self._from_series((ser <= other).rename(ser.name)) + + def __lt__(self, other: Column | Any) -> Column: + other = validate_comparand(self, other) + ser = self.column + return self._from_series((ser < other).rename(ser.name)) + + def __and__(self, other: Column | bool | Scalar) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser & other).rename(ser.name)) + + def __rand__(self, other: Column | Any) -> Column: + return self.__and__(other) + + def __or__(self, other: Column | bool | Scalar) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser | other).rename(ser.name)) + + def __ror__(self, other: Column | Any) -> Column: + return self.__or__(other) + + def __add__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser + other).rename(ser.name)) + + def __radd__(self, other: Column | Any) -> Column: + return self.__add__(other) + + def __sub__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser - other).rename(ser.name)) + + def __rsub__(self, other: Column | Any) -> Column: + return -1 * self.__sub__(other) + + def __mul__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser * other).rename(ser.name)) + + def __rmul__(self, other: Column | Any) -> Column: + return self.__mul__(other) + + def __truediv__(self, other: Column | Any) -> Column: + ser = self.column + other = 
validate_comparand(self, other) + return self._from_series((ser / other).rename(ser.name)) + + def __rtruediv__(self, other: Column | Any) -> Column: + raise NotImplementedError + + def __floordiv__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser // other).rename(ser.name)) + + def __rfloordiv__(self, other: Column | Any) -> Column: + raise NotImplementedError + + def __pow__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser**other).rename(ser.name)) + + def __rpow__(self, other: Column | Any) -> Column: # pragma: no cover + raise NotImplementedError + + def __mod__(self, other: Column | Any) -> Column: + ser = self.column + other = validate_comparand(self, other) + return self._from_series((ser % other).rename(ser.name)) + + def __rmod__(self, other: Column | Any) -> Column: # pragma: no cover + raise NotImplementedError + + def __divmod__(self, other: Column | Any) -> tuple[Column, Column]: + quotient = self // other + remainder = self - quotient * other + return quotient, remainder + + # Unary + + def __invert__(self: Column) -> Column: + ser = self.column + return self._from_series(~ser) + + # Reductions + + def any(self, *, skip_nulls: bool | Scalar = True) -> Scalar: + ser = self.column + return self._to_scalar(ser.any()) + + def all(self, *, skip_nulls: bool | Scalar = True) -> Scalar: + ser = self.column + return self._to_scalar(ser.all()) + + def min(self, *, skip_nulls: bool | Scalar = True) -> Any: + ser = self.column + return self._to_scalar(ser.min()) + + def max(self, *, skip_nulls: bool | Scalar = True) -> Any: + ser = self.column + return self._to_scalar(ser.max()) + + def sum(self, *, skip_nulls: bool | Scalar = True) -> Any: + ser = self.column + return self._to_scalar(ser.sum()) + + def prod(self, *, skip_nulls: bool | Scalar = True) -> Any: + ser = self.column + return 
def sort(
    self,
    *,
    ascending: bool = True,
    nulls_position: Literal["first", "last"] = "last",
) -> Column:
    """Return a new column with the values sorted.

    Parameters
    ----------
    ascending
        Sort from smallest to largest when ``True``, largest to smallest
        otherwise.
    nulls_position
        Where missing values are placed in the result: ``"first"`` or
        ``"last"``.

    Returns
    -------
    Column
        The sorted values, wrapped via ``_from_series`` and named after this
        column.
    """
    ser = self.column
    # Let `sort_values` handle both direction and null placement. Reversing
    # an ascending sort would move the nulls to the front and silently
    # ignore `nulls_position`.
    ordered = ser.sort_values(ascending=ascending, na_position=nulls_position)
    return self._from_series(ordered.rename(self.name))
def fill_nan(self, value: float | NullType | Scalar) -> Column:
    """Return a copy of the column with NaN entries replaced by ``value``.

    Parameters
    ----------
    value
        Replacement for NaN entries. When the column namespace reports it as
        null, NaN entries are set to the missing-value marker instead:
        ``pd.NA`` for extension-array dtypes, ``np.nan`` otherwise.

    Returns
    -------
    Column
        A new column built from a modified copy; the wrapped series itself
        is not mutated.
    """
    # this is copied from the pandas implementation;
    # however the extension-dtype path (in pandas) is only tested for
    # `pandas-nullable`
    # TODO: need similar tests for modin
    filled = self.column.copy()
    missing_marker = pd.NA if is_extension_array_dtype(filled.dtype) else np.nan
    if self.__column_namespace__().is_null(value):
        replacement = missing_marker
    else:
        replacement = value
    # `np.isnan` can itself yield missing entries; treat those as "not NaN".
    nan_mask = np.isnan(filled).fillna(False).to_numpy(bool)
    filled[nan_mask] = replacement
    return self._from_series(filled)
def floor(self, frequency: str) -> Column:
    """Round each datetime value down to a multiple of ``frequency``.

    Parameters
    ----------
    frequency
        Frequency string written with spelled-out unit names, e.g.
        ``"1day"``, ``"2hour"`` or ``"1millisecond"``. Unit names are
        rewritten to short aliases before calling ``Series.dt.floor``.

    Returns
    -------
    Column
        The floored values wrapped via ``_from_series``.
    """
    # Order matters: "millisecond", "microsecond" and "nanosecond" all
    # contain "second", so they must be rewritten before the plain "second"
    # rule runs. Otherwise "1millisecond" turns into the invalid "1milliS".
    unit_aliases = (
        ("millisecond", "ms"),
        ("microsecond", "us"),
        ("nanosecond", "ns"),
        ("day", "D"),
        ("hour", "H"),
        ("minute", "T"),
        ("second", "S"),
    )
    for unit, alias in unit_aliases:
        frequency = frequency.replace(unit, alias)
    ser = self.column
    return self._from_series(ser.dt.floor(frequency))
from dataframe_api.typing import AnyScalar + from dataframe_api.typing import Column + from dataframe_api.typing import DType + from dataframe_api.typing import NullType + from dataframe_api.typing import Scalar + + from dataframe_api_compat.modin_standard.group_by_object import GroupBy +else: + DataFrameT = object + + +class DataFrame(DataFrameT): + """dataframe object""" + + def __init__( + self, + dataframe: pd.DataFrame, + *, + api_version: str, + is_persisted: bool = False, + ) -> None: + self._is_persisted = is_persisted + self._validate_columns(dataframe.columns) + self._dataframe = dataframe.reset_index(drop=True) + self._api_version = api_version + + # Validation helper methods + + def _validate_is_persisted(self) -> pd.DataFrame: + if not self._is_persisted: + msg = "Method requires you to call `.persist` first.\n\nNote: `.persist` forces materialisation in lazy libraries and so should be called as late as possible in your pipeline. Use with care." + raise ValueError( + msg, + ) + return self.dataframe + + def __repr__(self) -> str: # pragma: no cover + header = f" Standard DataFrame (api_version={self._api_version}) " + length = len(header) + return ( + "┌" + + "─" * length + + "┐\n" + + f"|{header}|\n" + + "| Add `.dataframe` to see native output |\n" + + "└" + + "─" * length + + "┘\n" + ) + + def _validate_columns(self, columns: Sequence[str]) -> None: + counter = collections.Counter(columns) + for col, count in counter.items(): + if count > 1: + msg = f"Expected unique column names, got {col} {count} time(s)" + raise ValueError( + msg, + ) + + def _validate_booleanness(self) -> None: + if not ( + (self.dataframe.dtypes == "bool") | (self.dataframe.dtypes == "boolean") + ).all(): + msg = "'any' can only be called on DataFrame where all dtypes are 'bool'" + raise TypeError( + msg, + ) + + def _from_dataframe(self, df: pd.DataFrame) -> DataFrame: + return DataFrame( + df, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) + + # 
Properties + @property + def schema(self) -> dict[str, DType]: + return { + column_name: dataframe_api_compat.modin_standard.map_pandas_dtype_to_standard_dtype( + dtype.name, + ) + for column_name, dtype in self.dataframe.dtypes.items() + } + + @property + def dataframe(self) -> pd.DataFrame: + return self._dataframe + + @property + def column_names(self) -> list[str]: + return self.dataframe.columns.tolist() # type: ignore[no-any-return] + + # In the Standard + + def __dataframe_namespace__( + self, + ) -> dataframe_api_compat.modin_standard.Namespace: + return dataframe_api_compat.modin_standard.Namespace( + api_version=self._api_version, + ) + + def iter_columns(self) -> Iterator[Column]: + return (self.col(col_name) for col_name in self.column_names) + + def col(self, name: str) -> Column: + from dataframe_api_compat.modin_standard.column_object import Column + + return Column( + self.dataframe.loc[:, name], + df=None if self._is_persisted else self, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) + + def shape(self) -> tuple[int, int]: + df = self._validate_is_persisted() + return df.shape # type: ignore[no-any-return] + + def group_by(self, *keys: str) -> GroupBy: + from dataframe_api_compat.modin_standard.group_by_object import GroupBy + + for key in keys: + if key not in self.column_names: + msg = f"key {key} not present in DataFrame's columns" + raise KeyError(msg) + return GroupBy(self, keys, api_version=self._api_version) + + def select(self, *columns: str) -> DataFrame: + cols = list(columns) + if cols and isinstance(cols[0], (list, tuple)): + msg = f"Expected iterable of column names, but the first element is: {type(cols[0])}" + raise TypeError(msg) + return self._from_dataframe( + self.dataframe.loc[:, list(columns)], + ) + + def take( + self, + indices: Column, + ) -> DataFrame: + _indices = validate_comparand(self, indices) + return self._from_dataframe( + self.dataframe.iloc[_indices.to_list(), :], + ) + + def slice_rows( + 
self, + start: int | None, + stop: int | None, + step: int | None, + ) -> DataFrame: + return self._from_dataframe(self.dataframe.iloc[start:stop:step]) + + def filter( + self, + mask: Column, + ) -> DataFrame: + _mask = validate_comparand(self, mask) + df = self.dataframe + df = df.loc[_mask] + return self._from_dataframe(df) + + def assign( + self, + *columns: Column, + ) -> DataFrame: + from dataframe_api_compat.modin_standard.column_object import Column + + df = self.dataframe.copy() # TODO: remove defensive copy with CoW? + for column in columns: + if not isinstance(column, Column): + msg = f"Expected iterable of Column, but the first element is: {type(column)}" + raise TypeError(msg) + _series = validate_comparand(self, column) + df[_series.name] = _series + return self._from_dataframe(df) + + def drop(self, *labels: str) -> DataFrame: + return self._from_dataframe( + self.dataframe.drop(list(labels), axis=1), + ) + + def rename(self, mapping: Mapping[str, str]) -> DataFrame: + if not isinstance(mapping, collections.abc.Mapping): + msg = f"Expected Mapping, got: {type(mapping)}" + raise TypeError(msg) + return self._from_dataframe( + self.dataframe.rename(columns=mapping), + ) + + def get_column_names(self) -> list[str]: # pragma: no cover + # TODO: add a test after modin implements `__dataframe_consortium_standard__` + # DO NOT REMOVE + # This one is used in upstream tests - even if deprecated, + # just leave it in for backwards compatibility + return self.dataframe.columns.tolist() # type: ignore[no-any-return] + + def sort( + self, + *keys: str, + ascending: Sequence[bool] | bool = True, + nulls_position: Literal["first", "last"] = "last", + ) -> DataFrame: + if not keys: + keys = self.dataframe.columns.tolist() + df = self.dataframe + return self._from_dataframe( + df.sort_values(list(keys), ascending=ascending), + ) + + # Binary operations + + def __eq__(self, other: AnyScalar) -> DataFrame: # type: ignore[override] + return 
self._from_dataframe(self.dataframe.__eq__(other)) + + def __ne__(self, other: AnyScalar) -> DataFrame: # type: ignore[override] + return self._from_dataframe(self.dataframe.__ne__(other)) + + def __ge__(self, other: AnyScalar) -> DataFrame: + return self._from_dataframe(self.dataframe.__ge__(other)) + + def __gt__(self, other: AnyScalar) -> DataFrame: + return self._from_dataframe(self.dataframe.__gt__(other)) + + def __le__(self, other: AnyScalar) -> DataFrame: + return self._from_dataframe(self.dataframe.__le__(other)) + + def __lt__(self, other: AnyScalar) -> DataFrame: + return self._from_dataframe(self.dataframe.__lt__(other)) + + def __and__(self, other: AnyScalar) -> DataFrame: + return self._from_dataframe( + self.dataframe.__and__(other), + ) + + def __rand__(self, other: Column | AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self.__and__(_other) + + def __or__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe(self.dataframe.__or__(_other)) + + def __ror__(self, other: Column | AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self.__or__(_other) + + def __add__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__add__(_other), + ) + + def __radd__(self, other: Column | AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self.__add__(_other) + + def __sub__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__sub__(_other), + ) + + def __rsub__(self, other: Column | AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return -1 * self.__sub__(_other) + + def __mul__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__mul__(_other), + ) + + def __rmul__(self, other: 
Column | AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self.__mul__(_other) + + def __truediv__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__truediv__(_other), + ) + + def __rtruediv__(self, other: Column | AnyScalar) -> DataFrame: # pragma: no cover + _other = validate_comparand(self, other) + raise NotImplementedError + + def __floordiv__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__floordiv__(_other), + ) + + def __rfloordiv__(self, other: Column | AnyScalar) -> DataFrame: # pragma: no cover + _other = validate_comparand(self, other) + raise NotImplementedError + + def __pow__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__pow__(_other), + ) + + def __rpow__(self, other: Column | AnyScalar) -> DataFrame: # pragma: no cover + _other = validate_comparand(self, other) + raise NotImplementedError + + def __mod__(self, other: AnyScalar) -> DataFrame: + _other = validate_comparand(self, other) + return self._from_dataframe( + self.dataframe.__mod__(other), + ) + + def __rmod__(self, other: Column | AnyScalar) -> DataFrame: # type: ignore[misc] # pragma: no cover + _other = validate_comparand(self, other) + raise NotImplementedError + + def __divmod__( + self, + other: DataFrame | AnyScalar, + ) -> tuple[DataFrame, DataFrame]: + _other = validate_comparand(self, other) + quotient, remainder = self.dataframe.__divmod__(_other) + return self._from_dataframe(quotient), self._from_dataframe( + remainder, + ) + + # Unary + + def __invert__(self) -> DataFrame: + self._validate_booleanness() + return self._from_dataframe(self.dataframe.__invert__()) + + def __iter__(self) -> NoReturn: + raise NotImplementedError + + # Reductions + + def any(self, *, skip_nulls: bool | Scalar = True) -> 
DataFrame: + self._validate_booleanness() + return self._from_dataframe( + self.dataframe.any().to_frame().T, + ) + + def all(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + self._validate_booleanness() + return self._from_dataframe( + self.dataframe.all().to_frame().T, + ) + + def min(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.min().to_frame().T, + ) + + def max(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.max().to_frame().T, + ) + + def sum(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.sum().to_frame().T, + ) + + def prod(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.prod().to_frame().T, + ) + + def median(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.median().to_frame().T, + ) + + def mean(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + return self._from_dataframe( + self.dataframe.mean().to_frame().T, + ) + + def std( + self, + *, + correction: float | Scalar | NullType = 1.0, + skip_nulls: bool | Scalar = True, + ) -> DataFrame: + return self._from_dataframe( + self.dataframe.std().to_frame().T, + ) + + def var( + self, + *, + correction: float | Scalar | NullType = 1.0, + skip_nulls: bool | Scalar = True, + ) -> DataFrame: + return self._from_dataframe( + self.dataframe.var().to_frame().T, + ) + + # Transformations + + def is_null(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result: list[pd.Series] = [] + for column in self.dataframe.columns: + result.append(self.dataframe[column].isna()) + return self._from_dataframe(pd.concat(result, axis=1)) + + def is_nan(self) -> DataFrame: + return self.assign(*[col.is_nan() for col in self.iter_columns()]) + + def fill_nan(self, value: float | Scalar | NullType) -> DataFrame: + _value = 
validate_comparand(self, value) + new_cols = {} + df = self.dataframe + for col in df.columns: + ser = df[col].copy() + if is_extension_array_dtype(ser.dtype): # pragma: no cover + # this is copied from the pandas implementation; + # however this code (in pandas) is only tested for `pandas-nullable` + # TODO: need similar tests for modin + if self.__dataframe_namespace__().is_null(_value): + ser[np.isnan(ser).fillna(False).to_numpy(bool)] = pd.NA + else: + ser[np.isnan(ser).fillna(False).to_numpy(bool)] = _value + else: + if self.__dataframe_namespace__().is_null(_value): + ser[np.isnan(ser).fillna(False).to_numpy(bool)] = np.nan + else: + ser[np.isnan(ser).fillna(False).to_numpy(bool)] = _value + new_cols[col] = ser + df = pd.DataFrame(new_cols) + return self._from_dataframe(df) + + def fill_null( + self, + value: AnyScalar, + *, + column_names: list[str] | None = None, + ) -> DataFrame: + if column_names is None: + column_names = self.dataframe.columns.tolist() + assert isinstance(column_names, list) # help type checkers + return self.assign( + *[ + col.fill_null(value) + for col in self.iter_columns() + if col.name in column_names + ], + ) + + def drop_nulls( + self, + *, + column_names: list[str] | None = None, + ) -> DataFrame: + namespace = self.__dataframe_namespace__() + mask = ~namespace.any_horizontal( + *[ + self.col(col_name).is_null() + for col_name in column_names or self.column_names + ], + ) + return self.filter(mask) + + # Other + + def join( + self, + other: DataFrame, + *, + how: Literal["left", "inner", "outer"], + left_on: str | list[str], + right_on: str | list[str], + ) -> DataFrame: + if how not in ["left", "inner", "outer"]: + msg = f"Expected 'left', 'inner', 'outer', got: {how}" + raise ValueError(msg) + + if isinstance(left_on, str): + left_on = [left_on] + if isinstance(right_on, str): + right_on = [right_on] + + if overlap := (set(self.column_names) - set(left_on)).intersection( + set(other.column_names) - set(right_on), + ): + msg = 
f"Found overlapping columns in join: {overlap}. Please rename columns to avoid this." + raise ValueError(msg) + + return self._from_dataframe( + self.dataframe.merge( + other.dataframe, + left_on=left_on, + right_on=right_on, + how=how, + ), + ) + + def persist(self) -> DataFrame: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on DataFrame that was already persisted", + UserWarning, + stacklevel=2, + ) + return DataFrame( + self.dataframe, + api_version=self._api_version, + is_persisted=True, + ) + + # Conversion + + def to_array(self, dtype: DType | None = None) -> Any: + self._validate_is_persisted() + return self.dataframe.to_numpy() + + def cast(self, dtypes: Mapping[str, DType]) -> DataFrame: + from dataframe_api_compat.modin_standard import map_standard_dtype_to_pandas_dtype + + df = self._dataframe + return self._from_dataframe( + df.astype( + { + col: map_standard_dtype_to_pandas_dtype(dtype) + for col, dtype in dtypes.items() + }, + ), + ) diff --git a/dataframe_api_compat/modin_standard/group_by_object.py b/dataframe_api_compat/modin_standard/group_by_object.py new file mode 100644 index 00000000..2e0f5a73 --- /dev/null +++ b/dataframe_api_compat/modin_standard/group_by_object.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import cast + +import modin.pandas as pd + +from dataframe_api_compat.modin_standard import Namespace +from dataframe_api_compat.modin_standard.dataframe_object import DataFrame + +if TYPE_CHECKING: + from collections.abc import Sequence + + from dataframe_api import Aggregation as AggregationT + from dataframe_api import GroupBy as GroupByT + from dataframe_api.typing import NullType + from dataframe_api.typing import Scalar + + +else: + GroupByT = object + + +class GroupBy(GroupByT): + def __init__(self, df: DataFrame, keys: Sequence[str], api_version: str) -> None: + self._df = df.dataframe + self._is_persisted = df._is_persisted + self._grouped = 
self._df.groupby(list(keys), sort=False, as_index=False) + self._keys = list(keys) + self._api_version = api_version + + def _validate_result(self, result: pd.DataFrame) -> None: + failed_columns = self._df.columns.difference(result.columns) + if len(failed_columns) > 0: # pragma: no cover + msg = "Groupby operation could not be performed on columns " + f"{failed_columns}. Please drop them before calling group_by." + raise AssertionError( + msg, + ) + + def _validate_booleanness(self) -> None: + if not ( + (self._df.drop(columns=self._keys).dtypes == "bool") + | (self._df.drop(columns=self._keys).dtypes == "boolean") + ).all(): + msg = "'function' can only be called on DataFrame where all dtypes are 'bool'" + raise TypeError( + msg, + ) + + def _to_dataframe(self, result: pd.DataFrame) -> DataFrame: + return DataFrame( + result, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) + + def size(self) -> DataFrame: + return self._to_dataframe(self._grouped.size()) + + def any(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + self._validate_booleanness() + result = self._grouped.any() + self._validate_result(result) + return self._to_dataframe(result) + + def all(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + self._validate_booleanness() + result = self._grouped.all() + self._validate_result(result) + return self._to_dataframe(result) + + def min(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.min() + self._validate_result(result) + return self._to_dataframe(result) + + def max(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.max() + self._validate_result(result) + return self._to_dataframe(result) + + def sum(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.sum() + self._validate_result(result) + return self._to_dataframe(result) + + def prod(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.prod() 
+ self._validate_result(result) + return self._to_dataframe(result) + + def median(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.median() + self._validate_result(result) + return self._to_dataframe(result) + + def mean(self, *, skip_nulls: bool | Scalar = True) -> DataFrame: + result = self._grouped.mean() + self._validate_result(result) + return self._to_dataframe(result) + + def std( + self, + *, + correction: float | Scalar | NullType = 1.0, + skip_nulls: bool | Scalar = True, + ) -> DataFrame: + result = self._grouped.std() + self._validate_result(result) + return self._to_dataframe(result) + + def var( + self, + *, + correction: float | Scalar | NullType = 1.0, + skip_nulls: bool | Scalar = True, + ) -> DataFrame: + result = self._grouped.var() + self._validate_result(result) + return self._to_dataframe(result) + + def aggregate( + self, + *aggregations: AggregationT, + ) -> DataFrame: + aggregations = validate_aggregations(*aggregations, keys=self._keys) + df = self._grouped.agg( + **{ + aggregation.output_name: resolve_aggregation( # type: ignore[attr-defined] + aggregation, + ) + for aggregation in aggregations + }, + ) + return self._to_dataframe( + df, + ) + + +def validate_aggregations( + *aggregations: AggregationT, + keys: Sequence[str], +) -> tuple[AggregationT, ...]: + return tuple( + ( + aggregation + if aggregation.aggregation != "size" # type: ignore[attr-defined] + else aggregation._replace(column_name=keys[0]) # type: ignore[attr-defined] + ) + for aggregation in aggregations + ) + + +def resolve_aggregation(aggregation: AggregationT) -> pd.NamedAgg: + aggregation = cast(Namespace.Aggregation, aggregation) + return pd.NamedAgg( + column=aggregation.column_name, + aggfunc=aggregation.aggregation, + ) diff --git a/dataframe_api_compat/modin_standard/scalar_object.py b/dataframe_api_compat/modin_standard/scalar_object.py new file mode 100644 index 00000000..a6d2c44d --- /dev/null +++ 
b/dataframe_api_compat/modin_standard/scalar_object.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING +from typing import Any + +from dataframe_api_compat.utils import validate_comparand + +if TYPE_CHECKING: + from dataframe_api.typing import DType + from dataframe_api.typing import Namespace + from dataframe_api.typing import Scalar as ScalarT + + from dataframe_api_compat.modin_standard.dataframe_object import DataFrame +else: + ScalarT = object + + +class Scalar(ScalarT): + def __init__( + self, + value: Any, + api_version: str, + df: DataFrame | None, + *, + is_persisted: bool = False, + ) -> None: + self._value = value + self._api_version = api_version + self._df = df + self._is_persisted = is_persisted + assert is_persisted ^ (df is not None) + + def __scalar_namespace__(self) -> Namespace: + from dataframe_api_compat.modin_standard import Namespace + + return Namespace(api_version=self._api_version) + + def _from_scalar(self, scalar: Scalar) -> Scalar: + return Scalar( + scalar, + df=self._df, + api_version=self._api_version, + is_persisted=self._is_persisted, + ) + + @property + def dtype(self) -> DType: # pragma: no cover # todo + msg = "dtype not yet implemented for Scalar" + raise NotImplementedError(msg) + + @property + def scalar(self) -> Any: # pragma: no cover # todo + return self._value + + @property + def parent_dataframe(self) -> Any: # pragma: no cover # todo + return self._df + + def _materialise(self) -> Any: + if not self._is_persisted: + msg = "Can't call __bool__ on Scalar. Please use .persist() first." 
+ raise RuntimeError(msg) + return self._value + + def persist(self) -> Scalar: + if self._is_persisted: + warnings.warn( + "Calling `.persist` on Scalar that was already persisted", + UserWarning, + stacklevel=2, + ) + return Scalar( + self._value, + df=None, + api_version=self._api_version, + is_persisted=True, + ) + + def __lt__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__lt__(other)) + + def __le__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__le__(other)) + + def __eq__(self, other: Any) -> Scalar: # type: ignore[override] + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__eq__(other)) + + def __ne__(self, other: Any) -> Scalar: # type: ignore[override] + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__ne__(other)) + + def __gt__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__gt__(other)) + + def __ge__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__ge__(other)) + + def __add__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__add__(other)) + + def __radd__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other + self._value) + + def __sub__(self, other: Any) -> Scalar: + other = validate_comparand(self, 
other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__sub__(other)) + + def __rsub__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other - self._value) + + def __mul__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__mul__(other)) + + def __rmul__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other * self._value) + + def __mod__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__mod__(other)) + + def __rmod__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other % self._value) + + def __pow__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__pow__(other)) + + def __rpow__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other**self._value) + + def __floordiv__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(self._value.__floordiv__(other)) + + def __rfloordiv__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other // self._value) + + def __truediv__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return 
NotImplemented + return self._from_scalar(self._value.__truediv__(other)) + + def __rtruediv__(self, other: Any) -> Scalar: + other = validate_comparand(self, other) + if other is NotImplemented: + return NotImplemented + return self._from_scalar(other / self._value) + + def __neg__(self) -> Scalar: + return self._from_scalar(self._value.__neg__()) + + def __abs__(self) -> Scalar: + return self._from_scalar(self._value.__abs__()) + + def __bool__(self) -> bool: + return self._materialise().__bool__() # type: ignore[no-any-return] + + def __int__(self) -> int: + return self._materialise().__int__() # type: ignore[no-any-return] + + def __float__(self) -> float: + return self._materialise().__float__() # type: ignore[no-any-return] + + def __repr__(self) -> str: # pragma: no cover + header = f" Standard Scalar (api_version={self._api_version}) " + length = len(header) + return ( + "┌" + + "─" * length + + "┐\n" + + f"|{header}|\n" + + "| Add `.scalar` to see native output |\n" + + "└" + + "─" * length + + "┘\n" + ) diff --git a/mypy.ini b/mypy.ini index 37b6d040..f0a4c48d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,3 +6,9 @@ disable_error_code=empty-body [mypy-pandas.*] ignore_missing_imports = True + +[mypy-modin.*] +ignore_missing_imports = True + +[mypy-polars.*] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 6cf0717f..9d766a5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,13 @@ line-length = 90 filterwarnings = [ "error", 'ignore:distutils Version classes are deprecated:DeprecationWarning', + # modin specific + 'ignore:.*pkg_resources:DeprecationWarning', + 'ignore:Ray execution environment not yet initialized:UserWarning', + "ignore:Distributing object.:UserWarning", + 'ignore:.*ray:ResourceWarning', + 'ignore:.*is not currently supported by PandasOnRay:UserWarning', + 'ignore:.*implementation has mismatches with pandas:UserWarning', ] xfail_strict = true diff --git a/requirements-dev-modin.txt 
b/requirements-dev-modin.txt new file mode 100644 index 00000000..cf8cdf4b --- /dev/null +++ b/requirements-dev-modin.txt @@ -0,0 +1,5 @@ +covdefaults +modin[ray] +pre-commit +pytest +pytest-cov diff --git a/requirements-dev.txt b/requirements-dev.txt index dcc57d5f..57a44240 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ covdefaults pandas -polars +polars<0.20.8 pre-commit pyarrow pytest diff --git a/tests/column/and_or_test.py b/tests/column/and_or_test.py index 2fe5846c..3cdd98d5 100644 --- a/tests/column/and_or_test.py +++ b/tests/column/and_or_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_column_with_reference -def test_column_and(library: str) -> None: +def test_column_and(library: BaseHandler) -> None: df = bool_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() ser = df.col("a") @@ -14,7 +15,7 @@ def test_column_and(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_column_or(library: str) -> None: +def test_column_or(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -24,7 +25,7 @@ def test_column_or(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_column_and_with_scalar(library: str) -> None: +def test_column_and_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -34,7 +35,7 @@ def test_column_and_with_scalar(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_column_or_with_scalar(library: str) -> None: +def test_column_or_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff 
--git a/tests/column/any_all_test.py b/tests/column/any_all_test.py index a775dfff..9f59786f 100644 --- a/tests/column/any_all_test.py +++ b/tests/column/any_all_test.py @@ -2,10 +2,11 @@ import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 -def test_expr_any(library: str) -> None: +def test_expr_any(library: BaseHandler) -> None: df = bool_dataframe_1(library) with pytest.raises(RuntimeError): bool(df.col("a").any()) @@ -15,7 +16,7 @@ def test_expr_any(library: str) -> None: assert bool(result.persist()) -def test_expr_all(library: str) -> None: +def test_expr_all(library: BaseHandler) -> None: df = bool_dataframe_1(library).persist() result = df.col("a").all() assert not bool(result) diff --git a/tests/column/cast_test.py b/tests/column/cast_test.py index 111a4300..2e206b04 100644 --- a/tests/column/cast_test.py +++ b/tests/column/cast_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_cast_integers(library: str) -> None: +def test_cast_integers(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.assign(df.col("a").cast(ns.Int32())) diff --git a/tests/column/col_sorted_indices_test.py b/tests/column/col_sorted_indices_test.py index bd46afe5..2bcd0850 100644 --- a/tests/column/col_sorted_indices_test.py +++ b/tests/column/col_sorted_indices_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_6 -def test_expression_sorted_indices_ascending(library: str) -> None: +def test_expression_sorted_indices_ascending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() col = df.col @@ -14,7 +15,7 @@ def test_expression_sorted_indices_ascending(library: str) -> None: 
compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_expression_sorted_indices_descending(library: str) -> None: +def test_expression_sorted_indices_descending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() col = df.col @@ -24,7 +25,7 @@ def test_expression_sorted_indices_descending(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_column_sorted_indices_ascending(library: str) -> None: +def test_column_sorted_indices_ascending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() sorted_indices = df.col("b").sorted_indices() @@ -33,7 +34,7 @@ def test_column_sorted_indices_ascending(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_column_sorted_indices_descending(library: str) -> None: +def test_column_sorted_indices_descending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() sorted_indices = df.col("b").sorted_indices(ascending=False) diff --git a/tests/column/col_to_array_object_test.py b/tests/column/col_to_array_object_test.py index ab87f8aa..7bd50bf8 100644 --- a/tests/column/col_to_array_object_test.py +++ b/tests/column/col_to_array_object_test.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import integer_dataframe_1 @@ -22,21 +23,21 @@ "float64", ], ) -def test_column_to_array_object(library: str, dtype: str) -> None: # noqa: ARG001 +def test_column_to_array_object(library: BaseHandler, dtype: str) -> None: # noqa: ARG001 ser = integer_dataframe_1(library).col("a").persist() result = np.asarray(ser.to_array()) expected = np.array([1, 2, 3], dtype=np.int64) np.testing.assert_array_equal(result, expected) -def test_column_to_array_object_bool(library: str) -> None: +def 
test_column_to_array_object_bool(library: BaseHandler) -> None: df = bool_dataframe_1(library).persist().col("a") result = np.asarray(df.to_array()) expected = np.array([True, True, False], dtype="bool") np.testing.assert_array_equal(result, expected) -def test_column_to_array_object_invalid(library: str) -> None: +def test_column_to_array_object_invalid(library: BaseHandler) -> None: df = bool_dataframe_1(library).col("a") with pytest.raises(RuntimeError): _ = np.asarray(df.to_array()) diff --git a/tests/column/comparisons_test.py b/tests/column/comparisons_test.py index 9ba58710..fdf2cdca 100644 --- a/tests/column/comparisons_test.py +++ b/tests/column/comparisons_test.py @@ -4,6 +4,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_7 @@ -28,7 +29,7 @@ ], ) def test_column_comparisons( - library: str, + library: BaseHandler, comparison: str, expected_data: list[object], expected_dtype: str, @@ -40,7 +41,7 @@ def test_column_comparisons( other = df.col("b") result = df.assign(getattr(ser, comparison)(other).rename("result")) expected_ns_dtype = getattr(ns, expected_dtype) - if comparison == "__pow__" and library in ("polars", "polars-lazy"): + if comparison == "__pow__" and library.name in ("polars", "polars-lazy"): # TODO result = result.cast({"result": ns.Int64()}) expected_ns_dtype = ns.Int64 @@ -66,7 +67,7 @@ def test_column_comparisons( ], ) def test_column_comparisons_scalar( - library: str, + library: BaseHandler, comparison: str, expected_data: list[object], expected_dtype: str, @@ -78,7 +79,7 @@ def test_column_comparisons_scalar( other = 3 result = df.assign(getattr(ser, comparison)(other).rename("result")) expected_ns_dtype = getattr(ns, expected_dtype) - if comparison == "__pow__" and library in ("polars", "polars-lazy"): + if comparison == "__pow__" and library.name in ("polars", "polars-lazy"): result = 
result.cast({"result": ns.Int64()}) expected_ns_dtype = ns.Int64 compare_column_with_reference(result.col("result"), expected_data, expected_ns_dtype) @@ -93,7 +94,7 @@ def test_column_comparisons_scalar( ], ) def test_right_column_comparisons( - library: str, + library: BaseHandler, comparison: str, expected_data: list[object], ) -> None: diff --git a/tests/column/cross_df_comparisons_test.py b/tests/column/cross_df_comparisons_test.py index d7cca76e..26847551 100644 --- a/tests/column/cross_df_comparisons_test.py +++ b/tests/column/cross_df_comparisons_test.py @@ -2,16 +2,17 @@ import pytest +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 -def test_invalid_comparisons(library: str) -> None: +def test_invalid_comparisons(library: BaseHandler) -> None: with pytest.raises(ValueError): _ = integer_dataframe_1(library).col("a") > integer_dataframe_2(library).col("a") -def test_invalid_comparisons_scalar(library: str) -> None: +def test_invalid_comparisons_scalar(library: BaseHandler) -> None: with pytest.raises(ValueError): _ = ( integer_dataframe_1(library).col("a") diff --git a/tests/column/cumulative_test.py b/tests/column/cumulative_test.py index 2c229b82..8405d62a 100644 --- a/tests/column/cumulative_test.py +++ b/tests/column/cumulative_test.py @@ -5,6 +5,7 @@ from packaging.version import Version from packaging.version import parse +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 @@ -19,7 +20,7 @@ ], ) def test_cumulative_functions_column( - library: str, + library: BaseHandler, func: str, expected_data: list[float], ) -> None: @@ -30,7 +31,7 @@ def test_cumulative_functions_column( result = df.assign(getattr(ser, func)().rename("result")) if ( - parse(pd.__version__) < Version("2.0.0") and library == "pandas-nullable" + parse(pd.__version__) < Version("2.0.0") and library.name == "pandas-nullable" ): 
# pragma: no cover # Upstream bug result = result.cast({"result": ns.Int64()}) diff --git a/tests/column/divmod_test.py b/tests/column/divmod_test.py index dd16fec6..b56416a3 100644 --- a/tests/column/divmod_test.py +++ b/tests/column/divmod_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 -def test_expression_divmod(library: str) -> None: +def test_expression_divmod(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -18,7 +19,7 @@ def test_expression_divmod(library: str) -> None: compare_column_with_reference(result.col("result"), [1, 2, 3], dtype=ns.Int64) -def test_expression_divmod_with_scalar(library: str) -> None: +def test_expression_divmod_with_scalar(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/fill_nan_test.py b/tests/column/fill_nan_test.py index 137dd4e0..3b6c37c9 100644 --- a/tests/column/fill_nan_test.py +++ b/tests/column/fill_nan_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import nan_dataframe_1 -def test_column_fill_nan(library: str) -> None: +def test_column_fill_nan(library: BaseHandler) -> None: # TODO: test with nullable pandas, check null isn't filled df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() @@ -14,7 +15,7 @@ def test_column_fill_nan(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Float64) -def test_column_fill_nan_with_null(library: str) -> None: +def test_column_fill_nan_with_null(library: BaseHandler) -> None: # TODO: test with nullable pandas, check null isn't filled df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() diff --git 
a/tests/column/fill_null_test.py b/tests/column/fill_null_test.py index 611c9efb..c946420d 100644 --- a/tests/column/fill_null_test.py +++ b/tests/column/fill_null_test.py @@ -1,23 +1,24 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import nan_dataframe_1 from tests.utils import null_dataframe_2 -def test_fill_null_column(library: str) -> None: +def test_fill_null_column(library: BaseHandler) -> None: df = null_dataframe_2(library) ser = df.col("a") result = df.assign(ser.fill_null(0).rename("result")).col("result") - assert float(result.get_value(2).persist()) == 0.0 # type:ignore[arg-type] - assert float(result.get_value(1).persist()) != 0.0 # type:ignore[arg-type] - assert float(result.get_value(0).persist()) != 0.0 # type:ignore[arg-type] + assert float(result.get_value(2).persist()) == 0.0 # type: ignore[arg-type] + assert float(result.get_value(1).persist()) != 0.0 # type: ignore[arg-type] + assert float(result.get_value(0).persist()) != 0.0 # type: ignore[arg-type] -def test_fill_null_noop_column(library: str) -> None: +def test_fill_null_noop_column(library: BaseHandler) -> None: df = nan_dataframe_1(library) ser = df.col("a") result = df.assign(ser.fill_null(0).rename("result")).persist().col("result") - if library != "pandas-numpy": + if library.name not in ("pandas-numpy", "modin"): # nan should not have changed! 
assert float(result.get_value(2)) != float( # type: ignore[arg-type] result.get_value(2), # type: ignore[arg-type] diff --git a/tests/column/get_rows_by_mask_test.py b/tests/column/get_rows_by_mask_test.py index 2a170a0e..5e95ce2c 100644 --- a/tests/column/get_rows_by_mask_test.py +++ b/tests/column/get_rows_by_mask_test.py @@ -2,11 +2,12 @@ import pandas as pd +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 -def test_column_filter(library: str) -> None: +def test_column_filter(library: BaseHandler) -> None: df = integer_dataframe_1(library) ser = df.col("a") mask = ser > 1 @@ -16,7 +17,7 @@ def test_column_filter(library: str) -> None: pd.testing.assert_series_equal(result_pd, expected) -def test_column_take_by_mask_noop(library: str) -> None: +def test_column_take_by_mask_noop(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/get_rows_test.py b/tests/column/get_rows_test.py index 9254ef45..93dd3c8c 100644 --- a/tests/column/get_rows_test.py +++ b/tests/column/get_rows_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 -def test_expression_take(library: str) -> None: +def test_expression_take(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/get_value_test.py b/tests/column/get_value_test.py index 49a30c89..af333d46 100644 --- a/tests/column/get_value_test.py +++ b/tests/column/get_value_test.py @@ -1,13 +1,14 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_get_value(library: str) -> None: +def test_get_value(library: BaseHandler) -> None: result = 
integer_dataframe_1(library).persist().col("a").get_value(0) assert int(result) == 1 # type: ignore[call-overload] -def test_mean_scalar(library: str) -> None: +def test_mean_scalar(library: BaseHandler) -> None: result = integer_dataframe_1(library).persist().col("a").max() assert int(result) == 3 # type: ignore[call-overload] diff --git a/tests/column/invalid_pandas_test.py b/tests/column/invalid_pandas_test.py deleted file mode 100644 index 1505baf8..00000000 --- a/tests/column/invalid_pandas_test.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -import pandas as pd -import pytest - -from dataframe_api_compat.pandas_standard import convert_to_standard_compliant_dataframe - - -def test_repeated_columns() -> None: - df = pd.DataFrame({"a": [1, 2]}, index=["b", "b"]).T - with pytest.raises( - ValueError, - match=r"Expected unique column names, got b 2 time\(s\)", - ): - convert_to_standard_compliant_dataframe(df, "2023.08-beta") diff --git a/tests/column/invert_test.py b/tests/column/invert_test.py index b6003eea..32a39fe5 100644 --- a/tests/column/invert_test.py +++ b/tests/column/invert_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_column_with_reference -def test_expression_invert(library: str) -> None: +def test_expression_invert(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -13,7 +14,7 @@ def test_expression_invert(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_column_invert(library: str) -> None: +def test_column_invert(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/is_in_test.py b/tests/column/is_in_test.py index f840706b..d26edcf0 100644 --- a/tests/column/is_in_test.py +++ 
b/tests/column/is_in_test.py @@ -5,6 +5,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import float_dataframe_1 from tests.utils import float_dataframe_2 @@ -24,8 +25,8 @@ ) @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated") def test_is_in( - library: str, - df_factory: Callable[[str], Any], + library: BaseHandler, + df_factory: Callable[[BaseHandler], Any], expected_values: list[bool], ) -> None: df = df_factory(library) @@ -46,8 +47,8 @@ def test_is_in( ) @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated") def test_expr_is_in( - library: str, - df_factory: Callable[[str], Any], + library: BaseHandler, + df_factory: Callable[[BaseHandler], Any], expected_values: list[bool], ) -> None: df = df_factory(library) diff --git a/tests/column/is_nan_test.py b/tests/column/is_nan_test.py index b0d04025..3e0694fb 100644 --- a/tests/column/is_nan_test.py +++ b/tests/column/is_nan_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import nan_dataframe_1 -def test_column_is_nan(library: str) -> None: +def test_column_is_nan(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/is_null_test.py b/tests/column/is_null_test.py index fdc8e34b..b39cfdf9 100644 --- a/tests/column/is_null_test.py +++ b/tests/column/is_null_test.py @@ -1,23 +1,24 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import nan_dataframe_1 from tests.utils import null_dataframe_1 -def test_column_is_null_1(library: str) -> None: +def test_column_is_null_1(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") result = 
df.assign(ser.is_null().rename("result")) - if library == "pandas-numpy": + if library.name == "pandas-numpy": expected = [False, False, True] else: expected = [False, False, False] compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_column_is_null_2(library: str) -> None: +def test_column_is_null_2(library: BaseHandler) -> None: df = null_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") diff --git a/tests/column/len_test.py b/tests/column/len_test.py index 7008cf2e..ae61df34 100644 --- a/tests/column/len_test.py +++ b/tests/column/len_test.py @@ -1,8 +1,9 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_column_len(library: str) -> None: +def test_column_len(library: BaseHandler) -> None: result = integer_dataframe_1(library).col("a").len().persist().scalar assert result == 3 diff --git a/tests/column/n_unique_test.py b/tests/column/n_unique_test.py index 3f7d9f36..3f824fc4 100644 --- a/tests/column/n_unique_test.py +++ b/tests/column/n_unique_test.py @@ -1,8 +1,9 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_column_len(library: str) -> None: +def test_column_len(library: BaseHandler) -> None: result = integer_dataframe_1(library).col("a").n_unique().persist().scalar assert result == 3 diff --git a/tests/column/name_test.py b/tests/column/name_test.py index efd5934d..94ba559c 100644 --- a/tests/column/name_test.py +++ b/tests/column/name_test.py @@ -5,11 +5,12 @@ from packaging.version import Version from packaging.version import parse +from tests.utils import BaseHandler from tests.utils import convert_to_standard_compliant_dataframe from tests.utils import integer_dataframe_1 -def test_name(library: str) -> None: +def test_name(library: BaseHandler) -> None: df = integer_dataframe_1(library).persist() name = df.col("a").name assert name == "a" 
diff --git a/tests/column/non_unique_column_names_test.py b/tests/column/non_unique_column_names_test.py new file mode 100644 index 00000000..f4f5b440 --- /dev/null +++ b/tests/column/non_unique_column_names_test.py @@ -0,0 +1,33 @@ +from typing import Any + +import pytest + +from tests.utils import BaseHandler + + +def test_repeated_columns(library: BaseHandler) -> None: + convert_to_standard_compliant_dataframe: Any + if library.name in ("pandas-numpy", "pandas-nullable"): + import pandas as pd + + from dataframe_api_compat.pandas_standard import ( + convert_to_standard_compliant_dataframe, + ) + + df = pd.DataFrame([[1, 2]], columns=["b", "b"]) + elif library.name == "modin": + import modin.pandas as pd + + from dataframe_api_compat.modin_standard import ( + convert_to_standard_compliant_dataframe, + ) + + df = pd.DataFrame([[1, 2]], columns=["b", "b"]) + else: # pragma: no cover + msg = f"Not supported library: {library}" + raise AssertionError(msg) + with pytest.raises( + ValueError, + match=r"Expected unique column names, got b 2 time\(s\)", + ): + convert_to_standard_compliant_dataframe(df, "2023.08-beta") diff --git a/tests/column/parent_dataframe_test.py b/tests/column/parent_dataframe_test.py index 63ca6dd4..7122c675 100644 --- a/tests/column/parent_dataframe_test.py +++ b/tests/column/parent_dataframe_test.py @@ -1,6 +1,7 @@ +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_parent_dataframe(library: str) -> None: +def test_parent_dataframe(library: BaseHandler) -> None: df = integer_dataframe_1(library) assert df.col("a").parent_dataframe is df diff --git a/tests/column/pow_test.py b/tests/column/pow_test.py index 253a7218..c34e3aed 100644 --- a/tests/column/pow_test.py +++ b/tests/column/pow_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def 
test_float_powers_column(library: str) -> None: +def test_float_powers_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -15,7 +16,7 @@ def test_float_powers_column(library: str) -> None: compare_dataframe_with_reference(result, expected, expected_dtype) # type: ignore[arg-type] -def test_float_powers_scalar_column(library: str) -> None: +def test_float_powers_scalar_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") @@ -26,26 +27,26 @@ def test_float_powers_scalar_column(library: str) -> None: compare_dataframe_with_reference(result, expected, expected_dtype) # type: ignore[arg-type] -def test_int_powers_column(library: str) -> None: +def test_int_powers_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") other = df.col("b") * 1 result = df.assign(ser.__pow__(other).rename("result")) - if library in ("polars", "polars-lazy"): + if library.name in ("polars", "polars-lazy"): result = result.cast({name: ns.Int64() for name in ("a", "b", "result")}) expected = {"a": [1, 2, 3], "b": [4, 5, 6], "result": [1, 32, 729]} expected_dtype = {name: ns.Int64 for name in ("a", "b", "result")} compare_dataframe_with_reference(result, expected, expected_dtype) -def test_int_powers_scalar_column(library: str) -> None: +def test_int_powers_scalar_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() ser = df.col("a") other = 1 result = df.assign(ser.__pow__(other).rename("result")) - if library in ("polars", "polars-lazy"): + if library.name in ("polars", "polars-lazy"): result = result.cast({name: ns.Int64() for name in ("a", "b", "result")}) expected = {"a": [1, 2, 3], "b": [4, 5, 6], "result": [1, 2, 3]} expected_dtype = {name: ns.Int64 for name in ("a", "b", "result")} diff --git 
a/tests/column/reductions_test.py b/tests/column/reductions_test.py index 25d85d8b..107399d4 100644 --- a/tests/column/reductions_test.py +++ b/tests/column/reductions_test.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 @@ -20,7 +21,7 @@ ], ) def test_expression_reductions( - library: str, + library: BaseHandler, reduction: str, expected: float, expected_dtype: str, diff --git a/tests/column/rename_test.py b/tests/column/rename_test.py index 7904fb54..419bc922 100644 --- a/tests/column/rename_test.py +++ b/tests/column/rename_test.py @@ -1,9 +1,10 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_rename(library: str) -> None: +def test_rename(library: BaseHandler) -> None: df = integer_dataframe_1(library).persist() ser = df.col("a") result = ser.rename("new_name") diff --git a/tests/column/schema_test.py b/tests/column/schema_test.py index 20064924..d933a8db 100644 --- a/tests/column/schema_test.py +++ b/tests/column/schema_test.py @@ -3,15 +3,16 @@ import pytest from packaging.version import Version -from tests.utils import PANDAS_VERSION +from tests.utils import BaseHandler from tests.utils import mixed_dataframe_1 +from tests.utils import pandas_version -@pytest.mark.skipif( - Version("2.0.0") > PANDAS_VERSION, - reason="no pyarrow support", -) -def test_schema(library: str) -> None: +def test_schema(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable") and pandas_version() < Version( + "2.0.0", + ): # pragma: no cover + pytest.skip(reason="no pyarrow support") df = mixed_dataframe_1(library) namespace = df.__dataframe_namespace__() result = df.col("a").dtype diff --git a/tests/column/shift_test.py b/tests/column/shift_test.py index 86c084ce..a01ab025 100644 --- a/tests/column/shift_test.py +++ b/tests/column/shift_test.py @@ -1,44 
+1,32 @@ -import pandas as pd -import polars as pl -from polars.testing import assert_frame_equal - +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import float_dataframe_1 from tests.utils import integer_dataframe_1 -def test_shift_with_fill_value(library: str) -> None: +def test_shift_with_fill_value(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.assign(df.col("a").shift(1).fill_null(999)) - if library == "pandas-numpy": + if library.name in ("pandas-numpy", "modin"): result = result.cast({name: ns.Int64() for name in ("a", "b")}) expected = {"a": [999, 1, 2], "b": [4, 5, 6]} compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_shift_without_fill_value(library: str) -> None: +def test_shift_without_fill_value(library: BaseHandler) -> None: df = float_dataframe_1(library) + ns = df.__dataframe_namespace__() result = df.assign(df.col("a").shift(-1)) - if library == "pandas-numpy": - expected = pd.DataFrame({"a": [3.0, float("nan")]}) - pd.testing.assert_frame_equal(result.dataframe, expected) - elif library == "pandas-nullable": - expected = pd.DataFrame({"a": [3.0, None]}, dtype="Float64") - pd.testing.assert_frame_equal(result.dataframe, expected) - elif library == "polars-lazy": - expected = pl.DataFrame({"a": [3.0, None]}) - assert_frame_equal(result.dataframe.collect(), expected) # type: ignore[attr-defined] - else: # pragma: no cover - msg = "unexpected library" - raise AssertionError(msg) + expected = {"a": [3.0, float("nan")]} + compare_dataframe_with_reference(result, expected, dtype=ns.Float64) -def test_shift_with_fill_value_complicated(library: str) -> None: +def test_shift_with_fill_value_complicated(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.assign(df.col("a").shift(1).fill_null(df.col("a").mean())) - if library == "pandas-nullable": + 
if library.name == "pandas-nullable": result = result.cast({"a": ns.Float64()}) expected = {"a": [2.0, 1, 2], "b": [4, 5, 6]} expected_dtype = {"a": ns.Float64, "b": ns.Int64} diff --git a/tests/column/slice_rows_test.py b/tests/column/slice_rows_test.py index b37ba783..6047a5de 100644 --- a/tests/column/slice_rows_test.py +++ b/tests/column/slice_rows_test.py @@ -5,6 +5,7 @@ import pandas as pd import pytest +from tests.utils import BaseHandler from tests.utils import integer_dataframe_3 @@ -18,7 +19,7 @@ ], ) def test_column_slice_rows( - library: str, + library: BaseHandler, start: int | None, stop: int | None, step: int | None, diff --git a/tests/column/sort_test.py b/tests/column/sort_test.py index 7cafc5e3..045886a0 100644 --- a/tests/column/sort_test.py +++ b/tests/column/sort_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_6 -def test_expression_sort_ascending(library: str) -> None: +def test_expression_sort_ascending(library: BaseHandler) -> None: df = integer_dataframe_6(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() s_sorted = df.col("b").sort().rename("c") @@ -17,7 +18,7 @@ def test_expression_sort_ascending(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_expression_sort_descending(library: str) -> None: +def test_expression_sort_descending(library: BaseHandler) -> None: df = integer_dataframe_6(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() s_sorted = df.col("b").sort(ascending=False).rename("c") @@ -30,7 +31,7 @@ def test_expression_sort_descending(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_column_sort_ascending(library: str) -> None: +def test_column_sort_ascending(library: BaseHandler) -> None: df = integer_dataframe_6(library, 
api_version="2023.09-beta") ns = df.__dataframe_namespace__() s_sorted = df.col("b").sort().rename("c") @@ -43,7 +44,7 @@ def test_column_sort_ascending(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_column_sort_descending(library: str) -> None: +def test_column_sort_descending(library: BaseHandler) -> None: df = integer_dataframe_6(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() s_sorted = df.col("b").sort(ascending=False).rename("c") diff --git a/tests/column/statistics_test.py b/tests/column/statistics_test.py index b7e84868..02ff1716 100644 --- a/tests/column/statistics_test.py +++ b/tests/column/statistics_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 -def test_mean(library: str) -> None: +def test_mean(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.assign((df.col("a") - df.col("a").mean()).rename("result")) diff --git a/tests/column/temporal/components_test.py b/tests/column/temporal/components_test.py index f0fb3bd6..ccc2cf64 100644 --- a/tests/column/temporal/components_test.py +++ b/tests/column/temporal/components_test.py @@ -4,6 +4,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import temporal_dataframe_1 @@ -21,7 +22,7 @@ ("unix_timestamp", [1577840521, 1577934062, 1578027849]), ], ) -def test_col_components(library: str, attr: str, expected: list[int]) -> None: +def test_col_components(library: BaseHandler, attr: str, expected: list[int]) -> None: df = temporal_dataframe_1(library) ns = df.__dataframe_namespace__() for col_name in ("a", "c", "e"): @@ -43,7 +44,11 @@ def test_col_components(library: str, attr: str, expected: list[int]) -> None: ("e", [123543, 321654, 987321]), ], ) -def 
test_col_microsecond(library: str, col_name: str, expected: list[int]) -> None: +def test_col_microsecond( + library: BaseHandler, + col_name: str, + expected: list[int], +) -> None: df = temporal_dataframe_1(library) ns = df.__dataframe_namespace__() result = ( @@ -64,7 +69,7 @@ def test_col_microsecond(library: str, col_name: str, expected: list[int]) -> No ("e", [123543000, 321654000, 987321000]), ], ) -def test_col_nanosecond(library: str, col_name: str, expected: list[int]) -> None: +def test_col_nanosecond(library: BaseHandler, col_name: str, expected: list[int]) -> None: df = temporal_dataframe_1(library) ns = df.__dataframe_namespace__() result = ( @@ -87,7 +92,7 @@ def test_col_nanosecond(library: str, col_name: str, expected: list[int]) -> Non ], ) def test_col_unix_timestamp_time_units( - library: str, + library: BaseHandler, time_unit: Literal["s", "ms", "us", "ns"], expected: list[int], ) -> None: diff --git a/tests/column/temporal/filter_test.py b/tests/column/temporal/filter_test.py index 27c2d901..a095a46e 100644 --- a/tests/column/temporal/filter_test.py +++ b/tests/column/temporal/filter_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import temporal_dataframe_1 -def test_filter_w_date(library: str) -> None: +def test_filter_w_date(library: BaseHandler) -> None: df = temporal_dataframe_1(library).select("a", "index") ns = df.__dataframe_namespace__() result = df.filter(df.col("a") > ns.date(2020, 1, 2)).select("index") diff --git a/tests/column/temporal/floor_test.py b/tests/column/temporal/floor_test.py index b9bf5d85..9ccd08b6 100644 --- a/tests/column/temporal/floor_test.py +++ b/tests/column/temporal/floor_test.py @@ -4,6 +4,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import temporal_dataframe_1 @@ -14,7 +15,7 @@ ("1day", [datetime(2020, 1, 1), datetime(2020, 1, 2), 
datetime(2020, 1, 3)]), ], ) -def test_floor(library: str, freq: str, expected: list[datetime]) -> None: +def test_floor(library: BaseHandler, freq: str, expected: list[datetime]) -> None: df = temporal_dataframe_1(library) ns = df.__dataframe_namespace__() col = df.col diff --git a/tests/conftest.py b/tests/conftest.py index 7f1a5d5f..f2590734 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,12 @@ import sys from typing import Any +import pytest + +from tests.utils import ModinHandler +from tests.utils import PandasHandler +from tests.utils import PolarsHandler + LIBRARIES = { (3, 8): ["pandas-numpy", "pandas-nullable", "polars-lazy"], (3, 9): ["pandas-numpy", "pandas-nullable", "polars-lazy"], @@ -11,10 +17,75 @@ (3, 12): ["polars-lazy"], } +LIBRARIES_HANDLERS = { + "pandas-numpy": PandasHandler("pandas-numpy"), + "pandas-nullable": PandasHandler("pandas-nullable"), + "polars-lazy": PolarsHandler("polars-lazy"), + "modin": ModinHandler("modin"), +} + + +def pytest_addoption(parser: Any) -> None: + parser.addoption( + "--library", + action="store", + default=None, + type=str, + help="library to test", + ) + + +def pytest_configure(config: Any) -> None: + library = config.option.library + if library is None: # pragma: no cover + # `LIBRARIES` is already initialized + return + else: + assert library in ("pandas-numpy", "pandas-nullable", "polars-lazy", "modin") + global LIBRARIES # noqa: PLW0603 + LIBRARIES = { + (3, 8): [library], + (3, 9): [library], + (3, 10): [library], + (3, 11): [library], + (3, 12): [library], + } + def pytest_generate_tests(metafunc: Any) -> None: if "library" in metafunc.fixturenames: - metafunc.parametrize( - "library", - LIBRARIES[sys.version_info[:2]], - ) + libraries = LIBRARIES[sys.version_info[:2]] + lib_handlers = [LIBRARIES_HANDLERS[lib] for lib in libraries] + + metafunc.parametrize("library", lib_handlers, ids=libraries) + + +ci_skip_ids = [ + # polars does not allow to create a dataframe with non-unique columns + 
"non_unique_column_names_test.py::test_repeated_columns[polars-lazy]", + # TODO: enable after modin adds implementation for standard + "scale_column_test.py::test_scale_column[modin]", + "scale_column_test.py::test_scale_column_polars_from_persisted_df[modin]", + "convert_to_standard_column_test.py::test_convert_to_std_column[modin]", +] + + +ci_xfail_ids = [ + # https://github.com/modin-project/modin/issues/7212 + "join_test.py::test_join_left[modin]", + "join_test.py::test_join_two_keys[modin]", + "persistedness_test.py::test_cross_df_propagation[modin]", + # https://github.com/modin-project/modin/issues/3602 + "aggregate_test.py::test_aggregate[modin]", + "aggregate_test.py::test_aggregate_only_size[modin]", +] + + +def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: # pragma: no cover + for item in items: + if any(id_ in item.nodeid for id_ in ci_xfail_ids): + item.add_marker(pytest.mark.xfail(strict=True)) + elif any(id_ in item.nodeid for id_ in ci_skip_ids): + item.add_marker( + pytest.mark.skip("does not make sense for a specific implementation"), + ) diff --git a/tests/dataframe/all_rowwise_test.py b/tests/dataframe/all_rowwise_test.py index 92b2df73..60421a5c 100644 --- a/tests/dataframe/all_rowwise_test.py +++ b/tests/dataframe/all_rowwise_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_dataframe_with_reference -def test_all_horizontal(library: str) -> None: +def test_all_horizontal(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() mask = ns.all_horizontal(*[df.col(col_name) for col_name in df.column_names]) @@ -15,7 +16,7 @@ def test_all_horizontal(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_all_horizontal_invalid(library: str) -> None: +def test_all_horizontal_invalid(library: BaseHandler) -> None: df = bool_dataframe_1(library) 
namespace = df.__dataframe_namespace__() with pytest.raises(ValueError): diff --git a/tests/dataframe/and_test.py b/tests/dataframe/and_test.py index 2b99778b..9c26ac39 100644 --- a/tests/dataframe/and_test.py +++ b/tests/dataframe/and_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_dataframe_with_reference -def test_and_with_scalar(library: str) -> None: +def test_and_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() other = True @@ -13,7 +14,7 @@ def test_and_with_scalar(library: str) -> None: compare_dataframe_with_reference(result, expected, ns.Bool) -def test_rand_with_scalar(library: str) -> None: +def test_rand_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() other = True diff --git a/tests/dataframe/any_all_test.py b/tests/dataframe/any_all_test.py index 63f9a95d..e97f6a27 100644 --- a/tests/dataframe/any_all_test.py +++ b/tests/dataframe/any_all_test.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import bool_dataframe_3 from tests.utils import compare_dataframe_with_reference @@ -15,7 +16,7 @@ ], ) def test_reductions( - library: str, + library: BaseHandler, reduction: str, expected_data: dict[str, object], ) -> None: @@ -25,7 +26,7 @@ def test_reductions( compare_dataframe_with_reference(result, expected_data, dtype=ns.Bool) # type: ignore[arg-type] -def test_any(library: str) -> None: +def test_any(library: BaseHandler) -> None: df = bool_dataframe_3(library) ns = df.__dataframe_namespace__() result = df.any() @@ -33,7 +34,7 @@ def test_any(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_all(library: str) -> None: +def test_all(library: BaseHandler) -> None: df = bool_dataframe_3(library) 
ns = df.__dataframe_namespace__() result = df.all() diff --git a/tests/dataframe/any_rowwise_test.py b/tests/dataframe/any_rowwise_test.py index 6fbb9177..c43fc137 100644 --- a/tests/dataframe/any_rowwise_test.py +++ b/tests/dataframe/any_rowwise_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_dataframe_with_reference -def test_any_horizontal(library: str) -> None: +def test_any_horizontal(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() mask = ns.any_horizontal(*[df.col(col_name) for col_name in df.column_names]) @@ -15,7 +16,7 @@ def test_any_horizontal(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_any_horizontal_invalid(library: str) -> None: +def test_any_horizontal_invalid(library: BaseHandler) -> None: df = bool_dataframe_1(library) namespace = df.__dataframe_namespace__() with pytest.raises(ValueError): diff --git a/tests/dataframe/assign_test.py b/tests/dataframe/assign_test.py index f6daf5af..02d3ccb4 100644 --- a/tests/dataframe/assign_test.py +++ b/tests/dataframe/assign_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_insert_columns(library: str) -> None: +def test_insert_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() new_col = (df.col("b") + 3).rename("result") @@ -18,7 +19,7 @@ def test_insert_columns(library: str) -> None: compare_dataframe_with_reference(df, expected, dtype=ns.Int64) -def test_insert_multiple_columns(library: str) -> None: +def test_insert_multiple_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() new_col = (df.col("b") + 
3).rename("result") @@ -30,7 +31,7 @@ def test_insert_multiple_columns(library: str) -> None: compare_dataframe_with_reference(df, expected, dtype=ns.Int64) -def test_insert_multiple_columns_invalid(library: str) -> None: +def test_insert_multiple_columns_invalid(library: BaseHandler) -> None: df = integer_dataframe_1(library, api_version="2023.09-beta") df.__dataframe_namespace__() new_col = (df.col("b") + 3).rename("result") @@ -38,7 +39,7 @@ def test_insert_multiple_columns_invalid(library: str) -> None: _ = df.assign([new_col.rename("c"), new_col.rename("d")]) # type: ignore[arg-type] -def test_insert_eager_columns(library: str) -> None: +def test_insert_eager_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() new_col = (df.col("b") + 3).rename("result") diff --git a/tests/dataframe/cast_test.py b/tests/dataframe/cast_test.py index 7e3a199d..e5679335 100644 --- a/tests/dataframe/cast_test.py +++ b/tests/dataframe/cast_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_cast_integers(library: str) -> None: +def test_cast_integers(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.cast({"a": ns.Int32()}) diff --git a/tests/dataframe/columns_iter_test.py b/tests/dataframe/columns_iter_test.py index f8a02f3b..a65b640d 100644 --- a/tests/dataframe/columns_iter_test.py +++ b/tests/dataframe/columns_iter_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_iter_columns(library: str) -> None: +def test_iter_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.assign( diff --git a/tests/dataframe/comparisons_test.py 
b/tests/dataframe/comparisons_test.py index 6886191f..458c56f3 100644 --- a/tests/dataframe/comparisons_test.py +++ b/tests/dataframe/comparisons_test.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 @@ -25,7 +26,7 @@ ], ) def test_comparisons_with_scalar( - library: str, + library: BaseHandler, comparison: str, expected_data: dict[str, object], expected_dtype: str, @@ -47,7 +48,7 @@ def test_comparisons_with_scalar( ], ) def test_rcomparisons_with_scalar( - library: str, + library: BaseHandler, comparison: str, expected_data: dict[str, object], ) -> None: diff --git a/tests/dataframe/cross_df_comparison_test.py b/tests/dataframe/cross_df_comparison_test.py index 6e96ddc1..fc8cbbef 100644 --- a/tests/dataframe/cross_df_comparison_test.py +++ b/tests/dataframe/cross_df_comparison_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 -def test_invalid_comparisons(library: str) -> None: +def test_invalid_comparisons(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) df2 = integer_dataframe_2(library) mask = df2.col("a") > 1 diff --git a/tests/dataframe/divmod_test.py b/tests/dataframe/divmod_test.py index 0a62d3fe..7ddb61aa 100644 --- a/tests/dataframe/divmod_test.py +++ b/tests/dataframe/divmod_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_divmod_with_scalar(library: str) -> None: +def test_divmod_with_scalar(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() other = 2 diff --git a/tests/dataframe/drop_column_test.py b/tests/dataframe/drop_column_test.py index 9f948245..97e6d205 100644 --- 
a/tests/dataframe/drop_column_test.py +++ b/tests/dataframe/drop_column_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_drop_column(library: str) -> None: +def test_drop_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.drop("a") diff --git a/tests/dataframe/drop_nulls_test.py b/tests/dataframe/drop_nulls_test.py index 2bf6e604..7797fe68 100644 --- a/tests/dataframe/drop_nulls_test.py +++ b/tests/dataframe/drop_nulls_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import null_dataframe_1 -def test_drop_nulls(library: str) -> None: +def test_drop_nulls(library: BaseHandler) -> None: df = null_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.drop_nulls() diff --git a/tests/dataframe/fill_nan_test.py b/tests/dataframe/fill_nan_test.py index e21f8af1..8ae7264f 100644 --- a/tests/dataframe/fill_nan_test.py +++ b/tests/dataframe/fill_nan_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import nan_dataframe_1 -def test_fill_nan(library: str) -> None: +def test_fill_nan(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.fill_nan(-1) @@ -15,7 +16,7 @@ def test_fill_nan(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Float64) -def test_fill_nan_with_scalar(library: str) -> None: +def test_fill_nan_with_scalar(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.fill_nan(df.col("a").get_value(0)) @@ -24,20 +25,20 @@ def test_fill_nan_with_scalar(library: str) -> None: 
compare_dataframe_with_reference(result, expected, dtype=ns.Float64) -def test_fill_nan_with_scalar_invalid(library: str) -> None: +def test_fill_nan_with_scalar_invalid(library: BaseHandler) -> None: df = nan_dataframe_1(library) other = df + 1 with pytest.raises(ValueError): _ = df.fill_nan(other.col("a").get_value(0)) -def test_fill_nan_with_null(library: str) -> None: +def test_fill_nan_with_null(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.fill_nan(ns.null) n_nans = result.is_nan().sum() result = n_nans.col("a").persist().get_value(0).scalar - if library == "pandas-numpy": + if library.name in ("pandas-numpy", "modin"): # null is nan for pandas-numpy assert result == 1 else: diff --git a/tests/dataframe/fill_null_test.py b/tests/dataframe/fill_null_test.py index 12c24e8d..b93c1e88 100644 --- a/tests/dataframe/fill_null_test.py +++ b/tests/dataframe/fill_null_test.py @@ -2,6 +2,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import nan_dataframe_1 from tests.utils import null_dataframe_2 @@ -15,7 +16,7 @@ ["b"], ], ) -def test_fill_null(library: str, column_names: list[str] | None) -> None: +def test_fill_null(library: BaseHandler, column_names: list[str] | None) -> None: df = null_dataframe_2(library) df.__dataframe_namespace__() result = df.fill_null(0, column_names=column_names) @@ -32,14 +33,14 @@ def test_fill_null(library: str, column_names: list[str] | None) -> None: assert result.col("b").persist().get_value(2).scalar == 0 -def test_fill_null_noop(library: str) -> None: +def test_fill_null_noop(library: BaseHandler) -> None: df = nan_dataframe_1(library) result_raw = df.fill_null(0) if hasattr(result_raw.dataframe, "collect"): result = result_raw.dataframe.collect() else: result = result_raw.dataframe - if library != "pandas-numpy": + if library.name not in ("pandas-numpy", "modin"): # nan should not have changed! 
assert result["a"][2] != result["a"][2] else: diff --git a/tests/dataframe/get_column_by_name_test.py b/tests/dataframe/get_column_by_name_test.py index 6ddae877..c1b31dc4 100644 --- a/tests/dataframe/get_column_by_name_test.py +++ b/tests/dataframe/get_column_by_name_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_get_column(library: str) -> None: +def test_get_column(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() col = df.col diff --git a/tests/dataframe/get_column_names_test.py b/tests/dataframe/get_column_names_test.py index 86e45ab3..138ade7a 100644 --- a/tests/dataframe/get_column_names_test.py +++ b/tests/dataframe/get_column_names_test.py @@ -1,9 +1,10 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_get_column_names(library: str) -> None: +def test_get_column_names(library: BaseHandler) -> None: df = integer_dataframe_1(library) result = df.column_names assert list(result) == ["a", "b"] diff --git a/tests/dataframe/get_rows_by_mask_test.py b/tests/dataframe/get_rows_by_mask_test.py index a2ae421c..c1048550 100644 --- a/tests/dataframe/get_rows_by_mask_test.py +++ b/tests/dataframe/get_rows_by_mask_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_filter(library: str) -> None: +def test_filter(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() mask = df.col("a") % 2 == 1 diff --git a/tests/dataframe/get_rows_test.py b/tests/dataframe/get_rows_test.py index 16391c64..61b4ddd8 100644 --- a/tests/dataframe/get_rows_test.py +++ b/tests/dataframe/get_rows_test.py @@ -1,10 
+1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_take(library: str) -> None: +def test_take(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() df = df.assign((df.col("a") - 1).sort(ascending=False).rename("result")) diff --git a/tests/dataframe/invert_test.py b/tests/dataframe/invert_test.py index ed84c32e..91c47412 100644 --- a/tests/dataframe/invert_test.py +++ b/tests/dataframe/invert_test.py @@ -2,12 +2,13 @@ import pytest +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_invert(library: str) -> None: +def test_invert(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() result = ~df @@ -15,7 +16,7 @@ def test_invert(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_invert_invalid(library: str) -> None: +def test_invert_invalid(library: BaseHandler) -> None: df = integer_dataframe_1(library) with pytest.raises(TypeError): _ = ~df diff --git a/tests/dataframe/is_nan_test.py b/tests/dataframe/is_nan_test.py index 3d82f9dc..e7bbdea7 100644 --- a/tests/dataframe/is_nan_test.py +++ b/tests/dataframe/is_nan_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import nan_dataframe_1 -def test_dataframe_is_nan(library: str) -> None: +def test_dataframe_is_nan(library: BaseHandler) -> None: df = nan_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.is_nan() diff --git a/tests/dataframe/is_null_test.py b/tests/dataframe/is_null_test.py index c2a469b2..ac322893 100644 --- a/tests/dataframe/is_null_test.py +++ 
b/tests/dataframe/is_null_test.py @@ -1,15 +1,16 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import nan_dataframe_2 from tests.utils import null_dataframe_1 -def test_is_null_1(library: str) -> None: +def test_is_null_1(library: BaseHandler) -> None: df = nan_dataframe_2(library) ns = df.__dataframe_namespace__() result = df.is_null() - if library == "pandas-numpy": + if library.name == "pandas-numpy": # nan and null are the same in pandas-numpy expected = {"a": [False, False, True]} else: @@ -17,7 +18,7 @@ def test_is_null_1(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_is_null_2(library: str) -> None: +def test_is_null_2(library: BaseHandler) -> None: df = null_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.is_null() diff --git a/tests/dataframe/join_test.py b/tests/dataframe/join_test.py index b893e83e..b6b91ce9 100644 --- a/tests/dataframe/join_test.py +++ b/tests/dataframe/join_test.py @@ -3,13 +3,14 @@ import pytest from packaging.version import Version -from tests.utils import PANDAS_VERSION +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 +from tests.utils import pandas_version -def test_join_left(library: str) -> None: +def test_join_left(library: BaseHandler) -> None: left = integer_dataframe_1(library) right = integer_dataframe_2(library).rename({"b": "c"}) result = left.join(right, left_on="a", right_on="a", how="left") @@ -18,19 +19,21 @@ def test_join_left(library: str) -> None: expected_dtype = { "a": ns.Int64, "b": ns.Int64, - "c": ns.Int64 if library in ["pandas-nullable", "polars-lazy"] else ns.Float64, + "c": ns.Int64 + if library.name in ["pandas-nullable", "polars-lazy"] + else ns.Float64, } compare_dataframe_with_reference(result, 
expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_join_overlapping_names(library: str) -> None: +def test_join_overlapping_names(library: BaseHandler) -> None: left = integer_dataframe_1(library) right = integer_dataframe_2(library) with pytest.raises(ValueError): _ = left.join(right, left_on="a", right_on="a", how="left") -def test_join_inner(library: str) -> None: +def test_join_inner(library: BaseHandler) -> None: left = integer_dataframe_1(library) right = integer_dataframe_2(library).rename({"b": "c"}) result = left.join(right, left_on="a", right_on="a", how="inner") @@ -40,13 +43,13 @@ def test_join_inner(library: str) -> None: @pytest.mark.skip(reason="outer join has changed in Polars recently, need to fixup") -def test_join_outer(library: str) -> None: # pragma: no cover +def test_join_outer(library: BaseHandler) -> None: # pragma: no cover left = integer_dataframe_1(library) right = integer_dataframe_2(library).rename({"b": "c"}) result = left.join(right, left_on="a", right_on="a", how="outer").sort("a") ns = result.__dataframe_namespace__() if ( - library == "pandas-nullable" and Version("2.0.0") > PANDAS_VERSION + library.name == "pandas-nullable" and Version("2.0.0") > pandas_version() ): # pragma: no cover # upstream bug result = result.cast({"a": ns.Int64()}) @@ -57,13 +60,17 @@ def test_join_outer(library: str) -> None: # pragma: no cover } expected_dtype = { "a": ns.Int64, - "b": ns.Int64 if library in ["pandas-nullable", "polars-lazy"] else ns.Float64, - "c": ns.Int64 if library in ["pandas-nullable", "polars-lazy"] else ns.Float64, + "b": ns.Int64 + if library.name in ["pandas-nullable", "polars-lazy"] + else ns.Float64, + "c": ns.Int64 + if library.name in ["pandas-nullable", "polars-lazy"] + else ns.Float64, } compare_dataframe_with_reference(result, expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_join_two_keys(library: str) -> None: +def test_join_two_keys(library: BaseHandler) -> None: left = 
integer_dataframe_1(library) right = integer_dataframe_2(library).rename({"b": "c"}) result = left.join(right, left_on=["a", "b"], right_on=["a", "c"], how="left") @@ -72,12 +79,14 @@ def test_join_two_keys(library: str) -> None: expected_dtype = { "a": ns.Int64, "b": ns.Int64, - "c": ns.Int64 if library in ["pandas-nullable", "polars-lazy"] else ns.Float64, + "c": ns.Int64 + if library.name in ["pandas-nullable", "polars-lazy"] + else ns.Float64, } compare_dataframe_with_reference(result, expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_join_invalid(library: str) -> None: +def test_join_invalid(library: BaseHandler) -> None: left = integer_dataframe_1(library) right = integer_dataframe_2(library).rename({"b": "c"}) with pytest.raises(ValueError): diff --git a/tests/dataframe/or_test.py b/tests/dataframe/or_test.py index 1a4a8c95..383cac4f 100644 --- a/tests/dataframe/or_test.py +++ b/tests/dataframe/or_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import bool_dataframe_1 from tests.utils import compare_dataframe_with_reference -def test_or_with_scalar(library: str) -> None: +def test_or_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() other = True @@ -13,7 +14,7 @@ def test_or_with_scalar(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Bool) -def test_ror_with_scalar(library: str) -> None: +def test_ror_with_scalar(library: BaseHandler) -> None: df = bool_dataframe_1(library) ns = df.__dataframe_namespace__() other = True diff --git a/tests/dataframe/pow_test.py b/tests/dataframe/pow_test.py index eff8b95b..1ba1e370 100644 --- a/tests/dataframe/pow_test.py +++ b/tests/dataframe/pow_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def 
test_float_scalar_powers(library: str) -> None: +def test_float_scalar_powers(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() other = 1.0 diff --git a/tests/dataframe/reductions_test.py b/tests/dataframe/reductions_test.py index 2055a7ef..9a83bc44 100644 --- a/tests/dataframe/reductions_test.py +++ b/tests/dataframe/reductions_test.py @@ -4,6 +4,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 @@ -22,7 +23,7 @@ ], ) def test_dataframe_reductions( - library: str, + library: BaseHandler, reduction: str, expected: dict[str, Any], expected_dtype: str, diff --git a/tests/dataframe/rename_columns_test.py b/tests/dataframe/rename_columns_test.py index 63081cf5..2e193b11 100644 --- a/tests/dataframe/rename_columns_test.py +++ b/tests/dataframe/rename_columns_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_rename(library: str) -> None: +def test_rename(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.rename({"a": "c", "b": "e"}) @@ -14,7 +15,7 @@ def test_rename(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_rename_invalid(library: str) -> None: +def test_rename_invalid(library: BaseHandler) -> None: df = integer_dataframe_1(library) with pytest.raises( TypeError, diff --git a/tests/dataframe/schema_test.py b/tests/dataframe/schema_test.py index c8d0538e..a02cac00 100644 --- a/tests/dataframe/schema_test.py +++ b/tests/dataframe/schema_test.py @@ -1,19 +1,18 @@ from __future__ import annotations -import pandas as pd import pytest from packaging.version import Version -from packaging.version import parse -from tests.utils import PANDAS_VERSION +from tests.utils import 
BaseHandler from tests.utils import mixed_dataframe_1 +from tests.utils import pandas_version -@pytest.mark.skipif( - Version("2.0.0") > PANDAS_VERSION, - reason="no pyarrow support", -) -def test_schema(library: str) -> None: +def test_schema(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable") and pandas_version() < Version( + "2.0.0", + ): # pragma: no cover + pytest.skip(reason="no pyarrow support") df = mixed_dataframe_1(library) namespace = df.__dataframe_namespace__() result = df.schema @@ -51,7 +50,8 @@ def test_schema(library: str) -> None: assert isinstance(result["m"], namespace.Datetime) assert isinstance(result["n"], namespace.Datetime) if not ( - library.startswith("pandas") and parse(pd.__version__) < Version("2.0.0") + library.name in ("pandas-numpy", "pandas-nullable") + and pandas_version() < Version("2.0.0") ): # pragma: no cover (coverage bug?) # pandas non-nanosecond support only came in 2.0 assert result["n"].time_unit == "ms" @@ -60,14 +60,18 @@ def test_schema(library: str) -> None: assert result["n"].time_zone is None assert isinstance(result["o"], namespace.Datetime) if not ( - library.startswith("pandas") and parse(pd.__version__) < Version("2.0.0") + library.name in ("pandas-numpy", "pandas-nullable") + and pandas_version() < Version("2.0.0") ): # pragma: no cover (coverage bug?) 
# pandas non-nanosecond support only came in 2.0 assert result["o"].time_unit == "us" else: # pragma: no cover pass assert result["o"].time_zone is None - if not (library.startswith("pandas") and parse(pd.__version__) < Version("2.0.0")): + if not ( + library.name in ("pandas-numpy", "pandas-nullable") + and pandas_version() < Version("2.0.0") + ): # pandas non-nanosecond support only came in 2.0 - before that, these would be 'float' assert isinstance(result["p"], namespace.Duration) assert result["p"].time_unit == "ms" diff --git a/tests/dataframe/select_test.py b/tests/dataframe/select_test.py index 60bde31f..02c5d52f 100644 --- a/tests/dataframe/select_test.py +++ b/tests/dataframe/select_test.py @@ -2,11 +2,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_select(library: str) -> None: +def test_select(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.select("b") @@ -14,7 +15,7 @@ def test_select(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_select_list_of_str(library: str) -> None: +def test_select_list_of_str(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() result = df.select("a", "b") @@ -22,14 +23,14 @@ def test_select_list_of_str(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_select_list_of_str_invalid(library: str) -> None: +def test_select_list_of_str_invalid(library: BaseHandler) -> None: df = integer_dataframe_1(library) with pytest.raises(TypeError): _ = df.select(["a", "b"]) # type: ignore[arg-type] @pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated") -def test_select_empty(library: str) -> None: +def test_select_empty(library: BaseHandler) -> None: df = integer_dataframe_1(library) result = 
df.select() assert result.column_names == [] diff --git a/tests/dataframe/shape_test.py b/tests/dataframe/shape_test.py index 491ed0ee..b127b982 100644 --- a/tests/dataframe/shape_test.py +++ b/tests/dataframe/shape_test.py @@ -2,10 +2,11 @@ import pytest +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_shape(library: str) -> None: +def test_shape(library: BaseHandler) -> None: df = integer_dataframe_1(library).persist() assert df.shape() == (3, 2) diff --git a/tests/dataframe/slice_rows_test.py b/tests/dataframe/slice_rows_test.py index 027b98df..271e878e 100644 --- a/tests/dataframe/slice_rows_test.py +++ b/tests/dataframe/slice_rows_test.py @@ -4,6 +4,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_3 @@ -18,7 +19,7 @@ ], ) def test_slice_rows( - library: str, + library: BaseHandler, start: int | None, stop: int | None, step: int | None, diff --git a/tests/dataframe/sort_test.py b/tests/dataframe/sort_test.py index 1698e671..b5b8649b 100644 --- a/tests/dataframe/sort_test.py +++ b/tests/dataframe/sort_test.py @@ -2,12 +2,13 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_5 @pytest.mark.parametrize("keys", [["a", "b"], []]) -def test_sort(library: str, keys: list[str]) -> None: +def test_sort(library: BaseHandler, keys: list[str]) -> None: df = integer_dataframe_5(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() result = df.sort(*keys) @@ -17,7 +18,7 @@ def test_sort(library: str, keys: list[str]) -> None: @pytest.mark.parametrize("keys", [["a", "b"], []]) def test_sort_descending( - library: str, + library: BaseHandler, keys: list[str], ) -> None: df = integer_dataframe_5(library, api_version="2023.09-beta") diff --git a/tests/dataframe/to_array_object_test.py 
b/tests/dataframe/to_array_object_test.py index d9b53ad8..42bdb35f 100644 --- a/tests/dataframe/to_array_object_test.py +++ b/tests/dataframe/to_array_object_test.py @@ -2,10 +2,11 @@ import numpy as np +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_to_array_object(library: str) -> None: +def test_to_array_object(library: BaseHandler) -> None: df = integer_dataframe_1(library).persist() result = np.asarray(df.to_array(dtype="int64")) # type: ignore[call-arg] expected = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.int64) diff --git a/tests/dataframe/update_columns_test.py b/tests/dataframe/update_columns_test.py index 0dfc67fe..2b0ebc5c 100644 --- a/tests/dataframe/update_columns_test.py +++ b/tests/dataframe/update_columns_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_update_columns(library: str) -> None: +def test_update_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() col = df.col @@ -13,7 +14,7 @@ def test_update_columns(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_update_multiple_columns(library: str) -> None: +def test_update_multiple_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() col = df.col diff --git a/tests/dataframe/update_test.py b/tests/dataframe/update_test.py index 3c6b57a7..d5ef2c23 100644 --- a/tests/dataframe/update_test.py +++ b/tests/dataframe/update_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_update_column(library: str) -> None: +def test_update_column(library: BaseHandler) -> None: df = 
integer_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() new_col = df.col("b") + 3 @@ -16,7 +17,7 @@ def test_update_column(library: str) -> None: compare_dataframe_with_reference(df, expected, dtype=ns.Int64) -def test_update_columns(library: str) -> None: +def test_update_columns(library: BaseHandler) -> None: df = integer_dataframe_1(library, api_version="2023.09-beta") ns = df.__dataframe_namespace__() new_col_a = df.col("a") + 1 diff --git a/tests/groupby/aggregate_test.py b/tests/groupby/aggregate_test.py index 25619342..0ad381a2 100644 --- a/tests/groupby/aggregate_test.py +++ b/tests/groupby/aggregate_test.py @@ -1,8 +1,9 @@ +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_4 -def test_aggregate(library: str) -> None: +def test_aggregate(library: BaseHandler) -> None: df = integer_dataframe_4(library) df = df.assign((df.col("b") > 0).rename("d")) ns = df.__dataframe_namespace__() @@ -51,12 +52,12 @@ def test_aggregate(library: str) -> None: "d_any": ns.Bool, "d_all": ns.Bool, } - if library == "polars-lazy": + if library.name == "polars-lazy": result = result.cast({"b_count": ns.Int64()}) compare_dataframe_with_reference(result, expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_aggregate_only_size(library: str) -> None: +def test_aggregate_only_size(library: BaseHandler) -> None: df = integer_dataframe_4(library) ns = df.__dataframe_namespace__() result = ( @@ -70,12 +71,12 @@ def test_aggregate_only_size(library: str) -> None: "key": [1, 2], "b_count": [2, 2], } - if library == "polars-lazy": + if library.name == "polars-lazy": result = result.cast({"b_count": ns.Int64()}) compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_aggregate_no_size(library: str) -> None: +def test_aggregate_no_size(library: BaseHandler) -> None: df = integer_dataframe_4(library) ns = df.__dataframe_namespace__() result = 
( diff --git a/tests/groupby/groupby_any_all_test.py b/tests/groupby/groupby_any_all_test.py index 8ae9eea1..f2116e38 100644 --- a/tests/groupby/groupby_any_all_test.py +++ b/tests/groupby/groupby_any_all_test.py @@ -1,11 +1,13 @@ from __future__ import annotations +from typing import Any + import pandas as pd import pytest from packaging.version import Version from packaging.version import parse -from polars.exceptions import SchemaError +from tests.utils import BaseHandler from tests.utils import bool_dataframe_2 from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_4 @@ -19,7 +21,7 @@ ], ) def test_groupby_boolean( - library: str, + library: BaseHandler, aggregation: str, expected_b: list[bool], expected_c: list[bool], @@ -29,7 +31,7 @@ def test_groupby_boolean( result = getattr(df.group_by("key"), aggregation)() # need to sort result = result.sort("key") - if library == "pandas-nullable" and parse(pd.__version__) < Version( + if library.name == "pandas-nullable" and parse(pd.__version__) < Version( "2.0.0", ): # pragma: no cover # upstream bug @@ -39,9 +41,15 @@ def test_groupby_boolean( compare_dataframe_with_reference(result, expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_group_by_invalid_any_all(library: str) -> None: +def test_group_by_invalid_any_all(library: BaseHandler) -> None: df = integer_dataframe_4(library).persist() - with pytest.raises((TypeError, SchemaError)): + + exceptions: tuple[Any, ...] 
= (TypeError,) + if library.name == "polars-lazy": + from polars.exceptions import SchemaError + + exceptions = (TypeError, SchemaError) + with pytest.raises(exceptions): df.group_by("key").any() - with pytest.raises((TypeError, SchemaError)): + with pytest.raises(exceptions): df.group_by("key").all() diff --git a/tests/groupby/invalid_test.py b/tests/groupby/invalid_test.py index 679d1acd..ad4be435 100644 --- a/tests/groupby/invalid_test.py +++ b/tests/groupby/invalid_test.py @@ -2,10 +2,11 @@ import pytest +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_groupby_invalid(library: str) -> None: +def test_groupby_invalid(library: BaseHandler) -> None: df = integer_dataframe_1(library).select("a") with pytest.raises((KeyError, TypeError)): df.group_by(0) # type: ignore[arg-type] diff --git a/tests/groupby/numeric_test.py b/tests/groupby/numeric_test.py index 075f1588..7109bc83 100644 --- a/tests/groupby/numeric_test.py +++ b/tests/groupby/numeric_test.py @@ -5,6 +5,7 @@ from packaging.version import Version from packaging.version import parse +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_4 @@ -28,7 +29,7 @@ ], ) def test_group_by_numeric( - library: str, + library: BaseHandler, aggregation: str, expected_b: list[float], expected_c: list[float], @@ -41,7 +42,7 @@ def test_group_by_numeric( expected = {"key": [1, 2], "b": expected_b, "c": expected_c} dtype = getattr(ns, expected_dtype) expected_ns_dtype = {"key": ns.Int64, "b": dtype, "c": dtype} - if library == "pandas-nullable" and parse(pd.__version__) < Version( + if library.name == "pandas-nullable" and parse(pd.__version__) < Version( "2.0.0", ): # pragma: no cover # upstream bug diff --git a/tests/groupby/size_test.py b/tests/groupby/size_test.py index 2d7da647..cf7f4c39 100644 --- a/tests/groupby/size_test.py +++ b/tests/groupby/size_test.py @@ -1,10 +1,11 @@ from __future__ import 
annotations +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_4 -def test_group_by_size(library: str) -> None: +def test_group_by_size(library: BaseHandler) -> None: df = integer_dataframe_4(library) ns = df.__dataframe_namespace__() result = df.group_by("key").size() diff --git a/tests/integration/free_vs_w_parent_test.py b/tests/integration/free_vs_w_parent_test.py index dab61de4..8ac89133 100644 --- a/tests/integration/free_vs_w_parent_test.py +++ b/tests/integration/free_vs_w_parent_test.py @@ -1,27 +1,23 @@ import numpy as np -import polars as pl -from polars.testing import assert_series_equal +from tests.utils import BaseHandler +from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 -def test_free_vs_w_parent(library: str) -> None: +def test_free_vs_w_parent(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) - namespace = df1.__dataframe_namespace__() - free_ser1 = namespace.column_from_1d_array( # type: ignore[call-arg] + ns = df1.__dataframe_namespace__() + free_ser1 = ns.column_from_1d_array( # type: ignore[call-arg] np.array([1, 2, 3], dtype="int64"), name="preds", ) - free_ser2 = namespace.column_from_1d_array( # type: ignore[call-arg] + free_ser2 = ns.column_from_1d_array( # type: ignore[call-arg] np.array([4, 5, 6], dtype="int64"), name="preds", ) result = free_ser1 + free_ser2 - if library == "polars-lazy": - assert_series_equal( - pl.select(result.column)["preds"], - pl.Series("preds", [5, 7, 9], dtype=pl.Int64()), - ) - assert namespace.is_dtype(result.dtype, "integral") + compare_column_with_reference(result, [5, 7, 9], dtype=ns.Int64) + assert ns.is_dtype(result.dtype, "integral") diff --git a/tests/integration/persistedness_test.py b/tests/integration/persistedness_test.py index 9d6bf7de..8803e76c 100644 --- a/tests/integration/persistedness_test.py +++ b/tests/integration/persistedness_test.py @@ -1,11 
+1,12 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 -def test_within_df_propagation(library: str) -> None: +def test_within_df_propagation(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) df1 = df1 + 1 with pytest.raises(RuntimeError): @@ -46,14 +47,14 @@ def test_within_df_propagation(library: str) -> None: assert int(scalar + 1) == 3 # type: ignore[call-overload] -def test_within_df_within_col_propagation(library: str) -> None: +def test_within_df_within_col_propagation(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) df1 = df1 + 1 df1 = df1.persist() assert int((df1.col("a") + 1).mean()) == 4 # type: ignore[call-overload] -def test_cross_df_propagation(library: str) -> None: +def test_cross_df_propagation(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) df2 = integer_dataframe_2(library) ns = df1.__dataframe_namespace__() @@ -69,12 +70,14 @@ def test_cross_df_propagation(library: str) -> None: expected_dtype = { "a": ns.Int64, "b": ns.Int64, - "c": ns.Int64 if library in ["pandas-nullable", "polars-lazy"] else ns.Float64, + "c": ns.Int64 + if library.name in ["pandas-nullable", "polars-lazy"] + else ns.Float64, } compare_dataframe_with_reference(result, expected, dtype=expected_dtype) # type: ignore[arg-type] -def test_multiple_propagations(library: str) -> None: +def test_multiple_propagations(library: BaseHandler) -> None: # This is a bit "ugly", as the user is "required" to call `persist` # multiple times to do things optimally df = integer_dataframe_1(library) @@ -97,7 +100,7 @@ def test_multiple_propagations(library: str) -> None: int(df1.col("a").mean()) # type: ignore[call-overload] -def test_parent_propagations(library: str) -> None: +def test_parent_propagations(library: BaseHandler) -> None: # Set up something like this: # # df diff --git 
a/tests/integration/scale_column_test.py b/tests/integration/scale_column_test.py index 4f07c8fd..eca8e5bf 100644 --- a/tests/integration/scale_column_test.py +++ b/tests/integration/scale_column_test.py @@ -1,44 +1,57 @@ from __future__ import annotations -import pandas as pd -import polars as pl import pytest from packaging.version import Version -from packaging.version import parse -from polars.testing import assert_series_equal +from tests.utils import BaseHandler +from tests.utils import compare_column_with_reference +from tests.utils import pandas_version +from tests.utils import polars_version -@pytest.mark.skipif( - parse(pd.__version__) < Version("2.1.0"), - reason="pandas doesn't support 3.8", -) -def test_scale_column_pandas() -> None: - s = pd.Series([1, 2, 3], name="a") - ser = s.__column_consortium_standard__() - ser = ser - ser.mean() - result = ser.column - pd.testing.assert_series_equal(result, pd.Series([-1, 0, 1.0], name="a")) +def test_scale_column(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable"): + if pandas_version() < Version("2.1.0"): # pragma: no cover + pytest.skip(reason="pandas doesn't support 3.8") + import pandas as pd + + s = pd.Series([1, 2, 3], name="a") + ser = s.__column_consortium_standard__() + elif library.name == "polars-lazy": + if polars_version() < Version("0.19.0"): # pragma: no cover + pytest.skip(reason="before consortium standard in polars") + import polars as pl + + s = pl.Series("a", [1, 2, 3]) + ser = s.__column_consortium_standard__() + else: # pragma: no cover + msg = f"Not supported library: {library}" + raise AssertionError(msg) -@pytest.mark.skipif( - parse(pl.__version__) < Version("0.19.0"), - reason="before consortium standard in polars", -) -def test_scale_column_polars() -> None: - s = pl.Series("a", [1, 2, 3]) - ser = s.__column_consortium_standard__() + ns = ser.__column_namespace__() ser = ser - ser.mean() - result = pl.select(ser.column)["a"] - 
assert_series_equal(result, pl.Series("a", [-1, 0, 1.0])) + compare_column_with_reference(ser, [-1, 0, 1.0], dtype=ns.Float64) + + +def test_scale_column_polars_from_persisted_df(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable"): + if pandas_version() < Version("2.1.0"): # pragma: no cover + pytest.skip(reason="pandas doesn't support 3.8") + import pandas as pd + + df = pd.DataFrame({"a": [1, 2, 3]}) + ser = df.__dataframe_consortium_standard__().col("a") + elif library.name == "polars-lazy": + if polars_version() < Version("0.19.0"): # pragma: no cover + pytest.skip(reason="before consortium standard in polars") + import polars as pl + df = pl.DataFrame({"a": [1, 2, 3]}) + ser = df.__dataframe_consortium_standard__().col("a") + else: # pragma: no cover + msg = f"Not supported library: {library}" + raise AssertionError(msg) -@pytest.mark.skipif( - parse(pl.__version__) < Version("0.19.0"), - reason="before consortium standard in polars", -) -def test_scale_column_polars_from_persisted_df() -> None: - df = pl.DataFrame({"a": [1, 2, 3]}) - ser = df.__dataframe_consortium_standard__().col("a") + ns = ser.__column_namespace__() ser = ser - ser.mean() - result = pl.select(ser.persist().column)["a"] - assert_series_equal(result, pl.Series("a", [-1, 0, 1.0])) + compare_column_with_reference(ser, [-1, 0, 1.0], dtype=ns.Float64) diff --git a/tests/integration/upstream_test.py b/tests/integration/upstream_test.py index abbfc982..caad0e92 100644 --- a/tests/integration/upstream_test.py +++ b/tests/integration/upstream_test.py @@ -7,7 +7,7 @@ class TestPolars: def test_dataframe(self) -> None: - import polars as pl + pl = pytest.importorskip("polars") if parse(pl.__version__) < Version("0.19.0"): # pragma: no cover # before consortium standard in polars @@ -20,7 +20,7 @@ def test_dataframe(self) -> None: assert result == expected def test_lazyframe(self) -> None: - import polars as pl + pl = pytest.importorskip("polars") if 
parse(pl.__version__) < Version("0.19.0"): # pragma: no cover # before consortium standard in polars diff --git a/tests/namespace/column_from_1d_array_test.py b/tests/namespace/column_from_1d_array_test.py index b2dac631..5e4929fe 100644 --- a/tests/namespace/column_from_1d_array_test.py +++ b/tests/namespace/column_from_1d_array_test.py @@ -8,10 +8,11 @@ import pytest from packaging.version import Version -from tests.utils import PANDAS_VERSION -from tests.utils import POLARS_VERSION +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 +from tests.utils import pandas_version +from tests.utils import polars_version @pytest.mark.parametrize( @@ -30,7 +31,7 @@ ], ) def test_column_from_1d_array( - library: str, + library: BaseHandler, pandas_dtype: str, column_dtype: str, ) -> None: @@ -52,7 +53,7 @@ def test_column_from_1d_array( def test_column_from_1d_array_string( - library: str, + library: BaseHandler, ) -> None: ser = integer_dataframe_1(library).persist().col("a") ns = ser.__column_namespace__() @@ -68,7 +69,7 @@ def test_column_from_1d_array_string( def test_column_from_1d_array_bool( - library: str, + library: BaseHandler, ) -> None: ser = integer_dataframe_1(library).persist().col("a") ns = ser.__column_namespace__() @@ -83,7 +84,7 @@ def test_column_from_1d_array_bool( compare_column_with_reference(result.col("result"), expected, dtype=ns.Bool) -def test_datetime_from_1d_array(library: str) -> None: +def test_datetime_from_1d_array(library: BaseHandler) -> None: ser = integer_dataframe_1(library).persist().col("a") ns = ser.__column_namespace__() arr = np.array([date(2020, 1, 1), date(2020, 1, 2)], dtype="datetime64[ms]") @@ -97,15 +98,16 @@ def test_datetime_from_1d_array(library: str) -> None: compare_column_with_reference(result.col("result"), expected, dtype=ns.Datetime) -@pytest.mark.skipif( - Version("0.19.9") > POLARS_VERSION, - reason="upstream bug", -) 
-@pytest.mark.skipif( - Version("2.0.0") > PANDAS_VERSION, - reason="pandas before non-nano", -) -def test_duration_from_1d_array(library: str) -> None: +def test_duration_from_1d_array(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable") and pandas_version() < Version( + "2.0.0", + ): # pragma: no cover + pytest.skip(reason="pandas before non-nano") + if library.name == "polars-lazy" and polars_version() < Version( + "0.19.9", + ): # pragma: no cover + pytest.skip(reason="upstream bug") + ser = integer_dataframe_1(library).persist().col("a") ns = ser.__column_namespace__() arr = np.array([timedelta(1), timedelta(2)], dtype="timedelta64[ms]") @@ -115,7 +117,7 @@ def test_duration_from_1d_array(library: str) -> None: name="result", ), ) - if library == "polars-lazy": + if library.name == "polars-lazy": # https://github.com/data-apis/dataframe-api/issues/329 result = result.cast({"result": ns.Duration("ms")}) expected = [timedelta(1), timedelta(2)] diff --git a/tests/namespace/column_from_sequence_test.py b/tests/namespace/column_from_sequence_test.py index e6362e12..f089dd52 100644 --- a/tests/namespace/column_from_sequence_test.py +++ b/tests/namespace/column_from_sequence_test.py @@ -6,6 +6,7 @@ import pytest +from tests.utils import BaseHandler from tests.utils import compare_column_with_reference from tests.utils import integer_dataframe_1 @@ -30,7 +31,7 @@ ], ) def test_column_from_sequence( - library: str, + library: BaseHandler, values: list[Any], dtype: str, kwargs: dict[str, Any], @@ -51,7 +52,7 @@ def test_column_from_sequence( def test_column_from_sequence_no_dtype( - library: str, + library: BaseHandler, ) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() diff --git a/tests/namespace/concat_test.py b/tests/namespace/concat_test.py index 79901d5a..f4413fab 100644 --- a/tests/namespace/concat_test.py +++ b/tests/namespace/concat_test.py @@ -1,15 +1,17 @@ from __future__ import annotations 
-import polars as pl +from typing import Any + import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 from tests.utils import integer_dataframe_4 -def test_concat(library: str) -> None: +def test_concat(library: BaseHandler) -> None: df1 = integer_dataframe_1(library) df2 = integer_dataframe_2(library) ns = df1.__dataframe_namespace__() @@ -18,10 +20,15 @@ def test_concat(library: str) -> None: compare_dataframe_with_reference(result, expected, dtype=ns.Int64) -def test_concat_mismatch(library: str) -> None: +def test_concat_mismatch(library: BaseHandler) -> None: df1 = integer_dataframe_1(library).persist() df2 = integer_dataframe_4(library).persist() ns = df1.__dataframe_namespace__() + exceptions: tuple[Any, ...] = (ValueError,) + if library.name == "polars-lazy": + import polars as pl + + exceptions = (ValueError, pl.exceptions.ShapeError) # TODO check the error - with pytest.raises((ValueError, pl.exceptions.ShapeError)): + with pytest.raises(exceptions): _ = ns.concat([df1, df2]).persist() diff --git a/tests/namespace/convert_to_standard_column_test.py b/tests/namespace/convert_to_standard_column_test.py index 029a4047..57ef3b3c 100644 --- a/tests/namespace/convert_to_standard_column_test.py +++ b/tests/namespace/convert_to_standard_column_test.py @@ -1,24 +1,32 @@ from __future__ import annotations -import pandas as pd -import polars as pl import pytest from packaging.version import Version -from tests.utils import PANDAS_VERSION -from tests.utils import POLARS_VERSION +from tests.utils import BaseHandler +from tests.utils import pandas_version +from tests.utils import polars_version -@pytest.mark.skipif( - Version("0.19.0") > POLARS_VERSION or Version("2.1.0") > PANDAS_VERSION, - reason="before consortium standard in polars/pandas", -) -def test_convert_to_std_column() -> None: - s = pl.Series([1, 2, 
3]).__column_consortium_standard__() - assert float(s.mean()) == 2 - s = pl.Series("bob", [1, 2, 3]).__column_consortium_standard__() - assert float(s.mean()) == 2 - s = pd.Series([1, 2, 3]).__column_consortium_standard__() - assert float(s.mean()) == 2 - s = pd.Series([1, 2, 3], name="alice").__column_consortium_standard__() - assert float(s.mean()) == 2 +def test_convert_to_std_column(library: BaseHandler) -> None: + if library.name in ("pandas-numpy", "pandas-nullable"): + if pandas_version() < Version("2.1.0"): # pragma: no cover + pytest.skip(reason="before consortium standard in pandas") + import pandas as pd + + s = pd.Series([1, 2, 3]).__column_consortium_standard__() + assert float(s.mean()) == 2 + s = pd.Series([1, 2, 3], name="alice").__column_consortium_standard__() + assert float(s.mean()) == 2 + elif library.name == "polars-lazy": + if polars_version() < Version("0.19.0"): # pragma: no cover + pytest.skip(reason="before consortium standard in polars") + import polars as pl + + s = pl.Series([1, 2, 3]).__column_consortium_standard__() + assert float(s.mean()) == 2 + s = pl.Series("bob", [1, 2, 3]).__column_consortium_standard__() + assert float(s.mean()) == 2 + else: # pragma: no cover + msg = f"Not supported library: {library}" + raise AssertionError(msg) diff --git a/tests/namespace/dataframe_from_2d_array_test.py b/tests/namespace/dataframe_from_2d_array_test.py index 503486da..2b381d96 100644 --- a/tests/namespace/dataframe_from_2d_array_test.py +++ b/tests/namespace/dataframe_from_2d_array_test.py @@ -2,11 +2,12 @@ import numpy as np +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 -def test_dataframe_from_2d_array(library: str) -> None: +def test_dataframe_from_2d_array(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() arr = np.array([[1, 4], [2, 5], [3, 6]]) diff --git a/tests/namespace/is_dtype_test.py 
b/tests/namespace/is_dtype_test.py index 55eb17de..87b2995d 100644 --- a/tests/namespace/is_dtype_test.py +++ b/tests/namespace/is_dtype_test.py @@ -3,8 +3,9 @@ import pytest from packaging.version import Version -from tests.utils import PANDAS_VERSION +from tests.utils import BaseHandler from tests.utils import mixed_dataframe_1 +from tests.utils import pandas_version @pytest.mark.parametrize( @@ -20,11 +21,11 @@ (("string", "unsigned integer"), ["e", "f", "g", "h", "l"]), ], ) -@pytest.mark.skipif( - Version("2.0.0") > PANDAS_VERSION, - reason="before pandas got non-nano support", -) -def test_is_dtype(library: str, dtype: str, expected: list[str]) -> None: +def test_is_dtype(library: BaseHandler, dtype: str, expected: list[str]) -> None: + if library.name in ("pandas-numpy", "pandas-nullable") and pandas_version() < Version( + "2.0.0", + ): # pragma: no cover + pytest.skip(reason="pandas before non-nano") df = mixed_dataframe_1(library).persist() namespace = df.__dataframe_namespace__() result = [i for i in df.column_names if namespace.is_dtype(df.schema[i], dtype)] diff --git a/tests/namespace/namespace_is_null_test.py b/tests/namespace/namespace_is_null_test.py index 7d25d1c8..758795ba 100644 --- a/tests/namespace/namespace_is_null_test.py +++ b/tests/namespace/namespace_is_null_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 -def test_is_null(library: str) -> None: +def test_is_null(library: BaseHandler) -> None: df = integer_dataframe_1(library) other = integer_dataframe_2(library) # use scalar namespace just for coverage purposes diff --git a/tests/namespace/sorted_indices_test.py b/tests/namespace/sorted_indices_test.py index d99a4585..5e755d72 100644 --- a/tests/namespace/sorted_indices_test.py +++ b/tests/namespace/sorted_indices_test.py @@ -1,10 +1,11 @@ from __future__ import annotations +from tests.utils import 
BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_6 -def test_column_sorted_indices_ascending(library: str) -> None: +def test_column_sorted_indices_ascending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() sorted_indices = df.col("b").sorted_indices() @@ -19,7 +20,7 @@ def test_column_sorted_indices_ascending(library: str) -> None: "b": [4, 4, 3, 1, 2], "result": [3, 4, 2, 1, 0], } - if library in ("polars", "polars-lazy"): + if library.name in ("polars", "polars-lazy"): result = result.cast({"result": ns.Int64()}) try: compare_dataframe_with_reference(result, expected_1, dtype=ns.Int64) @@ -28,7 +29,7 @@ def test_column_sorted_indices_ascending(library: str) -> None: compare_dataframe_with_reference(result, expected_2, dtype=ns.Int64) -def test_column_sorted_indices_descending(library: str) -> None: +def test_column_sorted_indices_descending(library: BaseHandler) -> None: df = integer_dataframe_6(library) ns = df.__dataframe_namespace__() sorted_indices = df.col("b").sorted_indices(ascending=False) @@ -43,7 +44,7 @@ def test_column_sorted_indices_descending(library: str) -> None: "b": [4, 4, 3, 1, 2], "result": [0, 1, 2, 4, 3], } - if library in ("polars", "polars-lazy"): + if library.name in ("polars", "polars-lazy"): result = result.cast({"result": ns.Int64()}) try: compare_dataframe_with_reference(result, expected_1, dtype=ns.Int64) diff --git a/tests/namespace/to_array_object_test.py b/tests/namespace/to_array_object_test.py index 302d5738..3b1f63c2 100644 --- a/tests/namespace/to_array_object_test.py +++ b/tests/namespace/to_array_object_test.py @@ -2,17 +2,18 @@ import numpy as np +from tests.utils import BaseHandler from tests.utils import integer_dataframe_1 -def test_to_array_object(library: str) -> None: +def test_to_array_object(library: BaseHandler) -> None: df = integer_dataframe_1(library).persist() result = 
np.asarray(df.to_array(dtype="int64")) # type: ignore # noqa: PGH003 expected = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.int64) np.testing.assert_array_equal(result, expected) -def test_column_to_array_object(library: str) -> None: +def test_column_to_array_object(library: BaseHandler) -> None: col = integer_dataframe_1(library).col("a") result = np.asarray(col.persist().to_array()) result = np.asarray(col.persist().to_array()) diff --git a/tests/scalars/float_test.py b/tests/scalars/float_test.py index d8d76656..1784f85f 100644 --- a/tests/scalars/float_test.py +++ b/tests/scalars/float_test.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from tests.utils import BaseHandler from tests.utils import compare_dataframe_with_reference from tests.utils import integer_dataframe_1 from tests.utils import integer_dataframe_2 @@ -31,7 +32,7 @@ "__rtruediv__", ], ) -def test_float_binary(library: str, attr: str) -> None: +def test_float_binary(library: BaseHandler, attr: str) -> None: other = 0.5 df = integer_dataframe_2(library).persist() scalar = df.col("a").mean() @@ -39,14 +40,14 @@ def test_float_binary(library: str, attr: str) -> None: assert getattr(scalar, attr)(other) == getattr(float_scalar, attr)(other) -def test_float_binary_invalid(library: str) -> None: +def test_float_binary_invalid(library: BaseHandler) -> None: lhs = integer_dataframe_2(library).col("a").mean() rhs = integer_dataframe_1(library).col("b").mean() with pytest.raises(ValueError): _ = lhs > rhs -def test_float_binary_lazy_valid(library: str) -> None: +def test_float_binary_lazy_valid(library: BaseHandler) -> None: df = integer_dataframe_2(library).persist() lhs = df.col("a").mean() rhs = df.col("b").mean() @@ -61,7 +62,7 @@ def test_float_binary_lazy_valid(library: str) -> None: "__neg__", ], ) -def test_float_unary(library: str, attr: str) -> None: +def test_float_unary(library: BaseHandler, attr: str) -> None: df = integer_dataframe_2(library).persist() with pytest.warns(UserWarning): 
scalar = df.col("a").persist().mean() @@ -77,7 +78,7 @@ def test_float_unary(library: str, attr: str) -> None: "__bool__", ], ) -def test_float_unary_invalid(library: str, attr: str) -> None: +def test_float_unary_invalid(library: BaseHandler, attr: str) -> None: df = integer_dataframe_2(library) scalar = df.col("a").mean() float_scalar = float(scalar.persist()) # type: ignore[arg-type] @@ -85,7 +86,7 @@ def test_float_unary_invalid(library: str, attr: str) -> None: assert getattr(scalar, attr)() == getattr(float_scalar, attr)() -def test_free_standing(library: str) -> None: +def test_free_standing(library: BaseHandler) -> None: df = integer_dataframe_1(library) namespace = df.__dataframe_namespace__() ser = namespace.column_from_1d_array( # type: ignore[call-arg] @@ -96,7 +97,7 @@ def test_free_standing(library: str) -> None: assert result == 3.0 -def test_right_comparand(library: str) -> None: +def test_right_comparand(library: BaseHandler) -> None: df = integer_dataframe_1(library) ns = df.__dataframe_namespace__() col = df.col("a") # [1, 2, 3] diff --git a/tests/utils.py b/tests/utils.py index 019e9f3c..22fb0bdb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,25 +1,149 @@ from __future__ import annotations +import contextlib import math +from abc import abstractmethod from datetime import datetime from datetime import timedelta from typing import TYPE_CHECKING from typing import Any +from typing import ClassVar from typing import Mapping -import pandas as pd -import polars as pl +from packaging.version import Version from packaging.version import parse -import dataframe_api_compat.pandas_standard -import dataframe_api_compat.polars_standard - if TYPE_CHECKING: + import pandas as pd + import polars as pl from dataframe_api import Column from dataframe_api import DataFrame + from dataframe_api.typing import DType + + +def pandas_version() -> Version: + import pandas as pd + + return parse(pd.__version__) + + +def polars_version() -> Version: + import 
polars as pl + + return parse(pl.__version__) + + +class BaseHandler: + @property + @abstractmethod + def name(self) -> str: + ... + + @abstractmethod + def create_dataframe( + self, + data: Any, + api_version: str | None = None, + ) -> DataFrame: + ... + -POLARS_VERSION = parse(pl.__version__) -PANDAS_VERSION = parse(pd.__version__) +class PandasHandler(BaseHandler): + # for `pandas-nullable` case + # https://pandas.pydata.org/docs/user_guide/basics.html#dtypes + mapping: ClassVar[dict[str, str]] = { + "bool": "boolean", + "int64": "Int64", + "float64": "Float64", + } + + def __init__(self, name: str) -> None: + assert name in ("pandas-numpy", "pandas-nullable") + self._name = name + + @property + def name(self) -> str: + return self._name + + def create_dataframe( + self, + data: Any, + api_version: str | None = None, + ) -> DataFrame: + import pandas as pd + + import dataframe_api_compat.pandas_standard + + df = pd.DataFrame(data) + if self.name == "pandas-nullable": + new_dtypes = { + col_name: self.mapping.get(str(dtype), str(dtype)) + for col_name, dtype in zip(df.columns, df.dtypes) + } + df = df.astype(new_dtypes) + + return ( + dataframe_api_compat.pandas_standard.convert_to_standard_compliant_dataframe( + df, + api_version=api_version or "2023.11-beta", + ) + ) + + +class PolarsHandler(BaseHandler): + def __init__(self, name: str) -> None: + assert name == "polars-lazy" + self._name = name + + @property + def name(self) -> str: + return self._name + + def create_dataframe( + self, + data: Any, + api_version: str | None = None, + ) -> DataFrame: + import polars as pl + + import dataframe_api_compat.polars_standard + + df = pl.DataFrame(data) + + return ( + dataframe_api_compat.polars_standard.convert_to_standard_compliant_dataframe( + df, + api_version=api_version or "2023.11-beta", + ) + ) + + +class ModinHandler(BaseHandler): + def __init__(self, name: str) -> None: + assert name == "modin" + self._name = name + + @property + def name(self) -> str: + 
return self._name + + def create_dataframe( + self, + data: Any, + api_version: str | None = None, + ) -> DataFrame: + import modin.pandas as pd + + import dataframe_api_compat.modin_standard + + df = pd.DataFrame(data) + + return ( + dataframe_api_compat.modin_standard.convert_to_standard_compliant_dataframe( + df, + api_version=api_version or "2023.11-beta", + ) + ) def convert_to_standard_compliant_dataframe( @@ -27,14 +151,26 @@ def convert_to_standard_compliant_dataframe( api_version: str | None = None, ) -> DataFrame: # TODO: type return + import pandas as pd + + try: + polars_installed = True + import polars as pl + except ModuleNotFoundError: + polars_installed = False + if isinstance(df, pd.DataFrame): + import dataframe_api_compat.pandas_standard + return ( dataframe_api_compat.pandas_standard.convert_to_standard_compliant_dataframe( df, api_version=api_version, ) ) - elif isinstance(df, (pl.DataFrame, pl.LazyFrame)): + elif polars_installed and isinstance(df, (pl.DataFrame, pl.LazyFrame)): + import dataframe_api_compat.polars_standard + df_lazy = df.lazy() if isinstance(df, pl.DataFrame) else df return ( dataframe_api_compat.polars_standard.convert_to_standard_compliant_dataframe( @@ -47,316 +183,163 @@ def convert_to_standard_compliant_dataframe( raise AssertionError(msg) -def integer_dataframe_1(library: str, api_version: str | None = None) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, dtype="int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, dtype="Int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise 
AssertionError(msg) - - -def integer_dataframe_2(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1, 2, 4], "b": [4, 2, 6]}, dtype="int64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [1, 2, 4], "b": [4, 2, 6]}, dtype="Int64") - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 2, 4], "b": [4, 2, 6]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def integer_dataframe_1( + library: BaseHandler, + api_version: str | None = None, +) -> DataFrame: + return library.create_dataframe( + {"a": [1, 2, 3], "b": [4, 5, 6]}, + api_version=api_version, + ) -def integer_dataframe_3(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - {"a": [1, 2, 3, 4, 5, 6, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, - dtype="int64", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame( - {"a": [1, 2, 3, 4, 5, 6, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, - dtype="Int64", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7], "b": [7, 6, 5, 4, 3, 2, 1]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def integer_dataframe_2(library: BaseHandler) -> DataFrame: + return library.create_dataframe( + {"a": [1, 2, 4], "b": [4, 2, 6]}, + ) -def integer_dataframe_4(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - {"key": [1, 1, 2, 2], "b": [1, 2, 3, 4], "c": [4, 5, 6, 7]}, - dtype="int64", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame( - {"key": [1, 1, 2, 2], "b": 
[1, 2, 3, 4], "c": [4, 5, 6, 7]}, - dtype="Int64", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"key": [1, 1, 2, 2], "b": [1, 2, 3, 4], "c": [4, 5, 6, 7]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def integer_dataframe_3(library: BaseHandler) -> DataFrame: + return library.create_dataframe( + {"a": [1, 2, 3, 4, 5, 6, 7], "b": [7, 6, 5, 4, 3, 2, 1]}, + ) -def integer_dataframe_5(library: str, api_version: str | None = None) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1, 1], "b": [4, 3]}, dtype="int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [1, 1], "b": [4, 3]}, dtype="Int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 1], "b": [4, 3]}) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) - - -def integer_dataframe_6(library: str, api_version: str | None = None) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1, 1, 1, 2, 2], "b": [4, 4, 3, 1, 2]}, dtype="int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [1, 1, 1, 2, 2], "b": [4, 4, 3, 1, 2]}, dtype="Int64") - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 1, 1, 2, 2], "b": [4, 4, 3, 1, 2]}) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) - - -def 
integer_dataframe_7(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 4]}, dtype="int64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 4]}, dtype="Int64") - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 4]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def integer_dataframe_4(library: BaseHandler) -> DataFrame: + return library.create_dataframe( + {"key": [1, 1, 2, 2], "b": [1, 2, 3, 4], "c": [4, 5, 6, 7]}, + ) -def nan_dataframe_1(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1.0, 2.0, float("nan")]}, dtype="float64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": +def integer_dataframe_5( + library: BaseHandler, + api_version: str | None = None, +) -> DataFrame: + return library.create_dataframe( + {"a": [1, 1], "b": [4, 3]}, + api_version=api_version, + ) + + +def integer_dataframe_6( + library: BaseHandler, + api_version: str | None = None, +) -> DataFrame: + return library.create_dataframe( + {"a": [1, 1, 1, 2, 2], "b": [4, 4, 3, 1, 2]}, + api_version=api_version, + ) + + +def integer_dataframe_7(library: BaseHandler) -> DataFrame: + return library.create_dataframe({"a": [1, 2, 3], "b": [1, 2, 4]}) + + +def nan_dataframe_1(library: BaseHandler) -> DataFrame: + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame({"a": [1.0, 2.0, 0.0]}, dtype="Float64") other = pd.DataFrame({"a": [1.0, 1.0, 0.0]}, dtype="Float64") return convert_to_standard_compliant_dataframe(df / other) - if library == "polars-lazy": - df = pl.DataFrame({"a": [1.0, 2.0, float("nan")]}) - return 
convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + return library.create_dataframe({"a": [1.0, 2.0, float("nan")]}) -def nan_dataframe_2(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [0.0, 1.0, float("nan")]}, dtype="float64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": +def nan_dataframe_2(library: BaseHandler) -> DataFrame: + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame({"a": [0.0, 1.0, 0.0]}, dtype="Float64") other = pd.DataFrame({"a": [1.0, 1.0, 0.0]}, dtype="Float64") return convert_to_standard_compliant_dataframe(df / other) - if library == "polars-lazy": - df = pl.DataFrame({"a": [0.0, 1.0, float("nan")]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + return library.create_dataframe({"a": [0.0, 1.0, float("nan")]}) -def null_dataframe_1(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [1.0, 2.0, float("nan")]}, dtype="float64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": +def null_dataframe_1(library: BaseHandler) -> DataFrame: + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame({"a": [1.0, 2.0, pd.NA]}, dtype="Float64") return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": + if library.name == "polars-lazy": + import polars as pl + df = pl.DataFrame({"a": [1.0, 2.0, None]}) return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + return library.create_dataframe({"a": [1.0, 2.0, float("nan")]}) -def null_dataframe_2(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - {"a": 
[1.0, -1.0, float("nan")], "b": [1.0, -1.0, float("nan")]}, - dtype="float64", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": +def null_dataframe_2(library: BaseHandler) -> DataFrame: + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame( {"a": [1.0, 0.0, pd.NA], "b": [1.0, 1.0, pd.NA]}, dtype="Float64", ) return convert_to_standard_compliant_dataframe(df / df) - if library == "polars-lazy": + if library.name == "polars-lazy": + import polars as pl + df = pl.DataFrame({"a": [1.0, float("nan"), None], "b": [1.0, 1.0, None]}) return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + return library.create_dataframe( + {"a": [1.0, -1.0, float("nan")], "b": [1.0, -1.0, float("nan")]}, + ) -def bool_dataframe_1(library: str, api_version: str = "2023.09-beta") -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - {"a": [True, True, False], "b": [True, True, True]}, - dtype="bool", - ) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "pandas-nullable": - df = pd.DataFrame( - {"a": [True, True, False], "b": [True, True, True]}, - dtype="boolean", - ) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - if library == "polars-lazy": - df = pl.DataFrame({"a": [True, True, False], "b": [True, True, True]}) - return convert_to_standard_compliant_dataframe(df, api_version=api_version) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def bool_dataframe_1( + library: BaseHandler, + api_version: str = "2023.09-beta", +) -> DataFrame: + return library.create_dataframe( + {"a": [True, True, False], "b": [True, True, True]}, + api_version=api_version, + ) -def bool_dataframe_2(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - { - "key": [1, 1, 2, 
2], - "b": [False, True, True, True], - "c": [True, False, False, False], - }, - ).astype({"key": "int64", "b": "bool", "c": "bool"}) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame( - { - "key": [1, 1, 2, 2], - "b": [False, True, True, True], - "c": [True, False, False, False], - }, - ).astype({"key": "Int64", "b": "boolean", "c": "boolean"}) - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame( - { - "key": [1, 1, 2, 2], - "b": [False, True, True, True], - "c": [True, False, False, False], - }, - ) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def bool_dataframe_2(library: BaseHandler) -> DataFrame: + return library.create_dataframe( + { + "key": [1, 1, 2, 2], + "b": [False, True, True, True], + "c": [True, False, False, False], + }, + ) -def bool_dataframe_3(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame( - {"a": [False, False], "b": [False, True], "c": [True, True]}, - dtype="bool", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame( - {"a": [False, False], "b": [False, True], "c": [True, True]}, - dtype="boolean", - ) - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"a": [False, False], "b": [False, True], "c": [True, True]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def bool_dataframe_3(library: BaseHandler) -> DataFrame: + return library.create_dataframe( + {"a": [False, False], "b": [False, True], "c": [True, True]}, + ) -def float_dataframe_1(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [2.0, 3.0]}, dtype="float64") - return 
convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [2.0, 3.0]}, dtype="Float64") - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame({"a": [2.0, 3.0]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def float_dataframe_1(library: BaseHandler) -> DataFrame: + return library.create_dataframe({"a": [2.0, 3.0]}) -def float_dataframe_2(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [2.0, 1.0]}, dtype="float64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": - df = pd.DataFrame({"a": [2.0, 1.0]}, dtype="Float64") - return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": # pragma: no cover - df = pl.DataFrame({"a": [2.0, 1.0]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) +def float_dataframe_2(library: BaseHandler) -> DataFrame: + return library.create_dataframe({"a": [2.0, 1.0]}) -def float_dataframe_3(library: str) -> DataFrame: - df: Any - if library == "pandas-numpy": - df = pd.DataFrame({"a": [float("nan"), 2.0]}, dtype="float64") - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": +def float_dataframe_3(library: BaseHandler) -> DataFrame: + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame({"a": [0.0, 2.0]}, dtype="Float64") other = pd.DataFrame({"a": [0.0, 1.0]}, dtype="Float64") return convert_to_standard_compliant_dataframe(df / other) - if library == "polars-lazy": # pragma: no cover - df = pl.DataFrame({"a": [float("nan"), 2.0]}) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + return 
library.create_dataframe({"a": [float("nan"), 2.0]}) + +def temporal_dataframe_1(library: BaseHandler) -> DataFrame: + if library.name in ["pandas-numpy", "pandas-nullable"]: + import pandas as pd -def temporal_dataframe_1(library: str) -> DataFrame: - if library in ["pandas-numpy", "pandas-nullable"]: df = pd.DataFrame( { + # the data for column "a" differs from other implementations due to pandas 1.5 compat + # https://github.com/data-apis/dataframe-api-compat/commit/aeca5cf1a052033b72388e3f87ad8b70d66cedec "a": [ datetime(2020, 1, 1, 1, 2, 1, 123000), datetime(2020, 1, 2, 3, 1, 2, 321000), @@ -400,7 +383,9 @@ def temporal_dataframe_1(library: str) -> DataFrame: }, ) return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": + if library.name == "polars-lazy": + import polars as pl + df = pl.DataFrame( { "a": [ @@ -446,8 +431,42 @@ def temporal_dataframe_1(library: str) -> DataFrame: }, ) return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: {library}" # pragma: no cover - raise AssertionError(msg) + + return library.create_dataframe( + { + "a": [ + datetime(2020, 1, 1, 1, 2, 1, 123000), + datetime(2020, 1, 2, 3, 1, 2, 321000), + datetime(2020, 1, 3, 5, 4, 9, 987000), + ], + "b": [ + timedelta(1, milliseconds=1), + timedelta(2, milliseconds=3), + timedelta(3, milliseconds=5), + ], + "c": [ + datetime(2020, 1, 1, 1, 2, 1, 123543), + datetime(2020, 1, 2, 3, 1, 2, 321654), + datetime(2020, 1, 3, 5, 4, 9, 987321), + ], + "d": [ + timedelta(1, milliseconds=1), + timedelta(2, milliseconds=3), + timedelta(3, milliseconds=5), + ], + "e": [ + datetime(2020, 1, 1, 1, 2, 1, 123543), + datetime(2020, 1, 2, 3, 1, 2, 321654), + datetime(2020, 1, 3, 5, 4, 9, 987321), + ], + "f": [ + timedelta(1, milliseconds=1), + timedelta(2, milliseconds=3), + timedelta(3, milliseconds=5), + ], + "index": [0, 1, 2], + }, + ) def compare_column_with_reference( @@ -455,7 +474,9 @@ def compare_column_with_reference( reference: list[Any], 
dtype: Any, ) -> None: - column = column.persist() + with contextlib.suppress(UserWarning): + # the comparison should work regardless of whether method `persist` has already been called or not + column = column.persist() col_len = column.len().scalar assert col_len == len(reference), f"column length: {col_len} != {len(reference)}" assert isinstance( @@ -494,7 +515,7 @@ def compare_dataframe_with_reference( ) -def mixed_dataframe_1(library: str) -> DataFrame: +def mixed_dataframe_1(library: BaseHandler) -> DataFrame: df: Any data = { "a": [1, 2, 3], @@ -515,30 +536,9 @@ def mixed_dataframe_1(library: str) -> DataFrame: "p": [timedelta(days=1), timedelta(days=2), timedelta(days=3)], "q": [timedelta(days=1), timedelta(days=2), timedelta(days=3)], } - if library == "pandas-numpy": - df = pd.DataFrame(data).astype( - { - "a": "int64", - "b": "int32", - "c": "int16", - "d": "int8", - "e": "uint64", - "f": "uint32", - "g": "uint16", - "h": "uint8", - "i": "float64", - "j": "float32", - "k": "bool", - "l": "object", - "m": "datetime64[s]", - "n": "datetime64[ms]", - "o": "datetime64[us]", - "p": "timedelta64[ms]", - "q": "timedelta64[us]", - }, - ) - return convert_to_standard_compliant_dataframe(df) - if library == "pandas-nullable": + if library.name == "pandas-nullable": + import pandas as pd + df = pd.DataFrame(data).astype( { "a": "Int64", @@ -561,29 +561,26 @@ def mixed_dataframe_1(library: str) -> DataFrame: }, ) return convert_to_standard_compliant_dataframe(df) - if library == "polars-lazy": - df = pl.DataFrame( - data, - schema={ - "a": pl.Int64, - "b": pl.Int32, - "c": pl.Int16, - "d": pl.Int8, - "e": pl.UInt64, - "f": pl.UInt32, - "g": pl.UInt16, - "h": pl.UInt8, - "i": pl.Float64, - "j": pl.Float32, - "k": pl.Boolean, - "l": pl.Utf8, - "m": pl.Datetime("ms"), - "n": pl.Datetime("ms"), - "o": pl.Datetime("us"), - "p": pl.Duration("ms"), - "q": pl.Duration("us"), - }, - ) - return convert_to_standard_compliant_dataframe(df) - msg = f"Got unexpected library: 
{library}" # pragma: no cover - raise AssertionError(msg) + + result = library.create_dataframe(data) + ns = result.__dataframe_namespace__() + dtypes: Mapping[str, DType] = { + "a": ns.Int64(), + "b": ns.Int32(), + "c": ns.Int16(), + "d": ns.Int8(), + "e": ns.UInt64(), + "f": ns.UInt32(), + "g": ns.UInt16(), + "h": ns.UInt8(), + "i": ns.Float64(), + "j": ns.Float32(), + "k": ns.Bool(), + "l": ns.String(), + "m": ns.Datetime("ms"), + "n": ns.Datetime("ms"), + "o": ns.Datetime("us"), + "p": ns.Duration("ms"), + "q": ns.Duration("us"), + } + return result.cast(dtypes)