Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 40 additions & 45 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
from narwhals.utils import Version

T = TypeVar("T")
Expand Down Expand Up @@ -70,16 +69,14 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An

@lru_cache(maxsize=16)
def native_to_narwhals_dtype(
dtype: pl.DataType,
version: Version,
backend_version: tuple[int, ...],
dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
) -> DType:
dtypes = import_dtypes_module(version)
if dtype == pl.Float64:
return dtypes.Float64()
if dtype == pl.Float32:
return dtypes.Float32()
if dtype == getattr(pl, "Int128", None): # type: ignore[operator] # pragma: no cover
if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.Int128()
if dtype == pl.Int64:
Expand All @@ -90,7 +87,7 @@ def native_to_narwhals_dtype(
return dtypes.Int16()
if dtype == pl.Int8:
return dtypes.Int8()
if dtype == getattr(pl, "UInt128", None): # type: ignore[operator] # pragma: no cover
if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.UInt128()
if dtype == pl.UInt64:
Expand All @@ -113,32 +110,32 @@ def native_to_narwhals_dtype(
return dtypes.Enum()
if dtype == pl.Date:
return dtypes.Date()
if dtype == pl.Datetime:
dt_time_unit: TimeUnit = getattr(dtype, "time_unit", "us")
dt_time_zone = getattr(dtype, "time_zone", None)
return dtypes.Datetime(time_unit=dt_time_unit, time_zone=dt_time_zone)
if dtype == pl.Duration:
du_time_unit: TimeUnit = getattr(dtype, "time_unit", "us")
return dtypes.Duration(time_unit=du_time_unit)
if dtype == pl.Struct:
return dtypes.Struct(
[
dtypes.Field(
field_name,
native_to_narwhals_dtype(field_type, version, backend_version),
)
for field_name, field_type in dtype # type: ignore[attr-defined]
]
if isinstance_or_issubclass(dtype, pl.Datetime):
return (
dtypes.Datetime()
if dtype is pl.Datetime
else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
)
if dtype == pl.List:
if isinstance_or_issubclass(dtype, pl.Duration):
return (
dtypes.Duration()
if dtype is pl.Duration
else dtypes.Duration(dtype.time_unit)
)
if isinstance_or_issubclass(dtype, pl.Struct):
fields = [
dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version))
for name, tp in dtype
]
return dtypes.Struct(fields)
if isinstance_or_issubclass(dtype, pl.List):
return dtypes.List(
native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined]
native_to_narwhals_dtype(dtype.inner, version, backend_version)
)
if dtype == pl.Array:
outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined]
if isinstance_or_issubclass(dtype, pl.Array):
outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size
return dtypes.Array(
inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined]
shape=outer_shape,
native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape
)
if dtype == pl.Decimal:
return dtypes.Decimal()
Expand All @@ -157,7 +154,7 @@ def narwhals_to_native_dtype(
return pl.Float64()
if dtype == dtypes.Float32:
return pl.Float32()
if dtype == dtypes.Int128 and getattr(pl, "Int128", None) is not None:
if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
# Not available for Polars pre 1.8.0
return pl.Int128()
if dtype == dtypes.Int64:
Expand Down Expand Up @@ -200,24 +197,22 @@ def narwhals_to_native_dtype(
return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
if isinstance_or_issubclass(dtype, dtypes.Duration):
return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
if dtype == dtypes.List:
return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr]
if dtype == dtypes.Struct:
return pl.Struct(
fields=[
pl.Field(
name=field.name,
dtype=narwhals_to_native_dtype(field.dtype, version, backend_version),
)
for field in dtype.fields # type: ignore[union-attr]
]
)
if dtype == dtypes.Array: # pragma: no cover
size = dtype.size # type: ignore[union-attr]
if isinstance_or_issubclass(dtype, dtypes.List):
return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
if isinstance_or_issubclass(dtype, dtypes.Struct):
fields = [
pl.Field(
field.name,
narwhals_to_native_dtype(field.dtype, version, backend_version),
)
for field in dtype.fields
]
return pl.Struct(fields)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
size = dtype.size
kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
return pl.Array(
inner=narwhals_to_native_dtype(dtype.inner, version, backend_version), # type: ignore[union-attr]
**kwargs,
narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
)
return pl.Unknown() # pragma: no cover

Expand Down
Loading