Add proof of concept dask-friendly datetime encoding
spencerkclark committed Dec 30, 2023
1 parent 03ec3cb commit e5150c9
Showing 3 changed files with 233 additions and 2 deletions.
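
In short, this commit teaches encode_cf_datetime to accept dask-backed datetime arrays and to encode them lazily via map_blocks rather than computing them eagerly. A minimal sketch of the resulting behavior, assuming dask is installed (the call signature follows the tests added below):

import dask.array
import numpy as np
import pandas as pd

from xarray.coding.times import encode_cf_datetime

times = dask.array.from_array(pd.date_range("2000-01-01", periods=4), chunks=2)

# With units and dtype both prescribed, encoding builds a lazy graph ...
num, units, calendar = encode_cf_datetime(
    times, "days since 2000-01-01", None, np.dtype("int64")
)
print(num.chunks)     # ((2, 2),)  -- chunking is preserved
print(num.compute())  # [0 1 2 3]  -- values materialize only on request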
88 changes: 86 additions & 2 deletions xarray/coding/times.py
@@ -22,6 +22,7 @@
)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.duck_array_ops import asarray
from xarray.core.formatting import first_n_items, format_timestamp, last_item
from xarray.core.pdcompat import nanosecond_precision_timestamp
from xarray.core.pycompat import is_duck_dask_array
@@ -672,6 +673,43 @@ def encode_cf_datetime(
    units: str | None = None,
    calendar: str | None = None,
    dtype: np.dtype | None = None,
):
    dates = asarray(dates)
    if isinstance(dates, np.ndarray):
        return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)
    elif is_duck_dask_array(dates):
        return _lazily_encode_cf_datetime(dates, units, calendar, dtype)
    else:
        # Guard against silently returning None for other duck array types.
        raise ValueError(f"Unable to encode dates of type {type(dates)!r}.")


def _cast_to_dtype_safe(num, dtype) -> np.ndarray:
    cast_num = np.asarray(num, dtype=dtype)

    if np.issubdtype(dtype, np.integer):
        if not (num == cast_num).all():
            raise ValueError(
                f"Not possible to cast all encoded times from dtype {num.dtype!r} "
                f"to dtype {dtype!r} without changing any of their values. "
                f"Consider removing the dtype encoding or explicitly switching to "
                f"a dtype encoding with a higher precision."
            )
    else:
        if np.isinf(cast_num).any():
            raise OverflowError(
                f"Not possible to cast encoded times from dtype {num.dtype!r} "
                f"to dtype {dtype!r} without overflow. Consider removing the "
                f"dtype encoding or explicitly switching to a dtype encoding "
                f"with a higher precision."
            )

    return cast_num

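# Illustrative sketch (not part of this diff): _cast_to_dtype_safe guards the
# final cast in both the eager and lazy encoding paths.  Roughly:
#
#   _cast_to_dtype_safe(np.array([0, 1, 2]), np.dtype("int32"))  # int32 array
#   _cast_to_dtype_safe(np.array([0.5]), np.dtype("int64"))      # ValueError
#   _cast_to_dtype_safe(np.array([1e40]), np.dtype("float16"))   # OverflowError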

def _eagerly_encode_cf_datetime(
    dates,
    units: str | None = None,
    calendar: str | None = None,
    dtype: np.dtype | None = None,
    called_via_map_blocks: bool = False,
) -> tuple[np.ndarray, str, str] | np.ndarray:
    """Given an array of datetime objects, returns the tuple `(num, units,
    calendar)` suitable for a CF-compliant time variable.
@@ -731,7 +769,7 @@ def encode_cf_datetime(
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
elif np.issubdtype(dtype, np.integer) and not called_via_map_blocks:
new_units = f"{needed_units} since {format_timestamp(ref_date)}"
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
@@ -752,7 +790,53 @@
        # the pandas-based flow above already handles this cast
        num = cast_to_int_if_safe(num)

-    return (num, units, calendar)
+    if dtype is not None:
+        num = _cast_to_dtype_safe(num, dtype)
+
+    if called_via_map_blocks:
+        return num
+    else:
+        return (num, units, calendar)

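# Illustrative aside (not part of this diff): with the default
# called_via_map_blocks=False, the eager encoder keeps its public three-tuple
# return value; when invoked from dask.array.map_blocks it must return only
# the encoded array, because map_blocks expects a function that maps one
# array block to one array block.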

def _lazily_encode_cf_datetime(
    dates,
    units: str | None = None,
    calendar: str | None = None,
    dtype: np.dtype | None = None,
):
    import dask.array

    if calendar is None:
        # Inferring the calendar triggers a small compute, and only if dates
        # is an object-dtype (cftime) array.
        calendar = infer_calendar_name(dates)

    if units is None and dtype is None:
        if dates.dtype == "O":
            units = "microseconds since 1970-01-01"
            dtype = np.dtype("int64")
        else:
            units = "nanoseconds since 1970-01-01"
            dtype = np.dtype("int64")

    if units is None or dtype is None:
        raise ValueError(
            f"When encoding chunked arrays of datetime values, both the units "
            f"and dtype must be prescribed or both must be unprescribed. "
            f"Prescribing only one or the other is not currently supported. "
            f"Got a units encoding of {units} and a dtype encoding of {dtype}."
        )

    num = dask.array.map_blocks(
        _eagerly_encode_cf_datetime,
        dates,
        units,
        calendar,
        dtype,
        called_via_map_blocks=True,
        dtype=dtype,
    )
    return num, units, calendar

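# Illustrative note (not part of this diff): in the map_blocks call above, the
# positional dtype is forwarded to _eagerly_encode_cf_datetime, while the
# keyword dtype=dtype is consumed by dask itself to declare the dtype of the
# output blocks.  For example, with default units and dtype:
#
#   times = dask.array.from_array(pd.date_range("2000", periods=4), chunks=2)
#   num, units, calendar = _lazily_encode_cf_datetime(times)
#   units, num.dtype  # ('nanoseconds since 1970-01-01', dtype('int64'))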

def encode_cf_timedelta(
9 changes: 9 additions & 0 deletions xarray/tests/test_backends.py
@@ -2809,6 +2809,15 @@ def test_attributes(self, obj) -> None:
        with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
            ds.to_zarr(store_target, **self.version_kwargs)

    @requires_dask
    def test_chunked_datetime64(self) -> None:
        # Copied from @malmans2's PR #8253
        original = create_test_data().astype("datetime64[ns]").chunk(1)
        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
            for name, actual_var in actual.variables.items():
                assert original[name].chunks == actual_var.chunks
            assert original.chunks == actual.chunks

    def test_vectorized_indexing_negative_step(self) -> None:
        if not has_dask:
            pytest.xfail(
138 changes: 138 additions & 0 deletions xarray/tests/test_coding_times.py
@@ -16,6 +16,7 @@
    cftime_range,
    coding,
    conventions,
    date_range,
    decode_cf,
)
from xarray.coding.times import (
@@ -30,6 +31,7 @@
from xarray.coding.variables import SerializationWarning
from xarray.conventions import _update_bounds_attributes, cf_encoder
from xarray.core.common import contains_cftime_datetimes
from xarray.core.pycompat import is_duck_dask_array
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
    FirstElementAccessibleArray,
@@ -1387,3 +1389,139 @@ def test_roundtrip_float_times() -> None:
    assert_identical(var, decoded_var)
    assert decoded_var.encoding["units"] == units
    assert decoded_var.encoding["_FillValue"] == fill_value


ENCODE_DATETIME64_VIA_DASK_TESTS = {
    "pandas-encoding-with-prescribed-units-and-dtype": (
        "D",
        "days since 1700-01-01",
        np.dtype("int32"),
    ),
    "mixed-cftime-pandas-encoding-with-prescribed-units-and-dtype": (
        "252YS",
        "days since 1700-01-01",
        np.dtype("int32"),
    ),
    "pandas-encoding-with-default-units-and-dtype": ("252YS", None, None),
}


@requires_dask
@pytest.mark.parametrize(
    ("freq", "units", "dtype"),
    ENCODE_DATETIME64_VIA_DASK_TESTS.values(),
    ids=ENCODE_DATETIME64_VIA_DASK_TESTS.keys(),
)
def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype):
    import dask.array

    times = pd.date_range(start="1700", freq=freq, periods=3)
    times = dask.array.from_array(times, chunks=1)
    encoded_times, encoding_units, encoding_calendar = encode_cf_datetime(
        times, units, None, dtype
    )

    assert is_duck_dask_array(encoded_times)
    assert encoded_times.chunks == times.chunks

    if units is not None and dtype is not None:
        assert encoding_units == units
        assert encoded_times.dtype == dtype
    else:
        assert encoding_units == "nanoseconds since 1970-01-01"
        assert encoded_times.dtype == np.dtype("int64")

    assert encoding_calendar == "proleptic_gregorian"

    decoded_times = decode_cf_datetime(encoded_times, encoding_units, encoding_calendar)
    np.testing.assert_equal(decoded_times, times)


@requires_dask
@pytest.mark.parametrize(
    ("units", "dtype"), [(None, np.dtype("int32")), ("2000-01-01", None)]
)
def test_encode_cf_datetime_via_dask_error(units, dtype):
    import dask.array

    times = pd.date_range(start="1700", freq="D", periods=3)
    times = dask.array.from_array(times, chunks=1)

    with pytest.raises(ValueError, match="When encoding chunked arrays"):
        encode_cf_datetime(times, units, None, dtype)


ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS = {
    "prescribed-units-and-dtype": ("D", "days since 1700-01-01", np.dtype("int32")),
    "default-units-and-dtype": ("252YS", None, None),
}


@requires_cftime
@requires_dask
@pytest.mark.parametrize(
    "calendar",
    ["standard", "proleptic_gregorian", "julian", "noleap", "all_leap", "360_day"],
)
@pytest.mark.parametrize(
    ("freq", "units", "dtype"),
    ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.values(),
    ids=ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.keys(),
)
def test_encode_cf_datetime_cftime_datetime_via_dask(calendar, freq, units, dtype):
    import dask.array

    times = cftime_range(start="1700", freq=freq, periods=3, calendar=calendar)
    times = dask.array.from_array(times, chunks=1)
    encoded_times, encoding_units, encoding_calendar = encode_cf_datetime(
        times, units, None, dtype
    )

    assert is_duck_dask_array(encoded_times)
    assert encoded_times.chunks == times.chunks

    if units is not None and dtype is not None:
        assert encoding_units == units
        assert encoded_times.dtype == dtype
    else:
        assert encoding_units == "microseconds since 1970-01-01"
        assert encoded_times.dtype == np.int64

    assert encoding_calendar == calendar

    decoded_times = decode_cf_datetime(
        encoded_times, encoding_units, encoding_calendar, use_cftime=True
    )
    np.testing.assert_equal(decoded_times, times)


@requires_dask
@pytest.mark.parametrize(
    "use_cftime", [False, pytest.param(True, marks=requires_cftime)]
)
def test_encode_cf_datetime_via_dask_casting_value_error(use_cftime):
    import dask.array

    times = date_range(start="2000", freq="12h", periods=3, use_cftime=use_cftime)
    times = dask.array.from_array(times, chunks=1)
    units = "days since 2000-01-01"
    dtype = np.int64
    encoded_times, *_ = encode_cf_datetime(times, units, None, dtype)
    with pytest.raises(ValueError, match="Not possible"):
        encoded_times.compute()


@requires_dask
@pytest.mark.parametrize(
    "use_cftime", [False, pytest.param(True, marks=requires_cftime)]
)
def test_encode_cf_datetime_via_dask_casting_overflow_error(use_cftime):
    import dask.array

    times = date_range(start="1700", freq="252YS", periods=3, use_cftime=use_cftime)
    times = dask.array.from_array(times, chunks=1)
    units = "days since 1700-01-01"
    dtype = np.dtype("float16")
    encoded_times, *_ = encode_cf_datetime(times, units, None, dtype)
    with pytest.raises(OverflowError, match="Not possible"):
        encoded_times.compute()

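As the two casting tests above show, the cast-safety check runs inside each block, so an invalid dtype encoding surfaces only when the encoded array is computed. A minimal sketch of that behavior, assuming dask is installed:

import dask.array
import numpy as np
import pandas as pd

from xarray.coding.times import encode_cf_datetime

times = dask.array.from_array(pd.date_range("2000", freq="12h", periods=3), chunks=1)

# Building the graph succeeds even though 12-hourly times cannot be
# represented exactly as integer days since 2000-01-01 ...
num, _, _ = encode_cf_datetime(times, "days since 2000-01-01", None, np.int64)

# ... the ValueError ("Not possible to cast ...") is raised only at compute time.
num.compute()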