Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5ccf476
feat: Add `DType.base_type`
dangotbanned Aug 10, 2025
3070470
refactor: Simplify `_arrow.utils.narwhals_to_native_dtype`
dangotbanned Aug 10, 2025
3a1a257
test: fix doctests
dangotbanned Aug 11, 2025
b725314
Merge branch 'main' into dtype-base-type
dangotbanned Aug 11, 2025
970cb8a
refactor: Simplify `_polars.utils.narwhals_to_native_dtype`
dangotbanned Aug 11, 2025
9e994f3
refactor: Simplify `_ibis.utils.narwhals_to_native_dtype`
dangotbanned Aug 11, 2025
52dfc47
Merge branch 'main' into dtype-base-type
dangotbanned Aug 12, 2025
7f5b00e
refactor: Simplify `_duckdb.utils.narwhals_to_native_dtype`
dangotbanned Aug 12, 2025
ee3d38d
Merge branch 'main' into dtype-base-type
dangotbanned Aug 13, 2025
dacd1db
Merge branch 'main' into dtype-base-type
dangotbanned Aug 13, 2025
6924757
refactor: tweak duckdb array conversion
dangotbanned Aug 13, 2025
0d77180
refactor: Simplify `_dask.utils.narwhals_to_native_dtype`
dangotbanned Aug 13, 2025
5240e49
refactor: Align unsupported across backends
dangotbanned Aug 13, 2025
6212cda
Merge remote-tracking branch 'upstream/main' into dtype-base-type
dangotbanned Aug 13, 2025
654bd01
docs: remove returns section
dangotbanned Aug 13, 2025
17508b2
fix: Include `v1` dtypes for dask
dangotbanned Aug 13, 2025
b66891c
refactor: Simplify `_spark_like.utils.narwhals_to_native_dtype`
dangotbanned Aug 13, 2025
cf83c9d
refactor: Simplify `_pandas_like.utils.narwhals_to_native_dtype`
dangotbanned Aug 13, 2025
2465761
🧹🧹🧹
dangotbanned Aug 13, 2025
61d2011
actually raise the exception
MarcoGorelli Aug 13, 2025
d9e6c65
Update narwhals/_pandas_like/utils.py
MarcoGorelli Aug 13, 2025
2967c2d
fix syntax (i shouldnt have tried modifying from github ui)
MarcoGorelli Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow.compute as pc

from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import isinstance_or_issubclass
from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
Expand All @@ -25,7 +25,6 @@
ScalarAny,
)
from narwhals._duration import IntervalUnit
from narwhals._utils import Version
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, PythonLiteral

Expand Down Expand Up @@ -169,44 +168,38 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: #
return dtypes.Unknown() # pragma: no cover


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: # noqa: C901, PLR0912
# Module-level lookup tables used by `narwhals_to_native_dtype`.
# All keys are dtype classes taken from the `Version.MAIN` dtypes namespace.
dtypes = Version.MAIN.dtypes
# Narwhals dtype class -> ready-made PyArrow data type.
# Only dtypes whose PyArrow equivalent is a fixed value are listed here, so a
# single `NW_TO_PA_DTYPES.get(dtype.base_type())` lookup resolves them.
# Datetime, Duration, List, Struct and Array are not in this table; the
# conversion function builds their PyArrow types with explicit
# `isinstance_or_issubclass` branches.
NW_TO_PA_DTYPES: Mapping[type[DType], pa.DataType] = {
dtypes.Float64: pa.float64(),
dtypes.Float32: pa.float32(),
dtypes.Binary: pa.binary(),
dtypes.String: pa.string(),
dtypes.Boolean: pa.bool_(),
# Categorical maps to a dictionary type built from uint32 and string.
dtypes.Categorical: pa.dictionary(pa.uint32(), pa.string()),
dtypes.Date: pa.date32(),
# Time always maps to time64 with the "ns" unit.
dtypes.Time: pa.time64("ns"),
dtypes.Int8: pa.int8(),
dtypes.Int16: pa.int16(),
dtypes.Int32: pa.int32(),
dtypes.Int64: pa.int64(),
dtypes.UInt8: pa.uint8(),
dtypes.UInt16: pa.uint16(),
dtypes.UInt32: pa.uint32(),
dtypes.UInt64: pa.uint64(),
}
# Dtype classes with no PyArrow conversion. This is a tuple so it can be passed
# straight to `issubclass(base_type, UNSUPPORTED_DTYPES)`; the conversion
# function raises `NotImplementedError` on a match.
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Decimal):
msg = "Casting to Decimal is not supported yet."
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Float64):
return pa.float64()
if isinstance_or_issubclass(dtype, dtypes.Float32):
return pa.float32()
if isinstance_or_issubclass(dtype, dtypes.Int64):
return pa.int64()
if isinstance_or_issubclass(dtype, dtypes.Int32):
return pa.int32()
if isinstance_or_issubclass(dtype, dtypes.Int16):
return pa.int16()
if isinstance_or_issubclass(dtype, dtypes.Int8):
return pa.int8()
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return pa.uint64()
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return pa.uint32()
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return pa.uint16()
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return pa.uint8()
if isinstance_or_issubclass(dtype, dtypes.String):
return pa.string()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return pa.bool_()
if isinstance_or_issubclass(dtype, dtypes.Categorical):
return pa.dictionary(pa.uint32(), pa.string())
base_type = dtype.base_type()
if pa_type := NW_TO_PA_DTYPES.get(base_type):
return pa_type
if isinstance_or_issubclass(dtype, dtypes.Datetime):
unit = dtype.time_unit
return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
if isinstance_or_issubclass(dtype, dtypes.Duration):
return pa.duration(dtype.time_unit)
if isinstance_or_issubclass(dtype, dtypes.Date):
return pa.date32()
if isinstance_or_issubclass(dtype, dtypes.List):
return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
if isinstance_or_issubclass(dtype, dtypes.Struct):
Expand All @@ -220,12 +213,8 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
inner = narwhals_to_native_dtype(dtype.inner, version=version)
list_size = dtype.size
return pa.list_(inner, list_size=list_size)
if isinstance_or_issubclass(dtype, dtypes.Time):
return pa.time64("ns")
if isinstance_or_issubclass(dtype, dtypes.Binary):
return pa.binary()
if isinstance_or_issubclass(dtype, dtypes.Object):
msg = f"Converting to {dtype} dtype is not supported for pyarrow."
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
Expand Down
90 changes: 39 additions & 51 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence

import dask.dataframe as dd
import dask.dataframe.dask_expr as dx

from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
from narwhals._dask.expr import DaskExpr
from narwhals.dtypes import DType
from narwhals.typing import IntoDType
else:
try:
Expand Down Expand Up @@ -79,36 +80,45 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
raise RuntimeError(msg)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any: # noqa: C901, PLR0912
# Module-level lookup tables used by `narwhals_to_native_dtype`.
# `dtypes` is the `Version.MAIN` dtypes namespace; `dtypes_v1` is the
# `Version.V1` one.
dtypes = Version.MAIN.dtypes
dtypes_v1 = Version.V1.dtypes
# Narwhals dtype class -> dtype string returned for Dask.
# The conversion function looks entries up with
# `NW_TO_DASK_DTYPES.get(dtype.base_type())`, i.e. by class only, and returns
# the stored string unchanged.
# String and Enum are not in this table; the conversion function handles them
# with explicit branches.
NW_TO_DASK_DTYPES: Mapping[type[DType], str] = {
dtypes.Float64: "float64",
dtypes.Float32: "float32",
dtypes.Boolean: "bool",
dtypes.Categorical: "category",
dtypes.Date: "date32[day][pyarrow]",
dtypes.Int8: "int8",
dtypes.Int16: "int16",
dtypes.Int32: "int32",
dtypes.Int64: "int64",
dtypes.UInt8: "uint8",
dtypes.UInt16: "uint16",
dtypes.UInt32: "uint32",
dtypes.UInt64: "uint64",
dtypes.Datetime: "datetime64[us]",
dtypes.Duration: "timedelta64[ns]",
# NOTE(review): the V1 Datetime/Duration classes are registered separately,
# presumably because `Version.V1` defines its own classes that would
# otherwise miss the lookup above - confirm against `narwhals.stable.v1`.
dtypes_v1.Datetime: "datetime64[us]",
dtypes_v1.Duration: "timedelta64[ns]",
}
# Dtype classes with no Dask conversion. This is a tuple so it can be passed
# straight to `issubclass(base_type, UNSUPPORTED_DTYPES)`; the conversion
# function raises `NotImplementedError` on a match.
UNSUPPORTED_DTYPES = (
dtypes.List,
dtypes.Struct,
dtypes.Array,
dtypes.Time,
dtypes.Binary,
)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:
dtypes = version.dtypes
if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
return "float32"
if isinstance_or_issubclass(dtype, dtypes.Int64):
return "int64"
if isinstance_or_issubclass(dtype, dtypes.Int32):
return "int32"
if isinstance_or_issubclass(dtype, dtypes.Int16):
return "int16"
if isinstance_or_issubclass(dtype, dtypes.Int8):
return "int8"
if isinstance_or_issubclass(dtype, dtypes.UInt64):
return "uint64"
if isinstance_or_issubclass(dtype, dtypes.UInt32):
return "uint32"
if isinstance_or_issubclass(dtype, dtypes.UInt16):
return "uint16"
if isinstance_or_issubclass(dtype, dtypes.UInt8):
return "uint8"
base_type = dtype.base_type()
if dask_type := NW_TO_DASK_DTYPES.get(base_type):
return dask_type
if isinstance_or_issubclass(dtype, dtypes.String):
if Implementation.PANDAS._backend_version() >= (2, 0, 0):
if get_pyarrow() is not None:
return "string[pyarrow]"
return "string[python]" # pragma: no cover
return "string[pyarrow]" if get_pyarrow() else "string[python]"
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return "bool"
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
Expand All @@ -122,30 +132,8 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any: # noqa
return pd.CategoricalDtype(dtype.categories, ordered=True) # type: ignore[arg-type]
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)

if isinstance_or_issubclass(dtype, dtypes.Categorical):
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return "datetime64[us]"
if isinstance_or_issubclass(dtype, dtypes.Date):
return "date32[day][pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
return "timedelta64[ns]"
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
msg = "Converting to List dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover
msg = f"Converting to {base_type.__name__} dtype is not supported for Dask."
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Time): # pragma: no cover
msg = "Converting to Time dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Binary): # pragma: no cover
msg = "Converting to Binary dtype is not supported yet"
raise NotImplementedError(msg)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
8 changes: 2 additions & 6 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from duckdb import CoalesceOperator, StarExpression
from duckdb.typing import DuckDBPyType

from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
Expand Down Expand Up @@ -274,15 +273,12 @@ def cast(self, dtype: IntoDType) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return [expr.cast(DuckDBPyType(native_dtype)) for expr in self(df)]
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return [
expr.cast(DuckDBPyType(native_dtype))
for expr in self.window_function(df, inputs)
]
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
func,
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def func(df: DuckDBLazyFrame) -> list[Expression]:
tz = DeferredTimeZone(df.native)
if dtype is not None:
target = narwhals_to_native_dtype(dtype, self._version, tz)
return [lit(value).cast(target)] # type: ignore[arg-type]
return [lit(value).cast(target)]
return [lit(value)]

return self._expr(
Expand Down
Loading
Loading