diff --git a/docs/api-reference/expr_list.md b/docs/api-reference/expr_list.md index 84fb831c50..f44a25d751 100644 --- a/docs/api-reference/expr_list.md +++ b/docs/api-reference/expr_list.md @@ -7,6 +7,11 @@ - contains - get - len + - max + - mean + - median + - min + - sum - unique show_source: false show_bases: false diff --git a/docs/api-reference/series_list.md b/docs/api-reference/series_list.md index 7590732dee..39adbad185 100644 --- a/docs/api-reference/series_list.md +++ b/docs/api-reference/series_list.md @@ -7,6 +7,11 @@ - contains - get - len + - max + - mean + - median + - min + - sum - unique show_source: false show_bases: false diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index defad3dad6..25e598aedd 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -5,7 +5,7 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals._arrow.utils import ArrowSeriesNamespace +from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg from narwhals._compliant.any_namespace import ListNamespace from narwhals._utils import not_implemented @@ -20,5 +20,20 @@ def len(self) -> ArrowSeries: def get(self, index: int) -> ArrowSeries: return self.with_native(pc.list_element(self.native, index)) + def min(self) -> ArrowSeries: + return self.with_native(list_agg(self.native, "min")) + + def max(self) -> ArrowSeries: + return self.with_native(list_agg(self.native, "max")) + + def mean(self) -> ArrowSeries: + return self.with_native(list_agg(self.native, "mean")) + + def median(self) -> ArrowSeries: + return self.with_native(list_agg(self.native, "approximate_median")) + + def sum(self) -> ArrowSeries: + return self.with_native(list_agg(self.native, "sum")) + unique = not_implemented() contains = not_implemented() diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 46b5985e1d..dbd8aa6c62 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping + from typing import Literal from typing_extensions import TypeAlias, TypeIs @@ -494,3 +495,40 @@ def arange(start: int, end: int, step: int) -> ArrayAny: return pa.array(np.arange(start, end, step)) # NOTE: Added in https://github.com/apache/arrow/pull/46778 return pa.arange(start, end, step) # type: ignore[attr-defined] + + +def list_agg( + array: ChunkedArrayAny, + func: Literal["min", "max", "mean", "approximate_median", "sum"], +) -> ChunkedArrayAny: + lit_: Incomplete = lit + aggregation = ( + ("values", func, pc.ScalarAggregateOptions(min_count=0)) + if func == "sum" + else ("values", func) + ) + agg = pa.array( + pa.Table.from_arrays( + [pc.list_flatten(array), pc.list_parent_indices(array)], + names=["values", "offsets"], + ) + .group_by("offsets") + .aggregate([aggregation]) + .sort_by("offsets") + .column(f"values_{func}") + ) + non_empty_mask = pa.array(pc.not_equal(pc.list_value_length(array), lit(0))) + if func == "sum": + # Make sure sum of empty list is 0. + base_array = pc.if_else(non_empty_mask.is_null(), None, 0) + else: + base_array = pa.repeat(lit_(None, type=agg.type), len(array)) + return pa.chunked_array( + [ + pc.replace_with_mask( + base_array, + non_empty_mask.fill_null(False), # type: ignore[arg-type] + agg, + ) + ] + ) diff --git a/narwhals/_compliant/any_namespace.py b/narwhals/_compliant/any_namespace.py index b7e48a273f..cd76470d89 100644 --- a/narwhals/_compliant/any_namespace.py +++ b/narwhals/_compliant/any_namespace.py @@ -70,6 +70,11 @@ def get(self, index: int) -> CompliantT_co: ... def len(self) -> CompliantT_co: ... def unique(self) -> CompliantT_co: ... def contains(self, item: NonNestedLiteral) -> CompliantT_co: ... + def min(self) -> CompliantT_co: ... + def max(self) -> CompliantT_co: ... + def mean(self) -> CompliantT_co: ... + def median(self) -> CompliantT_co: ... + def sum(self) -> CompliantT_co: ... class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]): diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index ec59f604dd..8aa52f8d61 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -999,6 +999,21 @@ def contains(self, item: NonNestedLiteral) -> EagerExprT: def get(self, index: int) -> EagerExprT: return self.compliant._reuse_series_namespace("list", "get", index=index) + def min(self) -> EagerExprT: + return self.compliant._reuse_series_namespace("list", "min") + + def max(self) -> EagerExprT: + return self.compliant._reuse_series_namespace("list", "max") + + def mean(self) -> EagerExprT: + return self.compliant._reuse_series_namespace("list", "mean") + + def median(self) -> EagerExprT: + return self.compliant._reuse_series_namespace("list", "median") + + def sum(self) -> EagerExprT: + return self.compliant._reuse_series_namespace("list", "sum") + class CompliantExprNameNamespace( # type: ignore[misc] _ExprNamespace[CompliantExprT_co], diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index b726f2fc78..15b3f70a53 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -4,7 +4,7 @@ from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import ListNamespace -from narwhals._duckdb.utils import F, lit, when +from narwhals._duckdb.utils import F, col, lambda_expr, lit, when from narwhals._utils import requires if TYPE_CHECKING: @@ -40,3 +40,27 @@ def get(self, index: int) -> DuckDBExpr: return self.compliant._with_elementwise( lambda expr: F("list_extract", expr, lit(index + 1)) ) + + def min(self) -> DuckDBExpr: + return self.compliant._with_elementwise(lambda expr: F("list_min", expr)) + + def max(self) -> DuckDBExpr: + return self.compliant._with_elementwise(lambda expr: F("list_max", expr)) + + def mean(self) -> DuckDBExpr: + return self.compliant._with_elementwise(lambda expr: F("list_avg", expr)) + + def median(self) -> DuckDBExpr: + return self.compliant._with_elementwise(lambda expr: F("list_median", expr)) + + @requires.backend_version((1, 2)) + def sum(self) -> DuckDBExpr: + def func(expr: Expression) -> Expression: + elem = col("_") + expr_no_nulls = F("list_filter", expr, lambda_expr(elem, elem.isnotnull())) + expr_sum = F("list_sum", expr_no_nulls) + return when(F("array_length", expr_no_nulls) == lit(0), lit(0)).otherwise( + expr_sum + ) + + return self.compliant._with_callable(func) diff --git a/narwhals/_ibis/expr_list.py b/narwhals/_ibis/expr_list.py index 8070769308..64cc053831 100644 --- a/narwhals/_ibis/expr_list.py +++ b/narwhals/_ibis/expr_list.py @@ -2,8 +2,11 @@ from typing import TYPE_CHECKING +from ibis import cases, literal + from narwhals._compliant import LazyExprNamespace from narwhals._compliant.any_namespace import ListNamespace +from narwhals._utils import not_implemented if TYPE_CHECKING: import ibis.expr.types as ir @@ -27,3 +30,26 @@ def _get(expr: ir.ArrayColumn) -> ir.Column: return expr[index] return self.compliant._with_callable(_get) + + def min(self) -> IbisExpr: + return self.compliant._with_callable(lambda expr: expr.mins()) + + def max(self) -> IbisExpr: + return self.compliant._with_callable(lambda expr: expr.maxs()) + + def mean(self) -> IbisExpr: + return self.compliant._with_callable(lambda expr: expr.means()) + + def sum(self) -> IbisExpr: + def func(expr: ir.ArrayColumn) -> ir.Value: + expr_no_nulls = expr.filter(lambda x: x.notnull()) + len = expr_no_nulls.length() + return cases( + (len.isnull(), literal(None)), + (len == literal(0), literal(0)), + else_=expr.sums(), + ) + + return self.compliant._with_callable(func) + + median = not_implemented() diff --git a/narwhals/_pandas_like/series_list.py b/narwhals/_pandas_like/series_list.py index 2d087493df..d1fb085262 100644 --- a/narwhals/_pandas_like/series_list.py +++ b/narwhals/_pandas_like/series_list.py @@ -11,6 +11,8 @@ from narwhals._utils import not_implemented if TYPE_CHECKING: + from typing import Literal + from narwhals._pandas_like.series import PandasLikeSeries @@ -40,3 +42,41 @@ def get(self, index: int) -> PandasLikeSeries: result = self.native.list[index] result.name = self.native.name return self.with_native(result) + + def _agg( + self, func: Literal["min", "max", "mean", "approximate_median", "sum"] + ) -> PandasLikeSeries: + dtype_backend = get_dtype_backend( + self.native.dtype, self.compliant._implementation + ) + if dtype_backend != "pyarrow": # pragma: no cover + msg = "Only pyarrow backend is currently supported." + raise NotImplementedError(msg) + + from narwhals._arrow.utils import list_agg, native_to_narwhals_dtype + + ca = self.native.array._pa_array + result_arr = list_agg(ca, func) + nw_dtype = native_to_narwhals_dtype(result_arr.type, self.version) + out_dtype = narwhals_to_native_dtype( + nw_dtype, "pyarrow", self.implementation, self.version + ) + result_native = type(self.native)( + result_arr, dtype=out_dtype, index=self.native.index, name=self.native.name + ) + return self.with_native(result_native) + + def min(self) -> PandasLikeSeries: + return self._agg("min") + + def max(self) -> PandasLikeSeries: + return self._agg("max") + + def mean(self) -> PandasLikeSeries: + return self._agg("mean") + + def median(self) -> PandasLikeSeries: + return self._agg("approximate_median") + + def sum(self) -> PandasLikeSeries: + return self._agg("sum") diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index d638c791fd..f2081f99dc 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -362,6 +362,16 @@ def len(self) -> CompliantT: ... unique: Method[CompliantT] + max: Method[CompliantT] + + mean: Method[CompliantT] + + median: Method[CompliantT] + + min: Method[CompliantT] + + sum: Method[CompliantT] + class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]): _accessor: ClassVar[Accessor] = "struct" diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py index 31be5f5bb9..1b2ec723bc 100644 --- a/narwhals/_spark_like/expr_list.py +++ b/narwhals/_spark_like/expr_list.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from typing import TYPE_CHECKING from narwhals._compliant import LazyExprNamespace @@ -33,3 +34,51 @@ def _get(expr: Column) -> Column: return expr.getItem(index) return self.compliant._with_elementwise(_get) + + def min(self) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + return F.array_min(expr) + + return self.compliant._with_elementwise(func) + + def max(self) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + return F.array_max(F.array_compact(expr)) + + return self.compliant._with_elementwise(func) + + def sum(self) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + return F.aggregate(F.array_compact(expr), F.lit(0.0), operator.add) + + return self.compliant._with_elementwise(func) + + def mean(self) -> SparkLikeExpr: + def func(expr: Column) -> Column: + F = self.compliant._F + return F.try_divide( + F.aggregate(F.array_compact(expr), F.lit(0.0), operator.add), + F.array_size(F.array_compact(expr)), + ) + + return self.compliant._with_elementwise(func) + + def median(self) -> SparkLikeExpr: + def func(expr: Column) -> Column: # pragma: no cover + # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548 + F = self.compliant._F + sorted_expr = F.array_compact(F.sort_array(expr)) + size = F.array_size(sorted_expr) + mid_index = (size / 2).cast("int") + odd_case = sorted_expr[mid_index] + even_case = (sorted_expr[mid_index - 1] + sorted_expr[mid_index]) / 2 + return ( + F.when((size.isNull()) | (size == 0), F.lit(None)) + .when(size % 2 == 1, odd_case) + .otherwise(even_case) + ) + + return self.compliant._with_elementwise(func) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index 8f9c94c6ab..dcc54d4404 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -143,3 +143,122 @@ def get(self, index: int) -> ExprT: return self._expr._append_node( ExprNode(ExprKind.ELEMENTWISE, "list.get", index=index) ) + + def min(self) -> ExprT: + """Compute the min value of the lists in the array. + + Examples: + >>> import duckdb + >>> import narwhals as nw + >>> df_native = duckdb.sql("SELECT * FROM VALUES ([1]), ([3, 4, NULL]) df(a)") + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_min=nw.col("a").list.min()) + ┌────────────────────────┐ + | Narwhals LazyFrame | + |------------------------| + |┌──────────────┬───────┐| + |│ a │ a_min │| + |│ int32[] │ int32 │| + |├──────────────┼───────┤| + |│ [1] │ 1 │| + |│ [3, 4, NULL] │ 3 │| + |└──────────────┴───────┘| + └────────────────────────┘ + """ + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.min")) + + def max(self) -> ExprT: + """Compute the max value of the lists in the array. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> df_native = pl.DataFrame({"a": [[1], [3, 4, None]]}) + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_max=nw.col("a").list.max()) + ┌────────────────────────┐ + | Narwhals DataFrame | + |------------------------| + |shape: (2, 2) | + |┌──────────────┬───────┐| + |│ a ┆ a_max │| + |│ --- ┆ --- │| + |│ list[i64] ┆ i64 │| + |╞══════════════╪═══════╡| + |│ [1] ┆ 1 │| + |│ [3, 4, null] ┆ 4 │| + |└──────────────┴───────┘| + └────────────────────────┘ + """ + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.max")) + + def mean(self) -> ExprT: + """Compute the mean value of the lists in the array. + + Examples: + >>> import pyarrow as pa + >>> import narwhals as nw + >>> df_native = pa.table({"a": [[1], [3, 4, None]]}) + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_mean=nw.col("a").list.mean()) + ┌──────────────────────┐ + | Narwhals DataFrame | + |----------------------| + |pyarrow.Table | + |a: list | + | child 0, item: int64| + |a_mean: double | + |---- | + |a: [[[1],[3,4,null]]] | + |a_mean: [[1,3.5]] | + └──────────────────────┘ + """ + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.mean")) + + def median(self) -> ExprT: + """Compute the median value of the lists in the array. + + Examples: + >>> import duckdb + >>> import narwhals as nw + >>> df_native = duckdb.sql("SELECT * FROM VALUES ([1]), ([3, 4, NULL]) df(a)") + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_median=nw.col("a").list.median()) + ┌───────────────────────────┐ + | Narwhals LazyFrame | + |---------------------------| + |┌──────────────┬──────────┐| + |│ a │ a_median │| + |│ int32[] │ double │| + |├──────────────┼──────────┤| + |│ [1] │ 1.0 │| + |│ [3, 4, NULL] │ 3.5 │| + |└──────────────┴──────────┘| + └───────────────────────────┘ + """ + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.median")) + + def sum(self) -> ExprT: + """Compute the sum value of the lists in the array. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> df_native = pl.DataFrame({"a": [[1], [3, 4, None]]}) + >>> df = nw.from_native(df_native) + >>> df.with_columns(a_sum=nw.col("a").list.sum()) + ┌────────────────────────┐ + | Narwhals DataFrame | + |------------------------| + |shape: (2, 2) | + |┌──────────────┬───────┐| + |│ a ┆ a_sum │| + |│ --- ┆ --- │| + |│ list[i64] ┆ i64 │| + |╞══════════════╪═══════╡| + |│ [1] ┆ 1 │| + |│ [3, 4, null] ┆ 7 │| + |└──────────────┴───────┘| + └────────────────────────┘ + """ + return self._expr._append_node(ExprNode(ExprKind.ELEMENTWISE, "list.sum")) diff --git a/narwhals/series_list.py b/narwhals/series_list.py index baa7ed8c8e..d5ead244e1 100644 --- a/narwhals/series_list.py +++ b/narwhals/series_list.py @@ -117,3 +117,105 @@ def get(self, index: int) -> SeriesT: return self._narwhals_series._with_compliant( self._narwhals_series._compliant_series.list.get(index) ) + + def min(self) -> SeriesT: + """Compute the min value of the lists in the array. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> s_native = pl.Series([[1], [3, 4, None]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.min().to_native() # doctest: +NORMALIZE_WHITESPACE + shape: (2,) + Series: '' [i64] + [ + 1 + 3 + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.min() + ) + + def max(self) -> SeriesT: + """Compute the max value of the lists in the array. + + Examples: + >>> import pyarrow as pa + >>> import narwhals as nw + >>> s_native = pa.chunked_array([[[1], [3, 4, None]]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.max().to_native() # doctest: +ELLIPSIS + + [ + [ + 1, + 4 + ] + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.max() + ) + + def mean(self) -> SeriesT: + """Compute the mean value of the lists in the array. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> s_native = pl.Series([[1], [3, 4, None]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.mean().to_native() # doctest: +NORMALIZE_WHITESPACE + shape: (2,) + Series: '' [f64] + [ + 1.0 + 3.5 + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.mean() + ) + + def median(self) -> SeriesT: + """Compute the median value of the lists in the array. + + Examples: + >>> import pyarrow as pa + >>> import narwhals as nw + >>> s_native = pa.chunked_array([[[1], [3, 4, None]]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.median().to_native() # doctest: +ELLIPSIS + + [ + [ + 1, + 3 + ] + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.median() + ) + + def sum(self) -> SeriesT: + """Compute the sum value of the lists in the array. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> s_native = pl.Series([[1], [3, 4, None]]) + >>> s = nw.from_native(s_native, series_only=True) + >>> s.list.sum().to_native() # doctest: +NORMALIZE_WHITESPACE + shape: (2,) + Series: '' [i64] + [ + 1 + 7 + ] + """ + return self._narwhals_series._with_compliant( + self._narwhals_series._compliant_series.list.sum() + ) diff --git a/tests/expr_and_series/list/max_test.py b/tests/expr_and_series/list/max_test.py new file mode 100644 index 0000000000..f3cd5db5a1 --- /dev/null +++ b/tests/expr_and_series/list/max_test.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import PANDAS_VERSION, assert_equal_data + +if TYPE_CHECKING: + from tests.utils import Constructor, ConstructorEager + +data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} +expected = [4, -1, None, None, None] + + +def test_max_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any(backend in str(constructor) for backend in ("dask", "cudf", "sqlframe")): + # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548 + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + result = nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.max() + ) + assert_equal_data(result, {"a": expected}) + + +def test_max_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(backend in str(constructor_eager) for backend in ("cudf",)): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.max() + assert_equal_data({"a": result}, {"a": expected}) diff --git a/tests/expr_and_series/list/mean_test.py b/tests/expr_and_series/list/mean_test.py new file mode 100644 index 0000000000..9ff5984b2e --- /dev/null +++ b/tests/expr_and_series/list/mean_test.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import PANDAS_VERSION, assert_equal_data + +if TYPE_CHECKING: + from tests.utils import Constructor, ConstructorEager + +data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} +expected = [2.75, -1, None, None, None] + + +def test_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any(backend in str(constructor) for backend in ("dask", "cudf", "sqlframe")): + # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548 + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + result = nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.mean() + ) + assert_equal_data(result, {"a": expected}) + + +def test_mean_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(backend in str(constructor_eager) for backend in ("cudf",)): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.mean() + assert_equal_data({"a": result}, {"a": expected}) diff --git a/tests/expr_and_series/list/median_test.py b/tests/expr_and_series/list/median_test.py new file mode 100644 index 0000000000..b1baa242d7 --- /dev/null +++ b/tests/expr_and_series/list/median_test.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import PANDAS_VERSION, POLARS_VERSION, assert_equal_data, is_windows + +if TYPE_CHECKING: + from tests.utils import Constructor, ConstructorEager + +data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], [], [3, 4, None]]} +expected = [2.5, -1, None, None, None, 3.5] +expected_pyarrow = [2.5, -1, None, None, None, 3] + + +def test_median_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any( + backend in str(constructor) for backend in ("dask", "cudf", "sqlframe", "ibis") + ) or ("polars" in str(constructor) and POLARS_VERSION < (0, 20, 7)): + # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548 + # ibis issue: https://github.com/ibis-project/ibis/issues/11788 + request.applymarker(pytest.mark.xfail) + if os.environ.get("SPARK_CONNECT", None) and "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + if ( + any(backend in str(constructor) for backend in ("pandas", "pyarrow")) + and sys.version_info < (3, 10) + and is_windows + ): # pragma: no cover + reason = "The issue only affects old Python versions on Windows." + pytest.skip(reason=reason) + result = nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.median() + ) + if any( + backend in str(constructor) + for backend in ("pandas", "pyarrow", "pandas[pyarrow]") + ): + # there is a mismatch as pyarrow uses an approximate median + assert_equal_data(result, {"a": expected_pyarrow}) + else: + assert_equal_data(result, {"a": expected}) + + +def test_median_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(backend in str(constructor_eager) for backend in ("cudf",)) or ( + "polars" in str(constructor_eager) and POLARS_VERSION < (0, 20, 7) + ): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + if ( + any(backend in str(constructor_eager) for backend in ("pandas", "pyarrow")) + and sys.version_info < (3, 10) + and is_windows + ): # pragma: no cover + reason = "The issue only affects old Python versions on Windows." + pytest.skip(reason=reason) + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.median() + if any( + backend in str(constructor_eager) + for backend in ("pandas", "pyarrow", "pandas[pyarrow]") + ): + # there is a mismatch as pyarrow uses an approximate median + assert_equal_data({"a": result}, {"a": expected_pyarrow}) + else: + assert_equal_data({"a": result}, {"a": expected}) diff --git a/tests/expr_and_series/list/min_test.py b/tests/expr_and_series/list/min_test.py new file mode 100644 index 0000000000..2039f7de56 --- /dev/null +++ b/tests/expr_and_series/list/min_test.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import PANDAS_VERSION, assert_equal_data + +if TYPE_CHECKING: + from tests.utils import Constructor, ConstructorEager + +data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} +expected = [2, -1, None, None, None] + + +def test_min_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any(backend in str(constructor) for backend in ("dask", "cudf")): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + result = nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.min() + ) + assert_equal_data(result, {"a": expected}) + + +def test_min_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(backend in str(constructor_eager) for backend in ("cudf",)): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.min().to_list() + assert_equal_data({"a": result}, {"a": expected}) diff --git a/tests/expr_and_series/list/sum_test.py b/tests/expr_and_series/list/sum_test.py new file mode 100644 index 0000000000..1f0ff7729e --- /dev/null +++ b/tests/expr_and_series/list/sum_test.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +from tests.utils import DUCKDB_VERSION, PANDAS_VERSION, assert_equal_data + +if TYPE_CHECKING: + from tests.utils import Constructor, ConstructorEager + +data = {"a": [[3, None, 2, 2, 4, None], [-1], None, [None, None, None], []]} +expected = [11, -1, None, 0, 0] + + +def test_sum_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if any(backend in str(constructor) for backend in ("dask", "cudf", "sqlframe")): + # sqlframe issue: https://github.com/eakmanrq/sqlframe/issues/548 + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 2): + reason = "version too old, duckdb 1.2 required for LambdaExpression." + pytest.skip(reason=reason) + result = nw.from_native(constructor(data)).select( + nw.col("a").cast(nw.List(nw.Int32())).list.sum() + ) + assert_equal_data(result, {"a": expected}) + + +def test_sum_series( + request: pytest.FixtureRequest, constructor_eager: ConstructorEager +) -> None: + if any(backend in str(constructor_eager) for backend in ("cudf",)): + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor_eager): + if PANDAS_VERSION < (2, 2): + pytest.skip() + pytest.importorskip("pyarrow") + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].cast(nw.List(nw.Int32())).list.sum() + assert_equal_data({"a": result}, {"a": expected})