
Commit c4cb39d

feat: Add agg/aggregate methods to windows (#2288)
1 parent 2dcf6ae commit c4cb39d

3 files changed (+248, -40 lines)

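In short, the commit gives window objects (`rolling`, `expanding`, and their groupby variants) an `agg` method plus an `aggregate` alias that accept the same `func` shapes as pandas: a single name or callable, a list, or a dict keyed by column label. A minimal usage sketch based on the doctest and system tests added in this commit (toy data; it assumes an authenticated BigQuery DataFrames session):

    import numpy as np
    import bigframes.pandas as bpd

    df = bpd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    # Single function, by name or as a numpy callable.
    df.rolling(2).agg("sum")

    # List of functions: result columns gain a function level.
    df.expanding().agg(["sum", np.mean])

    # Dict of column label -> function or list of functions.
    df.rolling(2).agg({"A": "sum", "B": ["mean", np.max]})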

bigframes/core/window/rolling.py

Lines changed: 130 additions & 40 deletions

@@ -15,19 +15,24 @@
 from __future__ import annotations
 
 import datetime
-import typing
+from typing import Literal, Mapping, Sequence, TYPE_CHECKING, Union
 
 import bigframes_vendored.pandas.core.window.rolling as vendored_pandas_rolling
 import numpy
 import pandas
 
 from bigframes import dtypes
+from bigframes.core import agg_expressions
 from bigframes.core import expression as ex
-from bigframes.core import log_adapter, ordering, window_spec
+from bigframes.core import log_adapter, ordering, utils, window_spec
 import bigframes.core.blocks as blocks
 from bigframes.core.window import ordering as window_ordering
 import bigframes.operations.aggregations as agg_ops
 
+if TYPE_CHECKING:
+    import bigframes.dataframe as df
+    import bigframes.series as series
+
 
 @log_adapter.class_logger
 class Window(vendored_pandas_rolling.Window):
@@ -37,7 +42,7 @@ def __init__(
         self,
         block: blocks.Block,
         window_spec: window_spec.WindowSpec,
-        value_column_ids: typing.Sequence[str],
+        value_column_ids: Sequence[str],
         drop_null_groups: bool = True,
         is_series: bool = False,
         skip_agg_column_id: str | None = None,
@@ -52,55 +57,106 @@ def __init__(
         self._skip_agg_column_id = skip_agg_column_id
 
     def count(self):
-        return self._apply_aggregate(agg_ops.count_op)
+        return self._apply_aggregate_op(agg_ops.count_op)
 
     def sum(self):
-        return self._apply_aggregate(agg_ops.sum_op)
+        return self._apply_aggregate_op(agg_ops.sum_op)
 
     def mean(self):
-        return self._apply_aggregate(agg_ops.mean_op)
+        return self._apply_aggregate_op(agg_ops.mean_op)
 
     def var(self):
-        return self._apply_aggregate(agg_ops.var_op)
+        return self._apply_aggregate_op(agg_ops.var_op)
 
     def std(self):
-        return self._apply_aggregate(agg_ops.std_op)
+        return self._apply_aggregate_op(agg_ops.std_op)
 
     def max(self):
-        return self._apply_aggregate(agg_ops.max_op)
+        return self._apply_aggregate_op(agg_ops.max_op)
 
     def min(self):
-        return self._apply_aggregate(agg_ops.min_op)
+        return self._apply_aggregate_op(agg_ops.min_op)
 
-    def _apply_aggregate(
-        self,
-        op: agg_ops.UnaryAggregateOp,
-    ):
-        agg_block = self._aggregate_block(op)
+    def agg(self, func) -> Union[df.DataFrame, series.Series]:
+        if utils.is_dict_like(func):
+            return self._agg_dict(func)
+        elif utils.is_list_like(func):
+            return self._agg_list(func)
+        else:
+            return self._agg_func(func)
 
-        if self._is_series:
-            from bigframes.series import Series
+    aggregate = agg
+
+    def _agg_func(self, func) -> df.DataFrame:
+        ids, labels = self._aggregated_columns()
+        aggregations = [agg(col_id, agg_ops.lookup_agg_func(func)[0]) for col_id in ids]
+        return self._apply_aggs(aggregations, labels)
+
+    def _agg_dict(self, func: Mapping) -> df.DataFrame:
+        aggregations: list[agg_expressions.Aggregation] = []
+        column_labels = []
+        function_labels = []
 
-            return Series(agg_block)
+        want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())
+
+        for label, funcs_for_id in func.items():
+            col_id = self._block.label_to_col_id[label][-1]  # get last matching column
+            func_list = (
+                funcs_for_id if utils.is_list_like(funcs_for_id) else [funcs_for_id]
+            )
+            for f in func_list:
+                f_op, f_label = agg_ops.lookup_agg_func(f)
+                aggregations.append(agg(col_id, f_op))
+                column_labels.append(label)
+                function_labels.append(f_label)
+        if want_aggfunc_level:
+            result_labels: pandas.Index = utils.combine_indices(
+                pandas.Index(column_labels),
+                pandas.Index(function_labels),
+            )
         else:
-            from bigframes.dataframe import DataFrame
+            result_labels = pandas.Index(column_labels)
 
-            # Preserve column order.
-            column_labels = [
-                self._block.col_id_to_label[col_id] for col_id in self._value_column_ids
-            ]
-            return DataFrame(agg_block)._reindex_columns(column_labels)
+        return self._apply_aggs(aggregations, result_labels)
 
-    def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
-        agg_col_ids = [
-            col_id
-            for col_id in self._value_column_ids
-            if col_id != self._skip_agg_column_id
+    def _agg_list(self, func: Sequence) -> df.DataFrame:
+        ids, labels = self._aggregated_columns()
+        aggregations = [
+            agg(col_id, agg_ops.lookup_agg_func(f)[0]) for col_id in ids for f in func
         ]
-        block, result_ids = self._block.multi_apply_window_op(
-            agg_col_ids,
-            op,
-            self._window_spec,
+
+        if self._is_series:
+            # if series, no need to rebuild
+            result_cols_idx = pandas.Index(
+                [agg_ops.lookup_agg_func(f)[1] for f in func]
+            )
+        else:
+            if self._block.column_labels.nlevels > 1:
+                # Restructure MultiIndex for proper format: (idx1, idx2, func)
+                # rather than ((idx1, idx2), func).
+                column_labels = [
+                    tuple(label) + (agg_ops.lookup_agg_func(f)[1],)
+                    for label in labels.to_frame(index=False).to_numpy()
+                    for f in func
                ]
+            else:  # Single-level index
+                column_labels = [
+                    (label, agg_ops.lookup_agg_func(f)[1])
+                    for label in labels
+                    for f in func
+                ]
+            result_cols_idx = pandas.MultiIndex.from_tuples(
+                column_labels, names=[*self._block.column_labels.names, None]
+            )
+        return self._apply_aggs(aggregations, result_cols_idx)
+
+    def _apply_aggs(
+        self, exprs: Sequence[agg_expressions.Aggregation], labels: pandas.Index
+    ):
+        block, ids = self._block.apply_analytic(
+            agg_exprs=exprs,
+            window=self._window_spec,
+            result_labels=labels,
             skip_null_groups=self._drop_null_groups,
         )
 
@@ -115,24 +171,50 @@ def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
         )
         block = block.set_index(col_ids=index_ids)
 
-        labels = [self._block.col_id_to_label[col] for col in agg_col_ids]
         if self._skip_agg_column_id is not None:
-            result_ids = [self._skip_agg_column_id, *result_ids]
-            labels.insert(0, self._block.col_id_to_label[self._skip_agg_column_id])
+            block = block.select_columns([self._skip_agg_column_id, *ids])
+        else:
+            block = block.select_columns(ids).with_column_labels(labels)
+
+        if self._is_series and (len(block.value_columns) == 1):
+            import bigframes.series as series
+
+            return series.Series(block)
+        else:
+            import bigframes.dataframe as df
+
+            return df.DataFrame(block)
+
+    def _apply_aggregate_op(
+        self,
+        op: agg_ops.UnaryAggregateOp,
+    ):
+        ids, labels = self._aggregated_columns()
+        aggregations = [agg(col_id, op) for col_id in ids]
+        return self._apply_aggs(aggregations, labels)
 
-        return block.select_columns(result_ids).with_column_labels(labels)
+    def _aggregated_columns(self) -> tuple[Sequence[str], pandas.Index]:
+        agg_col_ids = [
+            col_id
+            for col_id in self._value_column_ids
+            if col_id != self._skip_agg_column_id
+        ]
+        labels: pandas.Index = pandas.Index(
+            [self._block.col_id_to_label[col] for col in agg_col_ids]
+        )
+        return agg_col_ids, labels
 
 
 def create_range_window(
     block: blocks.Block,
     window: pandas.Timedelta | numpy.timedelta64 | datetime.timedelta | str,
     *,
-    value_column_ids: typing.Sequence[str] = tuple(),
+    value_column_ids: Sequence[str] = tuple(),
     min_periods: int | None,
     on: str | None = None,
-    closed: typing.Literal["right", "left", "both", "neither"],
+    closed: Literal["right", "left", "both", "neither"],
     is_series: bool,
-    grouping_keys: typing.Sequence[str] = tuple(),
+    grouping_keys: Sequence[str] = tuple(),
     drop_null_groups: bool = True,
 ) -> Window:
 
@@ -184,3 +266,11 @@ def create_range_window(
         skip_agg_column_id=None if on is None else rolling_key_col_id,
         drop_null_groups=drop_null_groups,
     )
+
+
+def agg(input: str, op: agg_ops.AggregateOp) -> agg_expressions.Aggregation:
+    if isinstance(op, agg_ops.UnaryAggregateOp):
+        return agg_expressions.UnaryAggregation(op, ex.deref(input))
+    else:
+        assert isinstance(op, agg_ops.NullaryAggregateOp)
+        return agg_expressions.NullaryAggregation(op)
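
The dispatch in `Window.agg` is the heart of the change: dict-like input goes to `_agg_dict`, list-like to `_agg_list`, and anything else to `_agg_func`, with every path lowering to `_apply_aggs` over the shared window spec. The trickiest part is the column labeling on the DataFrame list path, where each original label is crossed with each function's display name so the result gains an extra column level. A pandas-only sketch of that label construction for the single-level case (names are illustrative; this does not call the library's internals):

    import pandas as pd

    labels = pd.Index(["A", "B"])   # original column labels
    func_names = ["sum", "mean"]    # display names, as lookup_agg_func(f)[1] would return

    column_labels = [(label, name) for label in labels for name in func_names]
    result_cols_idx = pd.MultiIndex.from_tuples(
        column_labels, names=[*labels.names, None]
    )
    print(result_cols_idx)
    # MultiIndex([('A', 'sum'), ('A', 'mean'), ('B', 'sum'), ('B', 'mean')], ...)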

tests/system/small/test_window.py

Lines changed: 69 additions & 0 deletions

@@ -228,6 +228,75 @@ def test_dataframe_window_agg_ops(scalars_dfs, windowing, agg_op):
     pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
 
 
+@pytest.mark.parametrize(
+    ("windowing"),
+    [
+        pytest.param(lambda x: x.expanding(), id="expanding"),
+        pytest.param(lambda x: x.rolling(3, min_periods=3), id="rolling"),
+        pytest.param(
+            lambda x: x.groupby(level=0).rolling(3, min_periods=3), id="rollinggroupby"
+        ),
+        pytest.param(
+            lambda x: x.groupby("int64_too").expanding(min_periods=2),
+            id="expandinggroupby",
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    ("func"),
+    [
+        pytest.param("sum", id="sum_by_name"),
+        pytest.param(np.sum, id="sum_by_by_np"),
+        pytest.param([np.sum, np.mean], id="list_of_funcs"),
+        pytest.param(
+            {"int64_col": np.sum, "float64_col": "mean"}, id="dict_of_single_funcs"
+        ),
+        pytest.param(
+            {"int64_col": np.sum, "float64_col": ["mean", np.max]},
+            id="dict_of_lists_and_single_funcs",
+        ),
+    ],
+)
+def test_dataframe_window_agg_func(scalars_dfs, windowing, func):
+    bf_df, pd_df = scalars_dfs
+    target_columns = ["int64_too", "float64_col", "bool_col", "int64_col"]
+    index_column = "bool_col"
+    bf_df = bf_df[target_columns].set_index(index_column)
+    pd_df = pd_df[target_columns].set_index(index_column)
+
+    bf_result = windowing(bf_df).agg(func).to_pandas()
+
+    pd_result = windowing(pd_df).agg(func)
+
+    pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
+
+
+def test_series_window_agg_single_func(scalars_dfs):
+    bf_df, pd_df = scalars_dfs
+    index_column = "bool_col"
+    bf_series = bf_df.set_index(index_column).int64_too
+    pd_series = pd_df.set_index(index_column).int64_too
+
+    bf_result = bf_series.expanding().agg("sum").to_pandas()
+
+    pd_result = pd_series.expanding().agg("sum")
+
+    pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+
+def test_series_window_agg_multi_func(scalars_dfs):
+    bf_df, pd_df = scalars_dfs
+    index_column = "bool_col"
+    bf_series = bf_df.set_index(index_column).int64_too
+    pd_series = pd_df.set_index(index_column).int64_too
+
+    bf_result = bf_series.expanding().agg(["sum", np.mean]).to_pandas()
+
+    pd_result = pd_series.expanding().agg(["sum", np.mean])
+
+    pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
+
+
 @pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
 @pytest.mark.parametrize(
     "window",  # skipped numpy timedelta because Pandas does not support it.

third_party/bigframes_vendored/pandas/core/window/rolling.py

Lines changed: 49 additions & 0 deletions

@@ -37,3 +37,52 @@ def max(self):
     def min(self):
         """Calculate the weighted window minimum."""
         raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
+
+    def agg(self, func):
+        """
+        Aggregate using one or more operations over the specified axis.
+
+        **Examples:**
+
+            >>> import bigframes.pandas as bpd
+
+            >>> df = bpd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
+            >>> df
+               A  B  C
+            0  1  4  7
+            1  2  5  8
+            2  3  6  9
+            <BLANKLINE>
+            [3 rows x 3 columns]
+
+            >>> df.rolling(2).sum()
+                  A     B     C
+            0  <NA>  <NA>  <NA>
+            1     3     9    15
+            2     5    11    17
+            <BLANKLINE>
+            [3 rows x 3 columns]
+
+            >>> df.rolling(2).agg({"A": "sum", "B": "min"})
+                  A     B
+            0  <NA>  <NA>
+            1     3     4
+            2     5     5
+            <BLANKLINE>
+            [3 rows x 2 columns]
+
+        Args:
+            func (function, str, list or dict):
+                Function to use for aggregating the data.
+
+                Accepted combinations are:
+
+                - string function name
+                - list of function names, e.g. ``['sum', 'mean']``
+                - dict of axis labels -> function names or list of such.
+
+        Returns:
+            Series or DataFrame
+
+        """
+        raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
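
Note that this vendored stub documents `agg` only, while the concrete `Window` class in this commit also binds `aggregate = agg`, so both spellings reach the same code path. A minimal equivalence check (toy data, assuming a configured BigQuery DataFrames session):

    import pandas as pd
    import bigframes.pandas as bpd

    df = bpd.DataFrame({"A": [1, 2, 3]})

    # agg and its aggregate alias should produce identical results.
    pd.testing.assert_frame_equal(
        df.rolling(2).agg("sum").to_pandas(),
        df.rolling(2).aggregate("sum").to_pandas(),
    )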
