From d3e379b79c656c384b141015ed7bbb17de0892bd Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Sun, 17 Aug 2025 11:14:52 +0100 Subject: [PATCH 01/10] replace `zip` with `zip_equal` --- narwhals/_arrow/dataframe.py | 9 ++++--- narwhals/_compliant/expr.py | 8 +++--- narwhals/_compliant/selectors.py | 19 ++++++++++---- narwhals/_dask/dataframe.py | 3 ++- narwhals/_dask/group_by.py | 3 ++- narwhals/_dask/namespace.py | 8 +++--- narwhals/_duckdb/dataframe.py | 13 ++++++---- narwhals/_duckdb/utils.py | 4 +-- narwhals/_expression_parsing.py | 8 +++--- narwhals/_ibis/dataframe.py | 3 ++- narwhals/_ibis/expr.py | 4 +-- narwhals/_pandas_like/dataframe.py | 3 ++- narwhals/_pandas_like/group_by.py | 3 ++- narwhals/_pandas_like/namespace.py | 7 +++-- narwhals/_polars/namespace.py | 6 ++--- narwhals/_spark_like/dataframe.py | 5 ++-- narwhals/_spark_like/expr.py | 4 +-- narwhals/_spark_like/namespace.py | 5 ++-- narwhals/_sql/group_by.py | 7 +++-- narwhals/_utils.py | 41 ++++++++++++++++++++++++++++++ narwhals/dataframe.py | 11 +++++--- narwhals/schema.py | 4 +-- 22 files changed, 125 insertions(+), 53 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 6ed9cdcb63..f7843f24cb 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -21,6 +21,7 @@ parse_columns_to_drop, scale_bytes, supports_arrow_c_stream, + zip_equal, ) from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ShapeError @@ -202,7 +203,7 @@ def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: return self.native.to_pylist() def iter_columns(self) -> Iterator[ArrowSeries]: - for name, series in zip(self.columns, self.native.itercolumns()): + for name, series in zip_equal(self.columns, self.native.itercolumns()): yield ArrowSeries.from_native(series, context=self, name=name) _iter_columns = iter_columns @@ -216,7 +217,7 @@ def iter_rows( if not named: for i in range(0, num_rows, buffer_size): rows = df[i : i + buffer_size].to_pydict().values() - yield from zip(*rows) + yield from zip_equal(*rows) else: for i in range(0, num_rows, buffer_size): yield from df[i : i + buffer_size].to_pylist() @@ -290,7 +291,7 @@ def schema(self) -> dict[str, DType]: schema = self.native.schema return { name: native_to_narwhals_dtype(dtype, self._version) - for name, dtype in zip(schema.names, schema.types) + for name, dtype in zip_equal(schema.names, schema.types) } def collect_schema(self) -> dict[str, DType]: @@ -431,7 +432,7 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> else: sorting = [ (key, "descending" if is_descending else "ascending") - for key, is_descending in zip(by, descending) + for key, is_descending in zip_equal(by, descending) ] null_placement = "at_end" if nulls_last else "at_start" diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index cfcfeda520..f105d17126 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -28,7 +28,7 @@ LazyExprT, NativeExprT, ) -from narwhals._utils import _StoresCompliant +from narwhals._utils import _StoresCompliant, zip_equal from narwhals.dependencies import get_numpy, is_numpy_array if TYPE_CHECKING: @@ -280,13 +280,13 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: if alias_output_names: return [ series.alias(name) - for series, name in zip( + for series, name in zip_equal( self(df), alias_output_names(self._evaluate_output_names(df)) ) ] return [ series.alias(name) - for series, name in zip(self(df), self._evaluate_output_names(df)) + for series, name in zip_equal(self(df), self._evaluate_output_names(df)) ] return self.__class__( @@ -767,7 +767,7 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]: ) result = [ from_numpy(array).alias(output_name) - for array, output_name in zip(result, output_names) + for array, output_name in zip_equal(result, output_names) ] if return_dtype is not None: result = [series.cast(return_dtype) for series in result] diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 8e318dc675..8c80220fc4 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -12,6 +12,7 @@ dtype_matches_time_unit_and_time_zone, get_column_names, is_compliant_dataframe, + zip_equal, ) if TYPE_CHECKING: @@ -77,7 +78,7 @@ def _iter_columns_dtypes( ) -> Iterator[tuple[SeriesOrExprT, DType]]: ... def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]: - yield from zip(self._iter_columns(df), df.columns) + yield from zip_equal(self._iter_columns(df), df.columns) def _is_dtype( self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], / @@ -192,7 +193,7 @@ def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]: yield from df._iter_columns() def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]: - yield from zip(self._iter_columns(df), df.schema.values()) + yield from zip_equal(self._iter_columns(df), df.schema.values()) class CompliantSelector( @@ -244,7 +245,9 @@ def __sub__( def series(df: FrameT) -> Sequence[SeriesOrExprT]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) return [ - x for x, name in zip(self(df), lhs_names) if name not in rhs_names + x + for x, name in zip_equal(self(df), lhs_names) + if name not in rhs_names ] def names(df: FrameT) -> Sequence[str]: @@ -268,7 +271,11 @@ def __or__( def series(df: FrameT) -> Sequence[SeriesOrExprT]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) return [ - *(x for x, name in zip(self(df), lhs_names) if name not in rhs_names), + *( + x + for x, name in zip_equal(self(df), lhs_names) + if name not in rhs_names + ), *other(df), ] @@ -292,7 +299,9 @@ def __and__( def series(df: FrameT) -> Sequence[SeriesOrExprT]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) - return [x for x, name in zip(self(df), lhs_names) if name in rhs_names] + return [ + x for x, name in zip_equal(self(df), lhs_names) if name in rhs_names + ] def names(df: FrameT) -> Sequence[str]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index b3b6dc2676..32df5604d7 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -17,6 +17,7 @@ generate_temporary_column_name, not_implemented, parse_columns_to_drop, + zip_equal, ) from narwhals.typing import CompliantLazyFrame @@ -284,7 +285,7 @@ def _join_left( ) extra = [ right_key if right_key not in self.columns else f"{right_key}{suffix}" - for left_key, right_key in zip(left_on, right_on) + for left_key, right_key in zip_equal(left_on, right_on) if right_key != left_key ] return result_native.drop(columns=extra) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index cb2df38815..20d443738a 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -7,6 +7,7 @@ from narwhals._compliant import DepthTrackingGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases +from narwhals._utils import zip_equal if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -138,7 +139,7 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( (alias, (output_name, agg_fn)) - for alias, output_name in zip(aliases, output_names) + for alias, output_name in zip_equal(aliases, output_names) ) return DaskLazyFrame( self._grouped.agg(**simple_aggregations).reset_index(), diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 518ab36b7e..e9285d48ad 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -26,7 +26,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation +from narwhals._utils import Implementation, zip_equal if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -263,7 +263,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) else: init_value, *values = [ - s.where(~nm, "") for s, nm in zip(series, null_mask) + s.where(~nm, "") for s, nm in zip_equal(series, null_mask) ] separators = ( @@ -271,7 +271,9 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: for nm in null_mask[:-1] ) result = reduce( - operator.add, (s + v for s, v in zip(separators, values)), init_value + operator.add, + (s + v for s, v in zip_equal(separators, values)), + init_value, ) return [result] diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 82b913405b..b18a1c7641 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -26,6 +26,7 @@ not_implemented, parse_columns_to_drop, requires, + zip_equal, ) from narwhals.dependencies import get_duckdb from narwhals.exceptions import InvalidOperationError @@ -231,7 +232,9 @@ def schema(self) -> dict[str, DType]: column_name: native_to_narwhals_dtype( duckdb_dtype, self._version, deferred_time_zone ) - for column_name, duckdb_dtype in zip(self.native.columns, self.native.types) + for column_name, duckdb_dtype in zip_equal( + self.native.columns, self.native.types + ) } @property @@ -295,7 +298,7 @@ def join( assert right_on is not None # noqa: S101 it = ( col(f'lhs."{left}"') == col(f'rhs."{right}"') - for left, right in zip(left_on, right_on) + for left, right in zip_equal(left_on, right_on) ) condition: Expression = reduce(and_, it) rel = self.native.set_alias("lhs").join( @@ -340,7 +343,7 @@ def join_asof( if by_left is not None and by_right is not None: conditions.extend( col(f'lhs."{left}"') == col(f'rhs."{right}"') - for left, right in zip(by_left, by_right) + for left, right in zip_equal(by_left, by_right) ) else: by_left = by_right = [] @@ -400,12 +403,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> if nulls_last: it = ( col(name).nulls_last() if not desc else col(name).desc().nulls_last() - for name, desc in zip(by, descending) + for name, desc in zip_equal(by, descending) ) else: it = ( col(name).nulls_first() if not desc else col(name).desc().nulls_first() - for name, desc in zip(by, descending) + for name, desc in zip_equal(by, descending) ) return self._with_native(self.native.sort(*it)) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 96a4698d60..740c2eb985 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -7,7 +7,7 @@ import duckdb.typing as duckdb_dtypes from duckdb.typing import DuckDBPyType -from narwhals._utils import Version, isinstance_or_issubclass +from narwhals._utils import Version, isinstance_or_issubclass, zip_equal from narwhals.exceptions import ColumnNotFoundError if TYPE_CHECKING: @@ -302,7 +302,7 @@ def generate_order_by_sql( return "" by_sql = ",".join( f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}" - for x, _descending, _nulls_last in zip(order_by, descending, nulls_last) + for x, _descending, _nulls_last in zip_equal(order_by, descending, nulls_last) ) return f"order by {by_sql}" diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 94a43b53ef..75ae11624b 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -8,7 +8,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast -from narwhals._utils import is_compliant_expr +from narwhals._utils import is_compliant_expr, zip_equal from narwhals.dependencies import is_narwhals_series, is_numpy_array from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError @@ -104,10 +104,10 @@ def evaluate_output_names_and_aliases( if exclude: assert expr._metadata is not None # noqa: S101 if expr._metadata.expansion_kind.is_multi_unnamed(): - output_names, aliases = zip( + output_names, aliases = zip_equal( *[ (x, alias) - for x, alias in zip(output_names, aliases) + for x, alias in zip_equal(output_names, aliases) if x not in exclude ] ) @@ -626,6 +626,6 @@ def apply_n_ary_operation( compliant_expr.broadcast(kind) if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip(compliant_exprs, kinds) + for compliant_expr, kind in zip_equal(compliant_exprs, kinds) ) return function(*compliant_exprs) diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 5fa857894f..38e4ad2803 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -15,6 +15,7 @@ Version, not_implemented, parse_columns_to_drop, + zip_equal, ) from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError @@ -307,7 +308,7 @@ def _convert_predicates( return left_on return [ cast("ir.BooleanColumn", (self.native[left] == other.native[right])) - for left, right in zip(left_on, right_on) + for left, right in zip_equal(left_on, right_on) ] def collect_schema(self) -> dict[str, DType]: diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 13c3989d8c..8ac257fe00 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -12,7 +12,7 @@ from narwhals._ibis.expr_struct import IbisExprStructNamespace from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype from narwhals._sql.expr import SQLExpr -from narwhals._utils import Implementation, Version, not_implemented +from narwhals._utils import Implementation, Version, not_implemented, zip_equal if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -128,7 +128,7 @@ def _sort( } yield from ( cast("ir.Column", mapping[(_desc, _nulls_last)](col)) - for col, _desc, _nulls_last in zip(cols, descending, nulls_last) + for col, _desc, _nulls_last in zip_equal(cols, descending, nulls_last) ) @classmethod diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a71b5256d1..fc86a8465a 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -29,6 +29,7 @@ generate_temporary_column_name, parse_columns_to_drop, scale_bytes, + zip_equal, ) from narwhals.dependencies import is_pandas_like_dataframe from narwhals.exceptions import InvalidOperationError, ShapeError @@ -561,7 +562,7 @@ def _join_left( ) extra = [ right_key if right_key not in self.columns else f"{right_key}{suffix}" - for left_key, right_key in zip(left_on, right_on) + for left_key, right_key in zip_equal(left_on, right_on) if right_key != left_key ] # NOTE: Keep `inplace=True` to avoid making a redundant copy. diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 7867115034..27278428fe 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -9,6 +9,7 @@ from narwhals._compliant import EagerGroupBy from narwhals._exceptions import issue_warning from narwhals._expression_parsing import evaluate_output_names_and_aliases +from narwhals._utils import zip_equal from narwhals.dependencies import is_pandas_like_dataframe if TYPE_CHECKING: @@ -283,7 +284,7 @@ def fn(df: pd.DataFrame) -> pd.Series[Any]: for expr in exprs for keys in expr(compliant) ) - out_group, out_names = zip(*results) if results else ([], []) + out_group, out_names = zip_equal(*results) if results else ([], []) return into_series(out_group, index=out_names, context=ns).native return fn diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 637c2ffae9..59e62ea720 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -17,6 +17,7 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT from narwhals._pandas_like.utils import is_non_nullable_boolean +from narwhals._utils import zip_equal if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -345,7 +346,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: # error: Cannot determine type of "values" [has-type] values: list[PandasLikeSeries] init_value, *values = [ - s.zip_with(~nm, "") for s, nm in zip(series, null_mask) + s.zip_with(~nm, "") for s, nm in zip_equal(series, null_mask) ] sep_array = init_value.from_iterable( @@ -356,7 +357,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) separators = (sep_array.zip_with(~nm, "") for nm in null_mask[:-1]) result = reduce( - operator.add, (s + v for s, v in zip(separators, values)), init_value + operator.add, + (s + v for s, v in zip_equal(separators, values)), + init_value, ) return [result] diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 047952916b..1dfdede2a3 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -8,7 +8,7 @@ from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype -from narwhals._utils import Implementation, requires +from narwhals._utils import Implementation, requires, zip_equal from narwhals.dependencies import is_numpy_array_2d from narwhals.dtypes import DType @@ -175,7 +175,7 @@ def concat_str( else: init_value, *values = [ pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String())) - for expr, nm in zip(pl_exprs, null_mask) + for expr, nm in zip_equal(pl_exprs, null_mask) ] separators = [ pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1] @@ -184,7 +184,7 @@ def concat_str( result = pl.fold( # type: ignore[assignment] acc=init_value, function=operator.add, - exprs=[s + v for s, v in zip(separators, values)], + exprs=[s + v for s, v in zip_equal(separators, values)], ) return self._expr(result, version=self._version) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 3cbaec22ee..5b49968494 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -22,6 +22,7 @@ generate_temporary_column_name, not_implemented, parse_columns_to_drop, + zip_equal, ) from narwhals.exceptions import InvalidOperationError @@ -335,7 +336,7 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> for d in descending ) - sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)] + sort_cols = [sort_f(col) for col, sort_f in zip_equal(by, sort_funcs)] return self._with_native(self.native.sort(*sort_cols)) def drop_nulls(self, subset: Sequence[str] | None) -> Self: @@ -423,7 +424,7 @@ def join( and_, ( getattr(self.native, left_key) == getattr(other_native, right_key) - for left_key, right_key in zip(left_on_, right_on_remapped) + for left_key, right_key in zip_equal(left_on_, right_on_remapped) ), ) if how == "full" diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 7eedb49495..ea1c0d39d0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -16,7 +16,7 @@ true_divide, ) from narwhals._sql.expr import SQLExpr -from narwhals._utils import Implementation, Version, not_implemented +from narwhals._utils import Implementation, Version, not_implemented, zip_equal if TYPE_CHECKING: from collections.abc import Iterator, Mapping, Sequence @@ -142,7 +142,7 @@ def _sort( } yield from ( mapping[(_desc, _nulls_last)](col) - for col, _desc, _nulls_last in zip(cols, descending, nulls_last) + for col, _desc, _nulls_last in zip_equal(cols, descending, nulls_last) ) def partition_by(self, *cols: Column | str) -> WindowSpec: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index ef00ffdd77..0f9aeb2e2b 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,6 +19,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._sql.when_then import SQLThen, SQLWhen +from narwhals._utils import zip_equal if TYPE_CHECKING: from collections.abc import Iterable @@ -188,7 +189,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: else: init_value, *values = [ df._F.when(~nm, col).otherwise(df._F.lit("")) - for col, nm in zip(cols_casted, null_mask) + for col, nm in zip_equal(cols_casted, null_mask) ] separators = ( @@ -199,7 +200,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: lambda x, y: df._F.format_string("%s%s", x, y), ( df._F.format_string("%s%s", s, v) - for s, v in zip(separators, values) + for s, v in zip_equal(separators, values) ), init_value, ) diff --git a/narwhals/_sql/group_by.py b/narwhals/_sql/group_by.py index 58e236f33e..e8991e4ae9 100644 --- a/narwhals/_sql/group_by.py +++ b/narwhals/_sql/group_by.py @@ -5,6 +5,7 @@ from narwhals._compliant.group_by import CompliantGroupBy, ParseKeysGroupBy from narwhals._compliant.typing import CompliantLazyFrameT_co, NativeExprT_co from narwhals._sql.typing import SQLExprT_contra +from narwhals._utils import zip_equal if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -28,11 +29,13 @@ def _evaluate_expr(self, expr: SQLExprT_contra, /) -> Iterator[NativeExprT_co]: native_exprs = expr(self.compliant) if expr._is_multi_output_unnamed(): exclude = {*self._keys, *self._output_key_names} - for native_expr, name, alias in zip(native_exprs, output_names, aliases): + for native_expr, name, alias in zip_equal( + native_exprs, output_names, aliases + ): if name not in exclude: yield expr._alias_native(native_expr, alias) else: - for native_expr, alias in zip(native_exprs, aliases): + for native_expr, alias in zip_equal(native_exprs, aliases): yield expr._alias_native(native_expr, alias) def _evaluate_exprs( diff --git a/narwhals/_utils.py b/narwhals/_utils.py index a30e940a61..8321a2c3c7 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -8,6 +8,7 @@ from functools import cache, lru_cache, wraps from importlib.util import find_spec from inspect import getattr_static, getdoc +from itertools import chain from operator import attrgetter from secrets import token_hex from typing import ( @@ -1079,6 +1080,46 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: return obj_any +@overload +def zip_equal(it1: Iterable[_T1], it2: Iterable[_T2], /) -> Iterable[tuple[_T1, _T2]]: ... + + +@overload +def zip_equal(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: ... + + +# https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 +def zip_equal(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: + # For trivial cases, use pure zip. + if len(iterables) < 2: + return zip(*iterables) + + # Tail for the first iterable + first_stopped = False + + def first_tail() -> Any: + nonlocal first_stopped + first_stopped = True + return + yield + + # Tail for the zip + def zip_tail() -> Any: + if not first_stopped: + msg = "zip_equal: first iterable is longer" + raise ValueError(msg) + for _ in chain.from_iterable(rest): + msg = "zip_equal: first iterable is shorter" + raise ValueError(msg) + yield + + # Put the pieces together + iterables_it = iter(iterables) + first = chain(next(iterables_it), first_tail()) + rest = list(map(iter, iterables_it)) + return chain(zip(first, *rest), zip_tail()) + + def _is_range_index(obj: Any, native_namespace: Any) -> TypeIs[pd.RangeIndex]: return isinstance(obj, native_namespace.RangeIndex) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 897aead1f6..9f82820ca6 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -35,6 +35,7 @@ is_sequence_like, is_slice_none, supports_arrow_c_stream, + zip_equal, ) from narwhals.dependencies import ( get_polars, @@ -166,7 +167,7 @@ def with_columns( compliant_exprs, kinds = self._flatten_and_extract(*exprs, **named_exprs) compliant_exprs = [ compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip(compliant_exprs, kinds) + for compliant_expr, kind in zip_equal(compliant_exprs, kinds) ] return self._with_compliant(self._compliant_frame.with_columns(*compliant_exprs)) @@ -190,7 +191,7 @@ def select( return self._with_compliant(self._compliant_frame.aggregate(*compliant_exprs)) compliant_exprs = [ compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip(compliant_exprs, kinds) + for compliant_expr, kind in zip_equal(compliant_exprs, kinds) ] return self._with_compliant(self._compliant_frame.select(*compliant_exprs)) @@ -1698,7 +1699,7 @@ def group_by( _keys = [ k if is_expr else col(k) - for k, is_expr in zip(flat_keys, key_is_expr_or_series) + for k, is_expr in zip_equal(flat_keys, key_is_expr_or_series) ] expr_flat_keys, kinds = self._flatten_and_extract(*_keys) @@ -2930,7 +2931,9 @@ def group_by( msg = "drop_null_keys cannot be True when keys contains Expr" raise NotImplementedError(msg) - _keys = [k if is_expr else col(k) for k, is_expr in zip(flat_keys, key_is_expr)] + _keys = [ + k if is_expr else col(k) for k, is_expr in zip_equal(flat_keys, key_is_expr) + ] expr_flat_keys, kinds = self._flatten_and_extract(*_keys) if not all(kind is ExprKind.ELEMENTWISE for kind in kinds): diff --git a/narwhals/schema.py b/narwhals/schema.py index e79a9ca86c..a33b765cd6 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -10,7 +10,7 @@ from functools import partial from typing import TYPE_CHECKING, cast -from narwhals._utils import Implementation, Version +from narwhals._utils import Implementation, Version, zip_equal if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -145,7 +145,7 @@ def to_pandas( raise ValueError(msg) return { name: to_native_dtype(dtype=dtype, dtype_backend=backend) - for name, dtype, backend in zip(self.keys(), self.values(), backends) + for name, dtype, backend in zip_equal(self.keys(), self.values(), backends) } def to_polars(self) -> pl.Schema: From 4ee1060feda155f311544a4bfe220a3a2ef99931 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Sun, 17 Aug 2025 11:25:57 +0100 Subject: [PATCH 02/10] add no cover --- narwhals/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 7d7bd32ded..9b7c257315 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1101,10 +1101,10 @@ def first_tail() -> Any: # Tail for the zip def zip_tail() -> Any: - if not first_stopped: + if not first_stopped: # pragma: no cover msg = "zip_equal: first iterable is longer" raise ValueError(msg) - for _ in chain.from_iterable(rest): + for _ in chain.from_iterable(rest): # pragma: no cover msg = "zip_equal: first iterable is shorter" raise ValueError(msg) yield From de3ffa46e266ce611578f6db9f990a7f29a0209d Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Sun, 17 Aug 2025 15:39:13 +0100 Subject: [PATCH 03/10] rename `zip_equal` to `zip_strict`, remove the existing `zip_strict` --- narwhals/_arrow/dataframe.py | 10 +++++----- narwhals/_compliant/expr.py | 8 ++++---- narwhals/_compliant/selectors.py | 12 ++++++------ narwhals/_dask/dataframe.py | 4 ++-- narwhals/_dask/group_by.py | 4 ++-- narwhals/_dask/namespace.py | 6 +++--- narwhals/_duckdb/dataframe.py | 12 ++++++------ narwhals/_duckdb/utils.py | 4 ++-- narwhals/_expression_parsing.py | 8 ++++---- narwhals/_ibis/dataframe.py | 4 ++-- narwhals/_ibis/expr.py | 4 ++-- narwhals/_pandas_like/dataframe.py | 4 ++-- narwhals/_pandas_like/group_by.py | 4 ++-- narwhals/_pandas_like/namespace.py | 6 +++--- narwhals/_polars/namespace.py | 6 +++--- narwhals/_spark_like/dataframe.py | 6 +++--- narwhals/_spark_like/expr.py | 4 ++-- narwhals/_spark_like/namespace.py | 6 +++--- narwhals/_sql/group_by.py | 6 +++--- narwhals/_utils.py | 12 +++++++----- narwhals/dataframe.py | 10 +++++----- narwhals/schema.py | 4 ++-- tests/utils.py | 11 ++--------- 23 files changed, 75 insertions(+), 80 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index f7843f24cb..5feb1fada4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -21,7 +21,7 @@ parse_columns_to_drop, scale_bytes, supports_arrow_c_stream, - zip_equal, + zip_strict, ) from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ShapeError @@ -203,7 +203,7 @@ def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: return self.native.to_pylist() def iter_columns(self) -> Iterator[ArrowSeries]: - for name, series in zip_equal(self.columns, self.native.itercolumns()): + for name, series in zip_strict(self.columns, self.native.itercolumns()): yield ArrowSeries.from_native(series, context=self, name=name) _iter_columns = iter_columns @@ -217,7 +217,7 @@ def iter_rows( if not named: for i in range(0, num_rows, buffer_size): rows = df[i : i + buffer_size].to_pydict().values() - yield from zip_equal(*rows) + yield from zip_strict(*rows) else: for i in range(0, num_rows, buffer_size): yield from df[i : i + buffer_size].to_pylist() @@ -291,7 +291,7 @@ def schema(self) -> dict[str, DType]: schema = self.native.schema return { name: native_to_narwhals_dtype(dtype, self._version) - for name, dtype in zip_equal(schema.names, schema.types) + for name, dtype in zip_strict(schema.names, schema.types) } def collect_schema(self) -> dict[str, DType]: @@ -432,7 +432,7 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> else: sorting = [ (key, "descending" if is_descending else "ascending") - for key, is_descending in zip_equal(by, descending) + for key, is_descending in zip_strict(by, descending) ] null_placement = "at_end" if nulls_last else "at_start" diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index f105d17126..8241f000e2 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -28,7 +28,7 @@ LazyExprT, NativeExprT, ) -from narwhals._utils import _StoresCompliant, zip_equal +from narwhals._utils import _StoresCompliant, zip_strict from narwhals.dependencies import get_numpy, is_numpy_array if TYPE_CHECKING: @@ -280,13 +280,13 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]: if alias_output_names: return [ series.alias(name) - for series, name in zip_equal( + for series, name in zip_strict( self(df), alias_output_names(self._evaluate_output_names(df)) ) ] return [ series.alias(name) - for series, name in zip_equal(self(df), self._evaluate_output_names(df)) + for series, name in zip_strict(self(df), self._evaluate_output_names(df)) ] return self.__class__( @@ -767,7 +767,7 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]: ) result = [ from_numpy(array).alias(output_name) - for array, output_name in zip_equal(result, output_names) + for array, output_name in zip_strict(result, output_names) ] if return_dtype is not None: result = [series.cast(return_dtype) for series in result] diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index 8c80220fc4..4f892134c4 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -12,7 +12,7 @@ dtype_matches_time_unit_and_time_zone, get_column_names, is_compliant_dataframe, - zip_equal, + zip_strict, ) if TYPE_CHECKING: @@ -78,7 +78,7 @@ def _iter_columns_dtypes( ) -> Iterator[tuple[SeriesOrExprT, DType]]: ... def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]: - yield from zip_equal(self._iter_columns(df), df.columns) + yield from zip_strict(self._iter_columns(df), df.columns) def _is_dtype( self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], / @@ -193,7 +193,7 @@ def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]: yield from df._iter_columns() def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]: - yield from zip_equal(self._iter_columns(df), df.schema.values()) + yield from zip_strict(self._iter_columns(df), df.schema.values()) class CompliantSelector( @@ -246,7 +246,7 @@ def series(df: FrameT) -> Sequence[SeriesOrExprT]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) return [ x - for x, name in zip_equal(self(df), lhs_names) + for x, name in zip_strict(self(df), lhs_names) if name not in rhs_names ] @@ -273,7 +273,7 @@ def series(df: FrameT) -> Sequence[SeriesOrExprT]: return [ *( x - for x, name in zip_equal(self(df), lhs_names) + for x, name in zip_strict(self(df), lhs_names) if name not in rhs_names ), *other(df), @@ -300,7 +300,7 @@ def __and__( def series(df: FrameT) -> Sequence[SeriesOrExprT]: lhs_names, rhs_names = _eval_lhs_rhs(df, self, other) return [ - x for x, name in zip_equal(self(df), lhs_names) if name in rhs_names + x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names ] def names(df: FrameT) -> Sequence[str]: diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 32df5604d7..cf103580a0 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -17,7 +17,7 @@ generate_temporary_column_name, not_implemented, parse_columns_to_drop, - zip_equal, + zip_strict, ) from narwhals.typing import CompliantLazyFrame @@ -285,7 +285,7 @@ def _join_left( ) extra = [ right_key if right_key not in self.columns else f"{right_key}{suffix}" - for left_key, right_key in zip_equal(left_on, right_on) + for left_key, right_key in zip_strict(left_on, right_on) if right_key != left_key ] return result_native.drop(columns=extra) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 20d443738a..3e5668d423 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -7,7 +7,7 @@ from narwhals._compliant import DepthTrackingGroupBy from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._utils import zip_equal +from narwhals._utils import zip_strict if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -139,7 +139,7 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn simple_aggregations.update( (alias, (output_name, agg_fn)) - for alias, output_name in zip_equal(aliases, output_names) + for alias, output_name in zip_strict(aliases, output_names) ) return DaskLazyFrame( self._grouped.agg(**simple_aggregations).reset_index(), diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index e9285d48ad..b60aba5340 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -26,7 +26,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation, zip_equal +from narwhals._utils import Implementation, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -263,7 +263,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) else: init_value, *values = [ - s.where(~nm, "") for s, nm in zip_equal(series, null_mask) + s.where(~nm, "") for s, nm in zip_strict(series, null_mask) ] separators = ( @@ -272,7 +272,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) result = reduce( operator.add, - (s + v for s, v in zip_equal(separators, values)), + (s + v for s, v in zip_strict(separators, values)), init_value, ) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index b18a1c7641..18e5492b6d 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -26,7 +26,7 @@ not_implemented, parse_columns_to_drop, requires, - zip_equal, + zip_strict, ) from narwhals.dependencies import get_duckdb from narwhals.exceptions import InvalidOperationError @@ -232,7 +232,7 @@ def schema(self) -> dict[str, DType]: column_name: native_to_narwhals_dtype( duckdb_dtype, self._version, deferred_time_zone ) - for column_name, duckdb_dtype in zip_equal( + for column_name, duckdb_dtype in zip_strict( self.native.columns, self.native.types ) } @@ -298,7 +298,7 @@ def join( assert right_on is not None # noqa: S101 it = ( col(f'lhs."{left}"') == col(f'rhs."{right}"') - for left, right in zip_equal(left_on, right_on) + for left, right in zip_strict(left_on, right_on) ) condition: Expression = reduce(and_, it) rel = self.native.set_alias("lhs").join( @@ -343,7 +343,7 @@ def join_asof( if by_left is not None and by_right is not None: conditions.extend( col(f'lhs."{left}"') == col(f'rhs."{right}"') - for left, right in zip_equal(by_left, by_right) + for left, right in zip_strict(by_left, by_right) ) else: by_left = by_right = [] @@ -403,12 +403,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> if nulls_last: it = ( col(name).nulls_last() if not desc else col(name).desc().nulls_last() - for name, desc in zip_equal(by, descending) + for name, desc in zip_strict(by, descending) ) else: it = ( col(name).nulls_first() if not desc else col(name).desc().nulls_first() - for name, desc in zip_equal(by, descending) + for name, desc in zip_strict(by, descending) ) return self._with_native(self.native.sort(*it)) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 740c2eb985..c11405976d 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -7,7 +7,7 @@ import duckdb.typing as duckdb_dtypes from duckdb.typing import DuckDBPyType -from narwhals._utils import Version, isinstance_or_issubclass, zip_equal +from narwhals._utils import Version, isinstance_or_issubclass, zip_strict from narwhals.exceptions import ColumnNotFoundError if TYPE_CHECKING: @@ -302,7 +302,7 @@ def generate_order_by_sql( return "" by_sql = ",".join( f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}" - for x, _descending, _nulls_last in zip_equal(order_by, descending, nulls_last) + for x, _descending, _nulls_last in zip_strict(order_by, descending, nulls_last) ) return f"order by {by_sql}" diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 75ae11624b..ec00300da6 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -8,7 +8,7 @@ from itertools import chain from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast -from narwhals._utils import is_compliant_expr, zip_equal +from narwhals._utils import is_compliant_expr, zip_strict from narwhals.dependencies import is_narwhals_series, is_numpy_array from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError @@ -104,10 +104,10 @@ def evaluate_output_names_and_aliases( if exclude: assert expr._metadata is not None # noqa: S101 if expr._metadata.expansion_kind.is_multi_unnamed(): - output_names, aliases = zip_equal( + output_names, aliases = zip_strict( *[ (x, alias) - for x, alias in zip_equal(output_names, aliases) + for x, alias in zip_strict(output_names, aliases) if x not in exclude ] ) @@ -626,6 +626,6 @@ def apply_n_ary_operation( compliant_expr.broadcast(kind) if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_equal(compliant_exprs, kinds) + for compliant_expr, kind in zip_strict(compliant_exprs, kinds) ) return function(*compliant_exprs) diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index b9256706f3..b16095250d 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -15,7 +15,7 @@ Version, not_implemented, parse_columns_to_drop, - zip_equal, + zip_strict, ) from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError @@ -308,7 +308,7 @@ def _convert_predicates( return left_on return [ cast("ir.BooleanColumn", (self.native[left] == other.native[right])) - for left, right in zip_equal(left_on, right_on) + for left, right in zip_strict(left_on, right_on) ] def collect_schema(self) -> dict[str, DType]: diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index f35c282b89..4f52df79e4 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -12,7 +12,7 @@ from narwhals._ibis.expr_struct import IbisExprStructNamespace from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype from narwhals._sql.expr import SQLExpr -from narwhals._utils import Implementation, Version, not_implemented, zip_equal +from narwhals._utils import Implementation, Version, not_implemented, zip_strict if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -128,7 +128,7 @@ def _sort( } yield from ( cast("ir.Column", mapping[(_desc, _nulls_last)](col)) - for col, _desc, _nulls_last in zip_equal(cols, descending, nulls_last) + for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last) ) @classmethod diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index fc86a8465a..0b33322b69 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -29,7 +29,7 @@ generate_temporary_column_name, parse_columns_to_drop, scale_bytes, - zip_equal, + zip_strict, ) from narwhals.dependencies import is_pandas_like_dataframe from narwhals.exceptions import InvalidOperationError, ShapeError @@ -562,7 +562,7 @@ def _join_left( ) extra = [ right_key if right_key not in self.columns else f"{right_key}{suffix}" - for left_key, right_key in zip_equal(left_on, right_on) + for left_key, right_key in zip_strict(left_on, right_on) if right_key != left_key ] # NOTE: Keep `inplace=True` to avoid making a redundant copy. diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 27278428fe..78fb725bce 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -9,7 +9,7 @@ from narwhals._compliant import EagerGroupBy from narwhals._exceptions import issue_warning from narwhals._expression_parsing import evaluate_output_names_and_aliases -from narwhals._utils import zip_equal +from narwhals._utils import zip_strict from narwhals.dependencies import is_pandas_like_dataframe if TYPE_CHECKING: @@ -284,7 +284,7 @@ def fn(df: pd.DataFrame) -> pd.Series[Any]: for expr in exprs for keys in expr(compliant) ) - out_group, out_names = zip_equal(*results) if results else ([], []) + out_group, out_names = zip_strict(*results) if results else ([], []) return into_series(out_group, index=out_names, context=ns).native return fn diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 59e62ea720..e3b975b811 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -17,7 +17,7 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT from narwhals._pandas_like.utils import is_non_nullable_boolean -from narwhals._utils import zip_equal +from narwhals._utils import zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -346,7 +346,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: # error: Cannot determine type of "values" [has-type] values: list[PandasLikeSeries] init_value, *values = [ - s.zip_with(~nm, "") for s, nm in zip_equal(series, null_mask) + s.zip_with(~nm, "") for s, nm in zip_strict(series, null_mask) ] sep_array = init_value.from_iterable( @@ -358,7 +358,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: separators = (sep_array.zip_with(~nm, "") for nm in null_mask[:-1]) result = reduce( operator.add, - (s + v for s, v in zip_equal(separators, values)), + (s + v for s, v in zip_strict(separators, values)), init_value, ) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 1dfdede2a3..9a7953c41a 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -8,7 +8,7 @@ from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype -from narwhals._utils import Implementation, requires, zip_equal +from narwhals._utils import Implementation, requires, zip_strict from narwhals.dependencies import is_numpy_array_2d from narwhals.dtypes import DType @@ -175,7 +175,7 @@ def concat_str( else: init_value, *values = [ pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String())) - for expr, nm in zip_equal(pl_exprs, null_mask) + for expr, nm in zip_strict(pl_exprs, null_mask) ] separators = [ pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1] @@ -184,7 +184,7 @@ def concat_str( result = pl.fold( # type: ignore[assignment] acc=init_value, function=operator.add, - exprs=[s + v for s, v in zip_equal(separators, values)], + exprs=[s + v for s, v in zip_strict(separators, values)], ) return self._expr(result, version=self._version) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 5b49968494..122659fa9f 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -22,7 +22,7 @@ generate_temporary_column_name, not_implemented, parse_columns_to_drop, - zip_equal, + zip_strict, ) from narwhals.exceptions import InvalidOperationError @@ -336,7 +336,7 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> for d in descending ) - sort_cols = [sort_f(col) for col, sort_f in zip_equal(by, sort_funcs)] + sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)] return self._with_native(self.native.sort(*sort_cols)) def drop_nulls(self, subset: Sequence[str] | None) -> Self: @@ -424,7 +424,7 @@ def join( and_, ( getattr(self.native, left_key) == getattr(other_native, right_key) - for left_key, right_key in zip_equal(left_on_, right_on_remapped) + for left_key, right_key in zip_strict(left_on_, right_on_remapped) ), ) if how == "full" diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index ea1c0d39d0..39f1f0f95f 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -16,7 +16,7 @@ true_divide, ) from narwhals._sql.expr import SQLExpr -from narwhals._utils import Implementation, Version, not_implemented, zip_equal +from narwhals._utils import Implementation, Version, not_implemented, zip_strict if TYPE_CHECKING: from collections.abc import Iterator, Mapping, Sequence @@ -142,7 +142,7 @@ def _sort( } yield from ( mapping[(_desc, _nulls_last)](col) - for col, _desc, _nulls_last in zip_equal(cols, descending, nulls_last) + for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last) ) def partition_by(self, *cols: Column | str) -> WindowSpec: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 0f9aeb2e2b..cc40931ff7 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,7 +19,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import zip_equal +from narwhals._utils import zip_strict if TYPE_CHECKING: from collections.abc import Iterable @@ -189,7 +189,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: else: init_value, *values = [ df._F.when(~nm, col).otherwise(df._F.lit("")) - for col, nm in zip_equal(cols_casted, null_mask) + for col, nm in zip_strict(cols_casted, null_mask) ] separators = ( @@ -200,7 +200,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: lambda x, y: df._F.format_string("%s%s", x, y), ( df._F.format_string("%s%s", s, v) - for s, v in zip_equal(separators, values) + for s, v in zip_strict(separators, values) ), init_value, ) diff --git a/narwhals/_sql/group_by.py b/narwhals/_sql/group_by.py index e8991e4ae9..ab5d669e02 100644 --- a/narwhals/_sql/group_by.py +++ b/narwhals/_sql/group_by.py @@ -5,7 +5,7 @@ from narwhals._compliant.group_by import CompliantGroupBy, ParseKeysGroupBy from narwhals._compliant.typing import CompliantLazyFrameT_co, NativeExprT_co from narwhals._sql.typing import SQLExprT_contra -from narwhals._utils import zip_equal +from narwhals._utils import zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -29,13 +29,13 @@ def _evaluate_expr(self, expr: SQLExprT_contra, /) -> Iterator[NativeExprT_co]: native_exprs = expr(self.compliant) if expr._is_multi_output_unnamed(): exclude = {*self._keys, *self._output_key_names} - for native_expr, name, alias in zip_equal( + for native_expr, name, alias in zip_strict( native_exprs, output_names, aliases ): if name not in exclude: yield expr._alias_native(native_expr, alias) else: - for native_expr, alias in zip_equal(native_exprs, aliases): + for native_expr, alias in zip_strict(native_exprs, aliases): yield expr._alias_native(native_expr, alias) def _evaluate_exprs( diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 9b7c257315..c81202f0ba 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1077,15 +1077,17 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: @overload -def zip_equal(it1: Iterable[_T1], it2: Iterable[_T2], /) -> Iterable[tuple[_T1, _T2]]: ... +def zip_strict( + it1: Iterable[_T1], it2: Iterable[_T2], / +) -> Iterable[tuple[_T1, _T2]]: ... @overload -def zip_equal(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: ... +def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: ... # https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 -def zip_equal(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: +def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: # For trivial cases, use pure zip. if len(iterables) < 2: return zip(*iterables) @@ -1102,10 +1104,10 @@ def first_tail() -> Any: # Tail for the zip def zip_tail() -> Any: if not first_stopped: # pragma: no cover - msg = "zip_equal: first iterable is longer" + msg = "zip_strict: first iterable is longer" raise ValueError(msg) for _ in chain.from_iterable(rest): # pragma: no cover - msg = "zip_equal: first iterable is shorter" + msg = "zip_strict: first iterable is shorter" raise ValueError(msg) yield diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 08a9e0f91c..eeeb48a100 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -35,7 +35,7 @@ is_sequence_like, is_slice_none, supports_arrow_c_stream, - zip_equal, + zip_strict, ) from narwhals.dependencies import ( get_polars, @@ -167,7 +167,7 @@ def with_columns( compliant_exprs, kinds = self._flatten_and_extract(*exprs, **named_exprs) compliant_exprs = [ compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_equal(compliant_exprs, kinds) + for compliant_expr, kind in zip_strict(compliant_exprs, kinds) ] return self._with_compliant(self._compliant_frame.with_columns(*compliant_exprs)) @@ -191,7 +191,7 @@ def select( return self._with_compliant(self._compliant_frame.aggregate(*compliant_exprs)) compliant_exprs = [ compliant_expr.broadcast(kind) if is_scalar_like(kind) else compliant_expr - for compliant_expr, kind in zip_equal(compliant_exprs, kinds) + for compliant_expr, kind in zip_strict(compliant_exprs, kinds) ] return self._with_compliant(self._compliant_frame.select(*compliant_exprs)) @@ -1699,7 +1699,7 @@ def group_by( _keys = [ k if is_expr else col(k) - for k, is_expr in zip_equal(flat_keys, key_is_expr_or_series) + for k, is_expr in zip_strict(flat_keys, key_is_expr_or_series) ] expr_flat_keys, kinds = self._flatten_and_extract(*_keys) @@ -2932,7 +2932,7 @@ def group_by( raise NotImplementedError(msg) _keys = [ - k if is_expr else col(k) for k, is_expr in zip_equal(flat_keys, key_is_expr) + k if is_expr else col(k) for k, is_expr in zip_strict(flat_keys, key_is_expr) ] expr_flat_keys, kinds = self._flatten_and_extract(*_keys) diff --git a/narwhals/schema.py b/narwhals/schema.py index a33b765cd6..a34379983d 100644 --- a/narwhals/schema.py +++ b/narwhals/schema.py @@ -10,7 +10,7 @@ from functools import partial from typing import TYPE_CHECKING, cast -from narwhals._utils import Implementation, Version, zip_equal +from narwhals._utils import Implementation, Version, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -145,7 +145,7 @@ def to_pandas( raise ValueError(msg) return { name: to_native_dtype(dtype=dtype, dtype_backend=backend) - for name, dtype, backend in zip_equal(self.keys(), self.values(), backends) + for name, dtype, backend in zip_strict(self.keys(), self.values(), backends) } def to_polars(self) -> pl.Schema: diff --git a/tests/utils.py b/tests/utils.py index 482e4362e1..c1a2991916 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,11 +12,11 @@ import pyarrow as pa import narwhals as nw -from narwhals._utils import Implementation, parse_version +from narwhals._utils import Implementation, parse_version, zip_strict from narwhals.translate import from_native if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Mapping from typing_extensions import TypeAlias @@ -45,13 +45,6 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: ConstructorLazy: TypeAlias = Callable[[Any], "NativeLazyFrame"] -def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]: - if len(left) != len(right): - msg = f"{len(left)=} != {len(right)=}\nLeft: {left}\nRight: {right}" # pragma: no cover - raise ValueError(msg) # pragma: no cover - return zip(left, right) - - def _to_comparable_list(column_values: Any) -> Any: if isinstance(column_values, nw.Series) and isinstance( column_values.to_native(), pa.Array From 2f7a68d7620a4fdba42d3aeec30aec7230bee49f Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Sun, 17 Aug 2025 17:01:45 +0100 Subject: [PATCH 04/10] Update narwhals/_utils.py Co-authored-by: Dan Redding <125183946+dangotbanned@users.noreply.github.com> --- narwhals/_utils.py | 77 ++++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index c81202f0ba..aa099e107a 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1076,46 +1076,43 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: return obj_any -@overload -def zip_strict( - it1: Iterable[_T1], it2: Iterable[_T2], / -) -> Iterable[tuple[_T1, _T2]]: ... - - -@overload -def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: ... - - -# https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 -def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: - # For trivial cases, use pure zip. - if len(iterables) < 2: - return zip(*iterables) - - # Tail for the first iterable - first_stopped = False - - def first_tail() -> Any: - nonlocal first_stopped - first_stopped = True - return - yield - - # Tail for the zip - def zip_tail() -> Any: - if not first_stopped: # pragma: no cover - msg = "zip_strict: first iterable is longer" - raise ValueError(msg) - for _ in chain.from_iterable(rest): # pragma: no cover - msg = "zip_strict: first iterable is shorter" - raise ValueError(msg) - yield - - # Put the pieces together - iterables_it = iter(iterables) - first = chain(next(iterables_it), first_tail()) - rest = list(map(iter, iterables_it)) - return chain(zip(first, *rest), zip_tail()) +if TYPE_CHECKING: + zip_strict = partial(zip, strict=True) +else: + import sys + + if sys.version_info >= (3, 10): + zip_strict = partial(zip, strict=True) + else: # https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 + + def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: + # For trivial cases, use pure zip. + if len(iterables) < 2: + return zip(*iterables) + # Tail for the first iterable + first_stopped = False + + def first_tail() -> Any: + nonlocal first_stopped + first_stopped = True + return + yield + + # Tail for the zip + def zip_tail() -> Any: + if not first_stopped: # pragma: no cover + msg = "zip_strict: first iterable is longer" + raise ValueError(msg) + for _ in chain.from_iterable(rest): # pragma: no cover + msg = "zip_strict: first iterable is shorter" + raise ValueError(msg) + yield + + # Put the pieces together + iterables_it = iter(iterables) + first = chain(next(iterables_it), first_tail()) + rest = list(map(iter, iterables_it)) + return chain(zip(first, *rest), zip_tail()) def _is_range_index(obj: Any, native_namespace: Any) -> TypeIs[pd.RangeIndex]: From e4be9c47fbe83fdc1f56a385f2273d2ace71735f Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:06:39 +0000 Subject: [PATCH 05/10] fix my goof --- narwhals/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index aa099e107a..1fcad1c6b6 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1076,6 +1076,8 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: return obj_any +from functools import partial + if TYPE_CHECKING: zip_strict = partial(zip, strict=True) else: From b1264171ffbf20642cb06765f23743ebd24221ef Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:07:48 +0000 Subject: [PATCH 06/10] ok ruff --- narwhals/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 1fcad1c6b6..b498c91402 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -5,7 +5,7 @@ from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence from datetime import timezone from enum import Enum, auto -from functools import cache, lru_cache, wraps +from functools import cache, lru_cache, partial, wraps from importlib.util import find_spec from inspect import getattr_static, getdoc from itertools import chain @@ -1076,8 +1076,6 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: return obj_any -from functools import partial - if TYPE_CHECKING: zip_strict = partial(zip, strict=True) else: From edcd83c59f15d830bc11e6b628accd4890906774 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:17:39 +0000 Subject: [PATCH 07/10] refactor: avoid `zip` entirely From https://github.com/narwhals-dev/narwhals/blob/4966868bfd94655f395c808ae12562596fe6e73b/narwhals/schema.py#L95-L131 --- narwhals/_arrow/dataframe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 5feb1fada4..ae785f184d 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -288,10 +288,9 @@ def _select_multi_name( @property def schema(self) -> dict[str, DType]: - schema = self.native.schema return { - name: native_to_narwhals_dtype(dtype, self._version) - for name, dtype in zip_strict(schema.names, schema.types) + field.name: native_to_narwhals_dtype(field.type, self._version) + for field in self.native.schema } def collect_schema(self) -> dict[str, DType]: From 98f2b5c99bc0b804f5206617383ed150763c1acf Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:35:04 +0100 Subject: [PATCH 08/10] add no cover --- narwhals/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_utils.py b/narwhals/_utils.py index b498c91402..cc18eca00c 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -1083,7 +1083,8 @@ def maybe_reset_index(obj: FrameOrSeriesT) -> FrameOrSeriesT: if sys.version_info >= (3, 10): zip_strict = partial(zip, strict=True) - else: # https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 + else: # pragma: no cover + # https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python/69485272#69485272 def zip_strict(*iterables: Iterable[Any]) -> Iterable[tuple[Any, ...]]: # For trivial cases, use pure zip. From 711f5038fb71c1e2c6f0ce7e9d27aded0e48e5c5 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:40:11 +0100 Subject: [PATCH 09/10] fix merge --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index e831830f49..57ff53cd0b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from narwhals.translate import from_native if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence from typing_extensions import TypeAlias From d735ffc06f6c9e61d3286708fbc4d6830bd1b59e Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:45:37 +0100 Subject: [PATCH 10/10] replace `zip` with `zip_strict` for `group_by` that was failing before --- narwhals/_compliant/group_by.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index ead3ee9674..8dfb93aa35 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -16,7 +16,7 @@ EagerExprT_contra, NarwhalsAggregation, ) -from narwhals._utils import is_sequence_of +from narwhals._utils import is_sequence_of, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence @@ -115,7 +115,7 @@ def _temporary_name(key: str) -> str: if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output() # otherwise it's single named and we can use Expr.alias else key.alias(_temporary_name(new_names[0])) - for key, new_names in zip(keys, keys_aliases) + for key, new_names in zip_strict(keys, keys_aliases) ] return ( compliant_frame.with_columns(*safe_keys),