@@ -2897,16 +2897,15 @@ def _get_cythonized_result(
28972897
28982898 ids , _ , ngroups = grouper .group_info
28992899 output : dict [base .OutputKey , np .ndarray ] = {}
2900- base_func = getattr (libgroupby , how )
2901-
2902- error_msg = ""
2903- for idx , obj in enumerate (self ._iterate_slices ()):
2904- name = obj .name
2905- values = obj ._values
29062900
2907- if numeric_only and not is_numeric_dtype (values .dtype ):
2908- continue
2901+ base_func = getattr (libgroupby , how )
2902+ base_func = partial (base_func , labels = ids )
2903+ if needs_ngroups :
2904+ base_func = partial (base_func , ngroups = ngroups )
2905+ if min_count is not None :
2906+ base_func = partial (base_func , min_count = min_count )
29092907
2908+ def blk_func (values : ArrayLike ) -> ArrayLike :
29102909 if aggregate :
29112910 result_sz = ngroups
29122911 else :
@@ -2915,54 +2914,31 @@ def _get_cythonized_result(
29152914 result = np .zeros (result_sz , dtype = cython_dtype )
29162915 if needs_2d :
29172916 result = result .reshape ((- 1 , 1 ))
2918- func = partial (base_func , result )
2917+ func = partial (base_func , out = result )
29192918
29202919 inferences = None
29212920
29222921 if needs_counts :
29232922 counts = np .zeros (self .ngroups , dtype = np .int64 )
2924- func = partial (func , counts )
2923+ func = partial (func , counts = counts )
29252924
29262925 if needs_values :
29272926 vals = values
29282927 if pre_processing :
2929- try :
2930- vals , inferences = pre_processing (vals )
2931- except TypeError as err :
2932- error_msg = str (err )
2933- howstr = how .replace ("group_" , "" )
2934- warnings .warn (
2935- "Dropping invalid columns in "
2936- f"{ type (self ).__name__ } .{ howstr } is deprecated. "
2937- "In a future version, a TypeError will be raised. "
2938- f"Before calling .{ howstr } , select only columns which "
2939- "should be valid for the function." ,
2940- FutureWarning ,
2941- stacklevel = 3 ,
2942- )
2943- continue
2928+ vals , inferences = pre_processing (vals )
2929+
29442930 vals = vals .astype (cython_dtype , copy = False )
29452931 if needs_2d :
29462932 vals = vals .reshape ((- 1 , 1 ))
2947- func = partial (func , vals )
2948-
2949- func = partial (func , ids )
2950-
2951- if min_count is not None :
2952- func = partial (func , min_count )
2933+ func = partial (func , values = vals )
29532934
29542935 if needs_mask :
29552936 mask = isna (values ).view (np .uint8 )
2956- func = partial (func , mask )
2957-
2958- if needs_ngroups :
2959- func = partial (func , ngroups )
2937+ func = partial (func , mask = mask )
29602938
29612939 if needs_nullable :
29622940 is_nullable = isinstance (values , BaseMaskedArray )
29632941 func = partial (func , nullable = is_nullable )
2964- if post_processing :
2965- post_processing = partial (post_processing , nullable = is_nullable )
29662942
29672943 func (** kwargs ) # Call func to modify indexer values in place
29682944
@@ -2973,9 +2949,38 @@ def _get_cythonized_result(
29732949 result = algorithms .take_nd (values , result )
29742950
29752951 if post_processing :
2976- result = post_processing (result , inferences )
2952+ pp_kwargs = {}
2953+ if needs_nullable :
2954+ pp_kwargs ["nullable" ] = isinstance (values , BaseMaskedArray )
29772955
2978- key = base .OutputKey (label = name , position = idx )
2956+ result = post_processing (result , inferences , ** pp_kwargs )
2957+
2958+ return result
2959+
2960+ error_msg = ""
2961+ for idx , obj in enumerate (self ._iterate_slices ()):
2962+ values = obj ._values
2963+
2964+ if numeric_only and not is_numeric_dtype (values .dtype ):
2965+ continue
2966+
2967+ try :
2968+ result = blk_func (values )
2969+ except TypeError as err :
2970+ error_msg = str (err )
2971+ howstr = how .replace ("group_" , "" )
2972+ warnings .warn (
2973+ "Dropping invalid columns in "
2974+ f"{ type (self ).__name__ } .{ howstr } is deprecated. "
2975+ "In a future version, a TypeError will be raised. "
2976+ f"Before calling .{ howstr } , select only columns which "
2977+ "should be valid for the function." ,
2978+ FutureWarning ,
2979+ stacklevel = 3 ,
2980+ )
2981+ continue
2982+
2983+ key = base .OutputKey (label = obj .name , position = idx )
29792984 output [key ] = result
29802985
29812986 # error_msg is "" on an frame/series with no rows or columns
0 commit comments