Skip to content

Commit 35ed378

Browse files
committed
Support dask arrays in datetime_to_numeric
1 parent c003dea commit 35ed378

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

xarray/core/duck_array_ops.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -517,10 +517,20 @@ def pd_timedelta_to_float(value, datetime_unit):
517517
return np_timedelta64_to_float(value, datetime_unit)
518518

519519

520+
def _timedelta_to_seconds(array):
521+
return np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6
522+
523+
520524
def py_timedelta_to_float(array, datetime_unit):
521525
"""Convert a timedelta object to a float, possibly at a loss of resolution."""
522-
array = np.asarray(array)
523-
array = np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6
526+
if not is_duck_array(array):
527+
array = np.asarray(array)
528+
if is_duck_dask_array(array):
529+
array = array.map_blocks(
530+
_timedelta_to_seconds, meta=np.array([], dtype=np.float64)
531+
)
532+
else:
533+
array = _timedelta_to_seconds(array)
524534
conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit)
525535
return conversion_factor * array
526536

xarray/tests/test_duck_array_ops.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -675,39 +675,68 @@ def test_multiple_dims(dtype, dask, skipna, func):
675675
assert_allclose(actual, expected)
676676

677677

678-
def test_datetime_to_numeric_datetime64():
678+
@pytest.mark.parametrize("dask", [True, False])
679+
def test_datetime_to_numeric_datetime64(dask):
680+
if dask and not has_dask:
681+
pytest.skip("requires dask")
682+
679683
times = pd.date_range("2000", periods=5, freq="7D").values
680-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h")
684+
if dask:
685+
import dask.array
686+
687+
times = dask.array.from_array(times, chunks=-1)
688+
689+
with raise_if_dask_computes():
690+
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h")
681691
expected = 24 * np.arange(0, 35, 7)
682692
np.testing.assert_array_equal(result, expected)
683693

684694
offset = times[1]
685-
result = duck_array_ops.datetime_to_numeric(times, offset=offset, datetime_unit="h")
695+
with raise_if_dask_computes():
696+
result = duck_array_ops.datetime_to_numeric(
697+
times, offset=offset, datetime_unit="h"
698+
)
686699
expected = 24 * np.arange(-7, 28, 7)
687700
np.testing.assert_array_equal(result, expected)
688701

689702
dtype = np.float32
690-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype)
703+
with raise_if_dask_computes():
704+
result = duck_array_ops.datetime_to_numeric(
705+
times, datetime_unit="h", dtype=dtype
706+
)
691707
expected = 24 * np.arange(0, 35, 7).astype(dtype)
692708
np.testing.assert_array_equal(result, expected)
693709

694710

695711
@requires_cftime
696-
def test_datetime_to_numeric_cftime():
712+
@pytest.mark.parametrize("dask", [True, False])
713+
def test_datetime_to_numeric_cftime(dask):
714+
if dask and not has_dask:
715+
pytest.skip("requires dask")
716+
697717
times = cftime_range("2000", periods=5, freq="7D", calendar="standard").values
698-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int)
718+
if dask:
719+
import dask.array
720+
721+
times = dask.array.from_array(times, chunks=-1)
722+
with raise_if_dask_computes():
723+
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int)
699724
expected = 24 * np.arange(0, 35, 7)
700725
np.testing.assert_array_equal(result, expected)
701726

702727
offset = times[1]
703-
result = duck_array_ops.datetime_to_numeric(
704-
times, offset=offset, datetime_unit="h", dtype=int
705-
)
728+
with raise_if_dask_computes():
729+
result = duck_array_ops.datetime_to_numeric(
730+
times, offset=offset, datetime_unit="h", dtype=int
731+
)
706732
expected = 24 * np.arange(-7, 28, 7)
707733
np.testing.assert_array_equal(result, expected)
708734

709735
dtype = np.float32
710-
result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype)
736+
with raise_if_dask_computes():
737+
result = duck_array_ops.datetime_to_numeric(
738+
times, datetime_unit="h", dtype=dtype
739+
)
711740
expected = 24 * np.arange(0, 35, 7).astype(dtype)
712741
np.testing.assert_array_equal(result, expected)
713742

0 commit comments

Comments
 (0)