4646)
4747from .dtypes import DatetimeTZDtype , ExtensionDtype , PeriodDtype
4848from .generic import (
49+ ABCDataFrame ,
4950 ABCDatetimeArray ,
5051 ABCDatetimeIndex ,
5152 ABCPeriodArray ,
@@ -95,12 +96,13 @@ def maybe_downcast_to_dtype(result, dtype):
9596 """ try to cast to the specified dtype (e.g. convert back to bool/int
9697 or could be an astype of float64->float32
9798 """
99+ do_round = False
98100
99101 if is_scalar (result ):
100102 return result
101-
102- def trans ( x ):
103- return x
103+ elif isinstance ( result , ABCDataFrame ):
104+ # occurs in pivot_table doctest
105+ return result
104106
105107 if isinstance (dtype , str ):
106108 if dtype == "infer" :
@@ -118,83 +120,115 @@ def trans(x):
118120 elif inferred_type == "floating" :
119121 dtype = "int64"
120122 if issubclass (result .dtype .type , np .number ):
121-
122- def trans (x ): # noqa
123- return x .round ()
123+ do_round = True
124124
125125 else :
126126 dtype = "object"
127127
128- if isinstance (dtype , str ):
129128 dtype = np .dtype (dtype )
130129
131- try :
130+ converted = maybe_downcast_numeric (result , dtype , do_round )
131+ if converted is not result :
132+ return converted
133+
134+ # a datetimelike
135+ # GH12821, iNaT is casted to float
136+ if dtype .kind in ["M" , "m" ] and result .dtype .kind in ["i" , "f" ]:
137+ try :
138+ result = result .astype (dtype )
139+ except Exception :
140+ if dtype .tz :
141+ # convert to datetime and change timezone
142+ from pandas import to_datetime
143+
144+ result = to_datetime (result ).tz_localize ("utc" )
145+ result = result .tz_convert (dtype .tz )
146+
147+ elif dtype .type is Period :
148+ # TODO(DatetimeArray): merge with previous elif
149+ from pandas .core .arrays import PeriodArray
132150
151+ try :
152+ return PeriodArray (result , freq = dtype .freq )
153+ except TypeError :
154+ # e.g. TypeError: int() argument must be a string, a
155+ # bytes-like object or a number, not 'Period
156+ pass
157+
158+ return result
159+
160+
161+ def maybe_downcast_numeric (result , dtype , do_round : bool = False ):
162+ """
163+ Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
164+
165+ Parameters
166+ ----------
167+ result : ndarray or ExtensionArray
168+ dtype : np.dtype or ExtensionDtype
169+ do_round : bool
170+
171+ Returns
172+ -------
173+ ndarray or ExtensionArray
174+ """
175+ if not isinstance (dtype , np .dtype ):
176+ # e.g. SparseDtype has no itemsize attr
177+ return result
178+
179+ if isinstance (result , list ):
180+ # reached via groupoby.agg _ohlc; really this should be handled
181+ # earlier
182+ result = np .array (result )
183+
184+ def trans (x ):
185+ if do_round :
186+ return x .round ()
187+ return x
188+
189+ if dtype .kind == result .dtype .kind :
133190 # don't allow upcasts here (except if empty)
134- if dtype .kind == result .dtype .kind :
135- if result .dtype .itemsize <= dtype .itemsize and np .prod (result .shape ):
136- return result
191+ if result .dtype .itemsize <= dtype .itemsize and result .size :
192+ return result
137193
138- if is_bool_dtype (dtype ) or is_integer_dtype (dtype ):
194+ if is_bool_dtype (dtype ) or is_integer_dtype (dtype ):
139195
196+ if not result .size :
140197 # if we don't have any elements, just astype it
141- if not np .prod (result .shape ):
142- return trans (result ).astype (dtype )
198+ return trans (result ).astype (dtype )
143199
144- # do a test on the first element, if it fails then we are done
145- r = result .ravel ()
146- arr = np .array ([r [0 ]])
200+ # do a test on the first element, if it fails then we are done
201+ r = result .ravel ()
202+ arr = np .array ([r [0 ]])
147203
204+ if isna (arr ).any () or not np .allclose (arr , trans (arr ).astype (dtype ), rtol = 0 ):
148205 # if we have any nulls, then we are done
149- if isna (arr ).any () or not np .allclose (
150- arr , trans (arr ).astype (dtype ), rtol = 0
151- ):
152- return result
206+ return result
153207
208+ elif not isinstance (r [0 ], (np .integer , np .floating , np .bool , int , float , bool )):
154209 # a comparable, e.g. a Decimal may slip in here
155- elif not isinstance (
156- r [0 ], (np .integer , np .floating , np .bool , int , float , bool )
157- ):
158- return result
210+ return result
159211
160- if (
161- issubclass (result .dtype .type , (np .object_ , np .number ))
162- and notna (result ).all ()
163- ):
164- new_result = trans (result ).astype (dtype )
165- try :
166- if np .allclose (new_result , result , rtol = 0 ):
167- return new_result
168- except Exception :
169-
170- # comparison of an object dtype with a number type could
171- # hit here
172- if (new_result == result ).all ():
173- return new_result
174- elif issubclass (dtype .type , np .floating ) and not is_bool_dtype (result .dtype ):
175- return result .astype (dtype )
176-
177- # a datetimelike
178- # GH12821, iNaT is casted to float
179- elif dtype .kind in ["M" , "m" ] and result .dtype .kind in ["i" , "f" ]:
212+ if (
213+ issubclass (result .dtype .type , (np .object_ , np .number ))
214+ and notna (result ).all ()
215+ ):
216+ new_result = trans (result ).astype (dtype )
180217 try :
181- result = result .astype (dtype )
218+ if np .allclose (new_result , result , rtol = 0 ):
219+ return new_result
182220 except Exception :
183- if dtype .tz :
184- # convert to datetime and change timezone
185- from pandas import to_datetime
186-
187- result = to_datetime (result ).tz_localize ("utc" )
188- result = result .tz_convert (dtype .tz )
189-
190- elif dtype .type == Period :
191- # TODO(DatetimeArray): merge with previous elif
192- from pandas .core .arrays import PeriodArray
193-
194- return PeriodArray (result , freq = dtype .freq )
195-
196- except Exception :
197- pass
221+ # comparison of an object dtype with a number type could
222+ # hit here
223+ if (new_result == result ).all ():
224+ return new_result
225+
226+ elif (
227+ issubclass (dtype .type , np .floating )
228+ and not is_bool_dtype (result .dtype )
229+ and not is_string_dtype (result .dtype )
230+ ):
231+ return result .astype (dtype )
198232
199233 return result
200234
0 commit comments