1111import pandas .core .algorithms as algos
1212import pandas .core .nanops as nanops
1313from pandas .compat import zip
14+ from pandas import to_timedelta , to_datetime
15+ from pandas .types .common import is_datetime64_dtype , is_timedelta64_dtype
1416
1517import numpy as np
1618
@@ -81,14 +83,17 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3,
8183 array([1, 1, 1, 1, 1], dtype=int64)
8284 """
8385 # NOTE: this binning code is changed a bit from histogram for var(x) == 0
86+
87+ # for handling the cut for datetime and timedelta objects
88+ x_is_series , series_index , name , x = _preprocess_for_cut (x )
89+ x , dtype = _coerce_to_type (x )
90+
8491 if not np .iterable (bins ):
8592 if is_scalar (bins ) and bins < 1 :
8693 raise ValueError ("`bins` should be a positive integer." )
87- try : # for array-like
88- sz = x .size
89- except AttributeError :
90- x = np .asarray (x )
91- sz = x .size
94+
95+ sz = x .size
96+
9297 if sz == 0 :
9398 raise ValueError ('Cannot cut empty array' )
9499 # handle empty arrays. Can't determine range, so use 0-1.
@@ -114,9 +119,12 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3,
114119 if (np .diff (bins ) < 0 ).any ():
115120 raise ValueError ('bins must increase monotonically.' )
116121
117- return _bins_to_cuts (x , bins , right = right , labels = labels ,
118- retbins = retbins , precision = precision ,
119- include_lowest = include_lowest )
122+ fac , bins = _bins_to_cuts (x , bins , right = right , labels = labels ,
123+ precision = precision ,
124+ include_lowest = include_lowest , dtype = dtype )
125+
126+ return _postprocess_for_cut (fac , bins , retbins , x_is_series ,
127+ series_index , name )
120128
121129
122130def qcut (x , q , labels = None , retbins = False , precision = 3 ):
@@ -166,26 +174,26 @@ def qcut(x, q, labels=None, retbins=False, precision=3):
166174 >>> pd.qcut(range(5), 4, labels=False)
167175 array([0, 0, 1, 2, 3], dtype=int64)
168176 """
177+ x_is_series , series_index , name , x = _preprocess_for_cut (x )
178+
179+ x , dtype = _coerce_to_type (x )
180+
169181 if is_integer (q ):
170182 quantiles = np .linspace (0 , 1 , q + 1 )
171183 else :
172184 quantiles = q
173185 bins = algos .quantile (x , quantiles )
174- return _bins_to_cuts (x , bins , labels = labels , retbins = retbins ,
175- precision = precision , include_lowest = True )
186+ fac , bins = _bins_to_cuts (x , bins , labels = labels ,
187+ precision = precision , include_lowest = True ,
188+ dtype = dtype )
176189
190+ return _postprocess_for_cut (fac , bins , retbins , x_is_series ,
191+ series_index , name )
177192
178- def _bins_to_cuts (x , bins , right = True , labels = None , retbins = False ,
179- precision = 3 , name = None , include_lowest = False ):
180- x_is_series = isinstance (x , Series )
181- series_index = None
182-
183- if x_is_series :
184- series_index = x .index
185- if name is None :
186- name = x .name
187193
188- x = np .asarray (x )
194+ def _bins_to_cuts (x , bins , right = True , labels = None ,
195+ precision = 3 , include_lowest = False ,
196+ dtype = None ):
189197
190198 side = 'left' if right else 'right'
191199 ids = bins .searchsorted (x , side = side )
@@ -205,7 +213,8 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False,
205213 while True :
206214 try :
207215 levels = _format_levels (bins , precision , right = right ,
208- include_lowest = include_lowest )
216+ include_lowest = include_lowest ,
217+ dtype = dtype )
209218 except ValueError :
210219 increases += 1
211220 precision += 1
@@ -229,18 +238,12 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False,
229238 fac = fac .astype (np .float64 )
230239 np .putmask (fac , na_mask , np .nan )
231240
232- if x_is_series :
233- fac = Series (fac , index = series_index , name = name )
234-
235- if not retbins :
236- return fac
237-
238241 return fac , bins
239242
240243
241244def _format_levels (bins , prec , right = True ,
242- include_lowest = False ):
243- fmt = lambda v : _format_label (v , precision = prec )
245+ include_lowest = False , dtype = None ):
246+ fmt = lambda v : _format_label (v , precision = prec , dtype = dtype )
244247 if right :
245248 levels = []
246249 for a , b in zip (bins , bins [1 :]):
@@ -258,12 +261,16 @@ def _format_levels(bins, prec, right=True,
258261 else :
259262 levels = ['[%s, %s)' % (fmt (a ), fmt (b ))
260263 for a , b in zip (bins , bins [1 :])]
261-
262264 return levels
263265
264266
265- def _format_label (x , precision = 3 ):
267+ def _format_label (x , precision = 3 , dtype = None ):
266268 fmt_str = '%%.%dg' % precision
269+
270+ if is_datetime64_dtype (dtype ):
271+ return to_datetime (x , unit = 'ns' )
272+ if is_timedelta64_dtype (dtype ):
273+ return to_timedelta (x , unit = 'ns' )
267274 if np .isinf (x ):
268275 return str (x )
269276 elif is_float (x ):
@@ -300,3 +307,55 @@ def _trim_zeros(x):
300307 if len (x ) > 1 and x [- 1 ] == '.' :
301308 x = x [:- 1 ]
302309 return x
310+
311+
312+ def _coerce_to_type (x ):
313+ """
314+ if the passed data is of datetime/timedelta type,
315+ this method converts it to integer so that cut method can
316+ handle it
317+ """
318+ dtype = None
319+
320+ if is_timedelta64_dtype (x ):
321+ x = to_timedelta (x ).view (np .int64 )
322+ dtype = np .timedelta64
323+ elif is_datetime64_dtype (x ):
324+ x = to_datetime (x ).view (np .int64 )
325+ dtype = np .datetime64
326+
327+ return x , dtype
328+
329+
330+ def _preprocess_for_cut (x ):
331+ """
332+ handles preprocessing for cut where we convert passed
333+ input to array, strip the index information and store it
334+ seperately
335+ """
336+ x_is_series = isinstance (x , Series )
337+ series_index = None
338+ name = None
339+
340+ if x_is_series :
341+ series_index = x .index
342+ name = x .name
343+
344+ x = np .asarray (x )
345+
346+ return x_is_series , series_index , name , x
347+
348+
349+ def _postprocess_for_cut (fac , bins , retbins , x_is_series , series_index , name ):
350+ """
351+ handles post processing for the cut method where
352+ we combine the index information if the originally passed
353+ datatype was a series
354+ """
355+ if x_is_series :
356+ fac = Series (fac , index = series_index , name = name )
357+
358+ if not retbins :
359+ return fac
360+
361+ return fac , bins
0 commit comments