Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def native_non_extension_to_narwhals_dtype(dtype: pa.DataType, version: Version)
return dtypes.Array(
native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
)
if pa.types.is_decimal(dtype):
return dtypes.Decimal()
if pa.types.is_decimal128(dtype):
return dtypes.Decimal(precision=dtype.precision, scale=dtype.scale)
if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
return dtypes.Time()
if pa.types.is_binary(dtype):
Expand All @@ -215,7 +215,7 @@ def native_non_extension_to_narwhals_dtype(dtype: pa.DataType, version: Version)
dtypes.UInt32: pa.uint32(),
dtypes.UInt64: pa.uint64(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object)
UNSUPPORTED_DTYPES = (dtypes.Object,)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
Expand All @@ -237,10 +237,12 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
for field in dtype.fields
]
)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Array):
inner = narwhals_to_native_dtype(dtype.inner, version=version)
list_size = dtype.size
return pa.list_(inner, list_size=list_size)
if isinstance_or_issubclass(dtype, dtypes.Decimal):
return pa.decimal128(precision=dtype.precision, scale=dtype.scale)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow."
raise NotImplementedError(msg)
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
dtypes.Array,
dtypes.Time,
dtypes.Binary,
dtypes.Decimal,
)


Expand Down
11 changes: 8 additions & 3 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def native_to_narwhals_dtype(
version: Version,
deferred_time_zone: DeferredTimeZone,
) -> DType:
duckdb_dtype_id = duckdb_dtype.id
duckdb_dtype_id: str = duckdb_dtype.id
dtypes = version.dtypes

# Handle nested data types first
Expand Down Expand Up @@ -180,6 +180,10 @@ def native_to_narwhals_dtype(
if duckdb_dtype_id == "timestamp with time zone":
return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

if duckdb_dtype_id == "decimal":
(_, precision), (_, scale) = duckdb_dtype.children
return dtypes.Decimal(precision=precision, scale=scale)

return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)


Expand Down Expand Up @@ -215,7 +219,6 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)
"timestamp_ns": dtypes.Datetime("ns"),
"boolean": dtypes.Boolean(),
"interval": dtypes.Duration(),
"decimal": dtypes.Decimal(),
"time": dtypes.Time(),
"blob": dtypes.Binary(),
}.get(duckdb_dtype_id, dtypes.Unknown())
Expand Down Expand Up @@ -247,7 +250,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version)
"us": duckdb_dtypes.TIMESTAMP,
"ns": duckdb_dtypes.TIMESTAMP_NS,
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Categorical)
UNSUPPORTED_DTYPES = dtypes.Categorical


def narwhals_to_native_dtype( # noqa: PLR0912, C901
Expand Down Expand Up @@ -298,6 +301,8 @@ def narwhals_to_native_dtype( # noqa: PLR0912, C901
duckdb_inner = narwhals_to_native_dtype(nw_inner, version, deferred_time_zone)
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape)
return duckdb_dtypes.DuckDBPyType(f"{duckdb_inner}{duckdb_shape_fmt}")
if isinstance(dtype, dtypes.Decimal):
return duckdb_dtypes.DuckDBPyType(f"DECIMAL({dtype.precision}, {dtype.scale})")
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for DuckDB."
raise NotImplementedError(msg)
Expand Down
18 changes: 15 additions & 3 deletions narwhals/_ibis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,15 @@ def native_to_narwhals_dtype(ibis_dtype: IbisDataType, version: Version) -> DTyp
for name, dtype in ibis_dtype.items()
]
)
if ibis_dtype.is_decimal(): # pragma: no cover
return dtypes.Decimal()
if is_decimal(ibis_dtype):
# Same default as in ibis.Decimal.{to_polars, to_pyarrow}
# https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/polars.py#L82-L86
# https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/pyarrow.py#L165-L177
# According to their own comment:
# > set default precision and scale to something; unclear how to choose this
# source: (see https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/pyarrow.py#L166)
scale = 9 if ibis_dtype.scale is None else ibis_dtype.scale
return dtypes.Decimal(precision=ibis_dtype.precision, scale=scale)
if ibis_dtype.is_time():
return dtypes.Time()
if ibis_dtype.is_binary():
Expand Down Expand Up @@ -202,6 +209,10 @@ def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]:
return obj.is_floating()


def is_decimal(obj: IbisDataType) -> TypeIs[ibis_dtypes.Decimal]:
return obj.is_decimal()


dtypes = Version.MAIN.dtypes
NW_TO_IBIS_DTYPES: Mapping[type[DType], IbisDataType] = {
dtypes.Float64: ibis_dtypes.Float64(),
Expand All @@ -219,7 +230,6 @@ def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]:
dtypes.UInt16: ibis_dtypes.UInt16(),
dtypes.UInt32: ibis_dtypes.UInt32(),
dtypes.UInt64: ibis_dtypes.UInt64(),
dtypes.Decimal: ibis_dtypes.Decimal(),
}
# Enum support: https://github.com/ibis-project/ibis/issues/10991
UNSUPPORTED_DTYPES = (dtypes.Int128, dtypes.UInt128, dtypes.Categorical, dtypes.Enum)
Expand All @@ -246,6 +256,8 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> IbisDataType
if isinstance_or_issubclass(dtype, dtypes.Array):
inner = narwhals_to_native_dtype(dtype.inner, version)
return ibis_dtypes.Array(value_type=inner, length=dtype.size)
if isinstance_or_issubclass(dtype, dtypes.Decimal):
return ibis_dtypes.Decimal(dtype.precision, dtype.scale)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for Ibis."
raise NotImplementedError(msg)
Expand Down
29 changes: 21 additions & 8 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
\] # Closing bracket for timedelta64
$"""

PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE)
PA_DURATION_RGX = r"""^
duration\[
Expand All @@ -103,6 +102,15 @@
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE)
PA_DECIMAL_RGX = r"""^
decimal128\(
(?P<precision>\d+) # Precision value
,\s # Literal string ", "
(?P<scale>\d+) # Scale value
\)
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DECIMAL = re.compile(PA_DECIMAL_RGX, re.VERBOSE)

NativeIntervalUnit: TypeAlias = Literal[
"year",
Expand Down Expand Up @@ -277,8 +285,9 @@ def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) ->
return dtypes.Duration(du_time_unit)
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
if dtype.startswith("decimal") and dtype.endswith("[pyarrow]"):
return dtypes.Decimal()
if match_ := PATTERN_PA_DECIMAL.match(dtype):
precision, scale = int(match_.group("precision")), int(match_.group("scale"))
return dtypes.Decimal(precision, scale)
if dtype.startswith("time") and dtype.endswith("[pyarrow]"):
return dtypes.Time()
if dtype.startswith("binary") and dtype.endswith("[pyarrow]"):
Expand Down Expand Up @@ -476,7 +485,6 @@ def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]:
None: "bool",
},
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)


def narwhals_to_native_dtype( # noqa: C901, PLR0912
Expand Down Expand Up @@ -568,12 +576,17 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if issubclass(
base_type, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
base_type,
(
dtypes.Struct,
dtypes.Array,
dtypes.List,
dtypes.Time,
dtypes.Binary,
dtypes.Decimal,
),
):
return narwhals_to_native_arrow_dtype(dtype, implementation, version)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for {implementation}."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
10 changes: 4 additions & 6 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
if isinstance_or_issubclass(dtype, pl.Array):
outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
if dtype == pl.Decimal:
return dtypes.Decimal()
if isinstance_or_issubclass(dtype, pl.Decimal):
return dtypes.Decimal(dtype.precision, dtype.scale)
if dtype == pl.Time:
return dtypes.Time()
if dtype == pl.Binary:
Expand Down Expand Up @@ -200,7 +200,6 @@ def _version_dependent_dtypes() -> dict[type[DType], pl.DataType]:
dtypes.Unknown: pl.Unknown(),
**_version_dependent_dtypes(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)


def narwhals_to_native_dtype( # noqa: C901
Expand Down Expand Up @@ -234,9 +233,8 @@ def narwhals_to_native_dtype( # noqa: C901
size = dtype.size
kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Decimal):
return pl.Decimal(dtype.precision, dtype.scale)
return pl.Unknown() # pragma: no cover


Expand Down
8 changes: 5 additions & 3 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
if isinstance(dtype, native.TimestampType):
return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
if isinstance(dtype, native.DecimalType):
# TODO(marco): cover this
return dtypes.Decimal() # pragma: no cover
return dtypes.Decimal(precision=dtype.precision, scale=dtype.scale)
if isinstance(dtype, native.ArrayType):
return dtypes.List(
inner=native_to_narwhals_dtype(
Expand Down Expand Up @@ -156,7 +155,7 @@ def fetch_session_time_zone(session: SparkSession) -> str:
)


def narwhals_to_native_dtype(
def narwhals_to_native_dtype( # noqa: C901
dtype: IntoDType, version: Version, spark_types: ModuleType, session: SparkSession
) -> _NativeDType:
dtypes = version.dtypes
Expand Down Expand Up @@ -195,6 +194,9 @@ def narwhals_to_native_dtype(
for field in dtype.fields
]
)
if isinstance_or_issubclass(dtype, dtypes.Decimal): # pragma: no cover
return native.DecimalType(precision=dtype.precision, scale=dtype.scale)

if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover
msg = f"Converting to {base_type.__name__} dtype is not supported for Spark-Like backend."
raise UnsupportedDTypeError(msg)
Expand Down
4 changes: 2 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2614,7 +2614,7 @@ def schema(self) -> Schema:
>>> import narwhals as nw
>>> lf_native = duckdb.sql("SELECT * FROM VALUES (1, 4.5), (3, 2.) df(a, b)")
>>> nw.from_native(lf_native).schema # doctest:+SKIP
Schema({'a': Int32, 'b': Decimal})
Schema({'a': Int32, 'b': Decimal(precision=2, scale=1)})
"""
if self._compliant_frame._version is not Version.V1:
msg = (
Expand All @@ -2632,7 +2632,7 @@ def collect_schema(self) -> Schema:
>>> import narwhals as nw
>>> lf_native = duckdb.sql("SELECT * FROM VALUES (1, 4.5), (3, 2.) df(a, b)")
>>> nw.from_native(lf_native).collect_schema()
Schema({'a': Int32, 'b': Decimal})
Schema({'a': Int32, 'b': Decimal(precision=2, scale=1)})
"""
return super().collect_schema()

Expand Down
49 changes: 47 additions & 2 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
isinstance_or_issubclass,
qualified_type_name,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -205,16 +206,60 @@ class NestedType(DType):


class Decimal(NumericType):
"""Decimal type.
"""Decimal 128-bit type with an optional precision and non-negative scale.

Arguments:
precision: Maximum number of digits in each number. If set to `None` (default),
the precision is set to 38 (the maximum supported by Polars).
scale: Number of digits to the right of the decimal point in each number.

Examples:
>>> import polars as pl
>>> import narwhals as nw
>>> s = pl.Series(["1.5"], dtype=pl.Decimal)
>>> nw.from_native(s, series_only=True).dtype
Decimal
Decimal(precision=2, scale=1)
"""

__slots__ = ("precision", "scale")

precision: int
scale: int

# !NOTE: Reason for `precision: int | None = None` rather than `precision: int = 38`
# is to mirror polars signature https://github.com/pola-rs/polars/blob/bb79993c3aa91d0db7d20be8f75c8075cad97067/py-polars/src/polars/datatypes/classes.py#L450-L454
def __init__(self, precision: int | None = None, scale: int = 0) -> None:
precision = 38 if precision is None else precision

if not ((is_int := isinstance(precision, int)) and 0 <= precision <= 38):
msg = f"precision must be a positive integer between 0 and 38, found {precision!r}"
raise ValueError(msg) if is_int else TypeError(msg)

if not ((is_int := isinstance(scale, int)) and scale >= 0):
msg = f"scale must be a positive integer, found {scale!r}"
raise ValueError(msg) if is_int else TypeError(msg)

if scale > precision:
msg = "scale must be less than or equal to precision"
raise InvalidOperationError(msg)

self.precision = precision
self.scale = scale

def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
return (other is Decimal) or (
isinstance(other, self.__class__)
and self.precision == other.precision
and self.scale == other.scale
)

def __hash__(self) -> int: # pragma: no cover
return hash((self.__class__, self.precision, self.scale))

def __repr__(self) -> str: # pragma: no cover
class_name = self.__class__.__name__
return f"{class_name}(precision={self.precision!r}, scale={self.scale!r})"


class Int128(SignedIntegerType):
"""128-bit signed integer type.
Expand Down
Loading
Loading