Skip to content

Commit 6753bb6

Browse files
authored
fix: Incorrect logic in assert_series_equal for infinities (#20763)
1 parent d05b942 commit 6753bb6

File tree

4 files changed

+32
-4
lines changed

4 files changed

+32
-4
lines changed

py-polars/polars/testing/asserts/series.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,11 @@ def _assert_series_values_within_tolerance(
319319

320320
difference = (left_unequal - right_unequal).abs()
321321
tolerance = atol + rtol * right_unequal.abs()
322-
exceeds_tolerance = difference > tolerance
322+
within_tolerance = (difference <= tolerance) & right_unequal.is_finite() | (
323+
left_unequal == right_unequal
324+
)
323325

324-
if exceeds_tolerance.any():
326+
if not within_tolerance.all():
325327
raise_assertion_error(
326328
"Series",
327329
"value mismatch",

py-polars/tests/unit/operations/arithmetic/test_list.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ def test_list_arithmetic_div_ops_zero_denominator(
719719

720720
assert_series_equal(
721721
s / pl.Series([1]).new_from_index(0, n),
722-
pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),
722+
pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)),
723723
)
724724

725725
# floordiv

py-polars/tests/unit/series/test_series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1016,7 +1016,7 @@ def test_diff() -> None:
10161016

10171017
def test_pct_change() -> None:
10181018
s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64])
1019-
expected = pl.Series("a", [None, None, float("inf"), 3.0, 3.0, 3.0, 3.0])
1019+
expected = pl.Series("a", [None, None, 3.0, 3.0, 3.0, 3.0, 3.0])
10201020
assert_series_equal(s.pct_change(2), expected)
10211021
assert_series_equal(s.pct_change(pl.Series([2])), expected)
10221022
# negative

py-polars/tests/unit/testing/test_assert_series_equal.py

+26
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,29 @@ def test_series_data_type_fail():
837837
assert "AssertionError: Series are different (nan value mismatch)" in stdout
838838
assert "AssertionError: Series are different (dtype mismatch)" in stdout
839839
assert "AssertionError: inputs are different (unexpected input types)" in stdout
840+
841+
842+
def test_assert_series_equal_inf() -> None:
843+
s1 = pl.Series([1.0, float("inf")])
844+
s2 = pl.Series([1.0, float("inf")])
845+
assert_series_equal(s1, s2)
846+
847+
s1 = pl.Series([1.0, float("-inf")])
848+
s2 = pl.Series([1.0, float("-inf")])
849+
assert_series_equal(s1, s2)
850+
851+
s1 = pl.Series([1.0, float("inf")])
852+
s2 = pl.Series([float("inf"), 1.0])
853+
assert_series_not_equal(s1, s2)
854+
855+
s1 = pl.Series([1.0, float("inf")])
856+
s2 = pl.Series([1.0, float("-inf")])
857+
assert_series_not_equal(s1, s2)
858+
859+
s1 = pl.Series([1.0, float("inf")])
860+
s2 = pl.Series([1.0, 2.0])
861+
assert_series_not_equal(s1, s2)
862+
863+
s1 = pl.Series([1.0, float("inf")])
864+
s2 = pl.Series([1.0, float("nan")])
865+
assert_series_not_equal(s1, s2)

0 commit comments

Comments
 (0)