Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
Expand Down Expand Up @@ -518,12 +517,9 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())

def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
def filter(self, predicate: ArrowExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
return self._with_native(
self.native.filter(mask_native), validate_column_names=False
)
Expand Down
11 changes: 4 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
def row(self, index: int) -> tuple[Any, ...]:
    """Return the values of the row at position `index` as a tuple."""
    # `iloc` selects the row by position; `tuple` consumes the result directly.
    selected = self.native.iloc[index]
    return tuple(selected)

def filter(self, predicate: PandasLikeExpr | list[bool]) -> Self:
if isinstance(predicate, list):
mask_native: pd.Series[Any] | list[bool] = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask = self._evaluate_into_exprs(predicate)[0]
mask_native = self._extract_comparand(mask)
def filter(self, predicate: PandasLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = self._evaluate_into_exprs(predicate)[0]
mask_native = self._extract_comparand(mask)
return self._with_native(
self.native.loc[mask_native], validate_column_names=False
)
Expand Down
39 changes: 20 additions & 19 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,25 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self:
def filter(
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
) -> Self:
if len(predicates) == 1 and is_list_of(predicates[0], bool):
predicate = predicates[0]
else:
from narwhals.functions import col

flat_predicates = flatten(predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx)
for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
from narwhals.functions import col

plx = self.__narwhals_namespace__()
parsed_predicates = tuple(
plx._series.from_iterable(pred, context=plx).to_narwhals()
if is_list_of(pred, bool)
else pred
for pred in predicates
)
flat_predicates = flatten(parsed_predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
return self._with_compliant(self._compliant_frame.filter(predicate))

def sort(
Expand Down Expand Up @@ -2924,9 +2927,7 @@ def filter(
└───────┴───────┴─────────┘
<BLANKLINE>
"""
if (
len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints
): # pragma: no cover
Comment on lines -2927 to -2929
Copy link
Member Author

@FBruzzesi FBruzzesi Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not covering the full spectrum of possibilities

if any(is_list_of(pred, bool) for pred in predicates): # pragma: no cover
msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead."
raise TypeError(msg)

Expand Down
105 changes: 84 additions & 21 deletions tests/frame/filter_test.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,89 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import Any

import pytest

import narwhals as nw
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError
from tests.utils import Constructor, ConstructorEager, assert_equal_data

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}

def test_filter(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}

@pytest.mark.parametrize(
("predicates", "expected"),
[
((nw.col("a") > 1,), {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}),
((nw.col("a") > 1, nw.col("z") < 9.0), {"a": [3], "b": [4], "z": [8.0]}),
],
)
def test_filter_with_expr_predicates(
constructor: Constructor,
predicates: tuple[nw.Expr, ...],
expected: dict[str, list[Any]],
) -> None:
df = nw.from_native(constructor(data))
result = df.filter(nw.col("a") > 1)
expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}
result = df.filter(*predicates)
assert_equal_data(result, expected)


def test_filter_with_series(constructor_eager: ConstructorEager) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
def test_filter_with_series_predicates(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.filter(df["a"] > 1)
expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}
assert_equal_data(result, expected)

result = df.filter(df["a"] > 1, df["b"] < 6)
expected = {"a": [3], "b": [4], "z": [8.0]}
assert_equal_data(result, expected)


def test_filter_with_boolean_list(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
@pytest.mark.parametrize(
("predicates", "expected"),
[
(([False, True, True],), {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}),
(([True, True, False], [False, True, True]), {"a": [3], "b": [4], "z": [8.0]}),
],
)
def test_filter_with_boolean_list_predicates(
constructor: Constructor,
predicates: tuple[list[bool], ...],
expected: dict[str, list[Any]],
) -> None:
df = nw.from_native(constructor(data))
context = (
pytest.raises(TypeError, match="not supported")
if isinstance(df, nw.LazyFrame)
else does_not_raise()
)
with context:
result = df.filter([False, True, True])
expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}
result = df.filter(*predicates)
assert_equal_data(result, expected)


def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df = nw.from_native(constructor(data))
with pytest.raises(InvalidOperationError):
df.filter(nw.col("a").max() > 2).lazy().collect()


def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df = nw.from_native(constructor(data))
with pytest.raises(InvalidOperationError):
df.filter(nw.col("b").unique() > 2).lazy().collect()


def test_filter_with_constrains(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6]}
def test_filter_with_constrains_only(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result_scalar = df.filter(a=3)
expected_scalar = {"a": [3], "b": [4]}
expected_scalar = {"a": [3], "b": [4], "z": [8.0]}

assert_equal_data(result_scalar, expected_scalar)

result_expr = df.filter(a=nw.col("b") // 3)
expected_expr = {"a": [1, 2], "b": [4, 6]}
expected_expr = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}

assert_equal_data(result_expr, expected_expr)

Expand All @@ -73,17 +94,16 @@ def test_filter_missing_column(
constructor_id = str(request.node.callspec.id)
if any(id_ == constructor_id for id_ in ("sqlframe", "pyspark[connect]", "ibis")):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2], "b": [3, 4]}
df = nw.from_native(constructor(data))

df = nw.from_native(constructor(data))
if "polars" in str(constructor):
msg = r"unable to find column \"c\"; valid columns: \[\"a\", \"b\"\]"
msg = r"unable to find column \"c\"; valid columns: \[\"a\", \"b\"\, \"z\"\]"
elif any(id_ == constructor_id for id_ in ("duckdb", "pyspark")):
msg = r"\n\nHint: Did you mean one of these columns: \['a', 'b'\]?"
msg = r"\n\nHint: Did you mean one of these columns: \['a', 'b', 'z'\]?"
else:
msg = (
r"The following columns were not found: \[.*\]"
r"\n\nHint: Did you mean one of these columns: \['a', 'b'\]?"
r"\n\nHint: Did you mean one of these columns: \['a', 'b', 'z'\]?"
)

if "polars_lazy" in str(constructor) and isinstance(df, nw.LazyFrame):
Expand All @@ -92,3 +112,46 @@ def test_filter_missing_column(
else:
with pytest.raises(ColumnNotFoundError, match=msg):
df.filter(c=5)


def test_filter_with_predicates_and_contraints(
    constructor_eager: ConstructorEager,
) -> None:
    """Exercise eager `filter` with boolean-list masks.

    The mask is first used on its own, then together with keyword
    constraints naming missing columns (expected to raise), and finally
    mixed with a constraint, an expression, a second mask and a Series.
    """
    # Adapted from https://github.com/narwhals-dev/narwhals/pull/3173/commits/8433b2d75438df98004a3c850ad23628e2376836
    frame = nw.from_native(constructor_eager({"a": range(5), "b": [2, 2, 4, 2, 4]}))
    mask = [True, False, True, True, False]
    mask_2 = [True, True, False, True, False]

    # A lone mask keeps the rows at positions 0, 2 and 3.
    assert_equal_data(frame.filter(mask), {"a": [0, 2, 3], "b": [2, 4, 2]})

    # The error text differs between polars and the other eager backends.
    if "polars" in str(constructor_eager):
        msg = r"unable to find column \"c\"; valid columns: \[\"a\", \"b\"\]"
    else:
        msg = (
            r"The following columns were not found: \[.*\]"
            r"\n\nHint: Did you mean one of these columns: \['a', 'b'\]?"
        )
    with pytest.raises(ColumnNotFoundError, match=msg):
        frame.filter(mask, c=1, d=2, e=3, f=4, g=5)

    # NOTE: Everything from here is currently undefined
    expected_mixed = {"a": [0, 3], "b": [2, 2]}

    # Mask combined with a keyword constraint.
    assert_equal_data(frame.filter(mask, b=2), expected_mixed)

    # Mask combined with an expression.
    assert_equal_data(frame.filter(mask, nw.col("b") == 2), expected_mixed)

    # Two masks.
    assert_equal_data(frame.filter(mask, mask_2), expected_mixed)

    # Mask combined with a Series built from the second mask.
    series_mask = nw.Series.from_iterable(
        "mask", mask_2, backend=frame.implementation
    )
    assert_equal_data(frame.filter(mask, series_mask), expected_mixed)

    # Mask, expression and constraint all at once.
    assert_equal_data(frame.filter(mask, nw.col("b") != 4, b=2), expected_mixed)
Loading