Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
Expand Down Expand Up @@ -518,12 +517,9 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())

def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
def filter(self, predicate: ArrowExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
return self._with_native(
self.native.filter(mask_native), validate_column_names=False
)
Expand Down
11 changes: 4 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,10 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
def row(self, index: int) -> tuple[Any, ...]:
    """Return the values of the row selected by `index` as a tuple.

    Arguments:
        index: Position handed to the native frame's `.iloc` indexer.

    Returns:
        A tuple holding the values of the selected row, in iteration order.
    """
    selected = self.native.iloc[index]
    return tuple(selected)

def filter(self, predicate: PandasLikeExpr | list[bool]) -> Self:
if isinstance(predicate, list):
mask_native: pd.Series[Any] | list[bool] = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask = self._evaluate_into_exprs(predicate)[0]
mask_native = self._extract_comparand(mask)
def filter(self, predicate: PandasLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = self._evaluate_into_exprs(predicate)[0]
mask_native = self._extract_comparand(mask)
return self._with_native(
self.native.loc[mask_native], validate_column_names=False
)
Expand Down
45 changes: 43 additions & 2 deletions narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

import os
import re
from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence
import sys
from collections.abc import (
Collection,
Container,
Generator,
Iterable,
Iterator,
Mapping,
MappingView,
Sequence,
Sized,
)
from datetime import timezone
from enum import Enum, auto
from functools import cache, lru_cache, partial, wraps
Expand Down Expand Up @@ -50,7 +61,7 @@
from narwhals.exceptions import ColumnNotFoundError, DuplicateError, InvalidOperationError

if TYPE_CHECKING:
from collections.abc import Set # noqa: PYI025
from collections.abc import Iterator, Reversible, Set # noqa: PYI025
from types import ModuleType

import pandas as pd
Expand Down Expand Up @@ -124,6 +135,7 @@
CompliantSeries,
DTypes,
FileSource,
IntoExpr,
IntoSeriesT,
MultiIndexSelector,
SingleIndexSelector,
Expand Down Expand Up @@ -163,6 +175,11 @@ class _StoresColumns(Protocol):
def columns(self) -> Sequence[str]: ...


# note: reversed views don't match as instances of MappingView
# Record the concrete types returned by `reversed(...)` on dict key, value and
# item views, so they can be matched with an explicit `isinstance` check.
# `_reverse_mapping_views` is only defined when running on Python 3.11+.
if sys.version_info >= (3, 11):  # pragma: no cover
    _views: list[Reversible[Any]] = [{}.keys(), {}.values(), {}.items()]
    _reverse_mapping_views = tuple(type(reversed(view)) for view in _views)

_T = TypeVar("_T")
NativeT_co = TypeVar("NativeT_co", covariant=True)
CompliantT_co = TypeVar("CompliantT_co", covariant=True)
Expand Down Expand Up @@ -690,6 +707,17 @@ def _is_iterable(arg: Any | Iterable[Any]) -> bool:
return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series))


def _is_generator(
val: object | Iterator[_T] | Generator[_T] | MappingView,
) -> TypeIs[Iterator[_T] | Generator[_T] | MappingView]:
# Adapted from https://github.com/pola-rs/polars/pull/16254
return (
(isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized))
or isinstance(val, MappingView)
or (sys.version_info >= (3, 11) and isinstance(val, _reverse_mapping_views))
)


def parse_version(version: str | ModuleType | _SupportsVersion) -> tuple[int, ...]:
"""Simple version parser; split into a tuple of ints for comparison.

Expand Down Expand Up @@ -1344,6 +1372,19 @@ def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]:
return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp))


def _predicates_has_no_list_of_bool(
    predicates: tuple[IntoExpr | Iterable[IntoExpr] | list[bool], ...],
) -> TypeIs[tuple[IntoExpr | Iterable[IntoExpr], ...]]:
    """Type guard asserting that no element of `predicates` is a `list[bool]`.

    `LazyFrame.filter` relies on this so that the type checker knows boolean
    masks have been ruled out before the predicates reach `BaseFrame.filter`.

    Arguments:
        predicates: The positional predicates passed to `filter`.

    Returns:
        False as soon as one predicate is recognised as a `list[bool]` by
        `is_list_of`, True if none is.
    """
    for candidate in predicates:
        if is_list_of(candidate, bool):
            return False
    return True


def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]:
# Check if an object is a sequence of `tp`, only sniffing the first element.
return bool(
Expand Down
53 changes: 28 additions & 25 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
Implementation,
Version,
_Implementation,
_is_generator,
_predicates_has_no_list_of_bool,
can_lazyframe_collect,
check_columns_exist,
flatten,
Expand Down Expand Up @@ -242,24 +244,20 @@ def drop(self, *columns: Iterable[str], strict: bool) -> Self:
return self._with_compliant(self._compliant_frame.drop(columns, strict=strict))

def filter(
self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
self, *predicates: IntoExpr | Iterable[IntoExpr], **constraints: Any
) -> Self:
if len(predicates) == 1 and is_list_of(predicates[0], bool):
predicate = predicates[0]
else:
from narwhals.functions import col

flat_predicates = flatten(predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx)
for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
from narwhals.functions import col

flat_predicates = flatten(predicates)
check_expressions_preserve_length(*flat_predicates, function_name="filter")
plx = self.__narwhals_namespace__()
compliant_predicates, _kinds = self._flatten_and_extract(*flat_predicates)
compliant_constraints = (
(col(name) == v)._to_compliant_expr(plx) for name, v in constraints.items()
)
predicate = plx.all_horizontal(
*chain(compliant_predicates, compliant_constraints), ignore_nulls=False
)
return self._with_compliant(self._compliant_frame.filter(predicate))

def sort(
Expand Down Expand Up @@ -1653,7 +1651,7 @@ def filter(

Arguments:
*predicates: Expression(s) that evaluates to a boolean Series. Can
also be a (single!) boolean list.
also be one or more boolean lists.
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
joined with the other filter conditions using &.
Expand Down Expand Up @@ -1695,7 +1693,14 @@ def filter(
foo bar ham
1 2 7 b
"""
return super().filter(*predicates, **constraints)
plx = self.__narwhals_namespace__()
parsed_predicates = tuple(
plx._series.from_iterable(pred, context=plx).to_narwhals()
if is_list_of(pred, bool)
else pred
for pred in predicates
)
return super().filter(*parsed_predicates, **constraints)

@overload
def group_by(
Expand Down Expand Up @@ -2857,8 +2862,7 @@ def filter(
The original order of the remaining rows is preserved.

Arguments:
*predicates: Expression that evaluates to a boolean Series. Can
also be a (single!) boolean list.
*predicates: Expression(s) that evaluates to a boolean Series.
**constraints: Column filters; use `name = value` to filter columns by the supplied value.
Each constraint will behave the same as `nw.col(name).eq(value)`, and will be implicitly
joined with the other filter conditions using &.
Expand Down Expand Up @@ -2924,13 +2928,12 @@ def filter(
└───────┴───────┴─────────┘
<BLANKLINE>
"""
if (
len(predicates) == 1 and is_list_of(predicates[0], bool) and not constraints
): # pragma: no cover
Comment on lines -2927 to -2929
Copy link
Member Author

@FBruzzesi FBruzzesi Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not covering full spectrum of possibilities

predicates_ = tuple(tuple(p) if _is_generator(p) else p for p in predicates)
if not _predicates_has_no_list_of_bool(predicates_):
msg = "`LazyFrame.filter` is not supported with Python boolean masks - use expressions instead."
raise TypeError(msg)

return super().filter(*predicates, **constraints)
return super().filter(*predicates_, **constraints)

def sink_parquet(self, file: str | Path | BytesIO) -> None:
"""Write LazyFrame to Parquet file.
Expand Down
Loading
Loading