Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fa82752
feat: Add testing.assert_frame_equal
FBruzzesi Oct 15, 2025
e12252c
it should be it
FBruzzesi Oct 16, 2025
d109448
WIP: Unit test
FBruzzesi Oct 16, 2025
db761a8
arguably improve typing(?), raise if only nested dtypes are available
FBruzzesi Oct 16, 2025
748b971
almost there with the testing
FBruzzesi Oct 17, 2025
bbf7a2f
Docstrings, more comments, docs
FBruzzesi Oct 18, 2025
9417b60
try within 'if TYPE_CHECKING'
FBruzzesi Oct 18, 2025
e2a7c7c
skip old pandas nested dtypes
FBruzzesi Oct 18, 2025
d20fdc4
add missing backtick, improve comment
FBruzzesi Oct 18, 2025
dadcf34
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Oct 18, 2025
c0c6210
merge main, simplify, rm unused
FBruzzesi Oct 20, 2025
128e05a
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Oct 22, 2025
d3d4ed8
ci: Test fairlearn using pytest marker
FBruzzesi Oct 23, 2025
149c9c6
use uv run
FBruzzesi Oct 23, 2025
a27931f
one more try
FBruzzesi Oct 24, 2025
c3046b0
one more try
FBruzzesi Oct 24, 2025
6d9528b
ok use system py
FBruzzesi Oct 24, 2025
aa03a19
skip if pyarrow is not installed
FBruzzesi Oct 24, 2025
1672b53
something went wrong in merging
FBruzzesi Oct 24, 2025
2bdfb68
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Oct 26, 2025
9e1c539
fix fixture name
FBruzzesi Oct 26, 2025
91bb851
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Oct 29, 2025
c169685
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Oct 30, 2025
253f3ad
fix docstrings
FBruzzesi Oct 30, 2025
a97f64a
Merge branch 'main' into feat/testing-assert-frame-equal
FBruzzesi Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
handler: python
options:
members:
- assert_frame_equal
- assert_series_equal
3 changes: 2 additions & 1 deletion narwhals/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from narwhals.testing.asserts.frame import assert_frame_equal
from narwhals.testing.asserts.series import assert_series_equal

__all__ = ("assert_series_equal",)
__all__ = ("assert_frame_equal", "assert_series_equal")
266 changes: 266 additions & 0 deletions narwhals/testing/asserts/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._utils import Implementation, qualified_type_name
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dependencies import is_narwhals_dataframe, is_narwhals_lazyframe
from narwhals.testing.asserts.series import assert_series_equal
from narwhals.testing.asserts.utils import (
raise_assertion_error,
raise_frame_assertion_error,
)

if TYPE_CHECKING:
from narwhals.typing import DataFrameT, LazyFrameT

# Implementations for which `_assert_dataframe_equal` honours `check_row_order`.
# Frames coming from any implementation that is NOT listed here are always sorted
# before their columns are compared, regardless of the flag.
GUARANTEES_ROW_ORDER = {
    Implementation.CUDF,
    Implementation.DASK,
    Implementation.MODIN,
    Implementation.PANDAS,
    Implementation.POLARS,
    Implementation.PYARROW,
}


def assert_frame_equal(
    left: DataFrameT | LazyFrameT,
    right: DataFrameT | LazyFrameT,
    *,
    check_row_order: bool = True,
    check_column_order: bool = True,
    check_dtypes: bool = True,
    check_exact: bool = False,
    rel_tol: float = 1e-5,
    abs_tol: float = 1e-8,
    categorical_as_str: bool = False,
) -> None:
    """Assert that the left and right frames are equal.

    Raises a detailed `AssertionError` if the frames differ.
    This function is intended for use in unit tests.

    Notes:
        In the case of backends that do not guarantee the row order, such as DuckDB, Ibis,
        PySpark, and SQLFrame, `check_row_order` argument is ignored and the comparands
        are sorted by all the columns regardless.

    Arguments:
        left: The first DataFrame or LazyFrame to compare.
        right: The second DataFrame or LazyFrame to compare.
        check_row_order: Requires row order to match. This flag is ignored for backends
            that do not guarantee row order such as DuckDB, Ibis, PySpark, SQLFrame.
        check_column_order: Requires column order to match.
        check_dtypes: Requires data types to match.
        check_exact: Requires float values to match exactly. If set to `False`, values are
            considered equal when within tolerance of each other (see `rel_tol` and `abs_tol`).
            Only affects columns with a Float data type.
        rel_tol: Relative tolerance for inexact checking. Fraction of values in `right`.
        abs_tol: Absolute tolerance for inexact checking.
        categorical_as_str: Cast categorical columns to string before comparing.
            Enabling this helps compare columns that do not share the same string cache.

    Raises:
        TypeError: If either input is not a `narwhals.DataFrame` or `narwhals.LazyFrame`.
        AssertionError: If the two frames come from different implementations, mix
            eager and lazy inputs, or differ in schema, height or values.

    Examples:
        >>> import duckdb
        >>> import narwhals as nw
        >>> from narwhals.testing import assert_frame_equal
        >>>
        >>> left_native = duckdb.sql("SELECT * FROM VALUES (1, ), (2, ), (3, ) df(a)")
        >>> right_native = duckdb.sql("SELECT * FROM VALUES (1, ), (5, ), (3, ) df(a)")
        >>> left = nw.from_native(left_native)
        >>> right = nw.from_native(right_native)
        >>> assert_frame_equal(left, right)  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        AssertionError: DataFrames are different (value mismatch for column "a")
        [left]:
        ┌────────────────────────────────────────────────┐
        |                Narwhals Series                 |
        |------------------------------------------------|
        |<pyarrow.lib.ChunkedArray object at ...
        |[                                               |
        |  [                                             |
        |    1,                                          |
        |    2,                                          |
        |    3                                           |
        |  ]                                             |
        |]                                               |
        └────────────────────────────────────────────────┘
        [right]:
        ┌────────────────────────────────────────────────┐
        |                Narwhals Series                 |
        |------------------------------------------------|
        |<pyarrow.lib.ChunkedArray object at ...
        |[                                               |
        |  [                                             |
        |    1,                                          |
        |    3,                                          |
        |    5                                           |
        |  ]                                             |
        |]                                               |
        └────────────────────────────────────────────────┘
    """
    __tracebackhide__ = True

    # Reject anything that is not a narwhals frame up front, so that the user gets
    # an actionable hint rather than an attribute error further down.
    if any(
        not (is_narwhals_dataframe(obj) or is_narwhals_lazyframe(obj))
        for obj in (left, right)
    ):
        msg = (
            "Expected `narwhals.DataFrame` or `narwhals.LazyFrame` instance, found:\n"
            f"[left]: {qualified_type_name(type(left))}\n"
            f"[right]: {qualified_type_name(type(right))}\n\n"
            "Hint: Use `nw.from_native(obj, allow_series=False)` to convert each native "
            "object into a `narwhals.DataFrame` or `narwhals.LazyFrame` first."
        )
        raise TypeError(msg)

    left_impl, right_impl = left.implementation, right.implementation
    if left_impl != right_impl:
        raise_frame_assertion_error("implementation mismatch", left_impl, right_impl)

    # Lazy inputs are collected here; `left_impl` is read *before* collecting because
    # it is what decides whether the row order can be trusted.
    left_eager, right_eager = _check_correct_input_type(left, right)

    _assert_dataframe_equal(
        left=left_eager,
        right=right_eager,
        impl=left_impl,
        check_row_order=check_row_order,
        check_column_order=check_column_order,
        check_dtypes=check_dtypes,
        check_exact=check_exact,
        rel_tol=rel_tol,
        abs_tol=abs_tol,
        categorical_as_str=categorical_as_str,
    )


def _check_correct_input_type(  # noqa: RET503
    left: DataFrameT | LazyFrameT, right: DataFrameT | LazyFrameT
) -> tuple[DataFrame[Any], DataFrame[Any]]:
    """Return both inputs as eager frames, or fail if their kinds do not match.

    Two `DataFrame`s are returned untouched, two `LazyFrame`s are collected.
    Any other combination is reported through `raise_assertion_error`.

    Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/py-polars/src/polars/testing/asserts/frame.py#L15-L17
    """
    if isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
        collected_left = left.collect()
        collected_right = right.collect()
        return collected_left, collected_right

    if isinstance(left, DataFrame) and isinstance(right, DataFrame):
        return left, right

    # Mixed (or otherwise unexpected) inputs: report the offending class names.
    left_kind = type(left).__name__
    right_kind = type(right).__name__
    raise_assertion_error(
        "inputs", "unexpected input types", left=left_kind, right=right_kind
    )


def _assert_dataframe_equal(
    left: DataFrameT,
    right: DataFrameT,
    impl: Implementation,
    *,
    check_row_order: bool,
    check_column_order: bool,
    check_dtypes: bool,
    check_exact: bool,
    rel_tol: float,
    abs_tol: float,
    categorical_as_str: bool,
) -> None:
    """Compare two eager frames: schema first, then height, then column by column.

    Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/crates/polars-testing/src/asserts/utils.rs#L829

    Arguments:
        left: The first (eager) frame to compare.
        right: The second (eager) frame to compare.
        impl: Implementation of the *original* frame, not the `.collect`-ed one. It is
            used to distinguish between backends that do and do not guarantee row order.
        check_row_order: Requires row order to match. Ignored (frames are sorted) when
            `impl` is not in `GUARANTEES_ROW_ORDER`.
        check_column_order: Requires column order to match.
        check_dtypes: Requires data types to match.
        check_exact: Forwarded to `assert_series_equal`.
        rel_tol: Forwarded to `assert_series_equal`.
        abs_tol: Forwarded to `assert_series_equal`.
        categorical_as_str: Forwarded to `assert_series_equal`.

    Raises:
        AssertionError: On schema, height or value mismatch.
        NotImplementedError: If the frames must be sorted, have more than one row, and
            contain only nested data types.
    """
    _check_schema_equal(
        left, right, check_dtypes=check_dtypes, check_column_order=check_column_order
    )

    left_len, right_len = len(left), len(right)
    if left_len != right_len:
        raise_frame_assertion_error("height (row count) mismatch", left_len, right_len)

    left_schema = left.schema
    column_names = left_schema.names()

    must_sort = (not check_row_order) or (impl not in GUARANTEES_ROW_ORDER)
    # Row order is only observable with at least two rows and one column: skip the
    # sort (and its nested-dtype limitation) when it cannot change the outcome.
    if must_sort and left_len > 1 and column_names:
        # NOTE: Sort by all the non-nested dtypes columns.
        # ! This might lead to wrong results.
        # If only nested dtypes are available, then we raise an exception.
        sort_by = [name for name, dtype in left_schema.items() if not dtype.is_nested()]

        if not sort_by:
            msg = "`check_row_order=False` is not supported (yet) with only nested data type."
            raise NotImplementedError(msg)

        left = left.sort(sort_by)
        right = right.sort(sort_by)

    for col_name in column_names:
        _series_left = left.get_column(col_name)
        _series_right = right.get_column(col_name)
        try:
            # Dtypes and names were already validated by the schema check, and the
            # rows are aligned at this point, hence the fixed flags below.
            assert_series_equal(
                _series_left,
                _series_right,
                check_dtypes=False,
                check_names=False,
                check_order=True,
                check_exact=check_exact,
                rel_tol=rel_tol,
                abs_tol=abs_tol,
                categorical_as_str=categorical_as_str,
            )
        except AssertionError:
            # Re-raise as a frame-level error that names the offending column.
            raise_frame_assertion_error(
                f'value mismatch for column "{col_name}"', _series_left, _series_right
            )


def _check_schema_equal(
    left: DataFrameT, right: DataFrameT, *, check_dtypes: bool, check_column_order: bool
) -> None:
    """Validate that two frames have compatible schemas.

    The checks run in this order, and the first failure is reported:

    1. both frames contain the same set of column names;
    2. the columns appear in the same order (only if `check_column_order`);
    3. the dtypes match column by column (only if `check_dtypes`).

    Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/crates/polars-testing/src/asserts/utils.rs#L667-L698
    """
    left_schema = left.schema
    right_schema = right.schema

    # Identical schemas need no further inspection.
    if left_schema == right_schema:
        return

    left_names = left_schema.names()
    right_names = right_schema.names()
    left_set = set(left_names)
    right_set = set(right_names)

    if left_set != right_set:
        only_in_left = sorted(left_set - right_set)
        if only_in_left:
            raise_frame_assertion_error(
                detail=f"{only_in_left} in left, but not in right",
                left=left_set,
                right=right_set,
            )
        only_in_right = sorted(right_set - left_set)
        if only_in_right:  # pragma: no cover
            # NOTE: the `pragma: no cover` flag is due to a false negative.
            # The last test in `test_check_schema_mismatch` does cover this case.
            raise_frame_assertion_error(
                detail=f"{only_in_right} in right, but not in left",
                left=left_set,
                right=right_set,
            )

    if check_column_order and left_names != right_names:
        raise_frame_assertion_error(
            detail="columns are not in the same order",
            left=left_names,
            right=right_names,
        )

    if not check_dtypes:
        return

    left_dtypes = left_schema.dtypes()
    if check_column_order:
        right_dtypes = right_schema.dtypes()
    else:
        # Column order may differ: line the right dtypes up with the left names.
        right_dtypes = [right_schema[name] for name in left_names]

    if left_dtypes != right_dtypes:
        raise_frame_assertion_error(
            detail="dtypes do not match", left=left_dtypes, right=right_dtypes
        )
38 changes: 32 additions & 6 deletions narwhals/testing/asserts/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: PYI051
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal
Expand All @@ -7,23 +8,42 @@
if TYPE_CHECKING:
from typing_extensions import Never, TypeAlias

# NOTE: This alias is created to facilitate autocomplete. Feel free to extend it as
# you please when adding a new feature.
# NOTE: These aliases are created to facilitate autocompletion.
# Feel free to extend them as you please when adding new features.
# See: https://github.com/narwhals-dev/narwhals/pull/2983#discussion_r2337548736
ObjectName: TypeAlias = Literal["inputs", "Series", "DataFrames"]
SeriesDetail: TypeAlias = Literal[
"dtype mismatch",
"exact value mismatch",
"implementation mismatch",
"length mismatch",
"dtype mismatch",
"name mismatch",
"nested value mismatch",
"null value mismatch",
"exact value mismatch",
"values not within tolerance",
"nested value mismatch",
]
DataFramesDetail: TypeAlias = (
Literal[
"columns are not in the same order",
"dtypes do not match",
"height (row count) mismatch",
"implementation mismatch",
]
| str
# NOTE: `| str` makes the literals above redundant for type checking, but they
# still show up as autocompletion suggestions when typing.
# The reason to have `str` is that other details are dynamic and depend upon
# which columns lead to the assertion error.
)


def raise_assertion_error(
objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None
objects: ObjectName,
detail: str,
left: Any,
right: Any,
*,
cause: Exception | None = None,
) -> Never:
"""Raise a detailed assertion error."""
__tracebackhide__ = True
Expand All @@ -43,3 +63,9 @@ def raise_series_assertion_error(
detail: SeriesDetail, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
raise_assertion_error("Series", detail, left, right, cause=cause)


def raise_frame_assertion_error(
    detail: DataFramesDetail, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
    """Raise a detailed assertion error for a "DataFrames" comparison.

    Thin wrapper that forwards to `raise_assertion_error` with the object name
    fixed to `"DataFrames"`.
    """
    raise_assertion_error(
        objects="DataFrames", detail=detail, left=left, right=right, cause=cause
    )
Loading
Loading