Skip to content

Commit 78b1004

Browse files
authored
Implemented GroupBy.median() (#1957)
This PR proposes `GroupBy.median()`. Note: the result can be slightly different from pandas since we use an approximated median based upon approximate percentile computation because computing median across a large dataset is extremely expensive. ```python >>> kdf = ks.DataFrame({'a': [1., 1., 1., 1., 2., 2., 2., 3., 3., 3.], ... 'b': [2., 3., 1., 4., 6., 9., 8., 10., 7., 5.], ... 'c': [3., 5., 2., 5., 1., 2., 6., 4., 3., 6.]}, ... columns=['a', 'b', 'c'], ... index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6]) >>> kdf a b c 7 1.0 2.0 3.0 2 1.0 3.0 5.0 4 1.0 1.0 2.0 1 1.0 4.0 5.0 3 2.0 6.0 1.0 4 2.0 9.0 2.0 9 2.0 8.0 6.0 10 3.0 10.0 4.0 5 3.0 7.0 3.0 6 3.0 5.0 6.0 >>> kdf.groupby('a').median().sort_index() # doctest: +NORMALIZE_WHITESPACE b c a 1.0 2.0 3.0 2.0 8.0 2.0 3.0 7.0 4.0 >>> kdf.groupby('a')['b'].median().sort_index() a 1.0 2.0 2.0 8.0 3.0 7.0 Name: b, dtype: float64 ``` ref #1929
1 parent bb31489 commit 78b1004

File tree

4 files changed

+93
-2
lines changed

4 files changed

+93
-2
lines changed

databricks/koalas/groupby.py

+68
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from databricks.koalas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
7272
from databricks.koalas.window import RollingGroupby, ExpandingGroupby
7373
from databricks.koalas.exceptions import DataError
74+
from databricks.koalas.spark import functions as SF
7475

7576
# to keep it the same as pandas
7677
NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
@@ -2343,6 +2344,73 @@ def get_group(self, name) -> Union[DataFrame, Series]:
23432344

23442345
return DataFrame(internal)
23452346

2347+
def median(self, numeric_only=True, accuracy=10000) -> Union[DataFrame, Series]:
2348+
"""
2349+
Compute median of groups, excluding missing values.
2350+
2351+
For multiple groupings, the result index will be a MultiIndex
2352+
2353+
.. note:: Unlike pandas', the median in Koalas is an approximated median based upon
2354+
approximate percentile computation because computing median across a large dataset
2355+
is extremely expensive.
2356+
2357+
Parameters
2358+
----------
2359+
numeric_only : bool, default True
2360+
Include only float, int, boolean columns. False is not supported. This parameter
2361+
is mainly for pandas compatibility.
2362+
2363+
Returns
2364+
-------
2365+
Series or DataFrame
2366+
Median of values within each group.
2367+
2368+
Examples
2369+
--------
2370+
>>> kdf = ks.DataFrame({'a': [1., 1., 1., 1., 2., 2., 2., 3., 3., 3.],
2371+
... 'b': [2., 3., 1., 4., 6., 9., 8., 10., 7., 5.],
2372+
... 'c': [3., 5., 2., 5., 1., 2., 6., 4., 3., 6.]},
2373+
... columns=['a', 'b', 'c'],
2374+
... index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6])
2375+
>>> kdf
2376+
a b c
2377+
7 1.0 2.0 3.0
2378+
2 1.0 3.0 5.0
2379+
4 1.0 1.0 2.0
2380+
1 1.0 4.0 5.0
2381+
3 2.0 6.0 1.0
2382+
4 2.0 9.0 2.0
2383+
9 2.0 8.0 6.0
2384+
10 3.0 10.0 4.0
2385+
5 3.0 7.0 3.0
2386+
6 3.0 5.0 6.0
2387+
2388+
DataFrameGroupBy
2389+
2390+
>>> kdf.groupby('a').median().sort_index() # doctest: +NORMALIZE_WHITESPACE
2391+
b c
2392+
a
2393+
1.0 2.0 3.0
2394+
2.0 8.0 2.0
2395+
3.0 7.0 4.0
2396+
2397+
SeriesGroupBy
2398+
2399+
>>> kdf.groupby('a')['b'].median().sort_index()
2400+
a
2401+
1.0 2.0
2402+
2.0 8.0
2403+
3.0 7.0
2404+
Name: b, dtype: float64
2405+
"""
2406+
if not isinstance(accuracy, int):
2407+
raise ValueError(
2408+
"accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
2409+
)
2410+
2411+
stat_function = lambda col: SF.percentile_approx(col, 0.5, accuracy)
2412+
return self._reduce_for_stat_function(stat_function, only_numeric=numeric_only)
2413+
23462414
def _reduce_for_stat_function(self, sfun, only_numeric):
23472415
agg_columns = self._agg_columns
23482416
agg_columns_scols = self._agg_columns_scols

databricks/koalas/missing/groupby.py

-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class MissingPandasLikeDataFrameGroupBy(object):
5757

5858
# Functions
5959
boxplot = _unsupported_function("boxplot")
60-
median = _unsupported_function("median")
6160
ngroup = _unsupported_function("ngroup")
6261
nth = _unsupported_function("nth")
6362
ohlc = _unsupported_function("ohlc")
@@ -93,7 +92,6 @@ class MissingPandasLikeSeriesGroupBy(object):
9392
agg = _unsupported_function("agg")
9493
aggregate = _unsupported_function("aggregate")
9594
describe = _unsupported_function("describe")
96-
median = _unsupported_function("median")
9795
ngroup = _unsupported_function("ngroup")
9896
nth = _unsupported_function("nth")
9997
ohlc = _unsupported_function("ohlc")

databricks/koalas/tests/test_groupby.py

+24
Original file line numberDiff line numberDiff line change
@@ -2609,6 +2609,30 @@ def test_get_group(self):
26092609
ValueError, lambda: kdf.groupby([("B", "class"), ("A", "name")]).get_group("mammal")
26102610
)
26112611

2612+
def test_median(self):
2613+
kdf = ks.DataFrame(
2614+
{
2615+
"a": [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0],
2616+
"b": [2.0, 3.0, 1.0, 4.0, 6.0, 9.0, 8.0, 10.0, 7.0, 5.0],
2617+
"c": [3.0, 5.0, 2.0, 5.0, 1.0, 2.0, 6.0, 4.0, 3.0, 6.0],
2618+
},
2619+
columns=["a", "b", "c"],
2620+
index=[7, 2, 4, 1, 3, 4, 9, 10, 5, 6],
2621+
)
2622+
# DataFrame
2623+
expected_result = ks.DataFrame(
2624+
{"b": [2.0, 8.0, 7.0], "c": [3.0, 2.0, 4.0]}, index=pd.Index([1.0, 2.0, 3.0], name="a")
2625+
)
2626+
self.assert_eq(expected_result, kdf.groupby("a").median().sort_index())
2627+
# Series
2628+
expected_result = ks.Series(
2629+
[2.0, 8.0, 7.0], name="b", index=pd.Index([1.0, 2.0, 3.0], name="a")
2630+
)
2631+
self.assert_eq(expected_result, kdf.groupby("a")["b"].median().sort_index())
2632+
2633+
with self.assertRaisesRegex(ValueError, "accuracy must be an integer; however"):
2634+
kdf.groupby("a").median(accuracy="a")
2635+
26122636
def test_tail(self):
26132637
pdf = pd.DataFrame(
26142638
{

docs/source/reference/groupby.rst

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Computations / Descriptive Stats
5151
GroupBy.last
5252
GroupBy.max
5353
GroupBy.mean
54+
GroupBy.median
5455
GroupBy.min
5556
GroupBy.rank
5657
GroupBy.std

0 commit comments

Comments
 (0)