diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py
index 5fbb8a8d7c6a..5d5fb5c07f54 100644
--- a/python/ray/data/expressions.py
+++ b/python/ray/data/expressions.py
@@ -23,6 +23,7 @@
 from ray.util.annotations import DeveloperAPI, PublicAPI
 
 if TYPE_CHECKING:
+    from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
     from ray.data.namespace_expressions.list_namespace import _ListNamespace
     from ray.data.namespace_expressions.string_namespace import _StringNamespace
     from ray.data.namespace_expressions.struct_namespace import _StructNamespace
@@ -486,6 +487,17 @@ def struct(self) -> "_StructNamespace":
 
         return _StructNamespace(self)
 
+    @property
+    def dt(self) -> "_DatetimeNamespace":
+        """Access datetime operations for this expression.
+
+        Returns a ``_DatetimeNamespace`` that wraps this expression.
+        """
+        # Imported lazily: dt_namespace itself imports from this module.
+        from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
+
+        return _DatetimeNamespace(self)
+
     def _unalias(self) -> "Expr":
         return self
 
@@ -1061,6 +1073,7 @@ def download(uri_column_name: str) -> DownloadExpr:
     "_ListNamespace",
     "_StringNamespace",
     "_StructNamespace",
+    "_DatetimeNamespace",
 ]
 
 
@@ -1078,4 +1091,8 @@ def __getattr__(name: str):
         from ray.data.namespace_expressions.struct_namespace import _StructNamespace
 
         return _StructNamespace
+    elif name == "_DatetimeNamespace":
+        from ray.data.namespace_expressions.dt_namespace import _DatetimeNamespace
+
+        return _DatetimeNamespace
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/python/ray/data/namespace_expressions/dt_namespace.py b/python/ray/data/namespace_expressions/dt_namespace.py
new file mode 100644
index 000000000000..851db899f7b9
--- /dev/null
+++ b/python/ray/data/namespace_expressions/dt_namespace.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Callable, Literal
+
+import pyarrow
+import pyarrow.compute as pc
+
+from ray.data.datatype import DataType
+from ray.data.expressions import pyarrow_udf
+
+if TYPE_CHECKING:
+    from ray.data.expressions import Expr, UDFExpr
+
+TemporalUnit = Literal[
+    "year",
+    "quarter",
+    "month",
+    "week",
+    "day",
+    "hour",
+    "minute",
+    "second",
+    "millisecond",
+    "microsecond",
+    "nanosecond",
+]
+
+
+@dataclass
+class _DatetimeNamespace:
+    """Datetime namespace for operations on datetime-typed expression columns."""
+
+    _expr: "Expr"
+
+    def _unary_temporal_int(
+        self, func: Callable[[pyarrow.Array], pyarrow.Array]
+    ) -> "UDFExpr":
+        """Wrap a pyarrow temporal extractor in a UDF whose result is int32."""
+
+        @pyarrow_udf(return_dtype=DataType.int32())
+        def _udf(arr: pyarrow.Array) -> pyarrow.Array:
+            # pc.year/month/... yield int64; cast to the declared int32 dtype.
+            return pc.cast(func(arr), pyarrow.int32())
+
+        return _udf(self._expr)
+
+    # extractors
+
+    def year(self) -> "UDFExpr":
+        """Extract the year component as an int32 expression."""
+        return self._unary_temporal_int(pc.year)
+
+    def month(self) -> "UDFExpr":
+        """Extract the month component as an int32 expression."""
+        return self._unary_temporal_int(pc.month)
+
+    def day(self) -> "UDFExpr":
+        """Extract the day component as an int32 expression."""
+        return self._unary_temporal_int(pc.day)
+
+    def hour(self) -> "UDFExpr":
+        """Extract the hour component as an int32 expression."""
+        return self._unary_temporal_int(pc.hour)
+
+    def minute(self) -> "UDFExpr":
+        """Extract the minute component as an int32 expression."""
+        return self._unary_temporal_int(pc.minute)
+
+    def second(self) -> "UDFExpr":
+        """Extract the second component as an int32 expression."""
+        return self._unary_temporal_int(pc.second)
+
+    # formatting
+
+    def strftime(self, fmt: str) -> "UDFExpr":
+        """Format values as strings using the strftime pattern ``fmt``."""
+
+        @pyarrow_udf(return_dtype=DataType.string())
+        def _format(arr: pyarrow.Array) -> pyarrow.Array:
+            return pc.strftime(arr, format=fmt)
+
+        return _format(self._expr)
+
+    # rounding
+
+    def ceil(self, unit: TemporalUnit) -> "UDFExpr":
+        """Round values up to the next boundary of ``unit`` (multiple of 1)."""
+        return_dtype = self._expr.data_type
+
+        @pyarrow_udf(return_dtype=return_dtype)
+        def _ceil(arr: pyarrow.Array) -> pyarrow.Array:
+            return pc.ceil_temporal(arr, multiple=1, unit=unit)
+
+        return _ceil(self._expr)
+
+    def floor(self, unit: TemporalUnit) -> "UDFExpr":
+        """Round values down to the previous boundary of ``unit`` (multiple of 1)."""
+        return_dtype = self._expr.data_type
+
+        @pyarrow_udf(return_dtype=return_dtype)
+        def _floor(arr: pyarrow.Array) -> pyarrow.Array:
+            return pc.floor_temporal(arr, multiple=1, unit=unit)
+
+        return _floor(self._expr)
+
+    def round(self, unit: TemporalUnit) -> "UDFExpr":
+        """Round values to the nearest boundary of ``unit`` (multiple of 1)."""
+        return_dtype = self._expr.data_type
+
+        @pyarrow_udf(return_dtype=return_dtype)
+        def _round(arr: pyarrow.Array) -> pyarrow.Array:
+            return pc.round_temporal(arr, multiple=1, unit=unit)
+
+        return _round(self._expr)
diff --git a/python/ray/data/tests/test_namespace_expressions.py b/python/ray/data/tests/test_namespace_expressions.py
index 0129b7a5609c..ef3510582d97 100644
--- a/python/ray/data/tests/test_namespace_expressions.py
+++ b/python/ray/data/tests/test_namespace_expressions.py
@@ -4,6 +4,7 @@
 convenient access to PyArrow compute functions through the expression API.
""" +import datetime from typing import Any import pandas as pd @@ -457,7 +458,9 @@ def test_struct_nested_field(self, dataset_format): ) items_data = [ {"user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}}, - {"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}}, + { + "user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}, + }, ] ds = _create_dataset(items_data, dataset_format, arrow_table) @@ -504,7 +507,9 @@ def test_struct_nested_bracket(self, dataset_format): ) items_data = [ {"user": {"name": "Alice", "address": {"city": "NYC", "zip": "10001"}}}, - {"user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}}, + { + "user": {"name": "Bob", "address": {"city": "LA", "zip": "90001"}}, + }, ] ds = _create_dataset(items_data, dataset_format, arrow_table) @@ -523,6 +528,64 @@ def test_struct_nested_bracket(self, dataset_format): assert rows_same(result, expected) +# ────────────────────────────────────── +# Datetime Namespace Tests +# ────────────────────────────────────── + + +def test_datetime_namespace_all_operations(ray_start_regular): + """Test all datetime namespace operations on a datetime column.""" + + ts = datetime.datetime(2024, 1, 2, 10, 30, 0) + + ds = ray.data.from_items([{"ts": ts}]) + + result_ds = ds.select( + [ + col("ts").dt.year().alias("year"), + col("ts").dt.month().alias("month"), + col("ts").dt.day().alias("day"), + col("ts").dt.hour().alias("hour"), + col("ts").dt.minute().alias("minute"), + col("ts").dt.second().alias("second"), + col("ts").dt.strftime("%Y-%m-%d").alias("date_str"), + col("ts").dt.floor("day").alias("ts_floor"), + col("ts").dt.ceil("day").alias("ts_ceil"), + col("ts").dt.round("day").alias("ts_round"), + ] + ) + + actual = result_ds.to_pandas() + + expected = pd.DataFrame( + [ + { + "year": 2024, + "month": 1, + "day": 2, + "hour": 10, + "minute": 30, + "second": 0, + "date_str": "2024-01-02", + "ts_floor": datetime.datetime(2024, 1, 2, 0, 0, 0), + "ts_ceil": 
datetime.datetime(2024, 1, 3, 0, 0, 0), + "ts_round": datetime.datetime(2024, 1, 3, 0, 0, 0), + } + ] + ) + + assert rows_same(actual, expected) + + +def test_dt_namespace_invalid_dtype_raises(ray_start_regular): + """Test that dt namespace on non-datetime column raises an error.""" + + ds = ray.data.from_items([{"value": 1}]) + + with pytest.raises(Exception): + ds.select(col("value").dt.year()).to_pandas() + + # ────────────────────────────────────── # Integration Tests # ──────────────────────────────────────