@@ -3395,6 +3395,7 @@ def join(self, other, how="left", level=None, return_indexers=False, sort=False)
33953395 -------
33963396 join_index, (left_indexer, right_indexer)
33973397 """
3398+ other = ensure_index (other )
33983399 self_is_mi = isinstance (self , ABCMultiIndex )
33993400 other_is_mi = isinstance (other , ABCMultiIndex )
34003401
@@ -3414,8 +3415,6 @@ def join(self, other, how="left", level=None, return_indexers=False, sort=False)
34143415 other , level , how = how , return_indexers = return_indexers
34153416 )
34163417
3417- other = ensure_index (other )
3418-
34193418 if len (other ) == 0 and how in ("left" , "outer" ):
34203419 join_index = self ._shallow_copy ()
34213420 if return_indexers :
@@ -3577,16 +3576,26 @@ def _join_multi(self, other, how, return_indexers=True):
35773576 def _join_non_unique (self , other , how = "left" , return_indexers = False ):
35783577 from pandas .core .reshape .merge import _get_join_indexers
35793578
3579+ # We only get here if dtypes match
3580+ assert self .dtype == other .dtype
3581+
3582+ if is_extension_array_dtype (self .dtype ):
3583+ lvalues = self ._data ._values_for_argsort ()
3584+ rvalues = other ._data ._values_for_argsort ()
3585+ else :
3586+ lvalues = self ._values
3587+ rvalues = other ._values
3588+
35803589 left_idx , right_idx = _get_join_indexers (
3581- [self . _ndarray_values ], [other . _ndarray_values ], how = how , sort = True
3590+ [lvalues ], [rvalues ], how = how , sort = True
35823591 )
35833592
35843593 left_idx = ensure_platform_int (left_idx )
35853594 right_idx = ensure_platform_int (right_idx )
35863595
3587- join_index = np .asarray (self . _ndarray_values .take (left_idx ))
3596+ join_index = np .asarray (lvalues .take (left_idx ))
35883597 mask = left_idx == - 1
3589- np .putmask (join_index , mask , other . _ndarray_values .take (right_idx ))
3598+ np .putmask (join_index , mask , rvalues .take (right_idx ))
35903599
35913600 join_index = self ._wrap_joined_index (join_index , other )
35923601
@@ -3737,15 +3746,22 @@ def _get_leaf_sorter(labels):
37373746 return join_index
37383747
37393748 def _join_monotonic (self , other , how = "left" , return_indexers = False ):
3749+ # We only get here with matching dtypes
3750+ assert other .dtype == self .dtype
3751+
37403752 if self .equals (other ):
37413753 ret_index = other if how == "right" else self
37423754 if return_indexers :
37433755 return ret_index , None , None
37443756 else :
37453757 return ret_index
37463758
3747- sv = self ._ndarray_values
3748- ov = other ._ndarray_values
3759+ if is_extension_array_dtype (self .dtype ):
3760+ sv = self ._data ._values_for_argsort ()
3761+ ov = other ._data ._values_for_argsort ()
3762+ else :
3763+ sv = self ._values
3764+ ov = other ._values
37493765
37503766 if self .is_unique and other .is_unique :
37513767 # We can perform much better than the general case
0 commit comments