Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals.exceptions import ShapeError
from narwhals.utils import _SeriesNamespace
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -278,6 +279,9 @@ def extract_dataframe_comparand(
) -> ArrowChunkedArray:
"""Extract native Series, broadcasting to `length` if necessary."""
if not other._broadcast:
if (len_other := len(other)) != length:
msg = f"Expected object of length {length}, got: {len_other}."
raise ShapeError(msg)
return other.native

import numpy as np # ignore-banned-import
Expand Down
7 changes: 6 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Sized
from typing import TypeVar
from typing import cast

import pandas as pd

from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import DuplicateError
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

T = TypeVar("T")
T = TypeVar("T", bound=Sized)

if TYPE_CHECKING:
from pandas._typing import Dtype as PandasDtype
Expand Down Expand Up @@ -137,6 +139,9 @@ def extract_dataframe_comparand(
if other._broadcast:
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)
if (len_other := len(other)) != (len_idx := len(index)):
msg = f"Expected object of length {len_idx}, got: {len_other}."
raise ShapeError(msg)
if other._native_series.index is not index:
return set_index(
other._native_series,
Expand Down
10 changes: 10 additions & 0 deletions tests/frame/with_columns_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.exceptions import ShapeError
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand Down Expand Up @@ -68,3 +69,12 @@ def test_with_columns_dtypes_single_row(
df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical))
result = df.with_columns(nw.col("a"))
assert result.collect_schema() == {"a": nw.Categorical}


def test_with_columns_series_shape_mismatch(constructor_eager: ConstructorEager) -> None:
df1 = nw.from_native(constructor_eager({"first": [1, 2, 3]}), eager_only=True)
second = nw.from_native(constructor_eager({"second": [1, 2, 3, 4]}), eager_only=True)[
"second"
]
with pytest.raises(ShapeError):
df1.with_columns(second=second)
Loading