Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parse_columns_to_drop,
scale_bytes,
supports_arrow_c_stream,
+ zip_strict,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
Expand Down Expand Up @@ -202,7 +203,7 @@ def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
return self.native.to_pylist()

def iter_columns(self) -> Iterator[ArrowSeries]:
- for name, series in zip(self.columns, self.native.itercolumns()):
+ for name, series in zip_strict(self.columns, self.native.itercolumns()):
yield ArrowSeries.from_native(series, context=self, name=name)

_iter_columns = iter_columns
Expand All @@ -216,7 +217,7 @@ def iter_rows(
if not named:
for i in range(0, num_rows, buffer_size):
rows = df[i : i + buffer_size].to_pydict().values()
- yield from zip(*rows)
+ yield from zip_strict(*rows)
else:
for i in range(0, num_rows, buffer_size):
yield from df[i : i + buffer_size].to_pylist()
Expand Down Expand Up @@ -287,10 +288,9 @@ def _select_multi_name(

@property
def schema(self) -> dict[str, DType]:
- schema = self.native.schema
return {
- name: native_to_narwhals_dtype(dtype, self._version)
- for name, dtype in zip(schema.names, schema.types)
+ field.name: native_to_narwhals_dtype(field.type, self._version)
+ for field in self.native.schema
}

def collect_schema(self) -> dict[str, DType]:
Expand Down Expand Up @@ -431,7 +431,7 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
else:
sorting = [
(key, "descending" if is_descending else "ascending")
- for key, is_descending in zip(by, descending)
+ for key, is_descending in zip_strict(by, descending)
]

null_placement = "at_end" if nulls_last else "at_start"
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
LazyExprT,
NativeExprT,
)
- from narwhals._utils import _StoresCompliant, qualified_type_name
+ from narwhals._utils import _StoresCompliant, qualified_type_name, zip_strict
from narwhals.dependencies import is_numpy_array, is_numpy_scalar

if TYPE_CHECKING:
Expand Down Expand Up @@ -282,13 +282,13 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]:
if alias_output_names:
return [
series.alias(name)
- for series, name in zip(
+ for series, name in zip_strict(
self(df), alias_output_names(self._evaluate_output_names(df))
)
]
return [
series.alias(name)
- for series, name in zip(self(df), self._evaluate_output_names(df))
+ for series, name in zip_strict(self(df), self._evaluate_output_names(df))
]

return self.__class__(
Expand Down Expand Up @@ -772,7 +772,7 @@ def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
)
result = tuple(
from_numpy(array).alias(output_name)
- for array, output_name in zip(udf_series_out, output_names)
+ for array, output_name in zip_strict(udf_series_out, output_names)
)
else:
result = udf_series_out
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_compliant/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
EagerExprT_contra,
NarwhalsAggregation,
)
- from narwhals._utils import is_sequence_of
+ from narwhals._utils import is_sequence_of, zip_strict

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
Expand Down Expand Up @@ -115,7 +115,7 @@ def _temporary_name(key: str) -> str:
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
# otherwise it's single named and we can use Expr.alias
else key.alias(_temporary_name(new_names[0]))
- for key, new_names in zip(keys, keys_aliases)
+ for key, new_names in zip_strict(keys, keys_aliases)
]
return (
compliant_frame.with_columns(*safe_keys),
Expand Down
19 changes: 14 additions & 5 deletions narwhals/_compliant/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
dtype_matches_time_unit_and_time_zone,
get_column_names,
is_compliant_dataframe,
+ zip_strict,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,7 +78,7 @@ def _iter_columns_dtypes(
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...

def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
- yield from zip(self._iter_columns(df), df.columns)
+ yield from zip_strict(self._iter_columns(df), df.columns)

def _is_dtype(
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
Expand Down Expand Up @@ -192,7 +193,7 @@ def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
yield from df._iter_columns()

def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
- yield from zip(self._iter_columns(df), df.schema.values())
+ yield from zip_strict(self._iter_columns(df), df.schema.values())


class CompliantSelector(
Expand Down Expand Up @@ -244,7 +245,9 @@ def __sub__(
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
- x for x, name in zip(self(df), lhs_names) if name not in rhs_names
+ x
+ for x, name in zip_strict(self(df), lhs_names)
+ if name not in rhs_names
]

def names(df: FrameT) -> Sequence[str]:
Expand All @@ -268,7 +271,11 @@ def __or__(
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
- *(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
+ *(
+ x
+ for x, name in zip_strict(self(df), lhs_names)
+ if name not in rhs_names
+ ),
*other(df),
]

Expand All @@ -292,7 +299,9 @@ def __and__(

def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
- return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]
+ return [
+ x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names
+ ]

def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
+ zip_strict,
)
from narwhals.typing import CompliantLazyFrame

Expand Down Expand Up @@ -284,7 +285,7 @@ def _join_left(
)
extra = [
right_key if right_key not in self.columns else f"{right_key}{suffix}"
- for left_key, right_key in zip(left_on, right_on)
+ for left_key, right_key in zip_strict(left_on, right_on)
if right_key != left_key
]
return result_native.drop(columns=extra)
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
+ from narwhals._utils import zip_strict

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -138,7 +139,7 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
simple_aggregations.update(
(alias, (output_name, agg_fn))
- for alias, output_name in zip(aliases, output_names)
+ for alias, output_name in zip_strict(aliases, output_names)
)
return DaskLazyFrame(
self._grouped.agg(**simple_aggregations).reset_index(),
Expand Down
8 changes: 5 additions & 3 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
combine_alias_output_names,
combine_evaluate_output_names,
)
- from narwhals._utils import Implementation
+ from narwhals._utils import Implementation, zip_strict

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -263,15 +263,17 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
)
else:
init_value, *values = [
- s.where(~nm, "") for s, nm in zip(series, null_mask)
+ s.where(~nm, "") for s, nm in zip_strict(series, null_mask)
]

separators = (
nm.map({True: "", False: separator}, meta=str)
for nm in null_mask[:-1]
)
result = reduce(
- operator.add, (s + v for s, v in zip(separators, values)), init_value
+ operator.add,
+ (s + v for s, v in zip_strict(separators, values)),
+ init_value,
)

return [result]
Expand Down
13 changes: 8 additions & 5 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
not_implemented,
parse_columns_to_drop,
requires,
+ zip_strict,
)
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import InvalidOperationError
Expand Down Expand Up @@ -231,7 +232,9 @@ def schema(self) -> dict[str, DType]:
column_name: native_to_narwhals_dtype(
duckdb_dtype, self._version, deferred_time_zone
)
- for column_name, duckdb_dtype in zip(self.native.columns, self.native.types)
+ for column_name, duckdb_dtype in zip_strict(
+ self.native.columns, self.native.types
+ )
}

@property
Expand Down Expand Up @@ -295,7 +298,7 @@ def join(
assert right_on is not None # noqa: S101
it = (
col(f'lhs."{left}"') == col(f'rhs."{right}"')
- for left, right in zip(left_on, right_on)
+ for left, right in zip_strict(left_on, right_on)
)
condition: Expression = reduce(and_, it)
rel = self.native.set_alias("lhs").join(
Expand Down Expand Up @@ -340,7 +343,7 @@ def join_asof(
if by_left is not None and by_right is not None:
conditions.extend(
col(f'lhs."{left}"') == col(f'rhs."{right}"')
- for left, right in zip(by_left, by_right)
+ for left, right in zip_strict(by_left, by_right)
)
else:
by_left = by_right = []
Expand Down Expand Up @@ -400,12 +403,12 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
if nulls_last:
it = (
col(name).nulls_last() if not desc else col(name).desc().nulls_last()
- for name, desc in zip(by, descending)
+ for name, desc in zip_strict(by, descending)
)
else:
it = (
col(name).nulls_first() if not desc else col(name).desc().nulls_first()
- for name, desc in zip(by, descending)
+ for name, desc in zip_strict(by, descending)
)
return self._with_native(self.native.sort(*it))

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import duckdb.typing as duckdb_dtypes
from duckdb.typing import DuckDBPyType

- from narwhals._utils import Version, isinstance_or_issubclass
+ from narwhals._utils import Version, isinstance_or_issubclass, zip_strict
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
Expand Down Expand Up @@ -302,7 +302,7 @@ def generate_order_by_sql(
return ""
by_sql = ",".join(
f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
- for x, _descending, _nulls_last in zip(order_by, descending, nulls_last)
+ for x, _descending, _nulls_last in zip_strict(order_by, descending, nulls_last)
)
return f"order by {by_sql}"

Expand Down
8 changes: 4 additions & 4 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast

- from narwhals._utils import is_compliant_expr
+ from narwhals._utils import is_compliant_expr, zip_strict
from narwhals.dependencies import is_narwhals_series, is_numpy_array
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError

Expand Down Expand Up @@ -104,10 +104,10 @@ def evaluate_output_names_and_aliases(
if exclude:
assert expr._metadata is not None # noqa: S101
if expr._metadata.expansion_kind.is_multi_unnamed():
- output_names, aliases = zip(
+ output_names, aliases = zip_strict(
*[
(x, alias)
- for x, alias in zip(output_names, aliases)
+ for x, alias in zip_strict(output_names, aliases)
if x not in exclude
]
)
Expand Down Expand Up @@ -626,6 +626,6 @@ def apply_n_ary_operation(
compliant_expr.broadcast(kind)
if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind)
else compliant_expr
- for compliant_expr, kind in zip(compliant_exprs, kinds)
+ for compliant_expr, kind in zip_strict(compliant_exprs, kinds)
)
return function(*compliant_exprs)
3 changes: 2 additions & 1 deletion narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Version,
not_implemented,
parse_columns_to_drop,
+ zip_strict,
)
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError

Expand Down Expand Up @@ -307,7 +308,7 @@ def _convert_predicates(
return left_on
return [
cast("ir.BooleanColumn", (self.native[left] == other.native[right]))
- for left, right in zip(left_on, right_on)
+ for left, right in zip_strict(left_on, right_on)
]

def collect_schema(self) -> dict[str, DType]:
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from narwhals._ibis.expr_struct import IbisExprStructNamespace
from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype
from narwhals._sql.expr import SQLExpr
- from narwhals._utils import Implementation, Version, not_implemented
+ from narwhals._utils import Implementation, Version, not_implemented, zip_strict

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -128,7 +128,7 @@ def _sort(
}
yield from (
cast("ir.Column", mapping[(_desc, _nulls_last)](col))
- for col, _desc, _nulls_last in zip(cols, descending, nulls_last)
+ for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
)

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
generate_temporary_column_name,
parse_columns_to_drop,
scale_bytes,
+ zip_strict,
)
from narwhals.dependencies import is_pandas_like_dataframe
from narwhals.exceptions import InvalidOperationError, ShapeError
Expand Down Expand Up @@ -561,7 +562,7 @@ def _join_left(
)
extra = [
right_key if right_key not in self.columns else f"{right_key}{suffix}"
- for left_key, right_key in zip(left_on, right_on)
+ for left_key, right_key in zip_strict(left_on, right_on)
if right_key != left_key
]
# NOTE: Keep `inplace=True` to avoid making a redundant copy.
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals._compliant import EagerGroupBy
from narwhals._exceptions import issue_warning
from narwhals._expression_parsing import evaluate_output_names_and_aliases
+ from narwhals._utils import zip_strict
from narwhals.dependencies import is_pandas_like_dataframe

if TYPE_CHECKING:
Expand Down Expand Up @@ -283,7 +284,7 @@ def fn(df: pd.DataFrame) -> pd.Series[Any]:
for expr in exprs
for keys in expr(compliant)
)
- out_group, out_names = zip(*results) if results else ([], [])
+ out_group, out_names = zip_strict(*results) if results else ([], [])
return into_series(out_group, index=out_names, context=ns).native

return fn
Expand Down
Loading
Loading