Skip to content

Commit

Permalink
REF: Share Index comparison and arithmetic methods (#43555)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Sep 13, 2021
1 parent e7efcca commit e9d0a58
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 151 deletions.
31 changes: 29 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@
remove_na_arraylike,
)

from pandas.core import algorithms
from pandas.core import (
algorithms,
ops,
)
from pandas.core.accessor import DirNamesMixin
from pandas.core.algorithms import (
duplicated,
Expand All @@ -61,7 +64,11 @@
)
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
from pandas.core.construction import create_series_with_explicit_dtype
from pandas.core.construction import (
create_series_with_explicit_dtype,
ensure_wrapped_if_datetimelike,
extract_array,
)
import pandas.core.nanops as nanops

if TYPE_CHECKING:
Expand Down Expand Up @@ -1238,3 +1245,23 @@ def _duplicated(
self, keep: Literal["first", "last", False] = "first"
) -> npt.NDArray[np.bool_]:
return duplicated(self._values, keep=keep)

def _arith_method(self, other, op):
res_name = ops.get_op_result_name(self, other)

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
result = ops.arithmetic_op(lvalues, rvalues, op)

return self._construct_result(result, name=res_name)

def _construct_result(self, result, name):
"""
Construct an appropriately-wrapped result from the ArrayLike result
of an arithmetic-like operation.
"""
raise AbstractMethodError(self)
35 changes: 25 additions & 10 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6364,7 +6364,10 @@ def _cmp_method(self, other, op):
arr[self.isna()] = False
return arr
elif op in {operator.ne, operator.lt, operator.gt}:
return np.zeros(len(self), dtype=bool)
arr = np.zeros(len(self), dtype=bool)
if self._can_hold_na and not isinstance(self, ABCMultiIndex):
arr[self.isna()] = True
return arr

if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)) and len(
self
Expand All @@ -6381,6 +6384,9 @@ def _cmp_method(self, other, op):
with np.errstate(all="ignore"):
result = op(self._values, other)

elif isinstance(self._values, ExtensionArray):
result = op(self._values, other)

elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
# don't pass MultiIndex
with np.errstate(all="ignore"):
Expand All @@ -6392,17 +6398,26 @@ def _cmp_method(self, other, op):

return result

def _arith_method(self, other, op):
"""
Wrapper used to dispatch arithmetic operations.
"""
def _construct_result(self, result, name):
if isinstance(result, tuple):
return (
Index._with_infer(result[0], name=name),
Index._with_infer(result[1], name=name),
)
return Index._with_infer(result, name=name)

from pandas import Series
def _arith_method(self, other, op):
if (
isinstance(other, Index)
and is_object_dtype(other.dtype)
and type(other) is not Index
):
# We return NotImplemented for object-dtype index *subclasses* so they have
# a chance to implement ops before we unwrap them.
# See https://github.com/pandas-dev/pandas/issues/31109
return NotImplemented

result = op(Series(self), other)
if isinstance(result, tuple):
return (Index._with_infer(result[0]), Index(result[1]))
return Index._with_infer(result)
return super()._arith_method(other, op)

@final
def _unary_method(self, op):
Expand Down
119 changes: 1 addition & 118 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,9 @@

from pandas.core.dtypes.common import (
is_dtype_equal,
is_object_dtype,
pandas_dtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)
from pandas.core.dtypes.generic import ABCDataFrame

from pandas.core.arrays import (
Categorical,
Expand All @@ -45,7 +41,6 @@
from pandas.core.arrays.base import ExtensionArray
from pandas.core.indexers import deprecate_ndim_indexing
from pandas.core.indexes.base import Index
from pandas.core.ops import get_op_result_name

if TYPE_CHECKING:

Expand Down Expand Up @@ -154,94 +149,6 @@ def wrapper(cls):
return wrapper


def _make_wrapped_comparison_op(opname: str):
"""
Create a comparison method that dispatches to ``._data``.
"""

def wrapper(self, other):
if isinstance(other, ABCSeries):
# the arrays defer to Series for comparison ops but the indexes
# don't, so we have to unwrap here.
other = other._values

other = _maybe_unwrap_index(other)

op = getattr(self._data, opname)
return op(other)

wrapper.__name__ = opname
return wrapper


def _make_wrapped_arith_op(opname: str):
def method(self, other):
if (
isinstance(other, Index)
and is_object_dtype(other.dtype)
and type(other) is not Index
):
# We return NotImplemented for object-dtype index *subclasses* so they have
# a chance to implement ops before we unwrap them.
# See https://github.com/pandas-dev/pandas/issues/31109
return NotImplemented

try:
meth = getattr(self._data, opname)
except AttributeError as err:
# e.g. Categorical, IntervalArray
cls = type(self).__name__
raise TypeError(
f"cannot perform {opname} with this index type: {cls}"
) from err

result = meth(_maybe_unwrap_index(other))
return _wrap_arithmetic_op(self, other, result)

method.__name__ = opname
return method


def _wrap_arithmetic_op(self, other, result):
if result is NotImplemented:
return NotImplemented

if isinstance(result, tuple):
# divmod, rdivmod
assert len(result) == 2
return (
_wrap_arithmetic_op(self, other, result[0]),
_wrap_arithmetic_op(self, other, result[1]),
)

if not isinstance(result, Index):
# Index.__new__ will choose appropriate subclass for dtype
result = Index(result)

res_name = get_op_result_name(self, other)
result.name = res_name
return result


def _maybe_unwrap_index(obj):
"""
If operating against another Index object, we need to unwrap the underlying
data before deferring to the DatetimeArray/TimedeltaArray/PeriodArray
implementation, otherwise we will incorrectly return NotImplemented.
Parameters
----------
obj : object
Returns
-------
unwrapped object
"""
if isinstance(obj, Index):
return obj._data
return obj


class ExtensionIndex(Index):
"""
Index subclass for indexes backed by ExtensionArray.
Expand Down Expand Up @@ -284,30 +191,6 @@ def _simple_new(
result._reset_identity()
return result

__eq__ = _make_wrapped_comparison_op("__eq__")
__ne__ = _make_wrapped_comparison_op("__ne__")
__lt__ = _make_wrapped_comparison_op("__lt__")
__gt__ = _make_wrapped_comparison_op("__gt__")
__le__ = _make_wrapped_comparison_op("__le__")
__ge__ = _make_wrapped_comparison_op("__ge__")

__add__ = _make_wrapped_arith_op("__add__")
__sub__ = _make_wrapped_arith_op("__sub__")
__radd__ = _make_wrapped_arith_op("__radd__")
__rsub__ = _make_wrapped_arith_op("__rsub__")
__pow__ = _make_wrapped_arith_op("__pow__")
__rpow__ = _make_wrapped_arith_op("__rpow__")
__mul__ = _make_wrapped_arith_op("__mul__")
__rmul__ = _make_wrapped_arith_op("__rmul__")
__floordiv__ = _make_wrapped_arith_op("__floordiv__")
__rfloordiv__ = _make_wrapped_arith_op("__rfloordiv__")
__mod__ = _make_wrapped_arith_op("__mod__")
__rmod__ = _make_wrapped_arith_op("__rmod__")
__divmod__ = _make_wrapped_arith_op("__divmod__")
__rdivmod__ = _make_wrapped_arith_op("__rdivmod__")
__truediv__ = _make_wrapped_arith_op("__truediv__")
__rtruediv__ = _make_wrapped_arith_op("__rtruediv__")

# ---------------------------------------------------------------------
# NDarray-Like Methods

Expand Down
13 changes: 1 addition & 12 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
import pandas.core.common as com
from pandas.core.construction import (
create_series_with_explicit_dtype,
ensure_wrapped_if_datetimelike,
extract_array,
is_empty_data,
sanitize_array,
Expand Down Expand Up @@ -5515,18 +5514,8 @@ def _logical_method(self, other, op):
return self._construct_result(res_values, name=res_name)

def _arith_method(self, other, op):
res_name = ops.get_op_result_name(self, other)
self, other = ops.align_method_SERIES(self, other)

lvalues = self._values
rvalues = extract_array(other, extract_numpy=True, extract_range=True)
rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape)
rvalues = ensure_wrapped_if_datetimelike(rvalues)

with np.errstate(all="ignore"):
result = ops.arithmetic_op(lvalues, rvalues, op)

return self._construct_result(result, name=res_name)
return base.IndexOpsMixin._arith_method(self, other, op)


Series._add_numeric_operations()
Expand Down
10 changes: 2 additions & 8 deletions pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,7 @@ def test_dti_sub_tdi(self, tz_naive_fixture):
result = dti - tdi.values
tm.assert_index_equal(result, expected)

msg = "cannot subtract DatetimeArray from"
msg = "cannot subtract a datelike from a TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdi.values - dti

Expand All @@ -2172,13 +2172,7 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
result -= tdi.values
tm.assert_index_equal(result, expected)

msg = "|".join(
[
"cannot perform __neg__ with this index type:",
"ufunc subtract cannot use operands with types",
"cannot subtract DatetimeArray from",
]
)
msg = "cannot subtract a datelike from a TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdi.values -= dti

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):

with pytest.raises(TypeError, match=msg):
rng - tdarr
msg = r"cannot subtract PeriodArray from timedelta64\[ns\]"
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdarr - rng

Expand Down

0 comments on commit e9d0a58

Please sign in to comment.