11import abc
22import inspect
3- from typing import TYPE_CHECKING , Any , Dict , Iterator , Optional , Tuple , Type
3+ from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Tuple , Type , cast
44
55import numpy as np
66
77from pandas ._config import option_context
88
9- from pandas ._typing import AggFuncType , Axis , FrameOrSeriesUnion
9+ from pandas ._typing import (
10+ AggFuncType ,
11+ AggFuncTypeBase ,
12+ AggFuncTypeDict ,
13+ Axis ,
14+ FrameOrSeriesUnion ,
15+ )
1016from pandas .util ._decorators import cache_readonly
1117
1218from pandas .core .dtypes .common import (
1723)
1824from pandas .core .dtypes .generic import ABCSeries
1925
26+ from pandas .core .aggregation import agg_dict_like , agg_list_like
2027from pandas .core .construction import create_series_with_explicit_dtype
2128
2229if TYPE_CHECKING :
2734
2835def frame_apply (
2936 obj : "DataFrame" ,
37+ how : str ,
3038 func : AggFuncType ,
3139 axis : Axis = 0 ,
3240 raw : bool = False ,
@@ -44,6 +52,7 @@ def frame_apply(
4452
4553 return klass (
4654 obj ,
55+ how ,
4756 func ,
4857 raw = raw ,
4958 result_type = result_type ,
@@ -84,13 +93,16 @@ def wrap_results_for_axis(
8493 def __init__ (
8594 self ,
8695 obj : "DataFrame" ,
96+ how : str ,
8797 func ,
8898 raw : bool ,
8999 result_type : Optional [str ],
90100 args ,
91101 kwds ,
92102 ):
103+ assert how in ("apply" , "agg" )
93104 self .obj = obj
105+ self .how = how
94106 self .raw = raw
95107 self .args = args or ()
96108 self .kwds = kwds or {}
@@ -104,15 +116,19 @@ def __init__(
104116 self .result_type = result_type
105117
106118 # curry if needed
107- if (kwds or args ) and not isinstance (func , (np .ufunc , str )):
119+ if (
120+ (kwds or args )
121+ and not isinstance (func , (np .ufunc , str ))
122+ and not is_list_like (func )
123+ ):
108124
109125 def f (x ):
110126 return func (x , * args , ** kwds )
111127
112128 else :
113129 f = func
114130
115- self .f = f
131+ self .f : AggFuncType = f
116132
117133 @property
118134 def res_columns (self ) -> "Index" :
@@ -139,6 +155,54 @@ def agg_axis(self) -> "Index":
139155 return self .obj ._get_agg_axis (self .axis )
140156
141157 def get_result (self ):
158+ if self .how == "apply" :
159+ return self .apply ()
160+ else :
161+ return self .agg ()
162+
163+ def agg (self ) -> Tuple [Optional [FrameOrSeriesUnion ], Optional [bool ]]:
164+ """
165+ Provide an implementation for the aggregators.
166+
167+ Returns
168+ -------
169+ tuple of result, how.
170+
171+ Notes
172+ -----
173+ how can be a string describe the required post-processing, or
174+ None if not required.
175+ """
176+ obj = self .obj
177+ arg = self .f
178+ args = self .args
179+ kwargs = self .kwds
180+
181+ _axis = kwargs .pop ("_axis" , None )
182+ if _axis is None :
183+ _axis = getattr (obj , "axis" , 0 )
184+
185+ if isinstance (arg , str ):
186+ return obj ._try_aggregate_string_function (arg , * args , ** kwargs ), None
187+ elif is_dict_like (arg ):
188+ arg = cast (AggFuncTypeDict , arg )
189+ return agg_dict_like (obj , arg , _axis ), True
190+ elif is_list_like (arg ):
191+ # we require a list, but not a 'str'
192+ arg = cast (List [AggFuncTypeBase ], arg )
193+ return agg_list_like (obj , arg , _axis = _axis ), None
194+ else :
195+ result = None
196+
197+ if callable (arg ):
198+ f = obj ._get_cython_func (arg )
199+ if f and not args and not kwargs :
200+ return getattr (obj , f )(), None
201+
202+ # caller can react
203+ return result , True
204+
205+ def apply (self ) -> FrameOrSeriesUnion :
142206 """ compute the results """
143207 # dispatch to agg
144208 if is_list_like (self .f ) or is_dict_like (self .f ):
@@ -191,6 +255,8 @@ def apply_empty_result(self):
191255 we will try to apply the function to an empty
192256 series in order to see if this is a reduction function
193257 """
258+ assert callable (self .f )
259+
194260 # we are not asked to reduce or infer reduction
195261 # so just return a copy of the existing object
196262 if self .result_type not in ["reduce" , None ]:
@@ -246,6 +312,8 @@ def wrapper(*args, **kwargs):
246312 return self .obj ._constructor_sliced (result , index = self .agg_axis )
247313
248314 def apply_broadcast (self , target : "DataFrame" ) -> "DataFrame" :
315+ assert callable (self .f )
316+
249317 result_values = np .empty_like (target .values )
250318
251319 # axis which we want to compare compliance
@@ -279,6 +347,8 @@ def apply_standard(self):
279347 return self .wrap_results (results , res_index )
280348
281349 def apply_series_generator (self ) -> Tuple [ResType , "Index" ]:
350+ assert callable (self .f )
351+
282352 series_gen = self .series_generator
283353 res_index = self .result_index
284354
0 commit comments