Skip to content

Speed up .dt accessor by preserving Index objects. #7796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions asv_bench/benchmarks/accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np

import xarray as xr

from . import parameterized

NTIME = 365 * 30


@parameterized(["calendar"], [("standard", "noleap")])
class DateTimeAccessor:
def setup(self, calendar):
np.random.randn(NTIME)
time = xr.date_range("2000", periods=30 * 365, calendar=calendar)
data = np.ones((NTIME,))
self.da = xr.DataArray(data, dims="time", coords={"time": time})

def time_dayofyear(self, calendar):
self.da.time.dt.dayofyear

def time_year(self, calendar):
self.da.time.dt.year

def time_floor(self, calendar):
self.da.time.dt.floor("D")
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ Breaking changes
Deprecations
~~~~~~~~~~~~

Performance
~~~~~~~~~~~
- Optimize ``.dt `` accessor performance with ``CFTimeIndex``. (:pull:`7796`)
By `Deepak Cherian <https://github.com/dcherian>`_.

Bug fixes
~~~~~~~~~
Expand Down
27 changes: 19 additions & 8 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import T_DataArray
from xarray.core.variable import IndexVariable

if TYPE_CHECKING:
from numpy.typing import DTypeLike
Expand Down Expand Up @@ -48,7 +49,10 @@ def _access_through_cftimeindex(values, name):
"""
from xarray.coding.cftimeindex import CFTimeIndex

values_as_cftimeindex = CFTimeIndex(values.ravel())
if not isinstance(values, CFTimeIndex):
values_as_cftimeindex = CFTimeIndex(values.ravel())
else:
values_as_cftimeindex = values
if name == "season":
months = values_as_cftimeindex.month
field_values = _season_from_months(months)
Expand Down Expand Up @@ -115,7 +119,7 @@ def _get_date_field(values, name, dtype):
access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks
)
else:
return access_method(values, name).astype(dtype)
return access_method(values, name).astype(dtype, copy=False)


def _round_through_series_or_index(values, name, freq):
Expand Down Expand Up @@ -200,6 +204,13 @@ def _strftime(values, date_format):
return access_method(values, date_format)


def _index_or_data(obj):
if isinstance(obj.variable, IndexVariable):
return obj.to_index()
else:
return obj.data


class TimeAccessor(Generic[T_DataArray]):
__slots__ = ("_obj",)

Expand All @@ -209,14 +220,14 @@ def __init__(self, obj: T_DataArray) -> None:
def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray:
if dtype is None:
dtype = self._obj.dtype
obj_type = type(self._obj)
result = _get_date_field(self._obj.data, name, dtype)
return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims)
result = _get_date_field(_index_or_data(self._obj), name, dtype)
newvar = self._obj.variable.copy(data=result, deep=False)
return self._obj._replace(newvar, name=name)

def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray:
obj_type = type(self._obj)
result = _round_field(self._obj.data, name, freq)
return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims)
result = _round_field(_index_or_data(self._obj), name, freq)
newvar = self._obj.variable.copy(data=result, deep=False)
return self._obj._replace(newvar, name=name)

def floor(self, freq: str) -> T_DataArray:
"""
Expand Down