Skip to content
Merged
138 changes: 42 additions & 96 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from collections.abc import Iterable, Iterator, Mapping
from types import ModuleType

import pyarrow as pa
from pandas._typing import Dtype as PandasDtype
from pandas.core.dtypes.dtypes import BaseMaskedDtype
from typing_extensions import TypeAlias, TypeIs
Expand Down Expand Up @@ -73,44 +74,12 @@
\] # Closing bracket for datetime64
$"""
PATTERN_PD_DATETIME = re.compile(PD_DATETIME_RGX, re.VERBOSE)
PA_DATETIME_RGX = r"""^
timestamp\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
(?:, # Begin non-capturing group for optional timezone
\s?tz= # Match "tz=" prefix
(?P<time_zone> # Start named group for timezone
[a-zA-Z\/]* # Match timezone name (e.g., UTC, America/New_York)
(?: # Begin optional non-capturing group for offset
[+-]\d{2}:\d{2} # Match offset in format +HH:MM or -HH:MM
)? # End optional offset group
) # End time_zone group
)? # End optional timezone group
\] # Closing bracket for timestamp
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DATETIME = re.compile(PA_DATETIME_RGX, re.VERBOSE)
PD_DURATION_RGX = r"""^
timedelta64\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
\] # Closing bracket for timedelta64
$"""
PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE)
PA_DURATION_RGX = r"""^
duration\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
\] # Closing bracket for duration
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE)
PA_DECIMAL_RGX = r"""^
decimal128\(
(?P<precision>\d+) # Precision value
,\s # Literal string ", "
(?P<scale>\d+) # Scale value
\)
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DECIMAL = re.compile(PA_DECIMAL_RGX, re.VERBOSE)

NativeIntervalUnit: TypeAlias = Literal[
"year",
Expand Down Expand Up @@ -219,79 +188,56 @@ def rename(
return cast("NativeNDFrameT", result) # type: ignore[redundant-cast]


@functools.lru_cache(maxsize=16)
def is_dtype_non_pyarrow_string(native_dtype: Any) -> bool:
"""*There is no problem which can't be solved by adding an extra string type* pandas."""
# TODO @dangotbanned: Investigate how we could handle `cudf` without `str(native_dtype)`
# https://github.com/rapidsai/cudf/blob/a32b8cf62c9b086b645b0825b78b99f065b1887f/python/cudf/cudf/utils/dtypes.py#L646-L670
return isinstance(native_dtype, pd.StringDtype) or str(native_dtype) in {
"string",
"string[python]",
"string[pyarrow_numpy]",
"<StringDtype(na_value=nan)>", # why? why? why?
"str",
}


@functools.lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: # noqa: C901, PLR0912
dtype = str(native_dtype)

dtypes = version.dtypes
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
if dtype in {"int64", "Int64"}:
return dtypes.Int64()
if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}:
if dtype in {"int32", "Int32"}:
return dtypes.Int32()
if dtype in {"int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"}:
if dtype in {"int16", "Int16"}:
return dtypes.Int16()
if dtype in {"int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"}:
if dtype in {"int8", "Int8"}:
return dtypes.Int8()
if dtype in {"uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"}:
if dtype in {"uint64", "UInt64"}:
return dtypes.UInt64()
if dtype in {"uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"}:
if dtype in {"uint32", "UInt32"}:
return dtypes.UInt32()
if dtype in {"uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"}:
if dtype in {"uint16", "UInt16"}:
return dtypes.UInt16()
if dtype in {"uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"}:
if dtype in {"uint8", "UInt8"}:
return dtypes.UInt8()
if dtype in {
"float64",
"Float64",
"Float64[pyarrow]",
"float64[pyarrow]",
"double[pyarrow]",
}:
if dtype in {"float64", "Float64"}:
return dtypes.Float64()
if dtype in {
"float32",
"Float32",
"Float32[pyarrow]",
"float32[pyarrow]",
"float[pyarrow]",
}:
if dtype in {"float32", "Float32"}:
return dtypes.Float32()
if dtype in {
# "there is no problem which can't be solved by adding an extra string type" pandas
"string",
"string[python]",
"string[pyarrow]",
"string[pyarrow_numpy]",
"large_string[pyarrow]",
"str",
}:
if is_dtype_non_pyarrow_string(native_dtype):
return dtypes.String()
if dtype in {"bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"}:
if dtype in {"bool", "boolean"}:
return dtypes.Boolean()
if dtype.startswith("dictionary<"):
return dtypes.Categorical()
if dtype == "category":
return native_categorical_to_narwhals_dtype(native_dtype, version)
if (match_ := PATTERN_PD_DATETIME.match(dtype)) or (
match_ := PATTERN_PA_DATETIME.match(dtype)
):
if match_ := PATTERN_PD_DATETIME.match(dtype):
dt_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment]
dt_time_zone: str | None = match_.group("time_zone")
return dtypes.Datetime(dt_time_unit, dt_time_zone)
if (match_ := PATTERN_PD_DURATION.match(dtype)) or (
match_ := PATTERN_PA_DURATION.match(dtype)
):
if match_ := PATTERN_PD_DURATION.match(dtype):
du_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment]
return dtypes.Duration(du_time_unit)
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
if match_ := PATTERN_PA_DECIMAL.match(dtype):
precision, scale = int(match_.group("precision")), int(match_.group("scale"))
return dtypes.Decimal(precision, scale)
if dtype.startswith("time") and dtype.endswith("[pyarrow]"):
return dtypes.Time()
if dtype.startswith("binary") and dtype.endswith("[pyarrow]"):
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover


Expand Down Expand Up @@ -319,9 +265,7 @@ def object_native_to_narwhals_dtype(


def native_categorical_to_narwhals_dtype(
native_dtype: pd.CategoricalDtype,
version: Version,
implementation: Literal[Implementation.CUDF] | None = None,
native_dtype: pd.CategoricalDtype, version: Version, implementation: Implementation
) -> DType:
dtypes = version.dtypes
if version is Version.V1:
Expand All @@ -340,12 +284,17 @@ def _cudf_categorical_to_list(
native_dtype: Any,
) -> Callable[[], list[Any]]: # pragma: no cover
# NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype
# https://github.com/rapidsai/cudf/issues/18536
# https://github.com/rapidsai/cudf/issues/14027
def fn() -> list[Any]:
return native_dtype.categories.to_arrow().to_pylist()

return fn


CUDF_BASE_DTYPE_PREFIX = ("list", "struct", "decimal")


def native_to_narwhals_dtype(
native_dtype: Any,
version: Version,
Expand All @@ -355,21 +304,18 @@ def native_to_narwhals_dtype(
) -> DType:
str_dtype = str(native_dtype)

if str_dtype.startswith(("large_list", "list", "struct", "fixed_size_list")):
if is_dtype_pyarrow(native_dtype) or str_dtype.startswith(CUDF_BASE_DTYPE_PREFIX):
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)

if hasattr(native_dtype, "to_arrow"): # pragma: no cover
# cudf, cudf.pandas
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
if str_dtype == "category" and implementation.is_cudf():
# https://github.com/rapidsai/cudf/issues/18536
# https://github.com/rapidsai/cudf/issues/14027
return native_categorical_to_narwhals_dtype(
native_dtype, version, Implementation.CUDF
)
pa_dtype: pa.DataType = native_dtype.to_arrow() # pyright: ignore[reportAttributeAccessIssue]
else:
pa_dtype = native_dtype.pyarrow_dtype
return arrow_native_to_narwhals_dtype(pa_dtype, version)
if str_dtype == "category":
return native_categorical_to_narwhals_dtype(native_dtype, version, implementation)
if str_dtype != "object":
return non_object_native_to_narwhals_dtype(native_dtype, version)
if implementation is Implementation.DASK:
Expand Down
Loading