Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from narwhals._compliant.typing import (
CompliantExprT,
CompliantExprT_co,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
Expand All @@ -23,7 +24,7 @@
from narwhals.dependencies import is_numpy_array_2d

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence

from typing_extensions import TypeAlias, TypeIs

Expand Down Expand Up @@ -98,6 +99,58 @@ def is_native(self, obj: Any, /) -> TypeIs[Any]:
...


class AlignDiagonal(Protocol[CompliantFrameT, CompliantExprT_co]):
    """Mixin that aligns frame schemas so `"diagonal*"` concatenation can be done vertically."""

    def lit(
        self, value: NonNestedLiteral, dtype: IntoDType | None
    ) -> CompliantExprT_co: ...

    def align_diagonal(
        self, frames: Collection[CompliantFrameT], /
    ) -> Sequence[CompliantFrameT]:
        """Give every frame the same columns, in the same order, for vertical concatenation.

        The target schema starts as the first frame's schema. Each later frame
        contributes only the names that have not been seen yet, so when a name
        occurs in several frames the `DType` of its first occurrence is kept.

        Adapted from [`convert_diagonal_concat`].

        [`convert_diagonal_concat`]: https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-plan/src/plans/conversion/dsl_to_ir/concat.rs#L10-L68
        """
        schemas = [frame.collect_schema() for frame in frames]
        remaining = iter(schemas)
        union = dict(next(remaining))
        for schema in remaining:
            for name, dtype in schema.items():
                # `setdefault` never overwrites: the first `DType` wins, unlike `dict |= dict`.
                union.setdefault(name, dtype)
        return self._align_diagonal(frames, schemas, union)

    def _align_diagonal(
        self,
        frames: Iterable[CompliantFrameT],
        schemas: Iterable[IntoSchema],
        union_schema: IntoSchema,
    ) -> Sequence[CompliantFrameT]:
        """Add each frame's missing columns as nulls, then select the columns of `union_schema` in order."""
        target_names = union_schema.keys()
        # One null expression per column name, built on first use and reused by later frames.
        null_cache: dict[str, CompliantExprT_co] = {}

        def null_column(name: str) -> CompliantExprT_co:
            cached = null_cache.get(name)
            if cached is None:
                cached = self.lit(None, union_schema[name]).alias(name)
                null_cache[name] = cached
            return cached

        aligned: list[CompliantFrameT] = []
        for frame, schema in zip(frames, schemas):
            absent = target_names - schema.keys()
            widened = frame.with_columns(*map(null_column, absent)) if absent else frame
            # Select even when nothing was added, so the column order matches between frames.
            aligned.append(widened.simple_select(*target_names))
        return aligned


class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
Expand Down
27 changes: 16 additions & 11 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ibis
import ibis.expr.types as ir

from narwhals._compliant.namespace import AlignDiagonal
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
Expand All @@ -26,7 +27,10 @@
from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral


class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]):
class IbisNamespace(
SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"],
AlignDiagonal[IbisLazyFrame, IbisExpr],
):
_implementation: Implementation = Implementation.IBIS

def __init__(self, *, version: Version) -> None:
Expand Down Expand Up @@ -63,17 +67,18 @@ def _coalesce(self, *exprs: ir.Value) -> ir.Value:
def concat(
self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod
) -> IbisLazyFrame:
# Concatenate Ibis lazy frames via `ibis.union`.
# NOTE(review): this is a rendered diff hunk whose +/- markers and indentation
# were lost, so lines of the previous and of the new implementation appear
# together below. The split marked in the comments is inferred - confirm
# against the actual commit.
frames: Sequence[IbisLazyFrame] = tuple(items)
if how == "diagonal":
# Presumably the previous implementation (removed lines): "diagonal" was
# rejected outright, and all schemas were compared before calling `ibis.union`.
msg = "diagonal concat not supported for Ibis. Please join instead."
raise NotImplementedError(msg)

items = list(items)
native_items = [item.native for item in items]
schema = items[0].schema
if not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
return self._lazyframe.from_native(ibis.union(*native_items), context=self)
# Presumably the new implementation (added lines): for "diagonal", the frames
# are first passed through `align_diagonal`.
frames = self.align_diagonal(frames)
try:
result = ibis.union(*(lf.native for lf in frames))
except ibis.IbisError:
# Schemas are only compared once `ibis.union` has failed: a mismatch becomes a
# `TypeError` (original traceback suppressed); any other `IbisError` is re-raised as is.
first = frames[0].schema
if not all(x.schema == first for x in frames[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg) from None
raise
return frames[0]._with_native(result)

def concat_str(
self, *exprs: IbisExpr, separator: str, ignore_nulls: bool
Expand Down
69 changes: 62 additions & 7 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import datetime as dt
import re
from typing import TYPE_CHECKING, Any

import pytest

import narwhals as nw
from narwhals.exceptions import InvalidOperationError
from tests.utils import Constructor, ConstructorEager, assert_equal_data
from narwhals._utils import Implementation
from narwhals.exceptions import InvalidOperationError, NarwhalsError
from tests.utils import POLARS_VERSION, Constructor, ConstructorEager, assert_equal_data

if TYPE_CHECKING:
from collections.abc import Iterator


def test_concat_horizontal(constructor_eager: ConstructorEager) -> None:
Expand Down Expand Up @@ -61,11 +67,7 @@ def test_concat_vertical(constructor: Constructor) -> None:
nw.concat([df_left, df_left.select("d")], how="vertical").collect()


def test_concat_diagonal(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "ibis" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_concat_diagonal(constructor: Constructor) -> None:
Copy link
Member Author

@dangotbanned dangotbanned Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(#3404 (comment))

I do think there's something here, I just wanna experiment a bit (and add more tests) ❤️

@FBruzzesi okay this led to discovering a bug (but not in any of the new code or your suggestion)

ibis is the only backend that doesn't guarantee the order of union.

I found that out by using the 3-tabled version of the test:

That fails for ibis, but it turns out "vertical" fails too:

Show test

def test_concat_vertical_bigger(constructor: Constructor) -> None:
    """Vertically concatenate three frames with identical columns and compare against input order."""
    sources = (
        {"a": [1, 2], "b": [3, 4], "c": [0, None]},
        {"a": [5, 6], "b": [0, None], "c": [7, 8]},
        {"a": [0, None], "b": [9, 10], "c": [11, 12]},
    )
    frames = [nw.from_native(constructor(source)).lazy() for source in sources]
    stacked = nw.concat(frames, how="vertical")
    # Rows are expected in input order: first frame, then second, then third.
    expected = {
        "a": [1, 2, 5, 6, 0, None],
        "b": [3, 4, 0, None, 9, 10],
        "c": [0, None, 7, 8, 11, 12],
    }
    assert_equal_data(stacked, expected)

Show error

E               AssertionError: Mismatch at index 0, key a: 0 != 1
E               Expected: {'a': [1, 2, 5, 6, 0, None], 'b': [3, 4, 0, None, 9, 10], 'c': [0, None, 7, 8, 11, 12]}
E               Got: {'a': [0, None, 5, 6, 1, 2], 'b': [9, 10, 0, None, 3, 4], 'c': [11, 12, 7, 8, 0, None]}

I suppose I'm stuck with testing two tables for now then 😂
I've added (a8e8388), but should probably follow this up with another issue.
(Neither we nor polars document that the result is ordered, but we do test for it, and polars.union was recently introduced for the unordered case.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ibis is the only backend that doesn't guarantee the order of union.

For a moment I thought order of columns, and I panicked sooo much 🤯

No guarantee on row order makes sense. You can add an index column and sort by it.

data_1 = {"a": [1, 3], "b": [4, 6]}
data_2 = {"a": [100, 200], "z": ["x", "y"]}
expected = {
Expand All @@ -83,3 +85,56 @@ def test_concat_diagonal(

with pytest.raises(ValueError, match="No items"):
nw.concat([], how="diagonal")


def _from_natives(
    constructor: Constructor, *sources: dict[str, list[Any]]
) -> Iterator[nw.LazyFrame[Any]]:
    """Lazily yield one narwhals `LazyFrame` per data mapping, in the order given."""
    for source in sources:
        yield nw.from_native(constructor(source)).lazy()


def test_concat_diagonal_bigger(constructor: Constructor) -> None:
    """Diagonally concatenate three frames whose columns overlap only partially and differ in order."""
    # NOTE: `ibis.union` doesn't guarantee the order of outputs
    # https://github.com/narwhals-dev/narwhals/pull/3404#discussion_r2694556781
    # Every frame therefore carries an `idx` column, and the result is sorted by it.
    sources = (
        {"idx": [1, 2], "a": [1, 2], "b": [3, 4]},
        {"a": [5, 6], "c": [7, 8], "idx": [3, 4]},
        {"b": [9, 10], "idx": [5, 6], "c": [11, 12]},
    )
    frames = _from_natives(constructor, *sources)
    combined = nw.concat(frames, how="diagonal")
    # Columns absent from a frame are expected to come back as nulls for that frame's rows.
    expected = {
        "idx": [1, 2, 3, 4, 5, 6],
        "a": [1, 2, 5, 6, None, None],
        "b": [3, 4, None, None, 9, 10],
        "c": [None, None, 7, 8, 11, 12],
    }
    assert_equal_data(combined.sort("idx"), expected)


def test_concat_diagonal_invalid(
    constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    """Check that diagonal concat raises when a shared column has conflicting dtypes.

    Column `"a"` holds integers in the first frame and datetimes in the second.
    The test is marked xfail for every implementation other than Ibis and Polars.
    """
    data_1 = {"a": [1, 3], "b": [4, 6]}
    data_2 = {
        "a": [dt.datetime(2000, 1, 1), dt.datetime(2000, 1, 2)],
        "b": [4, 6],
        "z": ["x", "y"],
    }
    df_1 = nw.from_native(constructor(data_1)).lazy()
    bad_schema = nw.from_native(constructor(data_2)).lazy()
    impl = df_1.implementation
    request.applymarker(
        pytest.mark.xfail(
            impl not in {Implementation.IBIS, Implementation.POLARS},
            # The closing backtick was missing, leaving the inline-code span unbalanced.
            reason=f"{impl!r} does not validate schemas for `concat(how='diagonal')`",
        )
    )
    context: Any
    if impl.is_polars() and POLARS_VERSION < (1,):  # pragma: no cover
        # Polars before 1.0 is matched on a message naming both dtypes, in either order.
        context = pytest.raises(
            NarwhalsError,
            match=re.compile(r"(int.+datetime)|(datetime.+int)", re.IGNORECASE),
        )
    else:
        context = pytest.raises((InvalidOperationError, TypeError), match=r"same schema")
    with context:
        # NOTE(review): presumably collecting is what forces lazy backends to surface the error - confirm.
        nw.concat([df_1, bad_schema], how="diagonal").collect().to_dict(as_series=False)
Loading