Skip to content

Commit

Permalink
Add dask support for timedelta encoding and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerkclark committed Dec 31, 2023
1 parent e5150c9 commit f0b9a8d
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 72 deletions.
115 changes: 89 additions & 26 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,31 +674,54 @@ def encode_cf_datetime(
calendar: str | None = None,
dtype: np.dtype | None = None,
):
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
Unlike `date2num`, this function can handle datetime64 arrays.
See Also
--------
cftime.date2num
"""
dates = asarray(dates)
if isinstance(dates, np.ndarray):
return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)
elif is_duck_dask_array(dates):
return _lazily_encode_cf_datetime(dates, units, calendar, dtype)


def _cast_to_dtype_safe(num, dtype) -> np.ndarray:
cast_num = np.asarray(num, dtype=dtype)
def _cast_to_dtype_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="overflow")
cast_num = np.asarray(num, dtype=dtype)

if np.issubdtype(dtype, np.integer):
if not (num == cast_num).all():
raise ValueError(
f"Not possible to cast all encoded times from dtype {num.dtype!r} "
f"to dtype {dtype!r} without changing any of their values. "
f"Consider removing the dtype encoding or explicitly switching to "
f"a dtype encoding with a higher precision."
)
if np.issubdtype(num.dtype, np.floating):
raise ValueError(
f"Not possible to cast all encoded times from dtype "
f"{num.dtype!r} to integer dtype {dtype!r} without losing "
f"precision. Consider modifying the units such that "
f"integer values can be used, or removing the units and "
f"dtype encoding, at which point xarray will make an "
f"appropriate choice."
)
else:
raise OverflowError(
f"Not possible to cast encoded times from dtype "
f"{num.dtype!r} to dtype {dtype!r} without overflow. "
f"Consider removing the dtype encoding, at which point "
f"xarray will make an appropriate choice, or explicitly "
f"switching to a larger integer dtype."
)
else:
if np.isinf(cast_num).any():
raise OverflowError(
f"Not possible to cast encoded times from dtype {num.dtype!r} "
f"to dtype {dtype!r} without overflow. Consider removing the "
f"dtype encoding or explicitly switching to a dtype encoding "
f"with a higher precision."
f"dtype encoding, at which point xarray will make an "
f"appropriate choice, or explicitly switching to a larger "
f"floating point dtype."
)

return cast_num
Expand All @@ -711,15 +734,6 @@ def _eagerly_encode_cf_datetime(
dtype: np.dtype | None = None,
called_via_map_blocks: bool = False,
) -> tuple[np.ndarray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
Unlike `date2num`, this function can handle datetime64 arrays.
See Also
--------
cftime.date2num
"""
dates = np.asarray(dates)

data_units = infer_datetime_units(dates)
Expand Down Expand Up @@ -796,7 +810,7 @@ def _eagerly_encode_cf_datetime(
if called_via_map_blocks:
return num
else:
return (num, units, calendar)
return num, units, calendar


def _lazily_encode_cf_datetime(
Expand All @@ -821,10 +835,10 @@ def _lazily_encode_cf_datetime(

if units is None or dtype is None:
raise ValueError(
f"When encoding chunked arrays of datetime values, both the units and "
f"dtype must be prescribed or both must be unprescribed. Prescribing "
f"only one or the other is not currently supported. Got a units "
f"encoding of {units} and a dtype encoding of {dtype}."
f"When encoding chunked arrays of datetime values, both the units "
f"and dtype must be prescribed or both must be unprescribed. "
f"Prescribing only one or the other is not currently supported. "
f"Got a units encoding of {units} and a dtype encoding of {dtype}."
)

num = dask.array.map_blocks(
Expand All @@ -841,6 +855,19 @@ def _lazily_encode_cf_datetime(

def encode_cf_timedelta(
    timedeltas, units: str | None = None, dtype: np.dtype | None = None
):
    """Encode timedeltas as numbers, choosing the lazy or eager code path.

    The input is first passed through ``asarray``. Duck dask arrays are
    handed to ``_lazily_encode_cf_timedelta``; everything else goes to
    ``_eagerly_encode_cf_timedelta``. Both are called with the same
    ``(timedeltas, units, dtype)`` arguments and their result is returned
    unchanged.
    """
    timedeltas = asarray(timedeltas)
    encoder = (
        _lazily_encode_cf_timedelta
        if is_duck_dask_array(timedeltas)
        else _eagerly_encode_cf_timedelta
    )
    return encoder(timedeltas, units, dtype)


def _eagerly_encode_cf_timedelta(
timedeltas,
units: str | None = None,
dtype: np.dtype | None = None,
called_via_map_blocks: bool = False,
) -> tuple[np.ndarray, str]:
data_units = infer_timedelta_units(timedeltas)

Expand Down Expand Up @@ -868,7 +895,7 @@ def encode_cf_timedelta(
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
elif np.issubdtype(dtype, np.integer) and not called_via_map_blocks:
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Serializing with units {needed_units!r} instead. "
Expand All @@ -881,7 +908,43 @@ def encode_cf_timedelta(

num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(timedeltas.shape)
return (num, units)

if dtype is not None:
num = _cast_to_dtype_safe(num, dtype)

if called_via_map_blocks:
return num
else:
return num, units


def _lazily_encode_cf_timedelta(
    timedeltas, units: str | None = None, dtype: np.dtype | None = None
):
    """Encode a chunked array of timedeltas block by block.

    When neither ``units`` nor ``dtype`` is given, ``"nanoseconds"`` and
    ``int64`` are used. Giving exactly one of the two raises ``ValueError``.
    The encoding itself is delegated to ``_eagerly_encode_cf_timedelta``,
    applied to each block through ``dask.array.map_blocks``.

    Returns
    -------
    tuple
        ``(num, units)``, where ``num`` is the result of ``map_blocks``.
    """
    import dask.array

    has_units = units is not None
    has_dtype = dtype is not None

    if not has_units and not has_dtype:
        units, dtype = "nanoseconds", np.dtype("int64")
    elif has_units != has_dtype:
        raise ValueError(
            f"When encoding chunked arrays of timedelta values, both the "
            f"units and dtype must be prescribed or both must be "
            f"unprescribed. Prescribing only one or the other is not "
            f"currently supported. Got a units encoding of {units} and a "
            f"dtype encoding of {dtype}."
        )

    # ``called_via_map_blocks=True`` makes the eager encoder return only the
    # numeric array instead of a ``(num, units)`` tuple.
    encoded = dask.array.map_blocks(
        _eagerly_encode_cf_timedelta,
        timedeltas,
        units,
        dtype,
        called_via_map_blocks=True,
        dtype=dtype,
    )
    return encoded, units


class CFDatetimeCoder(VariableCoder):
Expand Down
9 changes: 9 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2818,6 +2818,15 @@ def test_chunked_datetime64(self) -> None:
assert original[name].chunks == actual_var.chunks
assert original.chunks == actual.chunks

@requires_dask
def test_chunked_timedelta64(self) -> None:
    # Based on @malmans2's datetime64[ns] test in PR #8253
    # Chunking of timedelta64[ns] data should survive a roundtrip, both per
    # variable and for the dataset as a whole.
    expected = create_test_data().astype("timedelta64[ns]").chunk(1)
    with self.roundtrip(expected, open_kwargs={"chunks": {}}) as roundtripped:
        for name in roundtripped.variables:
            assert expected[name].chunks == roundtripped.variables[name].chunks
        assert expected.chunks == roundtripped.chunks

def test_vectorized_indexing_negative_step(self) -> None:
if not has_dask:
pytest.xfail(
Expand Down
Loading

0 comments on commit f0b9a8d

Please sign in to comment.