diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 0b4733175c..2de5813b81 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -31,6 +31,7 @@ from collections.abc import Iterable, Iterator, Mapping from types import ModuleType + import pyarrow as pa from pandas._typing import Dtype as PandasDtype from pandas.core.dtypes.dtypes import BaseMaskedDtype from typing_extensions import TypeAlias, TypeIs @@ -73,44 +74,12 @@ \] # Closing bracket for datetime64 $""" PATTERN_PD_DATETIME = re.compile(PD_DATETIME_RGX, re.VERBOSE) -PA_DATETIME_RGX = r"""^ - timestamp\[ - (?Ps|ms|us|ns) # Match time unit: s, ms, us, or ns - (?:, # Begin non-capturing group for optional timezone - \s?tz= # Match "tz=" prefix - (?P # Start named group for timezone - [a-zA-Z\/]* # Match timezone name (e.g., UTC, America/New_York) - (?: # Begin optional non-capturing group for offset - [+-]\d{2}:\d{2} # Match offset in format +HH:MM or -HH:MM - )? # End optional offset group - ) # End time_zone group - )? # End optional timezone group - \] # Closing bracket for timestamp - \[pyarrow\] # Literal string "[pyarrow]" -$""" -PATTERN_PA_DATETIME = re.compile(PA_DATETIME_RGX, re.VERBOSE) PD_DURATION_RGX = r"""^ timedelta64\[ (?Ps|ms|us|ns) # Match time unit: s, ms, us, or ns \] # Closing bracket for timedelta64 $""" PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE) -PA_DURATION_RGX = r"""^ - duration\[ - (?Ps|ms|us|ns) # Match time unit: s, ms, us, or ns - \] # Closing bracket for duration - \[pyarrow\] # Literal string "[pyarrow]" -$""" -PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE) -PA_DECIMAL_RGX = r"""^ - decimal128\( - (?P\d+) # Precision value - ,\s # Literal string ", " - (?P\d+) # Scale value - \) - \[pyarrow\] # Literal string "[pyarrow]" -$""" -PATTERN_PA_DECIMAL = re.compile(PA_DECIMAL_RGX, re.VERBOSE) NativeIntervalUnit: TypeAlias = Literal[ "year", @@ -219,79 +188,56 @@ def rename( return cast("NativeNDFrameT", result) # type: ignore[redundant-cast] +@functools.lru_cache(maxsize=16) +def is_dtype_non_pyarrow_string(native_dtype: Any) -> bool: + """*There is no problem which can't be solved by adding an extra string type* pandas.""" + # TODO @dangotbanned: Investigate how we could handle `cudf` without `str(native_dtype)` + # https://github.com/rapidsai/cudf/blob/a32b8cf62c9b086b645b0825b78b99f065b1887f/python/cudf/cudf/utils/dtypes.py#L646-L670 + return isinstance(native_dtype, pd.StringDtype) or str(native_dtype) in { + "string", + "string[python]", + "string[pyarrow_numpy]", + "", # why? why? why? + "str", + } + + @functools.lru_cache(maxsize=16) def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: # noqa: C901, PLR0912 dtype = str(native_dtype) dtypes = version.dtypes - if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: + if dtype in {"int64", "Int64"}: return dtypes.Int64() - if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}: + if dtype in {"int32", "Int32"}: return dtypes.Int32() - if dtype in {"int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"}: + if dtype in {"int16", "Int16"}: return dtypes.Int16() - if dtype in {"int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"}: + if dtype in {"int8", "Int8"}: return dtypes.Int8() - if dtype in {"uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"}: + if dtype in {"uint64", "UInt64"}: return dtypes.UInt64() - if dtype in {"uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"}: + if dtype in {"uint32", "UInt32"}: return dtypes.UInt32() - if dtype in {"uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"}: + if dtype in {"uint16", "UInt16"}: return dtypes.UInt16() - if dtype in {"uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"}: + if dtype in {"uint8", "UInt8"}: return dtypes.UInt8() - if dtype in { - "float64", - "Float64", - "Float64[pyarrow]", - "float64[pyarrow]", - "double[pyarrow]", - }: + if dtype in {"float64", "Float64"}: return dtypes.Float64() - if dtype in { - "float32", - "Float32", - "Float32[pyarrow]", - "float32[pyarrow]", - "float[pyarrow]", - }: + if dtype in {"float32", "Float32"}: return dtypes.Float32() - if dtype in { - # "there is no problem which can't be solved by adding an extra string type" pandas - "string", - "string[python]", - "string[pyarrow]", - "string[pyarrow_numpy]", - "large_string[pyarrow]", - "str", - }: + if is_dtype_non_pyarrow_string(native_dtype): return dtypes.String() - if dtype in {"bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"}: + if dtype in {"bool", "boolean"}: return dtypes.Boolean() - if dtype.startswith("dictionary<"): - return dtypes.Categorical() - if dtype == "category": - return native_categorical_to_narwhals_dtype(native_dtype, version) - if (match_ := PATTERN_PD_DATETIME.match(dtype)) or ( - match_ := PATTERN_PA_DATETIME.match(dtype) - ): + if match_ := PATTERN_PD_DATETIME.match(dtype): dt_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment] dt_time_zone: str | None = match_.group("time_zone") return dtypes.Datetime(dt_time_unit, dt_time_zone) - if (match_ := PATTERN_PD_DURATION.match(dtype)) or ( - match_ := PATTERN_PA_DURATION.match(dtype) - ): + if match_ := PATTERN_PD_DURATION.match(dtype): du_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment] return dtypes.Duration(du_time_unit) - if dtype == "date32[day][pyarrow]": - return dtypes.Date() - if match_ := PATTERN_PA_DECIMAL.match(dtype): - precision, scale = int(match_.group("precision")), int(match_.group("scale")) - return dtypes.Decimal(precision, scale) - if dtype.startswith("time") and dtype.endswith("[pyarrow]"): - return dtypes.Time() - if dtype.startswith("binary") and dtype.endswith("[pyarrow]"): - return dtypes.Binary() return dtypes.Unknown() # pragma: no cover @@ -319,9 +265,7 @@ def object_native_to_narwhals_dtype( def native_categorical_to_narwhals_dtype( - native_dtype: pd.CategoricalDtype, - version: Version, - implementation: Literal[Implementation.CUDF] | None = None, + native_dtype: pd.CategoricalDtype, version: Version, implementation: Implementation ) -> DType: dtypes = version.dtypes if version is Version.V1: @@ -340,12 +284,17 @@ def _cudf_categorical_to_list( native_dtype: Any, ) -> Callable[[], list[Any]]: # pragma: no cover # NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype + # https://github.com/rapidsai/cudf/issues/18536 + # https://github.com/rapidsai/cudf/issues/14027 def fn() -> list[Any]: return native_dtype.categories.to_arrow().to_pylist() return fn +CUDF_BASE_DTYPE_PREFIX = ("list", "struct", "decimal") + + def native_to_narwhals_dtype( native_dtype: Any, version: Version, @@ -355,21 +304,18 @@ def native_to_narwhals_dtype( ) -> DType: str_dtype = str(native_dtype) - if str_dtype.startswith(("large_list", "list", "struct", "fixed_size_list")): + if is_dtype_pyarrow(native_dtype) or str_dtype.startswith(CUDF_BASE_DTYPE_PREFIX): from narwhals._arrow.utils import ( native_to_narwhals_dtype as arrow_native_to_narwhals_dtype, ) if hasattr(native_dtype, "to_arrow"): # pragma: no cover - # cudf, cudf.pandas - return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version) - return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version) - if str_dtype == "category" and implementation.is_cudf(): - # https://github.com/rapidsai/cudf/issues/18536 - # https://github.com/rapidsai/cudf/issues/14027 - return native_categorical_to_narwhals_dtype( - native_dtype, version, Implementation.CUDF - ) + pa_dtype: pa.DataType = native_dtype.to_arrow() # pyright: ignore[reportAttributeAccessIssue] + else: + pa_dtype = native_dtype.pyarrow_dtype + return arrow_native_to_narwhals_dtype(pa_dtype, version) + if str_dtype == "category": + return native_categorical_to_narwhals_dtype(native_dtype, version, implementation) if str_dtype != "object": return non_object_native_to_narwhals_dtype(native_dtype, version) if implementation is Implementation.DASK: