Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@
from polars.dependencies import numpy as np

if TYPE_CHECKING:
from collections.abc import Reversible
from collections.abc import Iterator, Reversible

from polars import DataFrame
from polars.type_aliases import PolarsDataType, SizeUnit

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs

if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeGuard
else:
Expand All @@ -56,7 +61,7 @@ def _process_null_values(
return null_values


def _is_generator(val: object) -> bool:
def _is_generator(val: object | Iterator[T]) -> TypeIs[Iterator[T]]:
return (
(isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized))
or isinstance(val, MappingView)
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from polars._utils.unstable import issue_unstable_warning, unstable
from polars._utils.various import (
_in_notebook,
_is_generator,
is_bool_sequence,
is_sequence,
normalize_filepath,
Expand Down Expand Up @@ -2798,7 +2799,9 @@ def filter(
return self.clear() # type: ignore[return-value]
elif p is True:
continue # no-op; matches all rows
elif is_bool_sequence(p, include_series=True):
if _is_generator(p):
p = tuple(p)
if is_bool_sequence(p, include_series=True):
boolean_masks.append(pl.Series(p, dtype=Boolean))
elif (
(is_seq := is_sequence(p))
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,28 @@ def test_filter_multiple_predicates() -> None:
assert ldf.filter(predicate="==").select("description").collect().item() == "eq"


@pytest.mark.parametrize(
    "predicate",
    [
        [pl.lit(True)],
        iter([pl.lit(True)]),
        [True, True, True],
        iter([True, True, True]),
        (p for p in (pl.col("c") < 9,)),
        (p for p in (pl.col("a") > 0, pl.col("b") > 0)),
    ],
)
def test_filter_seq_iterable_all_true(predicate: Any) -> None:
    """Check that an all-true predicate leaves the LazyFrame unchanged.

    The predicate is supplied as a list, a plain iterator, or a generator
    expression, holding either boolean literals / masks or column
    expressions. Every case evaluates to true for all three rows of the
    fixture below, so the filtered frame must equal the unfiltered one.
    """
    # Every value in "a" and "b" is > 0 and every value in "c" is < 9,
    # so each expression-based predicate above keeps all rows.
    columns = {
        "a": [1, 1, 1],
        "b": [1, 1, 2],
        "c": [3, 1, 2],
    }
    frame = pl.LazyFrame(columns)
    filtered = frame.filter(predicate)
    assert_frame_equal(frame, filtered)


def test_apply_custom_function() -> None:
ldf = pl.LazyFrame(
{
Expand Down