16
16
from functools import partial
17
17
from typing import Any
18
18
19
- from databricks .koalas .internal import _InternalFrame , SPARK_INDEX_NAME_FORMAT
19
+ from databricks .koalas .internal import SPARK_INDEX_NAME_FORMAT
20
20
from databricks .koalas .utils import name_like_string
21
21
from pyspark .sql import Window
22
22
from pyspark .sql import functions as F
29
29
30
30
31
31
class _RollingAndExpanding(object):
    """
    Common base for the rolling and expanding window implementations.

    It stores the Spark window used for the aggregation, an unbounded window
    used to enforce ``min_periods``, and provides the aggregations shared by
    the subclasses. Subclasses must override ``_apply_as_series_or_frame``.
    """

    def __init__(self, window, index_scols, min_periods):
        """
        Parameters
        ----------
        window : the Spark window specification each aggregation is
            evaluated over.
        index_scols : Spark columns used to order the unbounded window.
        min_periods : threshold compared against the row number when
            deciding whether an aggregation result is kept or replaced
            with null.
        """
        self._window = window
        # This unbounded Window is later used to handle 'min_periods' for now.
        self._unbounded_window = Window.orderBy(index_scols).rowsBetween(
            Window.unboundedPreceding, Window.currentRow)
        self._min_periods = min_periods

    def _apply_as_series_or_frame(self, func):
        """
        Wraps a function that handles Spark column in order
        to support it in both Koalas Series and DataFrame.
        Note that the given `func` name should be same as the API's method name.
        """
        raise NotImplementedError(
            "A class that inherits this class should implement this method "
            "to handle the index and columns of output.")

    def _apply_min_periods(self, agg_scol):
        """
        Return `agg_scol` where the row number over the unbounded window is
        at least ``min_periods``, and null otherwise.

        The attributes are read at call time, so a subclass that replaces
        ``self._unbounded_window`` after construction is honoured.
        """
        return F.when(
            F.row_number().over(self._unbounded_window) >= self._min_periods,
            agg_scol
        ).otherwise(F.lit(None))

    # NOTE: the inner functions below deliberately reuse the method names
    # (shadowing the builtins `sum`, `min` and `max` locally) because
    # `_apply_as_series_or_frame` expects the function name to match the
    # API's method name.

    def count(self):
        """
        Count the values in each window and cast the result to 'float64'.

        ``min_periods`` is not applied to this aggregation.
        """
        def count(scol):
            return F.count(scol).over(self._window)

        return self._apply_as_series_or_frame(count).astype('float64')

    def sum(self):
        """Sum over each window; null where ``min_periods`` is not met."""
        def sum(scol):
            return self._apply_min_periods(F.sum(scol).over(self._window))

        return self._apply_as_series_or_frame(sum)

    def min(self):
        """Minimum over each window; null where ``min_periods`` is not met."""
        def min(scol):
            return self._apply_min_periods(F.min(scol).over(self._window))

        return self._apply_as_series_or_frame(min)

    def max(self):
        """Maximum over each window; null where ``min_periods`` is not met."""
        def max(scol):
            return self._apply_min_periods(F.max(scol).over(self._window))

        return self._apply_as_series_or_frame(max)

    def mean(self):
        """Mean over each window; null where ``min_periods`` is not met."""
        def mean(scol):
            return self._apply_min_periods(F.mean(scol).over(self._window))

        return self._apply_as_series_or_frame(mean)
33
89
34
90
35
91
class Rolling (_RollingAndExpanding ):
36
92
def __init__(self, kdf_or_kser, window, min_periods=None):
    """
    Validate the arguments and build the rolling window specification.

    Parameters
    ----------
    kdf_or_kser : Koalas DataFrame or Series to roll over.
    window : int, number of rows per window; must be >= 0.
    min_periods : int, optional; must be >= 0. Defaults to `window`
        when not given.

    Raises
    ------
    ValueError
        If `window` or `min_periods` is negative.
    TypeError
        If `kdf_or_kser` is neither a DataFrame nor a Series.
    """
    from databricks.koalas import DataFrame, Series

    if window < 0:
        raise ValueError("window must be >= 0")
    if (min_periods is not None) and (min_periods < 0):
        raise ValueError("min_periods must be >= 0")
    self._window_val = window
    if min_periods is None:
        # TODO: 'min_periods' is not equivalent in pandas because it does not count NA as
        # a value.
        min_periods = window

    self.kdf_or_kser = kdf_or_kser
    if not isinstance(kdf_or_kser, (DataFrame, Series)):
        raise TypeError(
            "kdf_or_kser must be a series or dataframe; however, got: %s" % type(kdf_or_kser))
    index_scols = kdf_or_kser._internal.index_scols
    # Frame ends at the current row and starts `window - 1` rows before it.
    # A separate name keeps the integer `window` argument from being rebound
    # to a Spark window specification.
    window_spec = Window.orderBy(index_scols).rowsBetween(
        Window.currentRow - (self._window_val - 1), Window.currentRow)

    super(Rolling, self).__init__(window_spec, index_scols, min_periods)
62
116
63
117
def __getattr__ (self , item : str ) -> Any :
64
118
if hasattr (_MissingPandasLikeRolling , item ):
@@ -143,10 +197,7 @@ def count(self):
143
197
2 2.0
144
198
3 2.0
145
199
"""
146
- def count (scol ):
147
- return F .count (scol ).over (self ._window )
148
-
149
- return self ._apply_as_series_or_frame (count ).astype ('float64' )
200
+ return super (Rolling , self ).count ()
150
201
151
202
def sum (self ):
152
203
"""
@@ -224,13 +275,7 @@ def sum(self):
224
275
3 10.0 38.0
225
276
4 13.0 65.0
226
277
"""
227
- def sum (scol ):
228
- return F .when (
229
- F .row_number ().over (self ._unbounded_window ) >= self ._min_periods ,
230
- F .sum (scol ).over (self ._window )
231
- ).otherwise (F .lit (None ))
232
-
233
- return self ._apply_as_series_or_frame (sum )
278
+ return super (Rolling , self ).sum ()
234
279
235
280
def min (self ):
236
281
"""
@@ -308,13 +353,7 @@ def min(self):
308
353
3 2.0 4.0
309
354
4 2.0 4.0
310
355
"""
311
- def min (scol ):
312
- return F .when (
313
- F .row_number ().over (self ._unbounded_window ) >= self ._min_periods ,
314
- F .min (scol ).over (self ._window )
315
- ).otherwise (F .lit (None ))
316
-
317
- return self ._apply_as_series_or_frame (min )
356
+ return super (Rolling , self ).min ()
318
357
319
358
def max (self ):
320
359
"""
@@ -391,13 +430,7 @@ def max(self):
391
430
3 5.0 25.0
392
431
4 6.0 36.0
393
432
"""
394
- def max (scol ):
395
- return F .when (
396
- F .row_number ().over (self ._unbounded_window ) >= self ._min_periods ,
397
- F .max (scol ).over (self ._window )
398
- ).otherwise (F .lit (None ))
399
-
400
- return self ._apply_as_series_or_frame (max )
433
+ return super (Rolling , self ).max ()
401
434
402
435
def mean (self ):
403
436
"""
@@ -475,13 +508,7 @@ def mean(self):
475
508
3 3.333333 12.666667
476
509
4 4.333333 21.666667
477
510
"""
478
- def mean (scol ):
479
- return F .when (
480
- F .row_number ().over (self ._unbounded_window ) >= self ._min_periods ,
481
- F .mean (scol ).over (self ._window )
482
- ).otherwise (F .lit (None ))
483
-
484
- return self ._apply_as_series_or_frame (mean )
511
+ return super (Rolling , self ).mean ()
485
512
486
513
487
514
class RollingGroupby (Rolling ):
@@ -866,14 +893,14 @@ def __init__(self, kdf_or_kser, min_periods=1):
866
893
867
894
if min_periods < 0 :
868
895
raise ValueError ("min_periods must be >= 0" )
869
- self ._min_periods = min_periods
870
896
self .kdf_or_kser = kdf_or_kser
871
897
if not isinstance (kdf_or_kser , (DataFrame , Series )):
872
898
raise TypeError (
873
899
"kdf_or_kser must be a series or dataframe; however, got: %s" % type (kdf_or_kser ))
874
900
index_scols = kdf_or_kser ._internal .index_scols
875
- self . _window = Window .orderBy (index_scols ).rowsBetween (
901
+ window = Window .orderBy (index_scols ).rowsBetween (
876
902
Window .unboundedPreceding , Window .currentRow )
903
+ super (Expanding , self ).__init__ (window , index_scols , min_periods )
877
904
878
905
def __getattr__ (self , item : str ) -> Any :
879
906
if hasattr (_MissingPandasLikeExpanding , item ):
@@ -954,15 +981,7 @@ def count(self):
954
981
2 2.0
955
982
3 3.0
956
983
"""
957
- def count (scol ):
958
- # TODO: is this a bug? min_periods is not respected in expanding().count() in pandas.
959
- # return F.when(
960
- # F.row_number().over(self._window) >= self._min_periods,
961
- # F.count(scol).over(self._window)
962
- # ).otherwise(F.lit(None))
963
- return F .count (scol ).over (self ._window )
964
-
965
- return self ._apply_as_series_or_frame (count ).astype ('float64' )
984
+ return super (Expanding , self ).count ()
966
985
967
986
def sum (self ):
968
987
"""
@@ -1024,13 +1043,7 @@ def sum(self):
1024
1043
3 10.0 30.0
1025
1044
4 15.0 55.0
1026
1045
"""
1027
- def sum (scol ):
1028
- return F .when (
1029
- F .row_number ().over (self ._window ) >= self ._min_periods ,
1030
- F .sum (scol ).over (self ._window )
1031
- ).otherwise (F .lit (None ))
1032
-
1033
- return self ._apply_as_series_or_frame (sum )
1046
+ return super (Expanding , self ).sum ()
1034
1047
1035
1048
def min (self ):
1036
1049
"""
@@ -1067,13 +1080,7 @@ def min(self):
1067
1080
4 2.0
1068
1081
Name: 0, dtype: float64
1069
1082
"""
1070
- def min (scol ):
1071
- return F .when (
1072
- F .row_number ().over (self ._window ) >= self ._min_periods ,
1073
- F .min (scol ).over (self ._window )
1074
- ).otherwise (F .lit (None ))
1075
-
1076
- return self ._apply_as_series_or_frame (min )
1083
+ return super (Expanding , self ).min ()
1077
1084
1078
1085
def max (self ):
1079
1086
"""
@@ -1096,13 +1103,7 @@ def max(self):
1096
1103
Series.max : Similar method for Series.
1097
1104
DataFrame.max : Similar method for DataFrame.
1098
1105
"""
1099
- def max (scol ):
1100
- return F .when (
1101
- F .row_number ().over (self ._window ) >= self ._min_periods ,
1102
- F .max (scol ).over (self ._window )
1103
- ).otherwise (F .lit (None ))
1104
-
1105
- return self ._apply_as_series_or_frame (max )
1106
+ return super (Expanding , self ).max ()
1106
1107
1107
1108
def mean (self ):
1108
1109
"""
@@ -1146,13 +1147,7 @@ def mean(self):
1146
1147
3 2.5
1147
1148
Name: 0, dtype: float64
1148
1149
"""
1149
- def mean (scol ):
1150
- return F .when (
1151
- F .row_number ().over (self ._window ) >= self ._min_periods ,
1152
- F .mean (scol ).over (self ._window )
1153
- ).otherwise (F .lit (None ))
1154
-
1155
- return self ._apply_as_series_or_frame (mean )
1150
+ return super (Expanding , self ).mean ()
1156
1151
1157
1152
1158
1153
class ExpandingGroupby (Expanding ):
@@ -1177,6 +1172,8 @@ def __init__(self, groupby, groupkeys, min_periods=1):
1177
1172
# being a different series.
1178
1173
self ._window = self ._window .partitionBy (
1179
1174
* [F .col (name_like_string (ser .name )) for ser in groupkeys ])
1175
+ self ._unbounded_window = self ._window .partitionBy (
1176
+ * [F .col (name_like_string (ser .name )) for ser in groupkeys ])
1180
1177
self ._groupkeys = groupkeys
1181
1178
# Current implementation reuses DataFrameGroupBy implementations for Series as well.
1182
1179
self .kdf = self .kdf_or_kser
0 commit comments