diff --git a/tests/conftest.py b/tests/conftest.py
index f46164fa9d..2597548324 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@
 from typing import Any
 from typing import Callable
 from typing import Sequence
+from typing import cast
 
 import pytest
 
@@ -15,10 +16,19 @@ if TYPE_CHECKING:
     import duckdb
+    import pandas as pd
     import polars as pl
+    import pyarrow as pa
+    from pyspark.sql import DataFrame as PySparkDataFrame
+    from typing_extensions import TypeAlias
+
+    from narwhals._spark_like.dataframe import SQLFrameDataFrame
+    from narwhals.typing import NativeFrame
+    from narwhals.typing import NativeLazyFrame
+    from tests.utils import Constructor
+    from tests.utils import ConstructorEager
 
-    from narwhals.typing import IntoDataFrame
-    from narwhals.typing import IntoFrame
+    Data: TypeAlias = "dict[str, list[Any]]"
 
 
 MIN_PANDAS_NULLABLE_VERSION = (2,)
@@ -69,59 +79,60 @@ def pytest_collection_modifyitems(
             item.add_marker(skip_slow)
 
 
-def pandas_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:
+def pandas_constructor(obj: Data) -> pd.DataFrame:
     import pandas as pd
 
     return pd.DataFrame(obj)
 
 
-def pandas_nullable_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:
+def pandas_nullable_constructor(obj: Data) -> pd.DataFrame:
     import pandas as pd
 
     return pd.DataFrame(obj).convert_dtypes(dtype_backend="numpy_nullable")
 
 
-def pandas_pyarrow_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:
+def pandas_pyarrow_constructor(obj: Data) -> pd.DataFrame:
     import pandas as pd
 
     return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow")
 
 
-def modin_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:  # pragma: no cover
+def modin_constructor(obj: Data) -> NativeFrame:  # pragma: no cover
     import modin.pandas as mpd
     import pandas as pd
 
-    return mpd.DataFrame(pd.DataFrame(obj))  # type: ignore[no-any-return]
+    df = mpd.DataFrame(pd.DataFrame(obj))
+    return cast("NativeFrame", df)
 
 
-def modin_pyarrow_constructor(
-    obj: dict[str, list[Any]],
-) -> IntoDataFrame:  # pragma: no cover
+def modin_pyarrow_constructor(obj: Data) -> NativeFrame:  # pragma: no cover
     import modin.pandas as mpd
     import pandas as pd
 
-    return mpd.DataFrame(pd.DataFrame(obj)).convert_dtypes(dtype_backend="pyarrow")  # type: ignore[no-any-return]
+    df = mpd.DataFrame(pd.DataFrame(obj)).convert_dtypes(dtype_backend="pyarrow")
+    return cast("NativeFrame", df)
 
 
-def cudf_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:  # pragma: no cover
+def cudf_constructor(obj: Data) -> NativeFrame:  # pragma: no cover
    import cudf
 
-    return cudf.DataFrame(obj)  # type: ignore[no-any-return]
+    df = cudf.DataFrame(obj)
+    return cast("NativeFrame", df)
 
 
-def polars_eager_constructor(obj: dict[str, list[Any]]) -> IntoDataFrame:
+def polars_eager_constructor(obj: Data) -> pl.DataFrame:
     import polars as pl
 
     return pl.DataFrame(obj)
 
 
-def polars_lazy_constructor(obj: dict[str, list[Any]]) -> pl.LazyFrame:
+def polars_lazy_constructor(obj: Data) -> pl.LazyFrame:
     import polars as pl
 
     return pl.LazyFrame(obj)
 
 
-def duckdb_lazy_constructor(obj: dict[str, list[Any]]) -> duckdb.DuckDBPyRelation:
+def duckdb_lazy_constructor(obj: Data) -> duckdb.DuckDBPyRelation:
     import duckdb
     import polars as pl
 
@@ -129,43 +140,40 @@ def duckdb_lazy_constructor(obj: dict[str, list[Any]]) -> duckdb.DuckDBPyRelatio
     return duckdb.table("_df")
 
 
-def dask_lazy_p1_constructor(obj: dict[str, list[Any]]) -> IntoFrame:  # pragma: no cover
+def dask_lazy_p1_constructor(obj: Data) -> NativeLazyFrame:  # pragma: no cover
     import dask.dataframe as dd
 
-    return dd.from_dict(obj, npartitions=1)  # type: ignore[no-any-return]
+    return cast("NativeLazyFrame", dd.from_dict(obj, npartitions=1))
 
 
-def dask_lazy_p2_constructor(obj: dict[str, list[Any]]) -> IntoFrame:  # pragma: no cover
+def dask_lazy_p2_constructor(obj: Data) -> NativeLazyFrame:  # pragma: no cover
     import dask.dataframe as dd
 
-    return dd.from_dict(obj, npartitions=2)  # type: ignore[no-any-return]
+    return cast("NativeLazyFrame", dd.from_dict(obj, npartitions=2))
 
 
-def pyarrow_table_constructor(obj: dict[str, Any]) -> IntoDataFrame:
+def pyarrow_table_constructor(obj: dict[str, Any]) -> pa.Table:
     import pyarrow as pa
 
     return pa.table(obj)
 
 
-def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cover
-    try:
-        from pyspark.sql import SparkSession
-    except ImportError:  # pragma: no cover
-        pytest.skip("pyspark is not installed")
-        return None
-
+def pyspark_lazy_constructor() -> Callable[[Data], PySparkDataFrame]:  # pragma: no cover
+    pytest.importorskip("pyspark")
     import warnings
     from atexit import register
 
+    from pyspark.sql import SparkSession
+
     with warnings.catch_warnings():
         # The spark session seems to trigger a polars warning.
         # Polars is imported in the tests, but not used in the spark operations
         warnings.filterwarnings(
             "ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
         )
-
+        builder = cast("SparkSession.Builder", SparkSession.builder)
         session = (
-            SparkSession.builder.appName("unit-tests")  # pyright: ignore[reportAttributeAccessIssue]
+            builder.appName("unit-tests")
             .master("local[1]")
             .config("spark.ui.enabled", "false")
             # executing one task at a time makes the tests faster
@@ -178,12 +186,12 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cove
 
     register(session.stop)
 
-    def _constructor(obj: dict[str, list[Any]]) -> IntoFrame:
+    def _constructor(obj: Data) -> PySparkDataFrame:
         _obj = deepcopy(obj)
         index_col_name = generate_temporary_column_name(n_bytes=8, columns=list(_obj))
         _obj[index_col_name] = list(range(len(_obj[next(iter(_obj))])))
 
-        return (  # type: ignore[no-any-return]
+        return (
             session.createDataFrame([*zip(*_obj.values())], schema=[*_obj.keys()])
             .repartition(2)
             .orderBy(index_col_name)
@@ -193,16 +201,14 @@ def _constructor(obj: dict[str, list[Any]]) -> IntoFrame:
     return _constructor
 
 
-def sqlframe_pyspark_lazy_constructor(
-    obj: dict[str, Any],
-) -> IntoFrame:  # pragma: no cover
+def sqlframe_pyspark_lazy_constructor(obj: Data) -> SQLFrameDataFrame:  # pragma: no cover
     from sqlframe.duckdb import DuckDBSession
 
     session = DuckDBSession()
     return session.createDataFrame([*zip(*obj.values())], schema=[*obj.keys()])
 
 
-EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = {
+EAGER_CONSTRUCTORS: dict[str, ConstructorEager] = {
     "pandas": pandas_constructor,
     "pandas[nullable]": pandas_nullable_constructor,
     "pandas[pyarrow]": pandas_pyarrow_constructor,
@@ -212,14 +218,14 @@ def sqlframe_pyspark_lazy_constructor(
     "cudf": cudf_constructor,
     "polars[eager]": polars_eager_constructor,
 }
-LAZY_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {
+LAZY_CONSTRUCTORS: dict[str, Constructor] = {
     "dask": dask_lazy_p2_constructor,
     "polars[lazy]": polars_lazy_constructor,
     "duckdb": duckdb_lazy_constructor,
     "pyspark": pyspark_lazy_constructor,  # type: ignore[dict-item]
     "sqlframe": sqlframe_pyspark_lazy_constructor,
 }
-GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}
+GPU_CONSTRUCTORS: dict[str, ConstructorEager] = {"cudf": cudf_constructor}
 
 
 def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
@@ -234,11 +240,12 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
             if x not in GPU_CONSTRUCTORS and x != "modin"  # too slow
         ]
     else:  # pragma: no cover
-        selected_constructors = metafunc.config.getoption("constructors").split(",")  # pyright: ignore[reportAttributeAccessIssue]
+        opt = cast("str", metafunc.config.getoption("constructors"))
+        selected_constructors = opt.split(",")
 
-    eager_constructors: list[Callable[[Any], IntoDataFrame]] = []
+    eager_constructors: list[ConstructorEager] = []
     eager_constructors_ids: list[str] = []
-    constructors: list[Callable[[Any], IntoFrame]] = []
+    constructors: list[Constructor] = []
     constructors_ids: list[str] = []
 
     for constructor in selected_constructors:
diff --git a/tests/series_only/hist_test.py b/tests/series_only/hist_test.py
index 86ff012ff6..f28747ccde 100644
--- a/tests/series_only/hist_test.py
+++ b/tests/series_only/hist_test.py
@@ -375,10 +375,9 @@ def test_hist_non_monotonic(constructor_eager: ConstructorEager) -> None:
 )
 @pytest.mark.slow
 def test_hist_bin_hypotheis(
-    constructor_eager: ConstructorEager,
-    data: list[float],
-    bin_deltas: list[float],
+    constructor_eager: ConstructorEager, data: list[float], bin_deltas: list[float]
 ) -> None:
+    pytest.importorskip("polars")
     import polars as pl
 
     if "cudf" in str(constructor_eager):
@@ -388,12 +387,13 @@ def test_hist_bin_hypotheis(
     df = nw.from_native(constructor_eager({"values": data})).select(
         nw.col("values").cast(nw.Float64)
     )
+    df_bins_native = constructor_eager({"bins": bin_deltas})
     bins = (
-        nw.from_native(constructor_eager({"bins": bin_deltas})["bins"], series_only=True)  # type:ignore[index]
+        nw.from_native(df_bins_native, eager_only=True)
+        .get_column("bins")
         .cast(nw.Float64)
         .cum_sum()
     )
-
     result = df["values"].hist(
         bins=bins.to_list(),
         include_breakpoint=True,
@@ -404,7 +404,6 @@ def test_hist_bin_hypotheis(
         include_breakpoint=True,
         include_category=False,
     )
-
     assert_equal_data(result, expected.to_dict(as_series=False))
diff --git a/tests/utils.py b/tests/utils.py
index 42d70db18a..4469dcf0fe 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,14 +17,16 @@
 import narwhals as nw
 from narwhals.translate import from_native
-from narwhals.typing import IntoDataFrame
-from narwhals.typing import IntoFrame
 from narwhals.utils import Implementation
 from narwhals.utils import parse_version
 
 if TYPE_CHECKING:
     from typing_extensions import TypeAlias
 
+    from narwhals.typing import DataFrameLike
+    from narwhals.typing import NativeFrame
+    from narwhals.typing import NativeLazyFrame
+
 
 def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
     try:
@@ -42,8 +44,8 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
 
 PYARROW_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyarrow")
 PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark")
 
-Constructor: TypeAlias = Callable[[Any], IntoFrame]
-ConstructorEager: TypeAlias = Callable[[Any], IntoDataFrame]
+Constructor: TypeAlias = Callable[[Any], "NativeLazyFrame | NativeFrame | DataFrameLike"]
+ConstructorEager: TypeAlias = Callable[[Any], "NativeFrame | DataFrameLike"]
 
 
 def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]:
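
Note (illustration, not part of the patch): the conftest changes above all apply one typing pattern, replacing `# type: ignore[no-any-return]` comments with `typing.cast` to narwhals' frame protocols. A minimal sketch of the idea, using pandas as a stand-in backend and the hypothetical name `example_constructor`:

    from __future__ import annotations

    from typing import TYPE_CHECKING, Any, cast

    if TYPE_CHECKING:
        from narwhals.typing import NativeFrame

    Data = dict[str, list[Any]]


    def example_constructor(obj: Data) -> NativeFrame:
        import pandas as pd

        df = pd.DataFrame(obj)
        # cast() is a no-op at runtime and simply returns `df`; it only
        # narrows the static type to the NativeFrame protocol, so the
        # old `type: ignore[no-any-return]` comment is no longer needed.
        return cast("NativeFrame", df)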