Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
Expand Down Expand Up @@ -518,12 +517,9 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())

def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
def filter(self, predicate: ArrowExpr) -> Self:
    """Filter rows using the mask obtained by evaluating `predicate`.

    Arguments:
        predicate: Expression evaluated against this frame to produce the mask.

    Returns:
        A new frame built from the result of `self.native.filter`.
    """
    # The predicate's expression only returns a single column, so taking
    # the first evaluated result is safe.
    evaluated = self._evaluate_into_exprs(predicate)
    mask_native = evaluated[0].native
    filtered = self.native.filter(mask_native)
    return self._with_native(filtered, validate_column_names=False)
Expand Down
11 changes: 4 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
def row(self, index: int) -> tuple[Any, ...]:
    """Return the values of the row at position `index` as a tuple.

    Arguments:
        index: Positional row index, forwarded to `self.native.iloc`.

    Returns:
        The values obtained by iterating over `self.native.iloc[index]`.
    """
    selected = self.native.iloc[index]
    return tuple(selected)

def filter(self, predicate: PandasLikeExpr | list[bool]) -> Self:
if isinstance(predicate, list):
mask_native: pd.Series[Any] | list[bool] = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask = self._evaluate_into_exprs(predicate)[0]
mask_native = self._extract_comparand(mask)
def filter(self, predicate: PandasLikeExpr) -> Self:
    """Filter rows using the mask obtained by evaluating `predicate`.

    Arguments:
        predicate: Expression evaluated against this frame to produce the mask.

    Returns:
        A new frame built from `self.native.loc[mask]`.
    """
    # The predicate's expression only returns a single column, so taking
    # the first evaluated result is safe.
    evaluated = self._evaluate_into_exprs(predicate)
    mask_native = self._extract_comparand(evaluated[0])
    selected = self.native.loc[mask_native]
    return self._with_native(selected, validate_column_names=False)
Expand Down
11 changes: 11 additions & 0 deletions narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import re
import sys
from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence
from datetime import timezone
from enum import Enum, auto
Expand Down Expand Up @@ -690,6 +691,10 @@ def _is_iterable(arg: Any | Iterable[Any]) -> bool:
return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series))


def is_iterator(val: Iterable[_T] | Any) -> TypeIs[Iterator[_T]]:
    """Return whether `val` is an instance of `collections.abc.Iterator`.

    Plain iterables such as lists, tuples and strings are not iterators
    and therefore yield `False`; generators and objects returned by
    `iter(...)` yield `True`.
    """
    return isinstance(val, Iterator)


def parse_version(version: str | ModuleType | _SupportsVersion) -> tuple[int, ...]:
"""Simple version parser; split into a tuple of ints for comparison.

Expand Down Expand Up @@ -1344,6 +1349,12 @@ def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]:
return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp))


def predicates_contains_list_of_bool(
    predicates: Collection[Any],
) -> TypeIs[Collection[list[bool]]]:
    """Return whether at least one element of `predicates` is a list of bools.

    Each element is tested with `is_list_of(pred, bool)`; the scan stops at
    the first match.
    """
    for candidate in predicates:
        if is_list_of(candidate, bool):
            return True
    return False


def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]:
# Check if an object is a sequence of `tp`, only sniffing the first element.
return bool(
Expand Down
53 changes: 27 additions & 26 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
is_compliant_lazyframe,
is_eager_allowed,
is_index_selector,
is_iterator,
is_lazy_allowed,
is_list_of,
is_sequence_like,
is_slice_none,
predicates_contains_list_of_bool,
qualified_type_name,
supports_arrow_c_stream,
zip_strict,
Expand Down Expand Up @@ -242,24 +244,20 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self:
return self._with_compliant(self._compliant_frame.drop(columns, strict=strict))

def filter(
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any
) -> Self:
if len(predicates) == 1 and is_list_of(predicates[0], bool):
predicate = predicates[0]
else:
from narwhals.functions import col

flat_predicates = flatten(predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx)
for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
from narwhals.functions import col

flat_predicates = flatten(predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
return self._with_compliant(self._compliant_frame.filter(predicate))

def sort(
Expand Down Expand Up @@ -1653,7 +1651,7 @@ def filter(

Arguments:
*predicates: Expression(s) that evaluates to a boolean Series. Can
also be a (single!) boolean list.
also be one or more boolean lists.
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
joined with the other filter conditions using &.
Expand Down Expand Up @@ -1695,7 +1693,12 @@ def filter(
foo bar ham
1 2 7 b
"""
return super().filter(*predicates, **constraints)
impl = self.implementation
parsed_predicates = (
self._series.from_iterable("", p, backend=impl) if is_list_of(p, bool) else p
for p in predicates
)
return super().filter(*parsed_predicates, **constraints)

@overload
def group_by(
Expand Down Expand Up @@ -2850,15 +2853,14 @@ def unique(
)

def filter(
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any
) -> Self:
r"""Filter the rows in the LazyFrame based on a predicate expression.

The original order of the remaining rows is preserved.

Arguments:
*predicates: Expression that evaluates to a boolean Series. Can
also be a (single!) boolean list.
*predicates: Expression(s) that evaluates to a boolean Series.
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
joined with the other filter conditions using &.
Expand Down Expand Up @@ -2924,13 +2926,12 @@ def filter(
└───────┴───────┴─────────┘
<BLANKLINE>
"""
if (
len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints
): # pragma: no cover
Comment on lines -2927 to -2929
Copy link
Member Author

@FBruzzesi FBruzzesi Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not covering full spectrum of possibilities

predicates_ = tuple(tuple(p) if is_iterator(p) else p for p in predicates)
if predicates_contains_list_of_bool(predicates_):
msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead."
raise TypeError(msg)

return super().filter(*predicates, **constraints)
return super().filter(*predicates_, **constraints)

def sink_parquet(self, file: str | Path | BytesIO) -> None:
"""Write LazyFrame to Parquet file.
Expand Down
Loading
Loading