@@ -568,3 +568,40 @@ def test__get_dtype(input_param, result):
568568def test__get_dtype_fails (input_param ):
569569 # python objects
570570 pytest .raises (TypeError , com ._get_dtype , input_param )
571+
572+
573+ @pytest .mark .parametrize ('input_param,result' , [
574+ (int , np .dtype (int ).type ),
575+ ('int32' , np .int32 ),
576+ (float , np .dtype (float ).type ),
577+ ('float64' , np .float64 ),
578+ (np .dtype ('float64' ), np .float64 ),
579+ (str , np .dtype (str ).type ),
580+ (pd .Series ([1 , 2 ], dtype = np .dtype ('int16' )), np .int16 ),
581+ (pd .Series (['a' , 'b' ]), np .object_ ),
582+ (pd .Index ([1 , 2 ], dtype = 'int64' ), np .int64 ),
583+ (pd .Index (['a' , 'b' ]), np .object_ ),
584+ ('category' , com .CategoricalDtypeType ),
585+ (pd .Categorical (['a' , 'b' ]).dtype , com .CategoricalDtypeType ),
586+ (pd .Categorical (['a' , 'b' ]), com .CategoricalDtypeType ),
587+ (pd .CategoricalIndex (['a' , 'b' ]).dtype , com .CategoricalDtypeType ),
588+ (pd .CategoricalIndex (['a' , 'b' ]), com .CategoricalDtypeType ),
589+ (pd .DatetimeIndex ([1 , 2 ]), np .datetime64 ),
590+ (pd .DatetimeIndex ([1 , 2 ]).dtype , np .datetime64 ),
591+ ('<M8[ns]' , np .datetime64 ),
592+ (pd .DatetimeIndex ([1 , 2 ], tz = 'Europe/London' ), com .DatetimeTZDtypeType ),
593+ (pd .DatetimeIndex ([1 , 2 ], tz = 'Europe/London' ).dtype ,
594+ com .DatetimeTZDtypeType ),
595+ ('datetime64[ns, Europe/London]' , com .DatetimeTZDtypeType ),
596+ (pd .SparseSeries ([1 , 2 ], dtype = 'int32' ), np .int32 ),
597+ (pd .SparseSeries ([1 , 2 ], dtype = 'int32' ).dtype , np .int32 ),
598+ (PeriodDtype (freq = 'D' ), com .PeriodDtypeType ),
599+ ('period[D]' , com .PeriodDtypeType ),
600+ (IntervalDtype (), com .IntervalDtypeType ),
601+ (None , type (None )),
602+ (1 , type (None )),
603+ (1.2 , type (None )),
604+ (pd .DataFrame ([1 , 2 ]), type (None )), # composite dtype
605+ ])
606+ def test__get_dtype_type (input_param , result ):
607+ assert com ._get_dtype_type (input_param ) == result
0 commit comments