Skip to content

Commit ad1afeb

Browse files
devin-petersohn authored and robertnishihara committed
[DataFrame] Implement sort_values and sort_index (#1977)
1 parent 9f28529 commit ad1afeb

File tree

2 files changed

+198
-14
lines changed

2 files changed

+198
-14
lines changed

python/ray/dataframe/dataframe.py

Lines changed: 166 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3839,15 +3839,175 @@ def slice_shift(self, periods=1, axis=0):
38393839
def sort_index(self, axis=0, level=None, ascending=True, inplace=False,
               kind='quicksort', na_position='last', sort_remaining=True,
               by=None):
    """Sort a DataFrame by one of the indices (columns or index).

    Args:
        axis: The axis to sort over.
        level: The MultiIndex level to sort over. Not supported yet.
        ascending: Ascending or descending.
        inplace: Whether or not to update this DataFrame inplace.
        kind: How to perform the sort.
        na_position: Where to position NA on the sort.
        sort_remaining: On Multilevel Index sort based on all levels.
        by: (Deprecated) argument to pass to sort_values.

    Returns:
        A sorted DataFrame, or None when inplace is True.

    Raises:
        NotImplementedError: If level is not None.
    """
    if level is not None:
        raise NotImplementedError("Multilevel index not yet implemented.")

    if by is not None:
        warnings.warn("by argument to sort_index is deprecated, "
                      "please use .sort_values(by=...)",
                      FutureWarning, stacklevel=2)
        # level is always None at this point (rejected above), so there
        # is no by/level conflict left to check for.
        return self.sort_values(by, axis=axis, ascending=ascending,
                                inplace=inplace)

    axis = pd.DataFrame()._get_axis_number(axis)

    # Positional arguments forwarded to pandas' sort_index on every
    # partition. inplace is forced to False so a new frame is returned.
    args = (axis, level, ascending, False, kind, na_position,
            sort_remaining)

    def _sort_helper(df, index, axis, *args):
        """Attach the real labels to a partition and sort it by them."""
        if axis == 0:
            df.index = index
        else:
            df.columns = index

        result = df.sort_index(*args)
        # Put positional labels back on the input partition, which was
        # relabelled in place above.
        # NOTE(review): presumably partitions are stored with RangeIndex
        # labels on both axes -- confirm.
        df.reset_index(drop=True, inplace=True)
        df.columns = pd.RangeIndex(len(df.columns))
        return result

    if axis == 0:
        index = self.index
        new_column_parts = _map_partitions(
            lambda df: _sort_helper(df, index, axis, *args),
            self._col_partitions)

        new_columns = self.columns
        # The labels must follow the same direction as the data sort;
        # otherwise a descending sort is paired with ascending labels.
        new_index = self.index.sort_values(ascending=ascending)
        new_row_parts = None
    else:
        columns = self.columns
        new_row_parts = _map_partitions(
            lambda df: _sort_helper(df, columns, axis, *args),
            self._row_partitions)

        new_columns = self.columns.sort_values(ascending=ascending)
        new_index = self.index
        new_column_parts = None

    if not inplace:
        return DataFrame(col_partitions=new_column_parts,
                         row_partitions=new_row_parts,
                         index=new_index,
                         columns=new_columns)
    else:
        self._update_inplace(row_partitions=new_row_parts,
                             col_partitions=new_column_parts,
                             columns=new_columns,
                             index=new_index)
38453914

38463915
def sort_values(self, by, axis=0, ascending=True, inplace=False,
                kind='quicksort', na_position='last'):
    """Sort the frame by the values of one or more columns/rows.

    Args:
        by: A label, or a list of labels, on the axis to sort by. A
            single label is wrapped into a one-element list.
        axis: The axis to sort.
        ascending: Sort in ascending or descending order.
        inplace: If true, do the operation inplace.
        kind: How to sort.
        na_position: Where to put np.nan values.

    Returns:
        A sorted DataFrame, or None when inplace is true.
    """
    axis = pd.DataFrame()._get_axis_number(axis)

    by = by if is_list_like(by) else [by]

    # Gather the key columns/rows into one pandas frame that is handed to
    # every partition, labelled with the stringified `by` entries.
    if axis == 0:
        broadcast_values = pd.DataFrame(
            {str(label): self[label] for label in by})
    else:
        n_rows = len(self.index)
        # NOTE(review): the slice start is the `by` entry itself, so this
        # presumably expects integer row positions -- confirm.
        row_frames = [to_pandas(self[label::n_rows]) for label in by]
        for frame, label in zip(row_frames, by):
            frame.index = [str(label)]
        broadcast_values = pd.concat(row_frames)

    # The keys are stringified so that they cannot collide with the
    # RangeIndex labels on the inner frames. It is cheap and guarantees
    # that we sort by the correct column.
    by = [str(label) for label in by]

    # Positional arguments for pandas' sort_values; inplace is forced to
    # False so that a new frame comes back.
    sort_args = (by, axis, ascending, False, kind, na_position)

    def _sort_partition(part, keys, axis, *sort_args):
        """Sort one partition using the broadcast key frame.

        Args:
            part: The partition to sort.
            keys: The frame holding the `by` values.
            axis: The axis to sort over.
            sort_args: The arguments forwarded to sort_values.

        Returns:
            A new sorted partition without the key labels.
        """
        concat_axis = axis ^ 1
        if axis == 0:
            keys.index = part.index
            key_labels = keys.columns
        else:
            keys.columns = part.columns
            key_labels = keys.index

        combined = pd.concat([part, keys], axis=concat_axis, copy=False)
        ordered = combined.sort_values(*sort_args)
        return ordered.drop(key_labels, axis=concat_axis)

    def _apply_sort(part):
        return _sort_partition(part, broadcast_values, axis, *sort_args)

    if axis == 0:
        new_column_partitions = _map_partitions(_apply_sort,
                                                self._col_partitions)
        new_row_partitions = None
        new_columns = self.columns

        # Sorting the key frame itself yields the label order of the
        # sorted axis, which the partitions do not carry.
        new_index = broadcast_values.sort_values(*sort_args).index
    else:
        new_row_partitions = _map_partitions(_apply_sort,
                                             self._row_partitions)
        new_column_partitions = None
        new_columns = broadcast_values.sort_values(*sort_args).columns
        new_index = self.index

    if not inplace:
        return DataFrame(row_partitions=new_row_partitions,
                         col_partitions=new_column_partitions,
                         columns=new_columns,
                         index=new_index)

    self._update_inplace(row_partitions=new_row_partitions,
                         col_partitions=new_column_partitions,
                         columns=new_columns,
                         index=new_index)
38514011

38524012
def sortlevel(self, level=0, axis=0, ascending=True, inplace=False,
38534013
sort_remaining=True):

python/ray/dataframe/test/test_dataframe.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -969,8 +969,6 @@ def test_append():
969969

970970
pandas_df2 = pd.DataFrame({"col5": [0], "col6": [1]})
971971

972-
print(ray_df.append(ray_df2))
973-
974972
assert ray_df_equals_pandas(ray_df.append(ray_df2),
975973
pandas_df.append(pandas_df2))
976974

@@ -2829,17 +2827,43 @@ def test_slice_shift():
28292827

28302828

28312829
def test_sort_index():
    # Compare sort_index against pandas for both sort directions. The
    # comparison result must be asserted, otherwise the test cannot fail.
    pandas_df = pd.DataFrame(np.random.randint(0, 100, size=(1000, 100)))
    ray_df = rdf.DataFrame(pandas_df)

    pandas_result = pandas_df.sort_index()
    ray_result = ray_df.sort_index()

    assert ray_df_equals_pandas(ray_result, pandas_result)

    pandas_result = pandas_df.sort_index(ascending=False)
    ray_result = ray_df.sort_index(ascending=False)

    assert ray_df_equals_pandas(ray_result, pandas_result)
28362842

28372843

28382844
def test_sort_values():
    # Compare sort_values against pandas for a single key and for several
    # keys on each axis. The comparison result must be asserted, otherwise
    # the test cannot fail.
    pandas_df = pd.DataFrame(np.random.randint(0, 100, size=(1000, 100)))
    ray_df = rdf.DataFrame(pandas_df)

    pandas_result = pandas_df.sort_values(by=1)
    ray_result = ray_df.sort_values(by=1)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    pandas_result = pandas_df.sort_values(by=1, axis=1)
    ray_result = ray_df.sort_values(by=1, axis=1)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    pandas_result = pandas_df.sort_values(by=[1, 3])
    ray_result = ray_df.sort_values(by=[1, 3])

    assert ray_df_equals_pandas(ray_result, pandas_result)

    pandas_result = pandas_df.sort_values(by=[1, 67], axis=1)
    ray_result = ray_df.sort_values(by=[1, 67], axis=1)

    assert ray_df_equals_pandas(ray_result, pandas_result)
28432867

28442868

28452869
def test_sortlevel():

0 commit comments

Comments
 (0)