
[Bug]: Missing handling for Iterator[IntoExpr] #1897

@dangotbanned

Description


Describe the bug

I discovered this in vega/altair#3631 (comment) and managed to narrow it down to a more minimal repro.

I think this may be an issue for any function that accepts Iterable[IntoExpr] but assumes Sequence[IntoExpr].
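The distinction matters because a Sequence can be traversed any number of times, while an Iterator (such as a generator) is exhausted after a single pass. A minimal pure-Python sketch, independent of narwhals:

```python
from collections.abc import Iterator, Sequence


def gen_items() -> Iterator[int]:
    # A generator is an Iterator: its items can only be consumed once.
    yield from (1, 2, 3)


seq: Sequence[int] = [1, 2, 3]
it = gen_items()

# A Sequence yields the same items on every pass.
seq_first, seq_second = list(seq), list(seq)

# An Iterator yields its items on the first pass and nothing afterwards.
it_first, it_second = list(it), list(it)

print(seq_first, seq_second)  # [1, 2, 3] [1, 2, 3]
print(it_first, it_second)    # [1, 2, 3] []
```

Any code that reads an `Iterable[IntoExpr]` argument more than once is therefore implicitly assuming `Sequence[IntoExpr]`.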

I discovered a similar bug in polars a while back, but that one was more limited in scope.

So far, I've found that this impacts these two functions:

  • nw.all_horizontal
  • nw.any_horizontal

I suspect that adding a step in flatten to collect an Iterator (e.g. into a list) might help.

narwhals/narwhals/utils.py

Lines 344 to 349 in f644931

def flatten(args: Any) -> list[Any]:
    if not args:
        return []
    if len(args) == 1 and _is_iterable(args[0]):
        return args[0]  # type: ignore[no-any-return]
    return args  # type: ignore[no-any-return]
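A minimal sketch of the suggested change, assuming a simplified, hypothetical `_is_iterable` (the real narwhals helper may perform additional checks that are omitted here). One caveat: in the traceback, `flatten(exprs)` is called inside the deferred lambda of `any_horizontal`, so this step would likely also need to run once, when the expression is built, for the collected list to be reused.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Any


def _is_iterable(arg: Any) -> bool:
    # Hypothetical, simplified stand-in for the private narwhals helper:
    # str/bytes are treated as scalars rather than as iterables.
    return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes))


def flatten(args: Any) -> list[Any]:
    if not args:
        return []
    if len(args) == 1 and _is_iterable(args[0]):
        inner = args[0]
        # Proposed extra step: materialize a one-shot Iterator (e.g. a
        # generator) so the caller always gets back a reusable list.
        if isinstance(inner, Iterator):
            return list(inner)
        return inner
    return args
```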

Steps or code to reproduce the bug

from __future__ import annotations  # noqa: F404

from collections.abc import Iterable, Iterator
from typing import Any

import polars as pl
from narwhals.stable import v1 as nw


def iter_eq(namespace: Any, items: Iterable[tuple[str, Any]], /) -> Iterator[Any]:
    for column, value in items:
        yield namespace.col(column) == value


def any_horizontal(namespace: Any, items: Iterable[tuple[str, Any]], /) -> Any:
    expr = iter_eq(namespace, items)
    return namespace.any_horizontal(expr)


pl_frame = pl.DataFrame({"A": range(5), "B": list("bbaab")})
nw_frame = nw.from_native(pl_frame, eager_only=True)

expr_items = [("A", 3), ("B", "b")]
nw_expr = any_horizontal(nw, expr_items)

nw_one = nw_frame.filter(nw_expr)  # first use succeeds
>>> nw_frame.filter(nw_expr)  # second use raises ComputeError
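Until this is handled inside narwhals, one possible caller-side workaround (a sketch, not an official recommendation) is to materialize the generator before passing it on. The block restates `iter_eq` from the repro and only changes `any_horizontal`:

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Any


def iter_eq(namespace: Any, items: Iterable[tuple[str, Any]], /) -> Iterator[Any]:
    for column, value in items:
        yield namespace.col(column) == value


def any_horizontal(namespace: Any, items: Iterable[tuple[str, Any]], /) -> Any:
    # Collect the generator into a list so the resulting expression does not
    # depend on a one-shot Iterator that is exhausted after its first use.
    exprs = list(iter_eq(namespace, items))
    return namespace.any_horizontal(exprs)
```

With this version, `nw_expr` is backed by a list rather than a generator, so `flatten` returns the same items each time the expression is resolved, and repeated `nw_frame.filter(nw_expr)` calls should no longer hit the empty-fold error.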

Expected results

No error is thrown.

Running the same example with polars, the resulting expression can be reused multiple times, regardless of the kind of Iterable passed in.

pl_expr = any_horizontal(pl, expr_items)
pl_results = (
    pl_frame.filter(pl_expr),
    pl_frame.filter(pl_expr),
    pl_frame.filter(~pl_expr),
    pl_frame.filter(~~pl_expr),
)
>>> pl_results

Output

(shape: (4, 2)
 ┌─────┬─────┐
 │ AB   │
 │ ------ │
 │ i64str │
 ╞═════╪═════╡
 │ 0b   │
 │ 1b   │
 │ 3a   │
 │ 4b   │
 └─────┴─────┘,
 shape: (4, 2)
 ┌─────┬─────┐
 │ AB   │
 │ ------ │
 │ i64str │
 ╞═════╪═════╡
 │ 0b   │
 │ 1b   │
 │ 3a   │
 │ 4b   │
 └─────┴─────┘,
 shape: (1, 2)
 ┌─────┬─────┐
 │ AB   │
 │ ------ │
 │ i64str │
 ╞═════╪═════╡
 │ 2a   │
 └─────┴─────┘,
 shape: (4, 2)
 ┌─────┬─────┐
 │ AB   │
 │ ------ │
 │ i64str │
 ╞═════╪═════╡
 │ 0b   │
 │ 1b   │
 │ 3a   │
 │ 4b   │
 └─────┴─────┘)

Actual results

No results: the second nw_frame.filter(nw_expr) call raises a ComputeError (see Relevant log output).

Please run narwhals.show_versions() and enter the output below.

System:
    python: 3.12.8 (main, Dec 19 2024, 14:41:01) [MSC v.1942 64 bit (AMD64)]
executable: c:\Users\danie\Documents\GitHub\altair\.venv\Scripts\python.exe
   machine: Windows-10-10.0.19045-SP0

Python dependencies:
     narwhals: 1.22.0
       pandas: 2.2.3
       polars: 1.20.0
         cudf: 
        modin: 
      pyarrow: 19.0.0
        numpy: 2.2.1

Relevant log output

---------------------------------------------------------------------------
ComputeError                              Traceback (most recent call last)
Cell In[20], line 2
      1 nw_one = nw_frame.filter(nw_expr)
----> 2 nw_frame.filter(nw_expr)

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\dataframe.py:2317, in DataFrame.filter(self, *predicates, **constraints)
   2136 def filter(
   2137     self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
   2138 ) -> Self:
   2139     r"""Filter the rows in the DataFrame based on one or more predicate expressions.
   2140 
   2141     The original order of the remaining rows is preserved.
   (...)
   2315         ham: [["b"]]
   2316     """
-> 2317     return super().filter(*predicates, **constraints)

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\dataframe.py:158, in BaseFrame.filter(self, *predicates, **constraints)
    150 def filter(
    151     self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool], **constraints: Any
    152 ) -> Self:
    153     if not (
    154         len(predicates) == 1
    155         and isinstance(predicates[0], list)
    156         and all(isinstance(x, bool) for x in predicates[0])
    157     ):
--> 158         predicates, constraints = self._flatten_and_extract(
    159             *predicates, **constraints
    160         )
    161     return self._from_compliant_dataframe(
    162         self._compliant_frame.filter(*predicates, **constraints),
    163     )

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\dataframe.py:68, in BaseFrame._flatten_and_extract(self, *args, **kwargs)
     66 def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
     67     """Process `args` and `kwargs`, extracting underlying objects as we go."""
---> 68     args = [self._extract_compliant(v) for v in flatten(args)]  # type: ignore[assignment]
     69     kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
     70     return args, kwargs

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\dataframe.py:81, in BaseFrame._extract_compliant(self, arg)
     79     return arg._compliant_series
     80 if isinstance(arg, Expr):
---> 81     return arg._to_compliant_expr(self.__narwhals_namespace__())
     82 if get_polars() is not None and "polars" in str(type(arg)):
     83     msg = (
     84         f"Expected Narwhals object, got: {type(arg)}.\n\n"
     85         "Perhaps you:\n"
     86         "- Forgot a `nw.from_native` somewhere?\n"
     87         "- Used `pl.col` instead of `nw.col`?"
     88     )

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\functions.py:2371, in any_horizontal.<locals>.<lambda>(plx)
   2368     msg = "At least one expression must be passed to `any_horizontal`"
   2369     raise ValueError(msg)
   2370 return Expr(
-> 2371     lambda plx: plx.any_horizontal(
   2372         *[extract_compliant(plx, v) for v in flatten(exprs)]
   2373     )
   2374 )

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\narwhals\_polars\namespace.py:43, in PolarsNamespace.__getattr__.<locals>.func(*args, **kwargs)
     40 def func(*args: Any, **kwargs: Any) -> Any:
     41     args, kwargs = extract_args_kwargs(args, kwargs)  # type: ignore[assignment]
     42     return PolarsExpr(
---> 43         getattr(pl, attr)(*args, **kwargs),
     44         version=self._version,
     45         backend_version=self._backend_version,
     46     )

File c:\Users\danie\Documents\GitHub\altair\.venv\Lib\site-packages\polars\functions\aggregation\horizontal.py:108, in any_horizontal(*exprs)
     67 """
     68 Compute the bitwise OR horizontally across columns.
     69 
   (...)
    105 └───────┴───────┴─────┴───────┘
    106 """
    107 pyexprs = parse_into_list_of_expressions(*exprs)
--> 108 return wrap_expr(plr.any_horizontal(pyexprs))

ComputeError: cannot return empty fold because the number of output rows is unknown
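The traceback suggests why only the second call fails. In narwhals, any_horizontal wraps `flatten(exprs)` inside a lambda that runs each time the expression is resolved against a frame, so the generator is exhausted by the first filter and polars then receives zero expressions ("cannot return empty fold"). The polars function, by contrast, reads `exprs` via `parse_into_list_of_expressions` at the moment it is called. The pure-Python mock below (all names hypothetical, not narwhals API) contrasts the two shapes:

```python
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import Any


class DeferredExpr:
    """Hypothetical stand-in for an expression that is resolved on each use."""

    def __init__(self, call: Callable[[], list[Any]]) -> None:
        self._call = call

    def resolve(self) -> list[Any]:
        # The closure runs again every time the expression is used.
        return self._call()


def deferred_any_horizontal(exprs: Iterable[Any]) -> DeferredExpr:
    # Same shape as the traceback: the iterable is only read inside the
    # closure, so a generator is consumed by the first resolve().
    return DeferredExpr(lambda: list(exprs))


def eager_any_horizontal(exprs: Iterable[Any]) -> DeferredExpr:
    # Reading the iterable once, up front, makes the expression reusable.
    collected = list(exprs)
    return DeferredExpr(lambda: list(collected))


lazy = deferred_any_horizontal(x for x in ["a", "b"])
print(lazy.resolve(), lazy.resolve())  # ['a', 'b'] []

eager = eager_any_horizontal(x for x in ["a", "b"])
print(eager.resolve(), eager.resolve())  # ['a', 'b'] ['a', 'b']
```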
