1515from __future__ import annotations
1616
1717import datetime
18- import typing
18+ from typing import Literal , Mapping , Sequence , TYPE_CHECKING , Union
1919
2020import bigframes_vendored .pandas .core .window .rolling as vendored_pandas_rolling
2121import numpy
2222import pandas
2323
2424from bigframes import dtypes
25+ from bigframes .core import agg_expressions
2526from bigframes .core import expression as ex
26- from bigframes .core import log_adapter , ordering , window_spec
27+ from bigframes .core import log_adapter , ordering , utils , window_spec
2728import bigframes .core .blocks as blocks
2829from bigframes .core .window import ordering as window_ordering
2930import bigframes .operations .aggregations as agg_ops
3031
32+ if TYPE_CHECKING :
33+ import bigframes .dataframe as df
34+ import bigframes .series as series
35+
3136
3237@log_adapter .class_logger
3338class Window (vendored_pandas_rolling .Window ):
@@ -37,7 +42,7 @@ def __init__(
3742 self ,
3843 block : blocks .Block ,
3944 window_spec : window_spec .WindowSpec ,
40- value_column_ids : typing . Sequence [str ],
45+ value_column_ids : Sequence [str ],
4146 drop_null_groups : bool = True ,
4247 is_series : bool = False ,
4348 skip_agg_column_id : str | None = None ,
@@ -52,55 +57,106 @@ def __init__(
5257 self ._skip_agg_column_id = skip_agg_column_id
5358
def count(self):
    # Windowed count: hand `agg_ops.count_op` to the shared unary-op path.
    window_op = agg_ops.count_op
    return self._apply_aggregate_op(window_op)
5661
def sum(self):
    # Windowed sum: hand `agg_ops.sum_op` to the shared unary-op path.
    window_op = agg_ops.sum_op
    return self._apply_aggregate_op(window_op)
5964
def mean(self):
    # Windowed mean: hand `agg_ops.mean_op` to the shared unary-op path.
    window_op = agg_ops.mean_op
    return self._apply_aggregate_op(window_op)
6267
def var(self):
    # Windowed variance: hand `agg_ops.var_op` to the shared unary-op path.
    window_op = agg_ops.var_op
    return self._apply_aggregate_op(window_op)
6570
def std(self):
    # Windowed standard deviation: hand `agg_ops.std_op` to the shared
    # unary-op path.
    window_op = agg_ops.std_op
    return self._apply_aggregate_op(window_op)
6873
def max(self):
    # Windowed maximum: hand `agg_ops.max_op` to the shared unary-op path.
    window_op = agg_ops.max_op
    return self._apply_aggregate_op(window_op)
7176
def min(self):
    # Windowed minimum: hand `agg_ops.min_op` to the shared unary-op path.
    window_op = agg_ops.min_op
    return self._apply_aggregate_op(window_op)
7479
75- def _apply_aggregate (
76- self ,
77- op : agg_ops .UnaryAggregateOp ,
78- ):
79- agg_block = self ._aggregate_block (op )
def agg(self, func) -> Union[df.DataFrame, series.Series]:
    # Dispatch on the shape of `func`:
    #   * dict-like  -> per-column aggregation spec (`_agg_dict`)
    #   * list-like  -> every function applied to every column (`_agg_list`)
    #   * otherwise  -> a single function for all columns (`_agg_func`)
    # NOTE(review): the dict check is kept first; presumably a mapping would
    # also satisfy `utils.is_list_like` -- confirm against `utils`.
    if utils.is_dict_like(func):
        return self._agg_dict(func)
    if utils.is_list_like(func):
        return self._agg_list(func)
    return self._agg_func(func)
8087
81- if self ._is_series :
82- from bigframes .series import Series
88+ aggregate = agg
89+
def _agg_func(self, func) -> Union[df.DataFrame, series.Series]:
    """Apply one aggregation function to every aggregated column.

    Args:
        func: A single aggregation spec, resolved through
            ``agg_ops.lookup_agg_func``.

    Returns:
        The result of ``_apply_aggs`` for one aggregation per column,
        labelled with the original column labels.
    """
    ids, labels = self._aggregated_columns()
    # The resolved op does not depend on the column, so look it up once
    # instead of once per column.
    op = agg_ops.lookup_agg_func(func)[0]
    # `agg` is the module-level expression builder: class-level names such
    # as the `agg` method are not visible from inside a method body.
    aggregations = [agg(col_id, op) for col_id in ids]
    return self._apply_aggs(aggregations, labels)
94+
def _agg_dict(self, func: Mapping) -> Union[df.DataFrame, series.Series]:
    """Apply a per-column aggregation spec.

    Args:
        func: Mapping from column label to either one aggregation spec or
            a list-like of specs. Each spec is resolved through
            ``agg_ops.lookup_agg_func``.

    Returns:
        The result of ``_apply_aggs``. Result columns are labelled by the
        source column label; if any mapping value is list-like, a second
        label level holding the function label is added.
    """
    aggregations: list[agg_expressions.Aggregation] = []
    column_labels = []
    function_labels = []
    # Becomes True as soon as one entry supplies a list of functions; in
    # that case the function label is added as an extra column level.
    want_aggfunc_level = False

    for label, funcs_for_id in func.items():
        col_id = self._block.label_to_col_id[label][-1]  # get last matching column
        # Check list-likeness once per entry and reuse the answer both for
        # normalising to a list and for the extra-level decision.
        if utils.is_list_like(funcs_for_id):
            want_aggfunc_level = True
            func_list = funcs_for_id
        else:
            func_list = [funcs_for_id]
        for f in func_list:
            f_op, f_label = agg_ops.lookup_agg_func(f)
            # `agg` is the module-level expression builder, not this
            # class's `agg` method.
            aggregations.append(agg(col_id, f_op))
            column_labels.append(label)
            function_labels.append(f_label)

    if want_aggfunc_level:
        result_labels: pandas.Index = utils.combine_indices(
            pandas.Index(column_labels),
            pandas.Index(function_labels),
        )
    else:
        result_labels = pandas.Index(column_labels)

    return self._apply_aggs(aggregations, result_labels)
93121
94- def _aggregate_block (self , op : agg_ops .UnaryAggregateOp ) -> blocks .Block :
95- agg_col_ids = [
96- col_id
97- for col_id in self ._value_column_ids
98- if col_id != self ._skip_agg_column_id
def _agg_list(self, func: Sequence) -> Union[df.DataFrame, series.Series]:
    """Apply every function in ``func`` to every aggregated column.

    Args:
        func: List-like of aggregation specs, each resolved through
            ``agg_ops.lookup_agg_func``.

    Returns:
        The result of ``_apply_aggs``. For a series window the result
        columns are labelled by function label only; otherwise each
        column label is extended with the function label as a trailing
        level.
    """
    ids, labels = self._aggregated_columns()

    # Resolve each spec exactly once. This avoids repeating the lookup for
    # every (column, function) pair and again for the labels, and it keeps
    # the method correct if `func` can only be iterated a single time.
    resolved = [agg_ops.lookup_agg_func(f) for f in func]
    func_labels = [f_label for _, f_label in resolved]

    # Column-major order: all functions for the first column, then all
    # functions for the second, and so on. `agg` is the module-level
    # expression builder, not this class's `agg` method.
    aggregations = [agg(col_id, f_op) for col_id in ids for f_op, _ in resolved]

    if self._is_series:
        # if series, no need to rebuild
        result_cols_idx = pandas.Index(func_labels)
    else:
        if self._block.column_labels.nlevels > 1:
            # Restructure MultiIndex for proper format: (idx1, idx2, func)
            # rather than ((idx1, idx2), func).
            column_labels = [
                tuple(label) + (f_label,)
                for label in labels.to_frame(index=False).to_numpy()
                for f_label in func_labels
            ]
        else:  # Single-level index
            column_labels = [
                (label, f_label) for label in labels for f_label in func_labels
            ]
        # The appended function level is left unnamed.
        result_cols_idx = pandas.MultiIndex.from_tuples(
            column_labels, names=[*self._block.column_labels.names, None]
        )
    return self._apply_aggs(aggregations, result_cols_idx)
152+
153+ def _apply_aggs (
154+ self , exprs : Sequence [agg_expressions .Aggregation ], labels : pandas .Index
155+ ):
156+ block , ids = self ._block .apply_analytic (
157+ agg_exprs = exprs ,
158+ window = self ._window_spec ,
159+ result_labels = labels ,
104160 skip_null_groups = self ._drop_null_groups ,
105161 )
106162
@@ -115,24 +171,50 @@ def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
115171 )
116172 block = block .set_index (col_ids = index_ids )
117173
118- labels = [self ._block .col_id_to_label [col ] for col in agg_col_ids ]
119174 if self ._skip_agg_column_id is not None :
120- result_ids = [self ._skip_agg_column_id , * result_ids ]
121- labels .insert (0 , self ._block .col_id_to_label [self ._skip_agg_column_id ])
175+ block = block .select_columns ([self ._skip_agg_column_id , * ids ])
176+ else :
177+ block = block .select_columns (ids ).with_column_labels (labels )
178+
179+ if self ._is_series and (len (block .value_columns ) == 1 ):
180+ import bigframes .series as series
181+
182+ return series .Series (block )
183+ else :
184+ import bigframes .dataframe as df
185+
186+ return df .DataFrame (block )
187+
def _apply_aggregate_op(
    self,
    op: agg_ops.UnaryAggregateOp,
):
    """Run one unary aggregate op over every aggregated column.

    Args:
        op: The unary aggregate op to apply to each column.

    Returns:
        The result of ``_apply_aggs``, labelled with the original column
        labels.
    """
    column_ids, column_labels = self._aggregated_columns()
    # `agg` resolves to the module-level expression builder here.
    return self._apply_aggs(
        [agg(column_id, op) for column_id in column_ids],
        column_labels,
    )
122195
123- return block .select_columns (result_ids ).with_column_labels (labels )
def _aggregated_columns(self) -> tuple[Sequence[str], pandas.Index]:
    """Return the ids and labels of the columns that get aggregated.

    These are the window's value columns, in their stored order, minus the
    column identified by ``_skip_agg_column_id`` (if any).

    Returns:
        A pair ``(column ids, pandas.Index of the matching labels)``.
    """
    selected_ids = []
    selected_labels = []
    for column_id in self._value_column_ids:
        if column_id == self._skip_agg_column_id:
            continue
        selected_ids.append(column_id)
        selected_labels.append(self._block.col_id_to_label[column_id])
    label_index: pandas.Index = pandas.Index(selected_labels)
    return selected_ids, label_index
124206
125207
126208def create_range_window (
127209 block : blocks .Block ,
128210 window : pandas .Timedelta | numpy .timedelta64 | datetime .timedelta | str ,
129211 * ,
130- value_column_ids : typing . Sequence [str ] = tuple (),
212+ value_column_ids : Sequence [str ] = tuple (),
131213 min_periods : int | None ,
132214 on : str | None = None ,
133- closed : typing . Literal ["right" , "left" , "both" , "neither" ],
215+ closed : Literal ["right" , "left" , "both" , "neither" ],
134216 is_series : bool ,
135- grouping_keys : typing . Sequence [str ] = tuple (),
217+ grouping_keys : Sequence [str ] = tuple (),
136218 drop_null_groups : bool = True ,
137219) -> Window :
138220
@@ -184,3 +266,11 @@ def create_range_window(
184266 skip_agg_column_id = None if on is None else rolling_key_col_id ,
185267 drop_null_groups = drop_null_groups ,
186268 )
269+
270+
def agg(input: str, op: agg_ops.AggregateOp) -> agg_expressions.Aggregation:
    """Build the aggregation expression for ``op``.

    Args:
        input: Id of the column the aggregation reads. Only used when
            ``op`` is a unary aggregate op.
        op: The aggregate op to wrap.

    Returns:
        A ``UnaryAggregation`` over a deref of ``input`` when ``op`` is a
        ``UnaryAggregateOp``; a ``NullaryAggregation`` when ``op`` is a
        ``NullaryAggregateOp``.

    Raises:
        TypeError: If ``op`` is neither a unary nor a nullary aggregate op.
    """
    if isinstance(op, agg_ops.UnaryAggregateOp):
        return agg_expressions.UnaryAggregation(op, ex.deref(input))
    if isinstance(op, agg_ops.NullaryAggregateOp):
        return agg_expressions.NullaryAggregation(op)
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let an unsupported op be wrapped silently.
    raise TypeError(
        f"Unsupported aggregate op for window aggregation: {type(op).__name__}"
    )
0 commit comments