Skip to content

Commit c8c6a4b

Browse files
dcherian and Illviljan authored
Speed up .dt accessor by preserving Index objects. (#7796)
* Speed up .dt accessor by preserving Index objects.
* Switch to _replace
* Add benchmark
* Add whats-new
* Shallow copy
* Fix mypy
* Revert "Fix mypy" (this reverts commit fb0b073)
---------
Co-authored-by: Illviljan <[email protected]>
1 parent 93d79d5 commit c8c6a4b

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

asv_bench/benchmarks/accessors.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
import xarray as xr
4+
5+
from . import parameterized
6+
7+
NTIME = 365 * 30
8+
9+
10+
@parameterized(["calendar"], [("standard", "noleap")])
11+
class DateTimeAccessor:
12+
def setup(self, calendar):
13+
np.random.randn(NTIME)
14+
time = xr.date_range("2000", periods=30 * 365, calendar=calendar)
15+
data = np.ones((NTIME,))
16+
self.da = xr.DataArray(data, dims="time", coords={"time": time})
17+
18+
def time_dayofyear(self, calendar):
19+
self.da.time.dt.dayofyear
20+
21+
def time_year(self, calendar):
22+
self.da.time.dt.year
23+
24+
def time_floor(self, calendar):
25+
self.da.time.dt.floor("D")

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ Breaking changes
3737
Deprecations
3838
~~~~~~~~~~~~
3939

40+
Performance
41+
~~~~~~~~~~~
42+
- Optimize ``.dt`` accessor performance with ``CFTimeIndex``. (:pull:`7796`)
43+
By `Deepak Cherian <https://github.com/dcherian>`_.
4044

4145
Bug fixes
4246
~~~~~~~~~

xarray/core/accessor_dt.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from xarray.core.pycompat import is_duck_dask_array
1616
from xarray.core.types import T_DataArray
17+
from xarray.core.variable import IndexVariable
1718

1819
if TYPE_CHECKING:
1920
from numpy.typing import DTypeLike
@@ -48,7 +49,10 @@ def _access_through_cftimeindex(values, name):
4849
"""
4950
from xarray.coding.cftimeindex import CFTimeIndex
5051

51-
values_as_cftimeindex = CFTimeIndex(values.ravel())
52+
if not isinstance(values, CFTimeIndex):
53+
values_as_cftimeindex = CFTimeIndex(values.ravel())
54+
else:
55+
values_as_cftimeindex = values
5256
if name == "season":
5357
months = values_as_cftimeindex.month
5458
field_values = _season_from_months(months)
@@ -115,7 +119,7 @@ def _get_date_field(values, name, dtype):
115119
access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks
116120
)
117121
else:
118-
return access_method(values, name).astype(dtype)
122+
return access_method(values, name).astype(dtype, copy=False)
119123

120124

121125
def _round_through_series_or_index(values, name, freq):
@@ -200,6 +204,13 @@ def _strftime(values, date_format):
200204
return access_method(values, date_format)
201205

202206

207+
def _index_or_data(obj):
208+
if isinstance(obj.variable, IndexVariable):
209+
return obj.to_index()
210+
else:
211+
return obj.data
212+
213+
203214
class TimeAccessor(Generic[T_DataArray]):
204215
__slots__ = ("_obj",)
205216

@@ -209,14 +220,14 @@ def __init__(self, obj: T_DataArray) -> None:
209220
def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray:
210221
if dtype is None:
211222
dtype = self._obj.dtype
212-
obj_type = type(self._obj)
213-
result = _get_date_field(self._obj.data, name, dtype)
214-
return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims)
223+
result = _get_date_field(_index_or_data(self._obj), name, dtype)
224+
newvar = self._obj.variable.copy(data=result, deep=False)
225+
return self._obj._replace(newvar, name=name)
215226

216227
def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray:
217-
obj_type = type(self._obj)
218-
result = _round_field(self._obj.data, name, freq)
219-
return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims)
228+
result = _round_field(_index_or_data(self._obj), name, freq)
229+
newvar = self._obj.variable.copy(data=result, deep=False)
230+
return self._obj._replace(newvar, name=name)
220231

221232
def floor(self, freq: str) -> T_DataArray:
222233
"""

0 commit comments

Comments
 (0)