Skip to content

Commit

Permalink
Backport PR #56123 on branch 2.1.x (BUG: ne comparison returns False …
Browse files Browse the repository at this point in the history
…for NA and other value) (#56382)
  • Loading branch information
phofl authored Dec 7, 2023
1 parent 7006d99 commit aacdf61
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Bug fixes
- Fixed bug in :meth:`DataFrame.__setitem__` casting :class:`Index` with object-dtype to PyArrow backed strings when ``infer_string`` option is set (:issue:`55638`)
- Fixed bug in :meth:`DataFrame.to_hdf` raising when columns have ``StringDtype`` (:issue:`55088`)
- Fixed bug in :meth:`Index.insert` casting object-dtype to PyArrow backed strings when ``infer_string`` option is set (:issue:`55638`)
- Fixed bug in :meth:`Series.__ne__` resulting in False for comparison between ``NA`` and string value for ``dtype="string[pyarrow_numpy]"`` (:issue:`56122`)
- Fixed bug in :meth:`Series.str.split` and :meth:`Series.str.rsplit` when ``pat=None`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56271`)
- Fixed bug in :meth:`Series.str.translate` losing object dtype when string option is set (:issue:`56152`)
-
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
import operator
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -600,7 +601,10 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
return result.to_numpy(np.bool_, na_value=False)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True):
from pandas import Series
Expand Down
29 changes: 18 additions & 11 deletions pandas/tests/arithmetic/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas import (
Series,
Timestamp,
option_context,
)
import pandas._testing as tm
from pandas.core import ops
Expand All @@ -31,20 +34,24 @@ def test_comparison_object_numeric_nas(self, comparison_op):
expected = func(ser.astype(float), shifted.astype(float))
tm.assert_series_equal(result, expected)

def test_object_comparisons(self):
ser = Series(["a", "b", np.nan, "c", "a"])
@pytest.mark.parametrize(
"infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))]
)
def test_object_comparisons(self, infer_string):
with option_context("future.infer_string", infer_string):
ser = Series(["a", "b", np.nan, "c", "a"])

result = ser == "a"
expected = Series([True, False, False, False, True])
tm.assert_series_equal(result, expected)
result = ser == "a"
expected = Series([True, False, False, False, True])
tm.assert_series_equal(result, expected)

result = ser < "a"
expected = Series([False, False, False, False, False])
tm.assert_series_equal(result, expected)
result = ser < "a"
expected = Series([False, False, False, False, False])
tm.assert_series_equal(result, expected)

result = ser != "a"
expected = -(ser == "a")
tm.assert_series_equal(result, expected)
result = ser != "a"
expected = -(ser == "a")
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize("dtype", [None, object])
def test_more_na_comparisons(self, dtype):
Expand Down
26 changes: 20 additions & 6 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
This module tests the functionality of StringArray and ArrowStringArray.
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
"""
import operator

import numpy as np
import pytest

Expand Down Expand Up @@ -221,7 +223,10 @@ def test_comparison_methods_scalar(comparison_op, dtype):
result = getattr(a, op_name)(other)
if dtype.storage == "pyarrow_numpy":
expected = np.array([getattr(item, op_name)(other) for item in a])
expected[1] = False
if comparison_op == operator.ne:
expected[1] = True
else:
expected[1] = False
tm.assert_numpy_array_equal(result, expected.astype(np.bool_))
else:
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
Expand All @@ -236,7 +241,10 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
result = getattr(a, op_name)(pd.NA)

if dtype.storage == "pyarrow_numpy":
expected = np.array([False, False, False])
if operator.ne == comparison_op:
expected = np.array([True, True, True])
else:
expected = np.array([False, False, False])
tm.assert_numpy_array_equal(result, expected)
else:
expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
Expand All @@ -262,7 +270,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
if dtype.storage == "pyarrow_numpy":
expected_data = {
"__eq__": [False, False, False],
"__ne__": [True, False, True],
"__ne__": [True, True, True],
}[op_name]
expected = np.array(expected_data)
tm.assert_numpy_array_equal(result, expected)
Expand All @@ -282,12 +290,18 @@ def test_comparison_methods_array(comparison_op, dtype):
other = [None, None, "c"]
result = getattr(a, op_name)(other)
if dtype.storage == "pyarrow_numpy":
expected = np.array([False, False, False])
expected[-1] = getattr(other[-1], op_name)(a[-1])
if operator.ne == comparison_op:
expected = np.array([True, True, False])
else:
expected = np.array([False, False, False])
expected[-1] = getattr(other[-1], op_name)(a[-1])
tm.assert_numpy_array_equal(result, expected)

result = getattr(a, op_name)(pd.NA)
expected = np.array([False, False, False])
if operator.ne == comparison_op:
expected = np.array([True, True, True])
else:
expected = np.array([False, False, False])
tm.assert_numpy_array_equal(result, expected)

else:
Expand Down

0 comments on commit aacdf61

Please sign in to comment.