Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,9 @@ def replace(self, to_replace, value, inplace: bool = False):

# ------------------------------------------------------------------------
# String methods interface
def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")):
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ def _cmp_method(self, other, op):
# String methods interface
_str_na_value = StringDtype.na_value

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
from pandas.arrays import BooleanArray

if dtype is None:
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,9 @@ def value_counts(self, dropna: bool = True) -> Series:

_str_na_value = ArrowStringDtype.na_value

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
# TODO: de-duplicate with StringArray method. This method is more or less a
# copy and paste.

Expand Down
50 changes: 6 additions & 44 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
import numpy as np

import pandas._libs.lib as lib
from pandas._typing import (
ArrayLike,
FrameOrSeriesUnion,
)
from pandas._typing import FrameOrSeriesUnion
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -160,7 +157,6 @@ class StringMethods(NoNewAttributesMixin):
# TODO: Dispatch all the methods
# Currently the following are not dispatched to the array
# * cat
# * extract
# * extractall

def __init__(self, data):
Expand Down Expand Up @@ -243,7 +239,7 @@ def _wrap_result(
self,
result,
name=None,
expand=None,
expand: bool | None = None,
fill_value=np.nan,
returns_string=True,
):
Expand Down Expand Up @@ -2385,10 +2381,7 @@ def extract(
2 NaN
dtype: object
"""
from pandas import (
DataFrame,
array as pd_array,
)
from pandas import DataFrame

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")
Expand All @@ -2400,8 +2393,6 @@ def extract(
if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex):
raise ValueError("only one regex group is supported with Index")

# TODO: dispatch

obj = self._data
result_dtype = _result_dtype(obj)

Expand All @@ -2415,8 +2406,8 @@ def extract(
result = DataFrame(columns=columns, dtype=result_dtype)

else:
result_list = _str_extract(
obj.array, pat, flags=flags, expand=returns_df
result_list = self._data.array._str_extract(
pat, flags=flags, expand=returns_df
)

result_index: Index | None
Expand All @@ -2431,9 +2422,7 @@ def extract(

else:
name = _get_single_group_name(regex)
result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df)
# not dispatching, so we have to reconstruct here.
result = pd_array(result_arr, dtype=result_dtype)
result = self._data.array._str_extract(pat, flags=flags, expand=returns_df)
return self._wrap_result(result, name=name)

@forbid_nonstring_types(["bytes"])
Expand Down Expand Up @@ -3121,33 +3110,6 @@ def _get_group_names(regex: re.Pattern) -> list[Hashable]:
return [names.get(1 + i, i) for i in range(regex.groups)]


def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True):
"""
Find groups in each string in the array using passed regular expression.

Returns
-------
np.ndarray or list of lists is expand is True
"""
regex = re.compile(pat, flags=flags)

empty_row = [np.nan] * regex.groups

def f(x):
if not isinstance(x, str):
return empty_row
m = regex.search(x)
if m:
return [np.nan if item is None else item for item in m.groups()]
else:
return empty_row

if expand:
return [f(val) for val in np.asarray(arr)]

return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object)


def str_extractall(arr, pat, flags=0):
regex = re.compile(pat, flags=flags)
# the regex must contain capture groups.
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,7 @@ def _str_split(self, pat=None, n=-1, expand=False):
@abc.abstractmethod
def _str_rsplit(self, pat=None, n=-1):
pass

# Abstract declaration of the regex-extraction hook that the string accessor
# dispatches to (it is invoked as ``array._str_extract(pat, flags=..., expand=...)``).
# ``pat`` is the regular-expression pattern, ``flags`` the regex flags and
# ``expand`` selects between per-element rows of all groups and the first
# group only.
# NOTE(review): the exact return type is left to the concrete implementation;
# confirm against each array type before relying on it.
@abc.abstractmethod
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
pass
37 changes: 33 additions & 4 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __len__(self):
# For typing, _str_map relies on the object being sized.
raise NotImplementedError

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
"""
Map a callable over valid element of the array.

Expand All @@ -47,6 +49,8 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
for object-dtype and Categorical and ``pd.NA`` for StringArray.
dtype : Dtype, optional
The dtype of the result array.
convert : bool, default True
Whether to call `maybe_convert_objects` on the resulting ndarray
"""
if dtype is None:
dtype = np.dtype("object")
Expand All @@ -60,9 +64,9 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):

arr = np.asarray(self, dtype=object)
mask = isna(arr)
convert = not np.all(mask)
map_convert = convert and not np.all(mask)
try:
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert)
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert)
except (TypeError, AttributeError) as e:
# Reraise the exception if callable `f` got wrong number of args.
# The user may want to be warned by this, instead of getting NaN
Expand All @@ -88,7 +92,7 @@ def g(x):
return result
if na_value is not np.nan:
np.putmask(result, mask, na_value)
if result.dtype == object:
if convert and result.dtype == object:
result = lib.maybe_convert_objects(result)
return result

Expand Down Expand Up @@ -410,3 +414,28 @@ def _str_lstrip(self, to_strip=None):

# Right-strip implementation: hands ``_str_map`` a callable that invokes
# ``x.rstrip(to_strip)`` on each element it is given and returns whatever
# ``_str_map`` produces. ``to_strip`` is passed through unchanged (default None).
def _str_rstrip(self, to_strip=None):
return self._str_map(lambda x: x.rstrip(to_strip))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
regex = re.compile(pat, flags=flags)
na_value = self._str_na_value

if not expand:

def g(x):
m = regex.search(x)
return m.groups()[0] if m else na_value

return self._str_map(g, convert=False)

empty_row = [na_value] * regex.groups

def f(x):
if not isinstance(x, str):
return empty_row
m = regex.search(x)
if m:
return [na_value if item is None else item for item in m.groups()]
else:
return empty_row

return [f(val) for val in np.asarray(self)]