From 6ebb91729fea124d58cfac6c78877eda5f9701e9 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Sun, 31 Dec 2023 10:48:25 -0500 Subject: [PATCH] Add dask support for timedelta encoding and more tests --- xarray/coding/times.py | 115 ++++++++++++++++----- xarray/tests/test_backends.py | 9 ++ xarray/tests/test_coding_times.py | 161 +++++++++++++++++++++--------- 3 files changed, 214 insertions(+), 71 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index fa30241fd55..3c9e41f8cd9 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -674,6 +674,15 @@ def encode_cf_datetime( calendar: str | None = None, dtype: np.dtype | None = None, ): + """Given an array of datetime objects, returns the tuple `(num, units, + calendar)` suitable for a CF compliant time variable. + + Unlike `date2num`, this function can handle datetime64 arrays. + + See Also + -------- + cftime.date2num + """ dates = asarray(dates) if isinstance(dates, np.ndarray): return _eagerly_encode_cf_datetime(dates, units, calendar, dtype) @@ -681,24 +690,38 @@ def encode_cf_datetime( return _lazily_encode_cf_datetime(dates, units, calendar, dtype) -def _cast_to_dtype_safe(num, dtype) -> np.ndarray: - cast_num = np.asarray(num, dtype=dtype) +def _cast_to_dtype_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="overflow") + cast_num = np.asarray(num, dtype=dtype) if np.issubdtype(dtype, np.integer): if not (num == cast_num).all(): - raise ValueError( - f"Not possible to cast all encoded times from dtype {num.dtype!r} " - f"to dtype {dtype!r} without changing any of their values. " - f"Consider removing the dtype encoding or explicitly switching to " - f"a dtype encoding with a higher precision." - ) + if np.issubdtype(num.dtype, np.floating): + raise ValueError( + f"Not possible to cast all encoded times from dtype " + f"{num.dtype!r} to integer dtype {dtype!r} without losing " + f"precision. 
Consider modifying the units such that " + f"integer values can be used, or removing the units and " + f"dtype encoding, at which point xarray will make an " + f"appropriate choice." + ) + else: + raise OverflowError( + f"Not possible to cast encoded times from dtype " + f"{num.dtype!r} to dtype {dtype!r} without overflow. " + f"Consider removing the dtype encoding, at which point " + f"xarray will make an appropriate choice, or explicitly " + f"switching to a larger integer dtype." + ) else: if np.isinf(cast_num).any(): raise OverflowError( f"Not possible to cast encoded times from dtype {num.dtype!r} " f"to dtype {dtype!r} without overflow. Consider removing the " - f"dtype encoding or explicitly switching to a dtype encoding " - f"with a higher precision." + f"dtype encoding, at which point xarray will make an " + f"appropriate choice, or explicitly switching to a larger " + f"floating point dtype." ) return cast_num @@ -711,15 +734,6 @@ def _eagerly_encode_cf_datetime( dtype: np.dtype | None = None, called_via_map_blocks: bool = False, ) -> tuple[np.ndarray, str, str]: - """Given an array of datetime objects, returns the tuple `(num, units, - calendar)` suitable for a CF compliant time variable. - - Unlike `date2num`, this function can handle datetime64 arrays. - - See Also - -------- - cftime.date2num - """ dates = np.asarray(dates) data_units = infer_datetime_units(dates) @@ -796,7 +810,7 @@ def _eagerly_encode_cf_datetime( if called_via_map_blocks: return num else: - return (num, units, calendar) + return num, units, calendar def _lazily_encode_cf_datetime( @@ -821,10 +835,10 @@ def _lazily_encode_cf_datetime( if units is None or dtype is None: raise ValueError( - f"When encoding chunked arrays of datetime values, both the units and " - f"dtype must be prescribed or both must be unprescribed. Prescribing " - f"only one or the other is not currently supported. Got a units " - f"encoding of {units} and a dtype encoding of {dtype}." 
+ f"When encoding chunked arrays of datetime values, both the units " + f"and dtype must be prescribed or both must be unprescribed. " + f"Prescribing only one or the other is not currently supported. " + f"Got a units encoding of {units} and a dtype encoding of {dtype}." ) num = dask.array.map_blocks( @@ -841,6 +855,19 @@ def _lazily_encode_cf_datetime( def encode_cf_timedelta( timedeltas, units: str | None = None, dtype: np.dtype | None = None +): + timedeltas = asarray(timedeltas) + if is_duck_dask_array(timedeltas): + return _lazily_encode_cf_timedelta(timedeltas, units, dtype) + else: + return _eagerly_encode_cf_timedelta(timedeltas, units, dtype) + + +def _eagerly_encode_cf_timedelta( + timedeltas, + units: str | None = None, + dtype: np.dtype | None = None, + called_via_map_blocks: bool = False, ) -> tuple[np.ndarray, str]: data_units = infer_timedelta_units(timedeltas) @@ -868,7 +895,7 @@ def encode_cf_timedelta( f"Set encoding['dtype'] to integer dtype to serialize to int64. " f"Set encoding['dtype'] to floating point dtype to silence this warning." ) - elif np.issubdtype(dtype, np.integer): + elif np.issubdtype(dtype, np.integer) and not called_via_map_blocks: emit_user_level_warning( f"Timedeltas can't be serialized faithfully with requested units {units!r}. " f"Serializing with units {needed_units!r} instead. 
" @@ -881,7 +908,43 @@ def encode_cf_timedelta( num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(timedeltas.shape) - return (num, units) + + if dtype is not None: + num = _cast_to_dtype_safe(num, dtype) + + if called_via_map_blocks: + return num + else: + return num, units + + +def _lazily_encode_cf_timedelta( + timedeltas, units: str | None = None, dtype: np.dtype | None = None +): + import dask.array + + if units is None and dtype is None: + units = "nanoseconds" + dtype = np.dtype("int64") + + if units is None or dtype is None: + raise ValueError( + f"When encoding chunked arrays of timedelta values, both the " + f"units and dtype must be prescribed or both must be " + f"unprescribed. Prescribing only one or the other is not " + f"currently supported. Got a units encoding of {units} and a " + f"dtype encoding of {dtype}." + ) + + num = dask.array.map_blocks( + _eagerly_encode_cf_timedelta, + timedeltas, + units, + dtype, + called_via_map_blocks=True, + dtype=dtype, + ) + return num, units class CFDatetimeCoder(VariableCoder): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d5eb5b1ee2c..bc086a72964 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2818,6 +2818,15 @@ def test_chunked_datetime64(self) -> None: assert original[name].chunks == actual_var.chunks assert original.chunks == actual.chunks + @requires_dask + def test_chunked_timedelta64(self) -> None: + # Based on @malmans2's datetime64[ns] test in PR #8253 + original = create_test_data().astype("timedelta64[ns]").chunk(1) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for name, actual_var in actual.variables.items(): + assert original[name].chunks == actual_var.chunks + assert original.chunks == actual.chunks + def test_vectorized_indexing_negative_step(self) -> None: if not has_dask: pytest.xfail( diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 
b896718a13f..4aa2f36c1e5 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -25,7 +25,9 @@ _should_cftime_be_used, cftime_to_nptime, decode_cf_datetime, + decode_cf_timedelta, encode_cf_datetime, + encode_cf_timedelta, to_timedelta_unboxed, ) from xarray.coding.variables import SerializationWarning @@ -1391,26 +1393,26 @@ def test_roundtrip_float_times() -> None: assert decoded_var.encoding["_FillValue"] == fill_value -ENCODE_DATETIME64_VIA_DASK_TESTS = { +_ENCODE_DATETIME64_VIA_DASK_TESTS = { "pandas-encoding-with-prescribed-units-and-dtype": ( "D", "days since 1700-01-01", np.dtype("int32"), ), "mixed-cftime-pandas-encoding-with-prescribed-units-and-dtype": ( - "252YS", + "250YS", "days since 1700-01-01", np.dtype("int32"), ), - "pandas-encoding-with-default-units-and-dtype": ("252YS", None, None), + "pandas-encoding-with-default-units-and-dtype": ("250YS", None, None), } @requires_dask @pytest.mark.parametrize( ("freq", "units", "dtype"), - ENCODE_DATETIME64_VIA_DASK_TESTS.values(), - ids=ENCODE_DATETIME64_VIA_DASK_TESTS.keys(), + _ENCODE_DATETIME64_VIA_DASK_TESTS.values(), + ids=_ENCODE_DATETIME64_VIA_DASK_TESTS.keys(), ) def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype): import dask.array @@ -1439,39 +1441,32 @@ def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype): @requires_dask @pytest.mark.parametrize( - ("units", "dtype"), [(None, np.dtype("int32")), ("2000-01-01", None)] + ("range_function", "start", "units", "dtype"), + [ + (pd.date_range, "2000", None, np.dtype("int32")), + (pd.date_range, "2000", "days since 2000-01-01", None), + (pd.timedelta_range, "0D", None, np.dtype("int32")), + (pd.timedelta_range, "0D", "days", None), + ], ) -def test_encode_cf_datetime_via_dask_error(units, dtype): - import dask.array - - times = pd.date_range(start="1700", freq="D", periods=3) - times = dask.array.from_array(times, chunks=1) - +def test_encode_via_dask_cannot_infer_error(range_function, 
start, units, dtype): + values = range_function(start=start, freq="D", periods=3) + encoding = dict(units=units, dtype=dtype) + variable = Variable(["time"], values, encoding=encoding).chunk({"time": 1}) with pytest.raises(ValueError, match="When encoding chunked arrays"): - encode_cf_datetime(times, units, None, dtype) - - -ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS = { - "prescribed-units-and-dtype": ("D", "days since 1700-01-01", np.dtype("int32")), - "default-units-and-dtype": ("252YS", None, None), -} + conventions.encode_cf_variable(variable) @requires_cftime @requires_dask @pytest.mark.parametrize( - "calendar", - ["standard", "proleptic_gregorian", "julian", "noleap", "all_leap", "360_day"], -) -@pytest.mark.parametrize( - ("freq", "units", "dtype"), - ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.values(), - ids=ENCODE_CFTIME_DATETIME_VIA_DASK_TESTS.keys(), + ("units", "dtype"), [("days since 1700-01-01", np.dtype("int32")), (None, None)] ) -def test_encode_cf_datetime_cftime_datetime_via_dask(calendar, freq, units, dtype): +def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype): import dask.array - times = cftime_range(start="1700", freq=freq, periods=3, calendar=calendar) + calendar = "standard" + times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) times = dask.array.from_array(times, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype @@ -1495,33 +1490,109 @@ def test_encode_cf_datetime_cftime_datetime_via_dask(calendar, freq, units, dtyp np.testing.assert_equal(decoded_times, times) -@requires_dask @pytest.mark.parametrize( "use_cftime", [False, pytest.param(True, marks=requires_cftime)] ) -def test_encode_cf_datetime_via_dask_casting_value_error(use_cftime): - import dask.array - +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_datetime_casting_value_error(use_cftime, use_dask): times = date_range(start="2000", 
freq="12h", periods=3, use_cftime=use_cftime) - times = dask.array.from_array(times, chunks=1) - units = "days since 2000-01-01" - dtype = np.int64 - encoded_times, *_ = encode_cf_datetime(times, units, None, dtype) - with pytest.raises(ValueError, match="Not possible"): - encoded_times.compute() + encoding = dict(units="days since 2000-01-01", dtype=np.dtype("int64")) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_cftime and not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. For all other cases we raise. + with pytest.warns(UserWarning, match="Times can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours since 2000-01-01" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() -@requires_dask @pytest.mark.parametrize( "use_cftime", [False, pytest.param(True, marks=requires_cftime)] ) -def test_encode_cf_datetime_via_dask_casting_overflow_error(use_cftime): +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype): + # Regression test for GitHub issue #8542 + times = date_range(start="2018", freq="5h", periods=3, use_cftime=use_cftime) + encoding = dict(units="microseconds since 2018-01-01", dtype=dtype) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + with pytest.raises(OverflowError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@requires_dask 
+@pytest.mark.parametrize( + ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] +) +def test_encode_cf_timedelta_via_dask(units, dtype): import dask.array - times = date_range(start="1700", freq="252YS", periods=3, use_cftime=use_cftime) + times = pd.timedelta_range(start="0D", freq="D", periods=3) times = dask.array.from_array(times, chunks=1) - units = "days since 1700-01-01" - dtype = np.dtype("float16") - encoded_times, *_ = encode_cf_datetime(times, units, None, dtype) + encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "nanoseconds" + assert encoded_times.dtype == np.dtype("int64") + + decoded_times = decode_cf_timedelta(encoded_times, encoding_units) + np.testing.assert_equal(decoded_times, times) + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_timedelta_casting_value_error(use_dask): + timedeltas = pd.timedelta_range(start="0h", freq="12h", periods=3) + encoding = dict(units="days", dtype=np.dtype("int64")) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. 
+ with pytest.warns(UserWarning, match="Timedeltas can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_timedelta_casting_overflow_error(use_dask, dtype): + timedeltas = pd.timedelta_range(start="0h", freq="5h", periods=3) + encoding = dict(units="microseconds", dtype=dtype) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + with pytest.raises(OverflowError, match="Not possible"): - encoded_times.compute() + encoded = conventions.encode_cf_variable(variable) + encoded.compute()