feat: Add testing.assert_frame_equal
#3220
Open
FBruzzesi wants to merge 25 commits into main from feat/testing-assert-frame-equal
fa82752 feat: Add testing.assert_frame_equal (FBruzzesi)
e12252c it should be it (FBruzzesi)
d109448 WIP: Unit test (FBruzzesi)
db761a8 arguably improve typing(?), raise if only nested dtypes are available (FBruzzesi)
748b971 almost there with the testing (FBruzzesi)
bbf7a2f Docstrings, more comments, docs (FBruzzesi)
9417b60 try within 'if TYPE_CHECKING' (FBruzzesi)
e2a7c7c skip old pandas nested dtypes (FBruzzesi)
d20fdc4 add missing backtick, improve comment (FBruzzesi)
dadcf34 Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
c0c6210 merge main, simplify, rm unused (FBruzzesi)
128e05a Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
d3d4ed8 ci: Test fairlearn using pytest marker (FBruzzesi)
149c9c6 use uv run (FBruzzesi)
a27931f one more try (FBruzzesi)
c3046b0 one more try (FBruzzesi)
6d9528b ok use system py (FBruzzesi)
aa03a19 skip if pyarrow is not installed (FBruzzesi)
1672b53 something went wrong in merging (FBruzzesi)
2bdfb68 Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
9e1c539 fix fixture name (FBruzzesi)
91bb851 Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
c169685 Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
253f3ad fix docstrings (FBruzzesi)
a97f64a Merge branch 'main' into feat/testing-assert-frame-equal (FBruzzesi)
@@ -4,4 +4,5 @@
     handler: python
     options:
       members:
+        - assert_frame_equal
         - assert_series_equal
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from narwhals.testing.asserts.frame import assert_frame_equal
 from narwhals.testing.asserts.series import assert_series_equal

-__all__ = ("assert_series_equal",)
+__all__ = ("assert_frame_equal", "assert_series_equal")
@@ -0,0 +1,267 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._utils import Implementation, qualified_type_name
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dependencies import is_narwhals_dataframe, is_narwhals_lazyframe
from narwhals.testing.asserts.series import assert_series_equal
from narwhals.testing.asserts.utils import (
    raise_assertion_error,
    raise_frame_assertion_error,
)

if TYPE_CHECKING:
    from narwhals.typing import DataFrameT, LazyFrameT

GUARANTEES_ROW_ORDER = {
    Implementation.PANDAS,
    Implementation.MODIN,
    Implementation.CUDF,
    Implementation.PYARROW,
    Implementation.POLARS,
    Implementation.DASK,
}


def assert_frame_equal(
    left: DataFrameT | LazyFrameT,
    right: DataFrameT | LazyFrameT,
    *,
    check_row_order: bool = True,
    check_column_order: bool = True,
    check_dtypes: bool = True,
    check_exact: bool = False,
    rel_tol: float = 1e-5,
    abs_tol: float = 1e-8,
    categorical_as_str: bool = False,
) -> None:
    """Assert that the left and right frames are equal.

    Raises a detailed `AssertionError` if the frames differ.
    This function is intended for use in unit tests.

    Notes:
        In the case of backends that do not guarantee the row order, such as DuckDB, Ibis,
        PySpark, and SQLFrame, `check_row_order` argument is ignored and the comparands
        are sorted by all the columns regardless.

    Arguments:
        left: The first DataFrame or LazyFrame to compare.
        right: The second DataFrame or LazyFrame to compare.
        check_row_order: Requires row order to match. This flag is ignored for backends
            that do not guarantee row order such as DuckDB, Ibis, PySpark, SQLFrame.
        check_column_order: Requires column order to match.
        check_dtypes: Requires data types to match.
        check_exact: Requires float values to match exactly. If set to `False`, values are
            considered equal when within tolerance of each other (see `rel_tol` and `abs_tol`).
            Only affects columns with a Float data type.
        rel_tol: Relative tolerance for inexact checking. Fraction of values in `right`.
        abs_tol: Absolute tolerance for inexact checking.
        categorical_as_str: Cast categorical columns to string before comparing.
            Enabling this helps compare columns that do not share the same string cache.

    Examples:
        >>> import duckdb
        >>> import narwhals as nw
        >>> from narwhals.testing import assert_frame_equal
        >>>
        >>> left_native = duckdb.sql("SELECT * FROM VALUES (1, ), (2, ), (3, ) df(a)")
        >>> right_native = duckdb.sql("SELECT * FROM VALUES (1, ), (5, ), (3, ) df(a)")
        >>> left = nw.from_native(left_native)
        >>> right = nw.from_native(right_native)
        >>> assert_frame_equal(left, right)  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        AssertionError: DataFrames are different (value mismatch for column "a")
        [left]:
        ┌────────────────────────────────────────────────┐
        |                Narwhals Series                 |
        |------------------------------------------------|
        |<pyarrow.lib.ChunkedArray object at ...
        |[                                               |
        |  [                                             |
        |    1,                                          |
        |    2,                                          |
        |    3                                           |
        |  ]                                             |
        |]                                               |
        └────────────────────────────────────────────────┘
        [right]:
        ┌────────────────────────────────────────────────┐
        |                Narwhals Series                 |
        |------------------------------------------------|
        |<pyarrow.lib.ChunkedArray object at ...
        |[                                               |
        |  [                                             |
        |    1,                                          |
        |    3,                                          |
        |    5                                           |
        |  ]                                             |
        |]                                               |
        └────────────────────────────────────────────────┘
    """
    __tracebackhide__ = True

    if any(
        not (is_narwhals_dataframe(obj) or is_narwhals_lazyframe(obj))
        for obj in (left, right)
    ):
        msg = (
            "Expected `narwhals.DataFrame` or `narwhals.LazyFrame` instance, found:\n"
            f"[left]: {qualified_type_name(type(left))}\n"
            f"[right]: {qualified_type_name(type(right))}\n\n"
            "Hint: Use `nw.from_native(obj, allow_series=False)` to convert each native "
            "object into a `narwhals.DataFrame` or `narwhals.LazyFrame` first."
        )
        raise TypeError(msg)

    left_impl, right_impl = left.implementation, right.implementation
    if left_impl != right_impl:
        raise_frame_assertion_error("implementation mismatch", left_impl, right_impl)

    left_eager, right_eager = _check_correct_input_type(left, right)

    _assert_dataframe_equal(
        left=left_eager,
        right=right_eager,
        impl=left_impl,
        check_row_order=check_row_order,
        check_column_order=check_column_order,
        check_dtypes=check_dtypes,
        check_exact=check_exact,
        rel_tol=rel_tol,
        abs_tol=abs_tol,
        categorical_as_str=categorical_as_str,
    )


def _check_correct_input_type(  # noqa: RET503
    left: DataFrameT | LazyFrameT, right: DataFrameT | LazyFrameT
) -> tuple[DataFrame[Any], DataFrame[Any]]:
    # Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/py-polars/src/polars/testing/asserts/frame.py#L15-L17
    if isinstance(left, DataFrame) and isinstance(right, DataFrame):
        return left, right

    if isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
        return left.collect(), right.collect()

    raise_assertion_error(
        "inputs",
        "unexpected input types",
        left=type(left).__name__,
        right=type(right).__name__,
    )


def _assert_dataframe_equal(
    left: DataFrameT,
    right: DataFrameT,
    impl: Implementation,
    *,
    check_row_order: bool,
    check_column_order: bool,
    check_dtypes: bool,
    check_exact: bool,
    rel_tol: float,
    abs_tol: float,
    categorical_as_str: bool,
) -> None:
    # Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/crates/polars-testing/src/asserts/utils.rs#L829
    # NOTE: Here `impl` comes from the original dataframe, not the `.collect`-ed one, and
    # it's used to distinguish between backends that do and do not guarantee row order.
    _check_schema_equal(
        left, right, check_dtypes=check_dtypes, check_column_order=check_column_order
    )

    left_len, right_len = len(left), len(right)
    if left_len != right_len:
        raise_frame_assertion_error("height (row count) mismatch", left_len, right_len)
    # TODO(FBruzzesi): Should we return early if row count is zero?

    left_schema = left.schema
    if (not check_row_order) or (impl not in GUARANTEES_ROW_ORDER):
        # NOTE: Sort by all the non-nested dtypes columns.
        # See: https://github.com/narwhals-dev/narwhals/issues/2939
        # ! This might lead to wrong results if there are duplicate values in the sorting
        # columns as the final order might still be non fully deterministic.
        sort_by = [name for name, dtype in left_schema.items() if not dtype.is_nested()]

        if not sort_by:
            # If only nested dtypes are available, then we raise an exception.
            msg = "`check_row_order=False` is not supported (yet) with only nested data type."
            raise NotImplementedError(msg)

        left = left.sort(sort_by)
        right = right.sort(sort_by)

    for col_name in left_schema.names():
        _series_left = left.get_column(col_name)
        _series_right = right.get_column(col_name)
        try:
            assert_series_equal(
                _series_left,
                _series_right,
                check_dtypes=False,
                check_names=False,
                check_order=True,
                check_exact=check_exact,
                rel_tol=rel_tol,
                abs_tol=abs_tol,
                categorical_as_str=categorical_as_str,
            )
        except AssertionError:
            raise_frame_assertion_error(
                f'value mismatch for column "{col_name}"', _series_left, _series_right
            )


def _check_schema_equal(
    left: DataFrameT, right: DataFrameT, *, check_dtypes: bool, check_column_order: bool
) -> None:
    """Compares DataFrame schema based on specified criteria.

    Adapted from https://github.com/pola-rs/polars/blob/afdbf3056d1228cf493901e45f536b0905cec8ea/crates/polars-testing/src/asserts/utils.rs#L667-L698
    """
    lschema, rschema = left.schema, right.schema

    # Fast path for equal DataFrames
    if lschema == rschema:
        return

    lnames, rnames = lschema.names(), rschema.names()
    lset, rset = set(lnames), set(rnames)

    if lset != rset:
        if left_not_in_right := sorted(lset.difference(rset)):
            raise_frame_assertion_error(
                detail=f"{left_not_in_right} in left, but not in right",
                left=lset,
                right=rset,
            )
        if right_not_in_left := sorted(rset.difference(lset)):  # pragma: no cover
            # NOTE: the `pragma: no cover` flag is due to a false negative.
            # The last test in `test_check_schema_mismatch` does cover this case.
            raise_frame_assertion_error(
                detail=f"{right_not_in_left} in right, but not in left",
                left=lset,
                right=rset,
            )

    if check_column_order and lnames != rnames:
        raise_frame_assertion_error(
            detail="columns are not in the same order", left=lnames, right=rnames
        )

    if check_dtypes:
        ldtypes = lschema.dtypes()
        rdtypes = (
            rschema.dtypes()
            if check_column_order
            else [rschema[col_name] for col_name in lnames]
        )

        if ldtypes != rdtypes:
            raise_frame_assertion_error(
                detail="dtypes do not match", left=ldtypes, right=rdtypes
            )
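
The sort-then-compare fallback in `_assert_dataframe_equal` can be illustrated without narwhals. The following is only a rough stdlib sketch, not the PR's implementation: the dict-of-lists table layout and the helpers `_is_nested`, `_sort_rows` and `rows_equal_unordered` are hypothetical, and `math.isclose` stands in for the tolerance check, which may not use the same formula as `assert_series_equal`.

```python
from __future__ import annotations

import math
from typing import Any


def _is_nested(value: Any) -> bool:
    # Stand-in for `dtype.is_nested()`: treat lists, tuples and dicts as nested values.
    return isinstance(value, (list, tuple, dict))


def _sort_rows(
    table: dict[str, list[Any]], sort_by: list[str], n_rows: int
) -> dict[str, list[Any]]:
    # Compute one row permutation from the `sort_by` columns and apply it to every column.
    order = sorted(range(n_rows), key=lambda i: tuple(table[name][i] for name in sort_by))
    return {name: [values[i] for i in order] for name, values in table.items()}


def rows_equal_unordered(
    left: dict[str, list[Any]],
    right: dict[str, list[Any]],
    *,
    rel_tol: float = 1e-5,
    abs_tol: float = 1e-8,
) -> bool:
    """Return True if both tables compare equal after sorting by the non-nested columns."""
    if set(left) != set(right):
        return False
    n_left = len(next(iter(left.values()), []))
    n_right = len(next(iter(right.values()), []))
    if n_left != n_right:
        return False
    if n_left == 0:
        # Assumption of this sketch: two empty tables with the same columns are equal.
        return True

    # Only columns without nested values are usable as sort keys.
    sort_by = [
        name for name, values in left.items() if not any(_is_nested(v) for v in values)
    ]
    if not sort_by:
        msg = "cannot ignore row order when every column is nested"
        raise NotImplementedError(msg)

    left_sorted = _sort_rows(left, sort_by, n_left)
    right_sorted = _sort_rows(right, sort_by, n_right)
    for name in left:
        for lv, rv in zip(left_sorted[name], right_sorted[name]):
            if isinstance(lv, float) and isinstance(rv, float):
                # Inexact comparison for floats only.
                if not math.isclose(lv, rv, rel_tol=rel_tol, abs_tol=abs_tol):
                    return False
            elif lv != rv:
                return False
    return True
```

In this sketch, `{"a": [1, 1], "n": [[1], [2]]}` and `{"a": [1, 1], "n": [[2], [1]]}` compare as unequal even though they hold the same rows: the only sortable column contains duplicates, so the sort cannot line up the nested column. This is the situation the `!` comment in the diff warns about.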
@dangotbanned do you see a better way of doing this?
My first thought would be to keep the `*Detail` aliases for static portion of the message.

I haven't looked through to see how the dynamic stuff works yet - particularly where it is sourced and where it gets inserted - but if you had these things separated into 2 parameters then it might work?

E.g. if you added another parameter with `not_thing: str = ""`, that could just get skipped when empty?
This example uses `None` and adds line breaks when you have content: https://github.com/vega/altair/blob/31c7f8a4b74f824311af459f17b8c5c65c32ba3c/altair/utils/deprecation.py
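
One possible reading of the two suggestions above, as a small stdlib-only sketch: the static detail and the dynamic part are passed as separate parameters, and the dynamic part is skipped when it is `None` or empty. The function name `format_frame_error`, the `extra` parameter and the exact message layout are hypothetical and are not taken from `narwhals.testing.asserts.utils`.

```python
from __future__ import annotations


def format_frame_error(
    detail: str, left: object, right: object, *, extra: str | None = None
) -> str:
    """Build an assertion message from a static detail and an optional dynamic part."""
    lines = [f"DataFrames are different ({detail})"]
    if extra:
        # Both `None` and "" are skipped, so no blank line is emitted without content.
        lines.append(extra)
    lines.append(f"[left]: {left}")
    lines.append(f"[right]: {right}")
    return "\n".join(lines)
```

With this split, `detail` could stay a fixed set of literal strings (which is what the `*Detail` aliases would type), while the column names that differ go into `extra`.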