@@ -604,7 +604,7 @@ def str_extract(arr, pat, flags=0, expand=None):
604604 return _str_extract_frame (arr ._orig , pat , flags = flags )
605605 else :
606606 result , name = _str_extract_noexpand (arr ._data , pat , flags = flags )
607- return arr ._wrap_result (result , name = name )
607+ return arr ._wrap_result (result , name = name , expand = expand )
608608
609609
610610def str_extractall (arr , pat , flags = 0 ):
@@ -1292,7 +1292,10 @@ def __iter__(self):
12921292 i += 1
12931293 g = self .get (i )
12941294
1295- def _wrap_result (self , result , use_codes = True , name = None ):
1295+ def _wrap_result (self , result , use_codes = True ,
1296+ name = None , expand = None ):
1297+
1298+ from pandas .core .index import Index , MultiIndex
12961299
12971300 # for category, we do the stuff on the categories, so blow it up
12981301 # to the full series again
@@ -1302,48 +1305,42 @@ def _wrap_result(self, result, use_codes=True, name=None):
13021305 if use_codes and self ._is_categorical :
13031306 result = take_1d (result , self ._orig .cat .codes )
13041307
1305- # leave as it is to keep extract and get_dummies results
1306- # can be merged to _wrap_result_expand in v0.17
1307- from pandas .core .series import Series
1308- from pandas .core .frame import DataFrame
1309- from pandas .core .index import Index
1310-
1311- if not hasattr (result , 'ndim' ):
1308+ if not hasattr (result , 'ndim' ) or not hasattr (result , 'dtype' ):
13121309 return result
1310+ assert result .ndim < 3
13131311
1314- if result .ndim == 1 :
1315- # Wait until we are sure result is a Series or Index before
1316- # checking attributes (GH 12180)
1317- name = name or getattr (result , 'name' , None ) or self ._orig .name
1318- if isinstance (self ._orig , Index ):
1319- # if result is a boolean np.array, return the np.array
1320- # instead of wrapping it into a boolean Index (GH 8875)
1321- if is_bool_dtype (result ):
1322- return result
1323- return Index (result , name = name )
1324- return Series (result , index = self ._orig .index , name = name )
1325- else :
1326- assert result .ndim < 3
1327- return DataFrame (result , index = self ._orig .index )
1312+ if expand is None :
1313+ # infer from ndim if expand is not specified
1314+ expand = False if result .ndim == 1 else True
1315+
1316+ elif expand is True and not isinstance (self ._orig , Index ):
1317+ # required when expand=True is explicitly specified
1318+ # not needed when inferred
1319+
1320+ def cons_row (x ):
1321+ if is_list_like (x ):
1322+ return x
1323+ else :
1324+ return [x ]
1325+
1326+ result = [cons_row (x ) for x in result ]
13281327
1329- def _wrap_result_expand (self , result , expand = False ):
13301328 if not isinstance (expand , bool ):
13311329 raise ValueError ("expand must be True or False" )
13321330
1333- # for category, we do the stuff on the categories, so blow it up
1334- # to the full series again
1335- if self ._is_categorical :
1336- result = take_1d (result , self ._orig .cat .codes )
1337-
1338- from pandas .core .index import Index , MultiIndex
1339- if not hasattr (result , 'ndim' ):
1340- return result
1331+ if name is None :
1332+ name = getattr (result , 'name' , None )
1333+ if name is None :
1334+ # do not use logical or, _orig may be a DataFrame
1335+ # which has "name" column
1336+ name = self ._orig .name
13411337
1338+ # Wait until we are sure result is a Series or Index before
1339+ # checking attributes (GH 12180)
13421340 if isinstance (self ._orig , Index ):
1343- name = getattr (result , 'name' , None )
13441341 # if result is a boolean np.array, return the np.array
13451342 # instead of wrapping it into a boolean Index (GH 8875)
1346- if hasattr ( result , 'dtype' ) and is_bool_dtype (result ):
1343+ if is_bool_dtype (result ):
13471344 return result
13481345
13491346 if expand :
@@ -1354,18 +1351,10 @@ def _wrap_result_expand(self, result, expand=False):
13541351 else :
13551352 index = self ._orig .index
13561353 if expand :
1357-
1358- def cons_row (x ):
1359- if is_list_like (x ):
1360- return x
1361- else :
1362- return [x ]
1363-
13641354 cons = self ._orig ._constructor_expanddim
1365- data = [cons_row (x ) for x in result ]
1366- return cons (data , index = index )
1355+ return cons (result , index = index )
13671356 else :
1368- name = getattr ( result , 'name' , None )
1357+ # Must be a Series
13691358 cons = self ._orig ._constructor
13701359 return cons (result , name = name , index = index )
13711360
@@ -1380,12 +1369,12 @@ def cat(self, others=None, sep=None, na_rep=None):
13801369 @copy (str_split )
13811370 def split (self , pat = None , n = - 1 , expand = False ):
13821371 result = str_split (self ._data , pat , n = n )
1383- return self ._wrap_result_expand (result , expand = expand )
1372+ return self ._wrap_result (result , expand = expand )
13841373
13851374 @copy (str_rsplit )
13861375 def rsplit (self , pat = None , n = - 1 , expand = False ):
13871376 result = str_rsplit (self ._data , pat , n = n )
1388- return self ._wrap_result_expand (result , expand = expand )
1377+ return self ._wrap_result (result , expand = expand )
13891378
13901379 _shared_docs ['str_partition' ] = ("""
13911380 Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1440,7 +1429,7 @@ def rsplit(self, pat=None, n=-1, expand=False):
14401429 def partition (self , pat = ' ' , expand = True ):
14411430 f = lambda x : x .partition (pat )
14421431 result = _na_map (f , self ._data )
1443- return self ._wrap_result_expand (result , expand = expand )
1432+ return self ._wrap_result (result , expand = expand )
14441433
14451434 @Appender (_shared_docs ['str_partition' ] % {
14461435 'side' : 'last' ,
@@ -1451,7 +1440,7 @@ def partition(self, pat=' ', expand=True):
14511440 def rpartition (self , pat = ' ' , expand = True ):
14521441 f = lambda x : x .rpartition (pat )
14531442 result = _na_map (f , self ._data )
1454- return self ._wrap_result_expand (result , expand = expand )
1443+ return self ._wrap_result (result , expand = expand )
14551444
14561445 @copy (str_get )
14571446 def get (self , i ):
@@ -1597,7 +1586,8 @@ def get_dummies(self, sep='|'):
15971586 # methods available for making the dummies...
15981587 data = self ._orig .astype (str ) if self ._is_categorical else self ._data
15991588 result = str_get_dummies (data , sep )
1600- return self ._wrap_result (result , use_codes = (not self ._is_categorical ))
1589+ return self ._wrap_result (result , use_codes = (not self ._is_categorical ),
1590+ expand = True )
16011591
16021592 @copy (str_translate )
16031593 def translate (self , table , deletechars = None ):
0 commit comments