@@ -343,7 +343,7 @@ def get_group_levels(self):
343343
344344 _cython_arity = {"ohlc" : 4 } # OHLC
345345
346- _name_functions = {"ohlc" : lambda * args : ["open" , "high" , "low" , "close" ]}
346+ _name_functions = {"ohlc" : ["open" , "high" , "low" , "close" ]}
347347
348348 def _is_builtin_func (self , arg ):
349349 """
@@ -433,6 +433,13 @@ def _cython_operation(
433433 assert kind in ["transform" , "aggregate" ]
434434 orig_values = values
435435
436+ if values .ndim > 2 :
437+ raise NotImplementedError ("number of dimensions is currently limited to 2" )
438+ elif values .ndim == 2 :
439+ # Note: it is *not* the case that axis is always 0 for 1-dim values,
440+ # as we can have 1D ExtensionArrays that we need to treat as 2D
441+ assert axis == 1 , axis
442+
436443 # can we do this operation with our cython functions
437444 # if not raise NotImplementedError
438445
@@ -545,10 +552,7 @@ def _cython_operation(
545552 if vdim == 1 and arity == 1 :
546553 result = result [:, 0 ]
547554
548- if how in self ._name_functions :
549- names = self ._name_functions [how ]() # type: Optional[List[str]]
550- else :
551- names = None
555+ names = self ._name_functions .get (how , None ) # type: Optional[List[str]]
552556
553557 if swapped :
554558 result = result .swapaxes (0 , axis )
@@ -578,10 +582,7 @@ def _aggregate(
578582 is_datetimelike : bool ,
579583 min_count : int = - 1 ,
580584 ):
581- if values .ndim > 2 :
582- # punting for now
583- raise NotImplementedError ("number of dimensions is currently limited to 2" )
584- elif agg_func is libgroupby .group_nth :
585+ if agg_func is libgroupby .group_nth :
585586 # different signature from the others
586587 # TODO: should we be using min_count instead of hard-coding it?
587588 agg_func (result , counts , values , comp_ids , rank = 1 , min_count = - 1 )
@@ -595,11 +596,7 @@ def _transform(
595596 ):
596597
597598 comp_ids , _ , ngroups = self .group_info
598- if values .ndim > 2 :
599- # punting for now
600- raise NotImplementedError ("number of dimensions is currently limited to 2" )
601- else :
602- transform_func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
599+ transform_func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
603600
604601 return result
605602
0 commit comments