Skip to content

Commit 0baa4f0

Browse files
authored
Deduplicate Expanding and Rolling codes at _RollingAndExpanding (#1036)
This PR proposes to deduplicate the method implementation at `Expanding` and `Rolling` to `_RollingAndExpanding`. Now, windowing logics of `Expanding`, `ExpandingGroupBy`, `Rolling` and `RollingGroupBy` are all located in `_RollingAndExpanding`.
1 parent 5390538 commit 0baa4f0

File tree

1 file changed

+79
-82
lines changed

1 file changed

+79
-82
lines changed

databricks/koalas/window.py

+79-82
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from functools import partial
1717
from typing import Any
1818

19-
from databricks.koalas.internal import _InternalFrame, SPARK_INDEX_NAME_FORMAT
19+
from databricks.koalas.internal import SPARK_INDEX_NAME_FORMAT
2020
from databricks.koalas.utils import name_like_string
2121
from pyspark.sql import Window
2222
from pyspark.sql import functions as F
@@ -29,36 +29,90 @@
2929

3030

3131
class _RollingAndExpanding(object):
32-
pass
32+
def __init__(self, window, index_scols, min_periods):
33+
self._window = window
34+
# This unbounded Window is later used to handle 'min_periods' for now.
35+
self._unbounded_window = Window.orderBy(index_scols).rowsBetween(
36+
Window.unboundedPreceding, Window.currentRow)
37+
self._min_periods = min_periods
38+
39+
def _apply_as_series_or_frame(self, func):
40+
"""
41+
Wraps a function that handles Spark column in order
42+
to support it in both Koalas Series and DataFrame.
43+
Note that the given `func` name should be same as the API's method name.
44+
"""
45+
raise NotImplementedError(
46+
"A class that inherits this class should implement this method "
47+
"to handle the index and columns of output.")
48+
49+
def count(self):
50+
def count(scol):
51+
return F.count(scol).over(self._window)
52+
return self._apply_as_series_or_frame(count).astype('float64')
53+
54+
def sum(self):
55+
def sum(scol):
56+
return F.when(
57+
F.row_number().over(self._unbounded_window) >= self._min_periods,
58+
F.sum(scol).over(self._window)
59+
).otherwise(F.lit(None))
60+
61+
return self._apply_as_series_or_frame(sum)
62+
63+
def min(self):
64+
def min(scol):
65+
return F.when(
66+
F.row_number().over(self._unbounded_window) >= self._min_periods,
67+
F.min(scol).over(self._window)
68+
).otherwise(F.lit(None))
69+
70+
return self._apply_as_series_or_frame(min)
71+
72+
def max(self):
73+
def max(scol):
74+
return F.when(
75+
F.row_number().over(self._unbounded_window) >= self._min_periods,
76+
F.max(scol).over(self._window)
77+
).otherwise(F.lit(None))
78+
79+
return self._apply_as_series_or_frame(max)
80+
81+
def mean(self):
82+
def mean(scol):
83+
return F.when(
84+
F.row_number().over(self._unbounded_window) >= self._min_periods,
85+
F.mean(scol).over(self._window)
86+
).otherwise(F.lit(None))
87+
88+
return self._apply_as_series_or_frame(mean)
3389

3490

3591
class Rolling(_RollingAndExpanding):
3692
def __init__(self, kdf_or_kser, window, min_periods=None):
3793
from databricks.koalas import DataFrame, Series
38-
from databricks.koalas.groupby import SeriesGroupBy, DataFrameGroupBy
3994

4095
if window < 0:
4196
raise ValueError("window must be >= 0")
4297
if (min_periods is not None) and (min_periods < 0):
4398
raise ValueError("min_periods must be >= 0")
4499
self._window_val = window
45100
if min_periods is not None:
46-
self._min_periods = min_periods
101+
min_periods = min_periods
47102
else:
48103
# TODO: 'min_periods' is not equivalent in pandas because it does not count NA as
49104
# a value.
50-
self._min_periods = window
105+
min_periods = window
51106

52107
self.kdf_or_kser = kdf_or_kser
53-
if not isinstance(kdf_or_kser, (DataFrame, Series, DataFrameGroupBy, SeriesGroupBy)):
108+
if not isinstance(kdf_or_kser, (DataFrame, Series)):
54109
raise TypeError(
55110
"kdf_or_kser must be a series or dataframe; however, got: %s" % type(kdf_or_kser))
56-
if isinstance(kdf_or_kser, (DataFrame, Series)):
57-
self._index_scols = kdf_or_kser._internal.index_scols
58-
self._window = Window.orderBy(self._index_scols).rowsBetween(
59-
Window.currentRow - (self._window_val-1), Window.currentRow)
111+
index_scols = kdf_or_kser._internal.index_scols
112+
window = Window.orderBy(index_scols).rowsBetween(
113+
Window.currentRow - (self._window_val - 1), Window.currentRow)
60114

61-
self._unbounded_window = Window.orderBy(self._index_scols)
115+
super(Rolling, self).__init__(window, index_scols, min_periods)
62116

63117
def __getattr__(self, item: str) -> Any:
64118
if hasattr(_MissingPandasLikeRolling, item):
@@ -143,10 +197,7 @@ def count(self):
143197
2 2.0
144198
3 2.0
145199
"""
146-
def count(scol):
147-
return F.count(scol).over(self._window)
148-
149-
return self._apply_as_series_or_frame(count).astype('float64')
200+
return super(Rolling, self).count()
150201

151202
def sum(self):
152203
"""
@@ -224,13 +275,7 @@ def sum(self):
224275
3 10.0 38.0
225276
4 13.0 65.0
226277
"""
227-
def sum(scol):
228-
return F.when(
229-
F.row_number().over(self._unbounded_window) >= self._min_periods,
230-
F.sum(scol).over(self._window)
231-
).otherwise(F.lit(None))
232-
233-
return self._apply_as_series_or_frame(sum)
278+
return super(Rolling, self).sum()
234279

235280
def min(self):
236281
"""
@@ -308,13 +353,7 @@ def min(self):
308353
3 2.0 4.0
309354
4 2.0 4.0
310355
"""
311-
def min(scol):
312-
return F.when(
313-
F.row_number().over(self._unbounded_window) >= self._min_periods,
314-
F.min(scol).over(self._window)
315-
).otherwise(F.lit(None))
316-
317-
return self._apply_as_series_or_frame(min)
356+
return super(Rolling, self).min()
318357

319358
def max(self):
320359
"""
@@ -391,13 +430,7 @@ def max(self):
391430
3 5.0 25.0
392431
4 6.0 36.0
393432
"""
394-
def max(scol):
395-
return F.when(
396-
F.row_number().over(self._unbounded_window) >= self._min_periods,
397-
F.max(scol).over(self._window)
398-
).otherwise(F.lit(None))
399-
400-
return self._apply_as_series_or_frame(max)
433+
return super(Rolling, self).max()
401434

402435
def mean(self):
403436
"""
@@ -475,13 +508,7 @@ def mean(self):
475508
3 3.333333 12.666667
476509
4 4.333333 21.666667
477510
"""
478-
def mean(scol):
479-
return F.when(
480-
F.row_number().over(self._unbounded_window) >= self._min_periods,
481-
F.mean(scol).over(self._window)
482-
).otherwise(F.lit(None))
483-
484-
return self._apply_as_series_or_frame(mean)
511+
return super(Rolling, self).mean()
485512

486513

487514
class RollingGroupby(Rolling):
@@ -866,14 +893,14 @@ def __init__(self, kdf_or_kser, min_periods=1):
866893

867894
if min_periods < 0:
868895
raise ValueError("min_periods must be >= 0")
869-
self._min_periods = min_periods
870896
self.kdf_or_kser = kdf_or_kser
871897
if not isinstance(kdf_or_kser, (DataFrame, Series)):
872898
raise TypeError(
873899
"kdf_or_kser must be a series or dataframe; however, got: %s" % type(kdf_or_kser))
874900
index_scols = kdf_or_kser._internal.index_scols
875-
self._window = Window.orderBy(index_scols).rowsBetween(
901+
window = Window.orderBy(index_scols).rowsBetween(
876902
Window.unboundedPreceding, Window.currentRow)
903+
super(Expanding, self).__init__(window, index_scols, min_periods)
877904

878905
def __getattr__(self, item: str) -> Any:
879906
if hasattr(_MissingPandasLikeExpanding, item):
@@ -954,15 +981,7 @@ def count(self):
954981
2 2.0
955982
3 3.0
956983
"""
957-
def count(scol):
958-
# TODO: is this a bug? min_periods is not respected in expanding().count() in pandas.
959-
# return F.when(
960-
# F.row_number().over(self._window) >= self._min_periods,
961-
# F.count(scol).over(self._window)
962-
# ).otherwise(F.lit(None))
963-
return F.count(scol).over(self._window)
964-
965-
return self._apply_as_series_or_frame(count).astype('float64')
984+
return super(Expanding, self).count()
966985

967986
def sum(self):
968987
"""
@@ -1024,13 +1043,7 @@ def sum(self):
10241043
3 10.0 30.0
10251044
4 15.0 55.0
10261045
"""
1027-
def sum(scol):
1028-
return F.when(
1029-
F.row_number().over(self._window) >= self._min_periods,
1030-
F.sum(scol).over(self._window)
1031-
).otherwise(F.lit(None))
1032-
1033-
return self._apply_as_series_or_frame(sum)
1046+
return super(Expanding, self).sum()
10341047

10351048
def min(self):
10361049
"""
@@ -1067,13 +1080,7 @@ def min(self):
10671080
4 2.0
10681081
Name: 0, dtype: float64
10691082
"""
1070-
def min(scol):
1071-
return F.when(
1072-
F.row_number().over(self._window) >= self._min_periods,
1073-
F.min(scol).over(self._window)
1074-
).otherwise(F.lit(None))
1075-
1076-
return self._apply_as_series_or_frame(min)
1083+
return super(Expanding, self).min()
10771084

10781085
def max(self):
10791086
"""
@@ -1096,13 +1103,7 @@ def max(self):
10961103
Series.max : Similar method for Series.
10971104
DataFrame.max : Similar method for DataFrame.
10981105
"""
1099-
def max(scol):
1100-
return F.when(
1101-
F.row_number().over(self._window) >= self._min_periods,
1102-
F.max(scol).over(self._window)
1103-
).otherwise(F.lit(None))
1104-
1105-
return self._apply_as_series_or_frame(max)
1106+
return super(Expanding, self).max()
11061107

11071108
def mean(self):
11081109
"""
@@ -1146,13 +1147,7 @@ def mean(self):
11461147
3 2.5
11471148
Name: 0, dtype: float64
11481149
"""
1149-
def mean(scol):
1150-
return F.when(
1151-
F.row_number().over(self._window) >= self._min_periods,
1152-
F.mean(scol).over(self._window)
1153-
).otherwise(F.lit(None))
1154-
1155-
return self._apply_as_series_or_frame(mean)
1150+
return super(Expanding, self).mean()
11561151

11571152

11581153
class ExpandingGroupby(Expanding):
@@ -1177,6 +1172,8 @@ def __init__(self, groupby, groupkeys, min_periods=1):
11771172
# being a different series.
11781173
self._window = self._window.partitionBy(
11791174
*[F.col(name_like_string(ser.name)) for ser in groupkeys])
1175+
self._unbounded_window = self._window.partitionBy(
1176+
*[F.col(name_like_string(ser.name)) for ser in groupkeys])
11801177
self._groupkeys = groupkeys
11811178
# Current implementation reuses DataFrameGroupBy implementations for Series as well.
11821179
self.kdf = self.kdf_or_kser

0 commit comments

Comments
 (0)