Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3003,6 +3003,7 @@ def _intersection(self, other, sort=False):

return result

@final
def difference(self, other, sort=None):
"""
Return a new Index with elements of index not in `other`.
Expand Down Expand Up @@ -3127,7 +3128,7 @@ def symmetric_difference(self, other, result_name=None, sort=None):
result_name = result_name_update

if not self._should_compare(other):
return self.union(other).rename(result_name)
return self.union(other, sort=sort).rename(result_name)
elif not is_dtype_equal(self.dtype, other.dtype):
dtype = find_common_type([self.dtype, other.dtype])
this = self.astype(dtype, copy=False)
Expand Down Expand Up @@ -6236,7 +6237,7 @@ def _maybe_cast_data_without_dtype(subarr):
try:
data = IntervalArray._from_sequence(subarr, copy=False)
return data
except ValueError:
except (ValueError, TypeError):
# GH27172: mixed closed Intervals --> object dtype
pass
elif inferred == "boolean":
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,8 @@ def is_type_compatible(self, kind: str) -> bool:
# --------------------------------------------------------------------
# Set Operation Methods

@Appender(Index.difference.__doc__)
def difference(self, other, sort=None):
new_idx = super().difference(other, sort=sort)._with_freq(None)
def _difference(self, other, sort=None):
new_idx = super()._difference(other, sort=sort)._with_freq(None)
return new_idx

def _intersection(self, other: Index, sort=False) -> Index:
Expand Down
29 changes: 11 additions & 18 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,17 @@ def setop_check(method):
def wrapped(self, other, sort=False):
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, _ = self._convert_can_do_setop(other)
other, result_name = self._convert_can_do_setop(other)

if not isinstance(other, IntervalIndex):
result = getattr(self.astype(object), op_name)(other)
if op_name in ("difference",):
result = result.astype(self.dtype)
return result
if op_name == "difference":
if not isinstance(other, IntervalIndex):
result = getattr(self.astype(object), op_name)(other, sort=sort)
return result.astype(self.dtype)

elif not self._should_compare(other):
# GH#19016: ensure set op will not return a prohibited dtype
result = getattr(self.astype(object), op_name)(other, sort=sort)
return result.astype(self.dtype)

return method(self, other, sort)

Expand Down Expand Up @@ -912,17 +916,6 @@ def _format_space(self) -> str:
# --------------------------------------------------------------------
# Set Operations

def _assert_can_do_setop(self, other):
super()._assert_can_do_setop(other)

if isinstance(other, IntervalIndex) and not self._should_compare(other):
# GH#19016: ensure set op will not return a prohibited dtype
raise TypeError(
"can only do set operations between two IntervalIndex "
"objects that are closed on the same side "
"and have compatible dtypes"
)

def _intersection(self, other, sort):
"""
intersection specialized to the case with matching dtypes.
Expand Down Expand Up @@ -1014,7 +1007,7 @@ def func(self, other, sort=sort):
return setop_check(func)

_union = _setop("union")
difference = _setop("difference")
_difference = _setop("difference")

# --------------------------------------------------------------------

Expand Down
10 changes: 5 additions & 5 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,14 @@ def _union(self, other, sort):
return type(self)(start_r, end_r + step_o, step_o)
return self._int64index._union(other, sort=sort)

def difference(self, other, sort=None):
def _difference(self, other, sort=None):
# optimized set operation if we have another RangeIndex
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other, result_name = self._convert_can_do_setop(other)

if not isinstance(other, RangeIndex):
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

res_name = ops.get_op_result_name(self, other)

Expand All @@ -654,11 +654,11 @@ def difference(self, other, sort=None):
return self[:0].rename(res_name)
if not isinstance(overlap, RangeIndex):
# We won't end up with RangeIndex, so fall back
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)
if overlap.step != first.step:
# In some cases we might be able to get a RangeIndex back,
# but not worth the effort.
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

if overlap[0] == first.start:
# The difference is everything after the intersection
Expand All @@ -668,7 +668,7 @@ def difference(self, other, sort=None):
new_rng = range(first.start, overlap[0], first.step)
else:
# The difference is not range-like
return super().difference(other, sort=sort)
return super()._difference(other, sort=sort)

new_index = type(self)._simple_new(new_rng, name=res_name)
if first is not self._range:
Expand Down
26 changes: 12 additions & 14 deletions pandas/tests/indexes/interval/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,19 @@ def test_set_incompatible_types(self, closed, op_name, sort):
result = set_op(Index([1, 2, 3]), sort=sort)
tm.assert_index_equal(result, expected)

# mixed closed
msg = (
"can only do set operations between two IntervalIndex objects "
"that are closed on the same side and have compatible dtypes"
)
# mixed closed -> cast to object
for other_closed in {"right", "left", "both", "neither"} - {closed}:
other = monotonic_index(0, 11, closed=other_closed)
with pytest.raises(TypeError, match=msg):
set_op(other, sort=sort)
expected = getattr(index.astype(object), op_name)(other, sort=sort)
if op_name == "difference":
expected = index
result = set_op(other, sort=sort)
tm.assert_index_equal(result, expected)

# GH 19016: incompatible dtypes
# GH 19016: incompatible dtypes -> cast to object
other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
msg = (
"can only do set operations between two IntervalIndex objects "
"that are closed on the same side and have compatible dtypes"
)
with pytest.raises(TypeError, match=msg):
set_op(other, sort=sort)
expected = getattr(index.astype(object), op_name)(other, sort=sort)
if op_name == "difference":
expected = index
result = set_op(other, sort=sort)
tm.assert_index_equal(result, expected)