Skip to content

Commit 51c0299

Browse files
authored
Fix sort_index with multi-index columns. (#637)
Resolves #634.
1 parent 90ec799 commit 51c0299

File tree

3 files changed

+43
-47
lines changed

3 files changed

+43
-47
lines changed

databricks/koalas/frame.py

+34-29
Original file line numberDiff line numberDiff line change
@@ -4177,6 +4177,32 @@ def get(self, key, default=None):
41774177
except (KeyError, ValueError, IndexError):
41784178
return default
41794179

4180+
def _sort(self, by: List[Column], ascending: Union[bool, List[bool]],
4181+
inplace: bool, na_position: str):
4182+
if isinstance(ascending, bool):
4183+
ascending = [ascending] * len(by)
4184+
if len(ascending) != len(by):
4185+
raise ValueError('Length of ascending ({}) != length of by ({})'
4186+
.format(len(ascending), len(by)))
4187+
if na_position not in ('first', 'last'):
4188+
raise ValueError("invalid na_position: '{}'".format(na_position))
4189+
4190+
# Mapper: Get a spark column function for (ascending, na_position) combination
4191+
# Note that 'asc_nulls_first' and friends were added as of Spark 2.4, see SPARK-23847.
4192+
mapper = {
4193+
(True, 'first'): lambda x: Column(getattr(x._jc, "asc_nulls_first")()),
4194+
(True, 'last'): lambda x: Column(getattr(x._jc, "asc_nulls_last")()),
4195+
(False, 'first'): lambda x: Column(getattr(x._jc, "desc_nulls_first")()),
4196+
(False, 'last'): lambda x: Column(getattr(x._jc, "desc_nulls_last")()),
4197+
}
4198+
by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]
4199+
kdf = DataFrame(self._internal.copy(sdf=self._sdf.sort(*by))) # type: ks.DataFrame
4200+
if inplace:
4201+
self._internal = kdf._internal
4202+
return None
4203+
else:
4204+
return kdf
4205+
41804206
def sort_values(self, by: Union[str, List[str]], ascending: Union[bool, List[bool]] = True,
41814207
inplace: bool = False, na_position: str = 'last') -> Optional['DataFrame']:
41824208
"""
@@ -4253,30 +4279,9 @@ def sort_values(self, by: Union[str, List[str]], ascending: Union[bool, List[boo
42534279
"""
42544280
if isinstance(by, str):
42554281
by = [by]
4256-
if isinstance(ascending, bool):
4257-
ascending = [ascending] * len(by)
4258-
if len(ascending) != len(by):
4259-
raise ValueError('Length of ascending ({}) != length of by ({})'
4260-
.format(len(ascending), len(by)))
4261-
if na_position not in ('first', 'last'):
4262-
raise ValueError("invalid na_position: '{}'".format(na_position))
4263-
4264-
# Mapper: Get a spark column function for (ascending, na_position) combination
4265-
# Note that 'asc_nulls_first' and friends were added as of Spark 2.4, see SPARK-23847.
4266-
mapper = {
4267-
(True, 'first'): lambda x: Column(getattr(x._jc, "asc_nulls_first")()),
4268-
(True, 'last'): lambda x: Column(getattr(x._jc, "asc_nulls_last")()),
4269-
(False, 'first'): lambda x: Column(getattr(x._jc, "desc_nulls_first")()),
4270-
(False, 'last'): lambda x: Column(getattr(x._jc, "desc_nulls_last")()),
4271-
}
4272-
by = [mapper[(asc, na_position)](self[colname]._scol)
4273-
for colname, asc in zip(by, ascending)]
4274-
kdf = DataFrame(self._internal.copy(sdf=self._sdf.sort(*by))) # type: ks.DataFrame
4275-
if inplace:
4276-
self._internal = kdf._internal
4277-
return None
4278-
else:
4279-
return kdf
4282+
by = [self[colname]._scol for colname in by]
4283+
return self._sort(by=by, ascending=ascending,
4284+
inplace=inplace, na_position=na_position)
42804285

42814286
def sort_index(self, axis: int = 0,
42824287
level: Optional[Union[int, List[int]]] = None, ascending: bool = True,
@@ -4367,14 +4372,14 @@ def sort_index(self, axis: int = 0,
43674372
raise ValueError("Specifying the sorting algorithm is supported at the moment.")
43684373

43694374
if level is None or (is_list_like(level) and len(level) == 0): # type: ignore
4370-
by = self._internal.index_columns
4375+
by = self._internal.index_scols
43714376
elif is_list_like(level):
4372-
by = [self._internal.index_columns[l] for l in level] # type: ignore
4377+
by = [self._internal.index_scols[l] for l in level] # type: ignore
43734378
else:
4374-
by = self._internal.index_columns[level]
4379+
by = [self._internal.index_scols[level]]
43754380

4376-
return self.sort_values(by=by, ascending=ascending,
4377-
inplace=inplace, na_position=na_position)
4381+
return self._sort(by=by, ascending=ascending,
4382+
inplace=inplace, na_position=na_position)
43784383

43794384
# TODO: add keep = First
43804385
def nlargest(self, n: int, columns: 'Any') -> 'DataFrame':

databricks/koalas/series.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -1739,24 +1739,8 @@ def sort_index(self, axis: int = 0,
17391739
b 1 0
17401740
Name: 0, dtype: int64
17411741
"""
1742-
if len(self._internal.index_map) == 0:
1743-
raise ValueError("Index should be set.")
1744-
1745-
if axis != 0:
1746-
raise ValueError("No other axes than 0 are supported at the moment")
1747-
if kind is not None:
1748-
raise ValueError("Specifying the sorting algorithm is supported at the moment.")
1749-
1750-
if level is None or (is_list_like(level) and len(level) == 0): # type: ignore
1751-
by = self._internal.index_columns
1752-
elif is_list_like(level):
1753-
by = [self._internal.index_columns[l] for l in level] # type: ignore
1754-
else:
1755-
by = self._internal.index_columns[level]
1756-
1757-
kseries = _col(self.to_dataframe().sort_values(by=by,
1758-
ascending=ascending,
1759-
na_position=na_position))
1742+
kseries = _col(self.to_dataframe().sort_index(axis=axis, level=level, ascending=ascending,
1743+
kind=kind, na_position=na_position))
17601744
if inplace:
17611745
self._internal = kseries._internal
17621746
self._kdf = kseries._kdf

databricks/koalas/tests/test_dataframe.py

+7
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ def test_sort_index(self):
543543

544544
self.assertRaises(ValueError, lambda: kdf.reset_index().sort_index())
545545

546+
# Assert with multi-index columns
547+
columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B')])
548+
pdf.columns = columns
549+
kdf.columns = columns
550+
551+
self.assert_eq(kdf.sort_index(), pdf.sort_index())
552+
546553
def test_nlargest(self):
547554
pdf = pd.DataFrame({'a': [1, 2, 3, 4, 5, None, 7],
548555
'b': [7, 6, 5, 4, 3, 2, 1]})

0 commit comments

Comments
 (0)