Skip to content

Commit

Permalink
Fix sort_index with multi-index columns.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Aug 12, 2019
1 parent 82e2e41 commit e482dd8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 47 deletions.
63 changes: 34 additions & 29 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4177,6 +4177,32 @@ def get(self, key, default=None):
except (KeyError, ValueError, IndexError):
return default

def _sort(self, by: List[Column], ascending: Union[bool, List[bool]],
          inplace: bool, na_position: str):
    """Sort the underlying Spark DataFrame by the given Spark columns.

    Shared implementation behind ``sort_values`` and ``sort_index``.

    :param by: list of Spark columns to sort by.
    :param ascending: a bool (applied to every column) or a list of bools,
        one per entry in ``by``.
    :param inplace: if True, update ``self`` in place and return None;
        otherwise return a new sorted DataFrame.
    :param na_position: either 'first' or 'last'.
    :raises ValueError: if ``ascending`` and ``by`` differ in length, or
        ``na_position`` is not one of 'first'/'last'.
    """
    if isinstance(ascending, bool):
        ascending = [ascending] * len(by)
    if len(ascending) != len(by):
        raise ValueError('Length of ascending ({}) != length of by ({})'
                         .format(len(ascending), len(by)))
    if na_position not in ('first', 'last'):
        raise ValueError("invalid na_position: '{}'".format(na_position))

    # Resolve the JVM ordering method name for each (ascending, na_position)
    # pair. NOTE: 'asc_nulls_first' and friends were added as of Spark 2.4,
    # see SPARK-23847.
    def ordered(scol, asc):
        method_name = {
            (True, 'first'): "asc_nulls_first",
            (True, 'last'): "asc_nulls_last",
            (False, 'first'): "desc_nulls_first",
            (False, 'last'): "desc_nulls_last",
        }[(asc, na_position)]
        return Column(getattr(scol._jc, method_name)())

    sort_cols = [ordered(scol, asc) for scol, asc in zip(by, ascending)]
    kdf = DataFrame(self._internal.copy(sdf=self._sdf.sort(*sort_cols)))  # type: ks.DataFrame
    if not inplace:
        return kdf
    # In-place: adopt the sorted internal state; stdlib convention — mutators
    # return None.
    self._internal = kdf._internal
    return None

def sort_values(self, by: Union[str, List[str]], ascending: Union[bool, List[bool]] = True,
inplace: bool = False, na_position: str = 'last') -> Optional['DataFrame']:
"""
Expand Down Expand Up @@ -4253,30 +4279,9 @@ def sort_values(self, by: Union[str, List[str]], ascending: Union[bool, List[boo
"""
if isinstance(by, str):
by = [by]
if isinstance(ascending, bool):
ascending = [ascending] * len(by)
if len(ascending) != len(by):
raise ValueError('Length of ascending ({}) != length of by ({})'
.format(len(ascending), len(by)))
if na_position not in ('first', 'last'):
raise ValueError("invalid na_position: '{}'".format(na_position))

# Mapper: Get a spark column function for (ascending, na_position) combination
# Note that 'asc_nulls_first' and friends were added as of Spark 2.4, see SPARK-23847.
mapper = {
(True, 'first'): lambda x: Column(getattr(x._jc, "asc_nulls_first")()),
(True, 'last'): lambda x: Column(getattr(x._jc, "asc_nulls_last")()),
(False, 'first'): lambda x: Column(getattr(x._jc, "desc_nulls_first")()),
(False, 'last'): lambda x: Column(getattr(x._jc, "desc_nulls_last")()),
}
by = [mapper[(asc, na_position)](self[colname]._scol)
for colname, asc in zip(by, ascending)]
kdf = DataFrame(self._internal.copy(sdf=self._sdf.sort(*by))) # type: ks.DataFrame
if inplace:
self._internal = kdf._internal
return None
else:
return kdf
by = [self[colname]._scol for colname in by]
return self._sort(by=by, ascending=ascending,
inplace=inplace, na_position=na_position)

def sort_index(self, axis: int = 0,
level: Optional[Union[int, List[int]]] = None, ascending: bool = True,
Expand Down Expand Up @@ -4367,14 +4372,14 @@ def sort_index(self, axis: int = 0,
raise ValueError("Specifying the sorting algorithm is supported at the moment.")

if level is None or (is_list_like(level) and len(level) == 0): # type: ignore
by = self._internal.index_columns
by = self._internal.index_scols
elif is_list_like(level):
by = [self._internal.index_columns[l] for l in level] # type: ignore
by = [self._internal.index_scols[l] for l in level] # type: ignore
else:
by = self._internal.index_columns[level]
by = [self._internal.index_scols[level]]

return self.sort_values(by=by, ascending=ascending,
inplace=inplace, na_position=na_position)
return self._sort(by=by, ascending=ascending,
inplace=inplace, na_position=na_position)

# TODO: add keep = First
def nlargest(self, n: int, columns: 'Any') -> 'DataFrame':
Expand Down
20 changes: 2 additions & 18 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,24 +1739,8 @@ def sort_index(self, axis: int = 0,
b 1 0
Name: 0, dtype: int64
"""
if len(self._internal.index_map) == 0:
raise ValueError("Index should be set.")

if axis != 0:
raise ValueError("No other axes than 0 are supported at the moment")
if kind is not None:
raise ValueError("Specifying the sorting algorithm is supported at the moment.")

if level is None or (is_list_like(level) and len(level) == 0): # type: ignore
by = self._internal.index_columns
elif is_list_like(level):
by = [self._internal.index_columns[l] for l in level] # type: ignore
else:
by = self._internal.index_columns[level]

kseries = _col(self.to_dataframe().sort_values(by=by,
ascending=ascending,
na_position=na_position))
kseries = _col(self.to_dataframe().sort_index(axis=axis, level=level, ascending=ascending,
kind=kind, na_position=na_position))
if inplace:
self._internal = kseries._internal
self._kdf = kseries._kdf
Expand Down
7 changes: 7 additions & 0 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,13 @@ def test_sort_index(self):

self.assertRaises(ValueError, lambda: kdf.reset_index().sort_index())

# Assert with multi-index columns
columns = pd.MultiIndex.from_tuples([('X', 'A'), ('X', 'B')])
pdf.columns = columns
kdf.columns = columns

self.assert_eq(kdf.sort_index(), pdf.sort_index())

def test_nlargest(self):
pdf = pd.DataFrame({'a': [1, 2, 3, 4, 5, None, 7],
'b': [7, 6, 5, 4, 3, 2, 1]})
Expand Down

0 comments on commit e482dd8

Please sign in to comment.