@@ -403,7 +403,7 @@ def time_srs_bfill(self):
403403
404404class GroupByMethods :
405405
406- param_names = ["dtype" , "method" , "application" ]
406+ param_names = ["dtype" , "method" , "application" , "ncols" ]
407407 params = [
408408 ["int" , "float" , "object" , "datetime" , "uint" ],
409409 [
@@ -443,15 +443,23 @@ class GroupByMethods:
443443 "var" ,
444444 ],
445445 ["direct" , "transformation" ],
446+ [1 , 2 , 5 , 10 ],
446447 ]
447448
448- def setup (self , dtype , method , application ):
449+ def setup (self , dtype , method , application , ncols ):
449450 if method in method_blocklist .get (dtype , {}):
450451 raise NotImplementedError # skip benchmark
452+
453+ if ncols != 1 and method in ["value_counts" , "unique" ]:
454+ # DataFrameGroupBy doesn't have these methods
455+ raise NotImplementedError
456+
451457 ngroups = 1000
452458 size = ngroups * 2
453- rng = np .arange (ngroups )
454- values = rng .take (np .random .randint (0 , ngroups , size = size ))
459+ rng = np .arange (ngroups ).reshape (- 1 , 1 )
460+ rng = np .broadcast_to (rng , (len (rng ), ncols ))
461+ taker = np .random .randint (0 , ngroups , size = size )
462+ values = rng .take (taker , axis = 0 )
455463 if dtype == "int" :
456464 key = np .random .randint (0 , size , size = size )
457465 elif dtype == "uint" :
@@ -465,22 +473,27 @@ def setup(self, dtype, method, application):
465473 elif dtype == "datetime" :
466474 key = date_range ("1/1/2011" , periods = size , freq = "s" )
467475
468- df = DataFrame ({"values" : values , "key" : key })
476+ cols = [f"values{ n } " for n in range (ncols )]
477+ df = DataFrame (values , columns = cols )
478+ df ["key" ] = key
479+
480+ if len (cols ) == 1 :
481+ cols = cols [0 ]
469482
470483 if application == "transform" :
471484 if method == "describe" :
472485 raise NotImplementedError
473486
474- self .as_group_method = lambda : df .groupby ("key" )["values" ].transform (method )
475- self .as_field_method = lambda : df .groupby ("values" )["key" ].transform (method )
487+ self .as_group_method = lambda : df .groupby ("key" )[cols ].transform (method )
488+ self .as_field_method = lambda : df .groupby (cols )["key" ].transform (method )
476489 else :
477- self .as_group_method = getattr (df .groupby ("key" )["values" ], method )
478- self .as_field_method = getattr (df .groupby ("values" )["key" ], method )
490+ self .as_group_method = getattr (df .groupby ("key" )[cols ], method )
491+ self .as_field_method = getattr (df .groupby (cols )["key" ], method )
479492
def time_dtype_as_group(self, dtype, method, application, ncols):
    """Run the callable that ``setup`` stored in ``self.as_group_method``.

    ``setup`` builds that callable from ``df.groupby("key")[cols]``, i.e. the
    selected method applied to the value column(s) grouped by ``"key"``.

    The four parameters mirror ``param_names``; none of them is read here
    because all parameter-dependent work is done in ``setup``.
    """
    grouped_call = self.as_group_method
    grouped_call()
482495
def time_dtype_as_field(self, dtype, method, application, ncols):
    """Run the callable that ``setup`` stored in ``self.as_field_method``.

    ``setup`` builds that callable from ``df.groupby(cols)["key"]``, i.e. the
    selected method applied to the ``"key"`` column grouped by the value
    column(s).

    The four parameters mirror ``param_names``; none of them is read here
    because all parameter-dependent work is done in ``setup``.
    """
    field_call = self.as_field_method
    field_call()
485498
486499
0 commit comments