diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index e51852f532..0a9a586fce 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -23,7 +23,6 @@ from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType - from narwhals.typing import TimeUnit from narwhals.utils import Version T = TypeVar("T") @@ -70,16 +69,14 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An @lru_cache(maxsize=16) def native_to_narwhals_dtype( - dtype: pl.DataType, - version: Version, - backend_version: tuple[int, ...], + dtype: pl.DataType, version: Version, backend_version: tuple[int, ...] ) -> DType: dtypes = import_dtypes_module(version) if dtype == pl.Float64: return dtypes.Float64() if dtype == pl.Float32: return dtypes.Float32() - if dtype == getattr(pl, "Int128", None): # type: ignore[operator] # pragma: no cover + if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover # Not available for Polars pre 1.8.0 return dtypes.Int128() if dtype == pl.Int64: @@ -90,7 +87,7 @@ def native_to_narwhals_dtype( return dtypes.Int16() if dtype == pl.Int8: return dtypes.Int8() - if dtype == getattr(pl, "UInt128", None): # type: ignore[operator] # pragma: no cover + if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover # Not available for Polars pre 1.8.0 return dtypes.UInt128() if dtype == pl.UInt64: @@ -113,32 +110,32 @@ def native_to_narwhals_dtype( return dtypes.Enum() if dtype == pl.Date: return dtypes.Date() - if dtype == pl.Datetime: - dt_time_unit: TimeUnit = getattr(dtype, "time_unit", "us") - dt_time_zone = getattr(dtype, "time_zone", None) - return dtypes.Datetime(time_unit=dt_time_unit, time_zone=dt_time_zone) - if dtype == pl.Duration: - du_time_unit: TimeUnit = getattr(dtype, "time_unit", "us") - return dtypes.Duration(time_unit=du_time_unit) - if dtype == pl.Struct: - return dtypes.Struct( - [ - dtypes.Field( - field_name, - native_to_narwhals_dtype(field_type, version, backend_version), - ) - for field_name, field_type in dtype # type: ignore[attr-defined] - ] + if isinstance_or_issubclass(dtype, pl.Datetime): + return ( + dtypes.Datetime() + if dtype is pl.Datetime + else dtypes.Datetime(dtype.time_unit, dtype.time_zone) ) - if dtype == pl.List: + if isinstance_or_issubclass(dtype, pl.Duration): + return ( + dtypes.Duration() + if dtype is pl.Duration + else dtypes.Duration(dtype.time_unit) + ) + if isinstance_or_issubclass(dtype, pl.Struct): + fields = [ + dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version)) + for name, tp in dtype + ] + return dtypes.Struct(fields) + if isinstance_or_issubclass(dtype, pl.List): return dtypes.List( - native_to_narwhals_dtype(dtype.inner, version, backend_version) # type: ignore[attr-defined] + native_to_narwhals_dtype(dtype.inner, version, backend_version) ) - if dtype == pl.Array: - outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size # type: ignore[attr-defined] + if isinstance_or_issubclass(dtype, pl.Array): + outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size return dtypes.Array( - inner=native_to_narwhals_dtype(dtype.inner, version, backend_version), # type: ignore[attr-defined] - shape=outer_shape, + native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape ) if dtype == pl.Decimal: return dtypes.Decimal() @@ -157,7 +154,7 @@ def narwhals_to_native_dtype( return pl.Float64() if dtype == dtypes.Float32: return pl.Float32() - if dtype == dtypes.Int128 and getattr(pl, "Int128", None) is not None: + if dtype == dtypes.Int128 and hasattr(pl, "Int128"): # Not available for Polars pre 1.8.0 return pl.Int128() if dtype == dtypes.Int64: @@ -200,24 +197,22 @@ def narwhals_to_native_dtype( return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type] if isinstance_or_issubclass(dtype, dtypes.Duration): return pl.Duration(dtype.time_unit) # type: ignore[arg-type] - if dtype == dtypes.List: - return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr] - if dtype == dtypes.Struct: - return pl.Struct( - fields=[ - pl.Field( - name=field.name, - dtype=narwhals_to_native_dtype(field.dtype, version, backend_version), - ) - for field in dtype.fields # type: ignore[union-attr] - ] - ) - if dtype == dtypes.Array: # pragma: no cover - size = dtype.size # type: ignore[union-attr] + if isinstance_or_issubclass(dtype, dtypes.List): + return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) + if isinstance_or_issubclass(dtype, dtypes.Struct): + fields = [ + pl.Field( + field.name, + narwhals_to_native_dtype(field.dtype, version, backend_version), + ) + for field in dtype.fields + ] + return pl.Struct(fields) + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + size = dtype.size kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} return pl.Array( - inner=narwhals_to_native_dtype(dtype.inner, version, backend_version), # type: ignore[union-attr] - **kwargs, + narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs ) return pl.Unknown() # pragma: no cover