diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index bb9921b97d..5a678b5bf4 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -187,8 +187,8 @@ def native_non_extension_to_narwhals_dtype(dtype: pa.DataType, version: Version) return dtypes.Array( native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size ) - if pa.types.is_decimal(dtype): - return dtypes.Decimal() + if pa.types.is_decimal128(dtype): + return dtypes.Decimal(precision=dtype.precision, scale=dtype.scale) if pa.types.is_time32(dtype) or pa.types.is_time64(dtype): return dtypes.Time() if pa.types.is_binary(dtype): @@ -215,7 +215,7 @@ def native_non_extension_to_narwhals_dtype(dtype: pa.DataType, version: Version) dtypes.UInt32: pa.uint32(), dtypes.UInt64: pa.uint64(), } -UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object) +UNSUPPORTED_DTYPES = (dtypes.Object,) def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: @@ -237,10 +237,12 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: for field in dtype.fields ] ) - if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + if isinstance_or_issubclass(dtype, dtypes.Array): inner = narwhals_to_native_dtype(dtype.inner, version=version) list_size = dtype.size return pa.list_(inner, list_size=list_size) + if isinstance_or_issubclass(dtype, dtypes.Decimal): + return pa.decimal128(precision=dtype.precision, scale=dtype.scale) if issubclass(base_type, UNSUPPORTED_DTYPES): msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow." raise NotImplementedError(msg) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index cf5d04eae3..6f99fcd6ca 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -100,6 +100,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None: dtypes.Array, dtypes.Time, dtypes.Binary, + dtypes.Decimal, ) diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 37d86a88ba..800f16043a 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -139,7 +139,7 @@ def native_to_narwhals_dtype( version: Version, deferred_time_zone: DeferredTimeZone, ) -> DType: - duckdb_dtype_id = duckdb_dtype.id + duckdb_dtype_id: str = duckdb_dtype.id dtypes = version.dtypes # Handle nested data types first @@ -180,6 +180,10 @@ def native_to_narwhals_dtype( if duckdb_dtype_id == "timestamp with time zone": return dtypes.Datetime(time_zone=deferred_time_zone.time_zone) + if duckdb_dtype_id == "decimal": + (_, precision), (_, scale) = duckdb_dtype.children + return dtypes.Decimal(precision=precision, scale=scale) + return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version) @@ -215,7 +219,6 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) "timestamp_ns": dtypes.Datetime("ns"), "boolean": dtypes.Boolean(), "interval": dtypes.Duration(), - "decimal": dtypes.Decimal(), "time": dtypes.Time(), "blob": dtypes.Binary(), }.get(duckdb_dtype_id, dtypes.Unknown()) @@ -247,7 +250,7 @@ def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) "us": duckdb_dtypes.TIMESTAMP, "ns": duckdb_dtypes.TIMESTAMP_NS, } -UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Categorical) +UNSUPPORTED_DTYPES = dtypes.Categorical def narwhals_to_native_dtype( # noqa: PLR0912, C901 @@ -298,6 +301,8 @@ def narwhals_to_native_dtype( # noqa: PLR0912, C901 duckdb_inner = narwhals_to_native_dtype(nw_inner, version, deferred_time_zone) duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape) return duckdb_dtypes.DuckDBPyType(f"{duckdb_inner}{duckdb_shape_fmt}") + if isinstance(dtype, dtypes.Decimal): + return duckdb_dtypes.DuckDBPyType(f"DECIMAL({dtype.precision}, {dtype.scale})") if issubclass(base_type, UNSUPPORTED_DTYPES): msg = f"Converting to {base_type.__name__} dtype is not supported for DuckDB." raise NotImplementedError(msg) diff --git a/narwhals/_ibis/utils.py b/narwhals/_ibis/utils.py index 09a8ce2ffb..eed667efec 100644 --- a/narwhals/_ibis/utils.py +++ b/narwhals/_ibis/utils.py @@ -173,8 +173,15 @@ def native_to_narwhals_dtype(ibis_dtype: IbisDataType, version: Version) -> DTyp for name, dtype in ibis_dtype.items() ] ) - if ibis_dtype.is_decimal(): # pragma: no cover - return dtypes.Decimal() + if is_decimal(ibis_dtype): + # Same default as in ibis.Decimal.{to_polars, to_pyarrow} + # https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/polars.py#L82-L86 + # https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/pyarrow.py#L165-L177 + # According to their own comment: + # > set default precision and scale to something; unclear how to choose this + # source: (see https://github.com/ibis-project/ibis/blob/5028d8e5e1921d5de48f59fad48b86c0de541b0d/ibis/formats/pyarrow.py#L166) + scale = 9 if ibis_dtype.scale is None else ibis_dtype.scale + return dtypes.Decimal(precision=ibis_dtype.precision, scale=scale) if ibis_dtype.is_time(): return dtypes.Time() if ibis_dtype.is_binary(): @@ -202,6 +209,10 @@ def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]: return obj.is_floating() +def is_decimal(obj: IbisDataType) -> TypeIs[ibis_dtypes.Decimal]: + return obj.is_decimal() + + dtypes = Version.MAIN.dtypes NW_TO_IBIS_DTYPES: Mapping[type[DType], IbisDataType] = { dtypes.Float64: ibis_dtypes.Float64(), @@ -219,7 +230,6 @@ def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]: dtypes.UInt16: ibis_dtypes.UInt16(), dtypes.UInt32: ibis_dtypes.UInt32(), dtypes.UInt64: ibis_dtypes.UInt64(), - dtypes.Decimal: ibis_dtypes.Decimal(), } # Enum support: https://github.com/ibis-project/ibis/issues/10991 UNSUPPORTED_DTYPES = (dtypes.Int128, dtypes.UInt128, dtypes.Categorical, dtypes.Enum) @@ -246,6 +256,8 @@ def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> IbisDataType if isinstance_or_issubclass(dtype, dtypes.Array): inner = narwhals_to_native_dtype(dtype.inner, version) return ibis_dtypes.Array(value_type=inner, length=dtype.size) + if isinstance_or_issubclass(dtype, dtypes.Decimal): + return ibis_dtypes.Decimal(dtype.precision, dtype.scale) if issubclass(base_type, UNSUPPORTED_DTYPES): msg = f"Converting to {base_type.__name__} dtype is not supported for Ibis." raise NotImplementedError(msg) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2777edb6f7..0b4733175c 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -94,7 +94,6 @@ (?Ps|ms|us|ns) # Match time unit: s, ms, us, or ns \] # Closing bracket for timedelta64 $""" - PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE) PA_DURATION_RGX = r"""^ duration\[ @@ -103,6 +102,15 @@ \[pyarrow\] # Literal string "[pyarrow]" $""" PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE) +PA_DECIMAL_RGX = r"""^ + decimal128\( + (?P\d+) # Precision value + ,\s # Literal string ", " + (?P\d+) # Scale value + \) + \[pyarrow\] # Literal string "[pyarrow]" +$""" +PATTERN_PA_DECIMAL = re.compile(PA_DECIMAL_RGX, re.VERBOSE) NativeIntervalUnit: TypeAlias = Literal[ "year", @@ -277,8 +285,9 @@ def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> return dtypes.Duration(du_time_unit) if dtype == "date32[day][pyarrow]": return dtypes.Date() - if dtype.startswith("decimal") and dtype.endswith("[pyarrow]"): - return dtypes.Decimal() + if match_ := PATTERN_PA_DECIMAL.match(dtype): + precision, scale = int(match_.group("precision")), int(match_.group("scale")) + return dtypes.Decimal(precision, scale) if dtype.startswith("time") and dtype.endswith("[pyarrow]"): return dtypes.Time() if dtype.startswith("binary") and dtype.endswith("[pyarrow]"): @@ -476,7 +485,6 @@ def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]: None: "bool", }, } -UNSUPPORTED_DTYPES = (dtypes.Decimal,) def narwhals_to_native_dtype( # noqa: C901, PLR0912 @@ -568,12 +576,17 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912 msg = "Can not cast / initialize Enum without categories present" raise ValueError(msg) if issubclass( - base_type, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary) + base_type, + ( + dtypes.Struct, + dtypes.Array, + dtypes.List, + dtypes.Time, + dtypes.Binary, + dtypes.Decimal, + ), ): return narwhals_to_native_arrow_dtype(dtype, implementation, version) - if issubclass(base_type, UNSUPPORTED_DTYPES): - msg = f"Converting to {base_type.__name__} dtype is not supported for {implementation}." - raise NotImplementedError(msg) msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 7d6c702297..1011f7ce93 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -160,8 +160,8 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912 if isinstance_or_issubclass(dtype, pl.Array): outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape) - if dtype == pl.Decimal: - return dtypes.Decimal() + if isinstance_or_issubclass(dtype, pl.Decimal): + return dtypes.Decimal(dtype.precision, dtype.scale) if dtype == pl.Time: return dtypes.Time() if dtype == pl.Binary: @@ -200,7 +200,6 @@ def _version_dependent_dtypes() -> dict[type[DType], pl.DataType]: dtypes.Unknown: pl.Unknown(), **_version_dependent_dtypes(), } -UNSUPPORTED_DTYPES = (dtypes.Decimal,) def narwhals_to_native_dtype( # noqa: C901 @@ -234,9 +233,8 @@ def narwhals_to_native_dtype( # noqa: C901 size = dtype.size kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size} return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs) - if issubclass(base_type, UNSUPPORTED_DTYPES): - msg = f"Converting to {base_type.__name__} dtype is not supported for Polars." - raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Decimal): + return pl.Decimal(dtype.precision, dtype.scale) return pl.Unknown() # pragma: no cover diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index b26e071efd..e0e0c8c857 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -96,8 +96,7 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912 if isinstance(dtype, native.TimestampType): return dtypes.Datetime(time_zone=fetch_session_time_zone(session)) if isinstance(dtype, native.DecimalType): - # TODO(marco): cover this - return dtypes.Decimal() # pragma: no cover + return dtypes.Decimal(precision=dtype.precision, scale=dtype.scale) if isinstance(dtype, native.ArrayType): return dtypes.List( inner=native_to_narwhals_dtype( @@ -156,7 +155,7 @@ def fetch_session_time_zone(session: SparkSession) -> str: ) -def narwhals_to_native_dtype( +def narwhals_to_native_dtype( # noqa: C901 dtype: IntoDType, version: Version, spark_types: ModuleType, session: SparkSession ) -> _NativeDType: dtypes = version.dtypes @@ -195,6 +194,9 @@ def narwhals_to_native_dtype( for field in dtype.fields ] ) + if isinstance_or_issubclass(dtype, dtypes.Decimal): # pragma: no cover + return native.DecimalType(precision=dtype.precision, scale=dtype.scale) + if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover msg = f"Converting to {base_type.__name__} dtype is not supported for Spark-Like backend." raise UnsupportedDTypeError(msg) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 827932c05e..63b5b74259 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -2614,7 +2614,7 @@ def schema(self) -> Schema: >>> import narwhals as nw >>> lf_native = duckdb.sql("SELECT * FROM VALUES (1, 4.5), (3, 2.) df(a, b)") >>> nw.from_native(lf_native).schema # doctest:+SKIP - Schema({'a': Int32, 'b': Decimal}) + Schema({'a': Int32, 'b': Decimal(precision=2, scale=1)}) """ if self._compliant_frame._version is not Version.V1: msg = ( @@ -2632,7 +2632,7 @@ def collect_schema(self) -> Schema: >>> import narwhals as nw >>> lf_native = duckdb.sql("SELECT * FROM VALUES (1, 4.5), (3, 2.) df(a, b)") >>> nw.from_native(lf_native).collect_schema() - Schema({'a': Int32, 'b': Decimal}) + Schema({'a': Int32, 'b': Decimal(precision=2, scale=1)}) """ return super().collect_schema() diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index b433bd6fd5..587b6c2758 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -12,6 +12,7 @@ isinstance_or_issubclass, qualified_type_name, ) +from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -205,16 +206,60 @@ class NestedType(DType): class Decimal(NumericType): - """Decimal type. + """Decimal 128-bit type with an optional precision and non-negative scale. + + Arguments: + precision: Maximum number of digits in each number. If set to `None` (default), + the precision is set to 38 (the maximum supported by Polars). + scale: Number of digits to the right of the decimal point in each number. Examples: >>> import polars as pl >>> import narwhals as nw >>> s = pl.Series(["1.5"], dtype=pl.Decimal) >>> nw.from_native(s, series_only=True).dtype - Decimal + Decimal(precision=2, scale=1) """ + __slots__ = ("precision", "scale") + + precision: int + scale: int + + # !NOTE: Reason for `precision: int | None = None` rather than `precision: int = 38` + # is to mirror polars signature https://github.com/pola-rs/polars/blob/bb79993c3aa91d0db7d20be8f75c8075cad97067/py-polars/src/polars/datatypes/classes.py#L450-L454 + def __init__(self, precision: int | None = None, scale: int = 0) -> None: + precision = 38 if precision is None else precision + + if not ((is_int := isinstance(precision, int)) and 0 <= precision <= 38): + msg = f"precision must be a positive integer between 0 and 38, found {precision!r}" + raise ValueError(msg) if is_int else TypeError(msg) + + if not ((is_int := isinstance(scale, int)) and scale >= 0): + msg = f"scale must be a positive integer, found {scale!r}" + raise ValueError(msg) if is_int else TypeError(msg) + + if scale > precision: + msg = "scale must be less than or equal to precision" + raise InvalidOperationError(msg) + + self.precision = precision + self.scale = scale + + def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] + return (other is Decimal) or ( + isinstance(other, self.__class__) + and self.precision == other.precision + and self.scale == other.scale + ) + + def __hash__(self) -> int: # pragma: no cover + return hash((self.__class__, self.precision, self.scale)) + + def __repr__(self) -> str: # pragma: no cover + class_name = self.__class__.__name__ + return f"{class_name}(precision={self.precision!r}, scale={self.scale!r})" + class Int128(SignedIntegerType): """128-bit signed integer type. diff --git a/tests/dtypes/dtypes_test.py b/tests/dtypes/dtypes_test.py index 7fb5cc29f2..83e367c08d 100644 --- a/tests/dtypes/dtypes_test.py +++ b/tests/dtypes/dtypes_test.py @@ -8,7 +8,7 @@ import pytest import narwhals as nw -from narwhals.exceptions import PerformanceWarning +from narwhals.exceptions import InvalidOperationError, PerformanceWarning from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION, pyspark_session if TYPE_CHECKING: @@ -407,38 +407,56 @@ def test_huge_int_to_native() -> None: assert type_a_unit == "UHUGEINT" -def test_cast_decimal_to_native() -> None: - pytest.importorskip("duckdb") - pytest.importorskip("pandas") - pytest.importorskip("polars") - pytest.importorskip("pyarrow") +@pytest.mark.parametrize( + ("precision", "scale"), [(None, 1), (None, 20), (2, 1), (10, 1), (10, 8)] +) +def test_cast_decimal_to_native( + request: pytest.FixtureRequest, + constructor: Constructor, + precision: int | None, + scale: int, +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail(reason="Unsupported dtype")) - import duckdb - import pandas as pd - import polars as pl - import pyarrow as pa + if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0): + pytest.skip(reason="too old to convert to decimal") - data = {"a": [1, 2, 3]} - - df = pl.DataFrame(data) - library_obj_to_test = [ - df, - duckdb.sql(""" - select cast(a as INT1) as a - from df - """), - pd.DataFrame(data), - pa.Table.from_arrays( - [pa.array(data["a"])], schema=pa.schema([("a", pa.int64())]) - ), - ] - for obj in library_obj_to_test: - with pytest.raises(NotImplementedError, match=r"to.+Decimal.+not supported."): - ( - nw.from_native(obj) # type: ignore[call-overload] - .with_columns(a=nw.col("a").cast(nw.Decimal())) - .to_native() - ) + data = {"a": [1.1, 2.2, 3.3]} + + df = nw.from_native(constructor(data)) + + if df.implementation.is_pandas_like() and ( + PYARROW_VERSION == (0, 0, 0) or PANDAS_VERSION < (2, 2) + ): + pytest.skip(reason="pyarrow is required to convert to decimal dtype") + + native_result = df.with_columns( + a=nw.col("a").cast(nw.Decimal(precision=precision, scale=scale)) + ).to_native() + + schema = nw.from_native(native_result).collect_schema() + assert schema["a"] == nw.Decimal(precision, scale) + + +@pytest.mark.parametrize( + ("precision", "scale", "exception", "msg"), + [ + (2.1, 0, TypeError, "precision must be a positive integer between 0 and 38"), + ("foo", 0, TypeError, "precision must be a positive integer between 0 and 38"), + (-1, 0, ValueError, "precision must be a positive integer between 0 and 38"), + (39, 0, ValueError, "precision must be a positive integer between 0 and 38"), + (None, 2.1, TypeError, "scale must be a positive integer"), + (None, "foo", TypeError, "scale must be a positive integer"), + (None, -1, ValueError, "scale must be a positive integer"), + (2, 3, InvalidOperationError, "scale must be less than or equal to precision"), + ], +) +def test_decimal_invalid( + precision: int | None, scale: int, exception: type[Exception], msg: str +) -> None: + with pytest.raises(exception, match=msg): + nw.Decimal(precision=precision, scale=scale) @pytest.mark.parametrize(