From 7790863435c612f3e1205292bd7f4efa847fd310 Mon Sep 17 00:00:00 2001 From: Aaron Spring Date: Wed, 16 Feb 2022 00:05:04 +0100 Subject: [PATCH 01/34] Implement multiplication of cftime Tick offsets by floats (#6135) --- doc/whats-new.rst | 8 +++ xarray/coding/cftime_offsets.py | 61 +++++++++++++++--- xarray/coding/cftimeindex.py | 17 ++--- xarray/tests/test_cftime_offsets.py | 96 +++++++++++++++++++++++------ xarray/tests/test_cftimeindex.py | 52 +++++++++++++--- 5 files changed, 190 insertions(+), 44 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 02b341f963e..88453a641e0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -19,9 +19,15 @@ What's New v2022.02.0 (unreleased) ----------------------- + New Features ~~~~~~~~~~~~ +- Enabled multiplying tick offsets by floats. Allows ``float`` ``n`` in + :py:meth:`CFTimeIndex.shift` if ``shift_freq`` is between ``Day`` + and ``Microsecond``. (:issue:`6134`, :pull:`6135`). + By `Aaron Spring `_. + Breaking changes ~~~~~~~~~~~~~~~~ @@ -46,6 +52,7 @@ Documentation ~~~~~~~~~~~~~ + Internal Changes ~~~~~~~~~~~~~~~~ @@ -86,6 +93,7 @@ New Features - Enable the limit option for dask array in the following methods :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`) By `Joseph Nowak `_. + Breaking changes ~~~~~~~~~~~~~~~~ - Rely on matplotlib's default datetime converters instead of pandas' (:issue:`6102`, :pull:`6109`). diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 30bfd882b5c..a4e2870650d 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -39,11 +39,12 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import re from datetime import datetime, timedelta from functools import partial -from typing import ClassVar, Optional +from typing import ClassVar import numpy as np import pandas as pd @@ -87,10 +88,10 @@ def get_date_type(calendar, use_cftime=True): class BaseCFTimeOffset: - _freq: ClassVar[Optional[str]] = None - _day_option: ClassVar[Optional[str]] = None + _freq: ClassVar[str | None] = None + _day_option: ClassVar[str | None] = None - def __init__(self, n=1): + def __init__(self, n: int = 1): if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. " @@ -122,6 +123,8 @@ def __sub__(self, other): return NotImplemented def __mul__(self, other): + if not isinstance(other, int): + return NotImplemented return type(self)(n=other * self.n) def __neg__(self): @@ -171,6 +174,40 @@ def _get_offset_day(self, other): return _get_day_of_month(other, self._day_option) +class Tick(BaseCFTimeOffset): + # analogous https://github.com/pandas-dev/pandas/blob/ccb25ab1d24c4fb9691270706a59c8d319750870/pandas/_libs/tslibs/offsets.pyx#L806 + + def _next_higher_resolution(self): + self_type = type(self) + if self_type not in [Day, Hour, Minute, Second, Millisecond]: + raise ValueError("Could not convert to integer offset at any resolution") + if type(self) is Day: + return Hour(self.n * 24) + if type(self) is Hour: + return Minute(self.n * 60) + if type(self) is Minute: + return Second(self.n * 60) + if type(self) is Second: + return Millisecond(self.n * 1000) + if type(self) is Millisecond: + return Microsecond(self.n * 1000) + + def __mul__(self, other): + if not isinstance(other, (int, float)): + return NotImplemented + if isinstance(other, float): + n = other * self.n + # If the new `n` is an integer, we can represent it using the + # same BaseCFTimeOffset subclass as self, otherwise we need to move up + # to a higher-resolution subclass + if np.isclose(n % 1, 0): + return type(self)(int(n)) + + new_self = self._next_higher_resolution() + return new_self * other + return type(self)(n=other * self.n) + + def _get_day_of_month(other, day_option): """Find the day in `other`'s month that satisfies a BaseCFTimeOffset's onOffset policy, as described by the `day_option` argument. @@ -396,6 +433,8 @@ def __sub__(self, other): return NotImplemented def __mul__(self, other): + if isinstance(other, float): + return NotImplemented return type(self)(n=other * self.n, month=self.month) def rule_code(self): @@ -482,6 +521,8 @@ def __sub__(self, other): return NotImplemented def __mul__(self, other): + if isinstance(other, float): + return NotImplemented return type(self)(n=other * self.n, month=self.month) def rule_code(self): @@ -541,7 +582,7 @@ def rollback(self, date): return date - YearEnd(month=self.month) -class Day(BaseCFTimeOffset): +class Day(Tick): _freq = "D" def as_timedelta(self): @@ -551,7 +592,7 @@ def __apply__(self, other): return other + self.as_timedelta() -class Hour(BaseCFTimeOffset): +class Hour(Tick): _freq = "H" def as_timedelta(self): @@ -561,7 +602,7 @@ def __apply__(self, other): return other + self.as_timedelta() -class Minute(BaseCFTimeOffset): +class Minute(Tick): _freq = "T" def as_timedelta(self): @@ -571,7 +612,7 @@ def __apply__(self, other): return other + self.as_timedelta() -class Second(BaseCFTimeOffset): +class Second(Tick): _freq = "S" def as_timedelta(self): @@ -581,7 +622,7 @@ def __apply__(self, other): return other + self.as_timedelta() -class Millisecond(BaseCFTimeOffset): +class Millisecond(Tick): _freq = "L" def as_timedelta(self): @@ -591,7 +632,7 @@ def __apply__(self, other): return other + self.as_timedelta() -class Microsecond(BaseCFTimeOffset): +class Microsecond(Tick): _freq = "U" def as_timedelta(self): diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 8f9d19d7897..d522d7910d4 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -38,11 +38,11 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import re import warnings from datetime import timedelta -from typing import Tuple, Type import numpy as np import pandas as pd @@ -66,7 +66,7 @@ REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END = 10 -OUT_OF_BOUNDS_TIMEDELTA_ERRORS: Tuple[Type[Exception], ...] +OUT_OF_BOUNDS_TIMEDELTA_ERRORS: tuple[type[Exception], ...] try: OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (pd.errors.OutOfBoundsTimedelta, OverflowError) except AttributeError: @@ -511,7 +511,7 @@ def contains(self, key): """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift(self, n, freq): + def shift(self, n: int | float, freq: str | timedelta): """Shift the CFTimeIndex a multiple of the given frequency. See the documentation for :py:func:`~xarray.cftime_range` for a @@ -519,7 +519,7 @@ def shift(self, n, freq): Parameters ---------- - n : int + n : int, float if freq of days or below Periods to shift by freq : str or datetime.timedelta A frequency string or datetime.timedelta object to shift by @@ -541,14 +541,15 @@ def shift(self, n, freq): >>> index.shift(1, "M") CFTimeIndex([2000-02-29 00:00:00], dtype='object', length=1, calendar='standard', freq=None) + >>> index.shift(1.5, "D") + CFTimeIndex([2000-02-01 12:00:00], + dtype='object', length=1, calendar='standard', freq=None) """ - from .cftime_offsets import to_offset - - if not isinstance(n, int): - raise TypeError(f"'n' must be an int, got {n}.") if isinstance(freq, timedelta): return self + n * freq elif isinstance(freq, str): + from .cftime_offsets import to_offset + return self + n * to_offset(freq) else: raise TypeError( diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 4f94b35e3c3..3879959675f 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -18,6 +18,7 @@ QuarterBegin, QuarterEnd, Second, + Tick, YearBegin, YearEnd, _days_in_month, @@ -54,11 +55,25 @@ def calendar(request): (YearEnd(), 1), (QuarterBegin(), 1), (QuarterEnd(), 1), + (Tick(), 1), + (Day(), 1), + (Hour(), 1), + (Minute(), 1), + (Second(), 1), + (Millisecond(), 1), + (Microsecond(), 1), (BaseCFTimeOffset(n=2), 2), (YearBegin(n=2), 2), (YearEnd(n=2), 2), (QuarterBegin(n=2), 2), (QuarterEnd(n=2), 2), + (Tick(n=2), 2), + (Day(n=2), 2), + (Hour(n=2), 2), + (Minute(n=2), 2), + (Second(n=2), 2), + (Millisecond(n=2), 2), + (Microsecond(n=2), 2), ], ids=_id_func, ) @@ -74,6 +89,15 @@ def test_cftime_offset_constructor_valid_n(offset, expected_n): (YearEnd, 1.5), (QuarterBegin, 1.5), (QuarterEnd, 1.5), + (MonthBegin, 1.5), + (MonthEnd, 1.5), + (Tick, 1.5), + (Day, 1.5), + (Hour, 1.5), + (Minute, 1.5), + (Second, 1.5), + (Millisecond, 1.5), + (Microsecond, 1.5), ], ids=_id_func, ) @@ -359,30 +383,64 @@ def test_eq(a, b): _MUL_TESTS = [ - (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)), - (YearEnd(), YearEnd(n=3)), - (YearBegin(), YearBegin(n=3)), - (QuarterEnd(), QuarterEnd(n=3)), - (QuarterBegin(), QuarterBegin(n=3)), - (MonthEnd(), MonthEnd(n=3)), - (MonthBegin(), MonthBegin(n=3)), - (Day(), Day(n=3)), - (Hour(), Hour(n=3)), - (Minute(), Minute(n=3)), - (Second(), Second(n=3)), - (Millisecond(), Millisecond(n=3)), - (Microsecond(), Microsecond(n=3)), + (BaseCFTimeOffset(), 3, BaseCFTimeOffset(n=3)), + (YearEnd(), 3, YearEnd(n=3)), + (YearBegin(), 3, YearBegin(n=3)), + (QuarterEnd(), 3, QuarterEnd(n=3)), + (QuarterBegin(), 3, QuarterBegin(n=3)), + (MonthEnd(), 3, MonthEnd(n=3)), + (MonthBegin(), 3, MonthBegin(n=3)), + (Tick(), 3, Tick(n=3)), + (Day(), 3, Day(n=3)), + (Hour(), 3, Hour(n=3)), + (Minute(), 3, Minute(n=3)), + (Second(), 3, Second(n=3)), + (Millisecond(), 3, Millisecond(n=3)), + (Microsecond(), 3, Microsecond(n=3)), + (Day(), 0.5, Hour(n=12)), + (Hour(), 0.5, Minute(n=30)), + (Minute(), 0.5, Second(n=30)), + (Second(), 0.5, Millisecond(n=500)), + (Millisecond(), 0.5, Microsecond(n=500)), ] -@pytest.mark.parametrize(("offset", "expected"), _MUL_TESTS, ids=_id_func) -def test_mul(offset, expected): - assert offset * 3 == expected +@pytest.mark.parametrize(("offset", "multiple", "expected"), _MUL_TESTS, ids=_id_func) +def test_mul(offset, multiple, expected): + assert offset * multiple == expected -@pytest.mark.parametrize(("offset", "expected"), _MUL_TESTS, ids=_id_func) -def test_rmul(offset, expected): - assert 3 * offset == expected +@pytest.mark.parametrize(("offset", "multiple", "expected"), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, multiple, expected): + assert multiple * offset == expected + + +def test_mul_float_multiple_next_higher_resolution(): + """Test more than one iteration through _next_higher_resolution is required.""" + assert 1e-6 * Second() == Microsecond() + assert 1e-6 / 60 * Minute() == Microsecond() + + +@pytest.mark.parametrize( + "offset", + [YearBegin(), YearEnd(), QuarterBegin(), QuarterEnd(), MonthBegin(), MonthEnd()], + ids=_id_func, +) +def test_nonTick_offset_multiplied_float_error(offset): + """Test that the appropriate error is raised if a non-Tick offset is + multiplied by a float.""" + with pytest.raises(TypeError, match="unsupported operand type"): + offset * 0.5 + + +def test_Microsecond_multiplied_float_error(): + """Test that the appropriate error is raised if a Tick offset is multiplied + by a float which causes it not to be representable by a + microsecond-precision timedelta.""" + with pytest.raises( + ValueError, match="Could not convert to integer offset at any resolution" + ): + Microsecond() * 0.5 @pytest.mark.parametrize( diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index c70fd53038b..2c6a0796c5f 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -755,7 +755,7 @@ def test_cftimeindex_add(index): @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) -def test_cftimeindex_add_timedeltaindex(calendar): +def test_cftimeindex_add_timedeltaindex(calendar) -> None: a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = a + deltas @@ -764,6 +764,44 @@ def test_cftimeindex_add_timedeltaindex(calendar): assert isinstance(result, CFTimeIndex) +@requires_cftime +@pytest.mark.parametrize("n", [2.0, 1.5]) +@pytest.mark.parametrize( + "freq,units", + [ + ("D", "D"), + ("H", "H"), + ("T", "min"), + ("S", "S"), + ("L", "ms"), + ], +) +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_cftimeindex_shift_float(n, freq, units, calendar) -> None: + a = xr.cftime_range("2000", periods=3, calendar=calendar, freq="D") + result = a + pd.Timedelta(n, units) + expected = a.shift(n, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@requires_cftime +def test_cftimeindex_shift_float_us() -> None: + a = xr.cftime_range("2000", periods=3, freq="D") + with pytest.raises( + ValueError, match="Could not convert to integer offset at any resolution" + ): + a.shift(2.5, "us") + + +@requires_cftime +@pytest.mark.parametrize("freq", ["AS", "A", "YS", "Y", "QS", "Q", "MS", "M"]) +def test_cftimeindex_shift_float_fails_for_non_tick_freqs(freq) -> None: + a = xr.cftime_range("2000", periods=3, freq="D") + with pytest.raises(TypeError, match="unsupported operand type"): + a.shift(2.5, freq) + + @requires_cftime def test_cftimeindex_radd(index): date_type = index.date_type @@ -781,7 +819,7 @@ def test_cftimeindex_radd(index): @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) -def test_timedeltaindex_add_cftimeindex(calendar): +def test_timedeltaindex_add_cftimeindex(calendar) -> None: a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = deltas + a @@ -829,7 +867,7 @@ def test_cftimeindex_sub_timedelta_array(index, other): @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) -def test_cftimeindex_sub_cftimeindex(calendar): +def test_cftimeindex_sub_cftimeindex(calendar) -> None: a = xr.cftime_range("2000", periods=5, calendar=calendar) b = a.shift(2, "D") result = b - a @@ -868,7 +906,7 @@ def test_distant_cftime_datetime_sub_cftimeindex(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) -def test_cftimeindex_sub_timedeltaindex(calendar): +def test_cftimeindex_sub_timedeltaindex(calendar) -> None: a = xr.cftime_range("2000", periods=5, calendar=calendar) deltas = pd.TimedeltaIndex([timedelta(days=2) for _ in range(5)]) result = a - deltas @@ -904,7 +942,7 @@ def test_cftimeindex_rsub(index): @requires_cftime @pytest.mark.parametrize("freq", ["D", timedelta(days=1)]) -def test_cftimeindex_shift(index, freq): +def test_cftimeindex_shift(index, freq) -> None: date_type = index.date_type expected_dates = [ date_type(1, 1, 3), @@ -919,14 +957,14 @@ def test_cftimeindex_shift(index, freq): @requires_cftime -def test_cftimeindex_shift_invalid_n(): +def test_cftimeindex_shift_invalid_n() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): index.shift("a", "D") @requires_cftime -def test_cftimeindex_shift_invalid_freq(): +def test_cftimeindex_shift_invalid_freq() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): index.shift(1, 1) From 07456ed71b89704a438b17cb3bcf442f5f208d5a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 16 Feb 2022 16:32:35 +0000 Subject: [PATCH 02/34] Remove xfail from tests decorated by @gen_cluster (#6282) --- xarray/tests/test_distributed.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index a6ea792b5ac..b97032014c4 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -5,14 +5,14 @@ import tempfile import pytest +from packaging.version import Version dask = pytest.importorskip("dask") # isort:skip distributed = pytest.importorskip("distributed") # isort:skip from dask.distributed import Client, Lock -from distributed.utils_test import cluster, gen_cluster -from distributed.utils_test import loop from distributed.client import futures_of +from distributed.utils_test import cluster, gen_cluster, loop import xarray as xr from xarray.backends.locks import HDF5_LOCK, CombinedLock @@ -208,7 +208,10 @@ def test_dask_distributed_cfgrib_integration_test(loop) -> None: assert_allclose(actual, expected) -@pytest.mark.xfail(reason="https://github.com/pydata/xarray/pull/6211") +@pytest.mark.xfail( + condition=Version(distributed.__version__) < Version("2022.02.0"), + reason="https://github.com/dask/distributed/pull/5739", +) @gen_cluster(client=True) async def test_async(c, s, a, b) -> None: x = create_test_data() @@ -241,7 +244,10 @@ def test_hdf5_lock() -> None: assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) -@pytest.mark.xfail(reason="https://github.com/pydata/xarray/pull/6211") +@pytest.mark.xfail( + condition=Version(distributed.__version__) < Version("2022.02.0"), + reason="https://github.com/dask/distributed/pull/5739", +) @gen_cluster(client=True) async def test_serializable_locks(c, s, a, b) -> None: def f(x, lock=None): From dfaedb2773208c78ab93940ef4a1979238ee0f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Gli=C3=9F?= Date: Thu, 17 Feb 2022 13:51:48 +0100 Subject: [PATCH 03/34] Allow to parse more backend kwargs to pydap backend (#6276) * propose implementation of #6274 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix * Add tests for new method PydapDataStore._update_default_params * minor fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add entry in whats-new.rst * Minor change * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Define open_dataset_params for cleaner solution * remove import of inspect lib * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/backends/pydap_.py Co-authored-by: Mathias Hauser * update whats-new * Set pydap backend arguments explicitly * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set defaults in PydapDataStore.open rather than entry point * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: Jonas Gliss Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mathias Hauser --- doc/whats-new.rst | 4 ++++ xarray/backends/pydap_.py | 47 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 88453a641e0..37abf931eb4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,10 @@ New Features and ``Microsecond``. (:issue:`6134`, :pull:`6135`). By `Aaron Spring `_. +- Enbable to provide more keyword arguments to `pydap` backend when reading + OpenDAP datasets (:issue:`6274`). + By `Jonas Gliß `. + Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index ffaf3793928..a5a1430abf2 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -86,9 +86,40 @@ def __init__(self, ds): self.ds = ds @classmethod - def open(cls, url, session=None): + def open( + cls, + url, + application=None, + session=None, + output_grid=None, + timeout=None, + verify=None, + user_charset=None, + ): + + if output_grid is None: + output_grid = True + + if verify is None: + verify = True + + if timeout is None: + from pydap.lib import DEFAULT_TIMEOUT - ds = pydap.client.open_url(url, session=session) + timeout = DEFAULT_TIMEOUT + + if user_charset is None: + user_charset = "ascii" + + ds = pydap.client.open_url( + url=url, + application=application, + session=session, + output_grid=output_grid, + timeout=timeout, + verify=verify, + user_charset=user_charset, + ) return cls(ds) def open_store_variable(self, var): @@ -123,12 +154,22 @@ def open_dataset( drop_variables=None, use_cftime=None, decode_timedelta=None, + application=None, session=None, + output_grid=None, + timeout=None, + verify=None, + user_charset=None, ): store = PydapDataStore.open( - filename_or_obj, + url=filename_or_obj, + application=application, session=session, + output_grid=output_grid, + timeout=timeout, + verify=verify, + user_charset=user_charset, ) store_entrypoint = StoreBackendEntrypoint() From 678bc68b85eb4ea14d9cb262e683e65070d8a31b Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 18 Feb 2022 07:01:57 -0800 Subject: [PATCH 04/34] Add pytest-gha-annotations (#6271) This looks really good -- it adds an annotation to the PR inline for any failures. I'm not sure how it will do with our many environments -- we may have to find a way of only having it run on one environment if it doesn't de-dupe them. --- ci/requirements/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 516c964afc7..9269a70badf 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -37,6 +37,7 @@ dependencies: - pytest - pytest-cov - pytest-env + - pytest-github-actions-annotate-failures - pytest-xdist - rasterio - scipy From 33fbb648ac042f821a11870ffa544e5bcb6e178f Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 18 Feb 2022 17:51:55 +0100 Subject: [PATCH 05/34] use `warnings.catch_warnings(record=True)` instead of `pytest.warns(None)` (#6251) * no longer use pytest.warns(None) * use 'assert_no_warnings' --- xarray/tests/__init__.py | 17 ++++++++++------- xarray/tests/test_backends.py | 13 ++++++------- xarray/tests/test_coding_times.py | 13 +++++-------- xarray/tests/test_dataarray.py | 4 ++-- xarray/tests/test_dataset.py | 6 +++--- xarray/tests/test_ufuncs.py | 5 ++--- xarray/tests/test_variable.py | 4 ++-- 7 files changed, 30 insertions(+), 32 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 20dfdaf5076..00fec07f793 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -1,7 +1,7 @@ import importlib import platform import warnings -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from unittest import mock # noqa: F401 import numpy as np @@ -113,15 +113,10 @@ def __call__(self, dsk, keys, **kwargs): return dask.get(dsk, keys, **kwargs) -@contextmanager -def dummy_context(): - yield None - - def raise_if_dask_computes(max_computes=0): # return a dummy context manager so that this can be used for non-dask objects if not has_dask: - return dummy_context() + return nullcontext() scheduler = CountingScheduler(max_computes) return dask.config.set(scheduler=scheduler) @@ -170,6 +165,14 @@ def source_ndarray(array): return base +@contextmanager +def assert_no_warnings(): + + with warnings.catch_warnings(record=True) as record: + yield record + assert len(record) == 0, "got unexpected warning(s)" + + # Internal versions of xarray's test functions that validate additional # invariants diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index c0e340dd723..1d0342dd344 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -54,6 +54,7 @@ assert_array_equal, assert_equal, assert_identical, + assert_no_warnings, has_dask, has_h5netcdf_0_12, has_netCDF4, @@ -1814,12 +1815,11 @@ def test_warning_on_bad_chunks(self): good_chunks = ({"dim2": 3}, {"dim3": (6, 4)}, {}) for chunks in good_chunks: kwargs = {"chunks": chunks} - with pytest.warns(None) as record: + with assert_no_warnings(): with self.roundtrip(original, open_kwargs=kwargs) as actual: for k, v in actual.variables.items(): # only index variables should be in memory assert v._in_memory == (k in actual.dims) - assert len(record) == 0 @requires_dask def test_deprecate_auto_chunk(self): @@ -4986,10 +4986,9 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get(): with create_tmp_file() as tmpfile: - with pytest.warns(None) as record: + with assert_no_warnings(): ds = Dataset() ds.to_netcdf(tmpfile) - assert len(record) == 0 @requires_scipy_or_netCDF4 @@ -5031,7 +5030,7 @@ def test_use_cftime_standard_calendar_default_in_range(calendar): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: with open_dataset(tmp_file) as ds: assert_identical(expected_x, ds.x) assert_identical(expected_time, ds.time) @@ -5094,7 +5093,7 @@ def test_use_cftime_true(calendar, units_year): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: with open_dataset(tmp_file, use_cftime=True) as ds: assert_identical(expected_x, ds.x) assert_identical(expected_time, ds.time) @@ -5123,7 +5122,7 @@ def test_use_cftime_false_standard_calendar_in_range(calendar): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: with open_dataset(tmp_file, use_cftime=False) as ds: assert_identical(expected_x, ds.x) assert_identical(expected_time, ds.time) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 2e19ddb3a75..92d27f22eb8 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -32,6 +32,7 @@ from . import ( arm_xfail, assert_array_equal, + assert_no_warnings, has_cftime, has_cftime_1_4_1, requires_cftime, @@ -905,10 +906,9 @@ def test_use_cftime_default_standard_calendar_in_range(calendar) -> None: units = "days since 2000-01-01" expected = pd.date_range("2000", periods=2) - with pytest.warns(None) as record: + with assert_no_warnings(): result = decode_cf_datetime(numerical_dates, units, calendar) np.testing.assert_array_equal(result, expected) - assert not record @requires_cftime @@ -942,10 +942,9 @@ def test_use_cftime_default_non_standard_calendar(calendar, units_year) -> None: numerical_dates, units, calendar, only_use_cftime_datetimes=True ) - with pytest.warns(None) as record: + with assert_no_warnings(): result = decode_cf_datetime(numerical_dates, units, calendar) np.testing.assert_array_equal(result, expected) - assert not record @requires_cftime @@ -960,10 +959,9 @@ def test_use_cftime_true(calendar, units_year) -> None: numerical_dates, units, calendar, only_use_cftime_datetimes=True ) - with pytest.warns(None) as record: + with assert_no_warnings(): result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=True) np.testing.assert_array_equal(result, expected) - assert not record @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @@ -972,10 +970,9 @@ def test_use_cftime_false_standard_calendar_in_range(calendar) -> None: units = "days since 2000-01-01" expected = pd.date_range("2000", periods=2) - with pytest.warns(None) as record: + with assert_no_warnings(): result = decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) np.testing.assert_array_equal(result, expected) - assert not record @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b707ae2a063..8d73f9ec7ee 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -33,6 +33,7 @@ assert_chunks_equal, assert_equal, assert_identical, + assert_no_warnings, has_dask, raise_if_dask_computes, requires_bottleneck, @@ -6155,9 +6156,8 @@ def test_rolling_keep_attrs(funcname, argument): def test_raise_no_warning_for_nan_in_binary_ops(): - with pytest.warns(None) as record: + with assert_no_warnings(): xr.DataArray([1, 2, np.NaN]) > 0 - assert len(record) == 0 @pytest.mark.filterwarnings("error") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fed886465ed..c4fa847e664 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -38,6 +38,7 @@ assert_array_equal, assert_equal, assert_identical, + assert_no_warnings, create_test_data, has_cftime, has_dask, @@ -1873,7 +1874,7 @@ def test_reindex_warning(self): # Should not warn ind = xr.DataArray([0.0, 1.0], dims=["dim2"], name="ind") - with pytest.warns(None) as ws: + with warnings.catch_warnings(record=True) as ws: data.reindex(dim2=ind) assert len(ws) == 0 @@ -6165,9 +6166,8 @@ def test_ndrolling_construct(center, fill_value, dask): def test_raise_no_warning_for_nan_in_binary_ops(): - with pytest.warns(None) as record: + with assert_no_warnings(): Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 - assert len(record) == 0 @pytest.mark.filterwarnings("error") diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 3379fba44f8..590ae9ae003 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -8,7 +8,7 @@ from . import assert_array_equal from . import assert_identical as assert_identical_ -from . import mock +from . import assert_no_warnings, mock def assert_identical(a, b): @@ -164,9 +164,8 @@ def test_xarray_ufuncs_deprecation(): with pytest.warns(FutureWarning, match="xarray.ufuncs"): xu.cos(xr.DataArray([0, 1])) - with pytest.warns(None) as record: + with assert_no_warnings(): xu.angle(xr.DataArray([0, 1])) - assert len(record) == 0 @pytest.mark.filterwarnings("ignore::RuntimeWarning") diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 33fff62c304..a88d5a22c0d 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -33,6 +33,7 @@ assert_array_equal, assert_equal, assert_identical, + assert_no_warnings, raise_if_dask_computes, requires_cupy, requires_dask, @@ -2537,9 +2538,8 @@ def __init__(self, array): def test_raise_no_warning_for_nan_in_binary_ops(): - with pytest.warns(None) as record: + with assert_no_warnings(): Variable("x", [1, 2, np.NaN]) > 0 - assert len(record) == 0 class TestBackendIndexing: From d26894bdc3d80acd117a16a528663bcdf26bab32 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 22 Feb 2022 19:49:52 -0800 Subject: [PATCH 06/34] Move Zarr up in io.rst (#6289) * Move Zarr up in io.rst The existing version had it right down the page, below Iris / Pickle / et al. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/contributing.rst | 4 +- doc/user-guide/io.rst | 858 ++++++++++++++++++------------------- doc/whats-new.rst | 4 +- xarray/core/computation.py | 2 +- 4 files changed, 434 insertions(+), 434 deletions(-) diff --git a/doc/contributing.rst b/doc/contributing.rst index df279caa54f..0913702fd83 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -274,13 +274,13 @@ Some other important things to know about the docs: .. ipython:: python x = 2 - x ** 3 + x**3 will be rendered as:: In [1]: x = 2 - In [2]: x ** 3 + In [2]: x**3 Out[2]: 8 Almost all code examples in the docs are run (and the output saved) during the diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 28eeeeda99b..834e9ad2464 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -498,596 +498,596 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: Note that this produces a file that is likely to be not readable by other netCDF libraries! -.. _io.iris: +.. _io.zarr: -Iris +Zarr ---- -The Iris_ tool allows easy reading of common meteorological and climate model formats -(including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very -similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is -installed, xarray can convert a ``DataArray`` into a ``Cube`` using -:py:meth:`DataArray.to_iris`: - -.. ipython:: python +`Zarr`_ is a Python package that provides an implementation of chunked, compressed, +N-dimensional arrays. +Zarr has the ability to store arrays in a range of ways, including in memory, +in files, and in cloud-based object storage such as `Amazon S3`_ and +`Google Cloud Storage`_. +Xarray's Zarr backend allows xarray to leverage these capabilities, including +the ability to store and analyze datasets far too large fit onto disk +(particularly :ref:`in combination with dask `). - da = xr.DataArray( - np.random.rand(4, 5), - dims=["x", "y"], - coords=dict(x=[10, 20, 30, 40], y=pd.date_range("2000-01-01", periods=5)), - ) +Xarray can't open just any zarr dataset, because xarray requires special +metadata (attributes) describing the dataset dimensions and coordinates. +At this time, xarray can only open zarr datasets that have been written by +xarray. For implementation details, see :ref:`zarr_encoding`. - cube = da.to_iris() - cube +To write a dataset with zarr, we use the :py:meth:`Dataset.to_zarr` method. -Conversely, we can create a new ``DataArray`` object from a ``Cube`` using -:py:meth:`DataArray.from_iris`: +To write to a local directory, we pass a path to a directory: .. ipython:: python + :suppress: - da_cube = xr.DataArray.from_iris(cube) - da_cube + ! rm -rf path/to/directory.zarr +.. ipython:: python -.. _Iris: https://scitools.org.uk/iris + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + ds.to_zarr("path/to/directory.zarr") +(The suffix ``.zarr`` is optional--just a reminder that a zarr store lives +there.) If the directory does not exist, it will be created. If a zarr +store is already present at that path, an error will be raised, preventing it +from being overwritten. To override this behavior and overwrite an existing +store, add ``mode='w'`` when invoking :py:meth:`~Dataset.to_zarr`. -OPeNDAP -------- +To store variable length strings, convert them to object arrays first with +``dtype=object``. -Xarray includes support for `OPeNDAP`__ (via the netCDF4 library or Pydap), which -lets us access large datasets over HTTP. +To read back a zarr dataset that has been created this way, we use the +:py:func:`open_zarr` method: -__ https://www.opendap.org/ +.. ipython:: python -For example, we can open a connection to GBs of weather data produced by the -`PRISM`__ project, and hosted by `IRI`__ at Columbia: + ds_zarr = xr.open_zarr("path/to/directory.zarr") + ds_zarr -__ https://www.prism.oregonstate.edu/ -__ https://iri.columbia.edu/ +Cloud Storage Buckets +~~~~~~~~~~~~~~~~~~~~~ -.. ipython source code for this section - we don't use this to avoid hitting the DAP server on every doc build. +It is possible to read and write xarray datasets directly from / to cloud +storage buckets using zarr. This example uses the `gcsfs`_ package to provide +an interface to `Google Cloud Storage`_. - remote_data = xr.open_dataset( - 'http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods', - decode_times=False) - tmax = remote_data.tmax[:500, ::3, ::3] - tmax +From v0.16.2: general `fsspec`_ URLs are parsed and the store set up for you +automatically when reading, such that you can open a dataset in a single +call. You should include any arguments to the storage backend as the +key ``storage_options``, part of ``backend_kwargs``. - @savefig opendap-prism-tmax.png - tmax[0].plot() +.. code:: python -.. ipython:: - :verbatim: + ds_gcs = xr.open_dataset( + "gcs:///path.zarr", + backend_kwargs={ + "storage_options": {"project": "", "token": None} + }, + engine="zarr", + ) - In [3]: remote_data = xr.open_dataset( - ...: "http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods", - ...: decode_times=False, - ...: ) - In [4]: remote_data - Out[4]: - - Dimensions: (T: 1422, X: 1405, Y: 621) - Coordinates: - * X (X) float32 -125.0 -124.958 -124.917 -124.875 -124.833 -124.792 -124.75 ... - * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 -772.5 -771.5 ... - * Y (Y) float32 49.9167 49.875 49.8333 49.7917 49.75 49.7083 49.6667 49.625 ... - Data variables: - ppt (T, Y, X) float64 ... - tdmean (T, Y, X) float64 ... - tmax (T, Y, X) float64 ... - tmin (T, Y, X) float64 ... - Attributes: - Conventions: IRIDL - expires: 1375315200 +This also works with ``open_mfdataset``, allowing you to pass a list of paths or +a URL to be interpreted as a glob string. -.. TODO: update this example to show off decode_cf? +For older versions, and for writing, you must explicitly set up a ``MutableMapping`` +instance and pass this, as follows: -.. note:: +.. code:: python - Like many real-world datasets, this dataset does not entirely follow - `CF conventions`_. Unexpected formats will usually cause xarray's automatic - decoding to fail. The way to work around this is to either set - ``decode_cf=False`` in ``open_dataset`` to turn off all use of CF - conventions, or by only disabling the troublesome parser. - In this case, we set ``decode_times=False`` because the time axis here - provides the calendar attribute in a format that xarray does not expect - (the integer ``360`` instead of a string like ``'360_day'``). + import gcsfs -We can select and slice this data any number of times, and nothing is loaded -over the network until we look at particular values: + fs = gcsfs.GCSFileSystem(project="", token=None) + gcsmap = gcsfs.mapping.GCSMap("", gcs=fs, check=True, create=False) + # write to the bucket + ds.to_zarr(store=gcsmap) + # read it back + ds_gcs = xr.open_zarr(gcsmap) -.. ipython:: - :verbatim: +(or use the utility function ``fsspec.get_mapper()``). - In [4]: tmax = remote_data["tmax"][:500, ::3, ::3] +.. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ +.. _Zarr: https://zarr.readthedocs.io/ +.. _Amazon S3: https://aws.amazon.com/s3/ +.. _Google Cloud Storage: https://cloud.google.com/storage/ +.. _gcsfs: https://github.com/fsspec/gcsfs - In [5]: tmax - Out[5]: - - [48541500 values with dtype=float64] - Coordinates: - * Y (Y) float32 49.9167 49.7917 49.6667 49.5417 49.4167 49.2917 ... - * X (X) float32 -125.0 -124.875 -124.75 -124.625 -124.5 -124.375 ... - * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 ... - Attributes: - pointwidth: 120 - standard_name: air_temperature - units: Celsius_scale - expires: 1443657600 +Zarr Compressors and Filters +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # the data is downloaded automatically when we make the plot - In [6]: tmax[0].plot() +There are many different options for compression and filtering possible with +zarr. These are described in the +`zarr documentation `_. +These options can be passed to the ``to_zarr`` method as variable encoding. +For example: -.. image:: ../_static/opendap-prism-tmax.png +.. ipython:: python + :suppress: -Some servers require authentication before we can access the data. For this -purpose we can explicitly create a :py:class:`backends.PydapDataStore` -and pass in a `Requests`__ session object. For example for -HTTP Basic authentication:: + ! rm -rf foo.zarr - import xarray as xr - import requests +.. ipython:: python - session = requests.Session() - session.auth = ('username', 'password') + import zarr - store = xr.backends.PydapDataStore.open('http://example.com/data', - session=session) - ds = xr.open_dataset(store) + compressor = zarr.Blosc(cname="zstd", clevel=3, shuffle=2) + ds.to_zarr("foo.zarr", encoding={"foo": {"compressor": compressor}}) -`Pydap's cas module`__ has functions that generate custom sessions for -servers that use CAS single sign-on. For example, to connect to servers -that require NASA's URS authentication:: +.. note:: - import xarray as xr - from pydata.cas.urs import setup_session + Not all native zarr compression and filtering options have been tested with + xarray. - ds_url = 'https://gpm1.gesdisc.eosdis.nasa.gov/opendap/hyrax/example.nc' +.. _io.zarr.consolidated_metadata: - session = setup_session('username', 'password', check_url=ds_url) - store = xr.backends.PydapDataStore.open(ds_url, session=session) +Consolidated Metadata +~~~~~~~~~~~~~~~~~~~~~ - ds = xr.open_dataset(store) +Xarray needs to read all of the zarr metadata when it opens a dataset. +In some storage mediums, such as with cloud object storage (e.g. amazon S3), +this can introduce significant overhead, because two separate HTTP calls to the +object store must be made for each variable in the dataset. +As of xarray version 0.18, xarray by default uses a feature called +*consolidated metadata*, storing all metadata for the entire dataset with a +single key (by default called ``.zmetadata``). This typically drastically speeds +up opening the store. (For more information on this feature, consult the +`zarr docs `_.) -__ https://docs.python-requests.org -__ https://www.pydap.org/en/latest/client.html#authentication +By default, xarray writes consolidated metadata and attempts to read stores +with consolidated metadata, falling back to use non-consolidated metadata for +reads. Because this fall-back option is so much slower, xarray issues a +``RuntimeWarning`` with guidance when reading with consolidated metadata fails: -.. _io.pickle: + Failed to open Zarr store with consolidated metadata, falling back to try + reading non-consolidated metadata. This is typically much slower for + opening a dataset. To silence this warning, consider: -Pickle ------- + 1. Consolidating metadata in this existing store with + :py:func:`zarr.consolidate_metadata`. + 2. Explicitly setting ``consolidated=False``, to avoid trying to read + consolidate metadata. + 3. Explicitly setting ``consolidated=True``, to raise an error in this case + instead of falling back to try reading non-consolidated metadata. -The simplest way to serialize an xarray object is to use Python's built-in pickle -module: +.. _io.zarr.appending: -.. ipython:: python +Appending to existing Zarr stores +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - import pickle +Xarray supports several ways of incrementally writing variables to a Zarr +store. These options are useful for scenarios when it is infeasible or +undesirable to write your entire dataset at once. - # use the highest protocol (-1) because it is way faster than the default - # text based pickle format - pkl = pickle.dumps(ds, protocol=-1) +.. tip:: - pickle.loads(pkl) + If you can load all of your data into a single ``Dataset`` using dask, a + single call to ``to_zarr()`` will write all of your data in parallel. -Pickling is important because it doesn't require any external libraries -and lets you use xarray objects with Python modules like -:py:mod:`multiprocessing` or :ref:`Dask `. However, pickling is -**not recommended for long-term storage**. +.. warning:: -Restoring a pickle requires that the internal structure of the types for the -pickled data remain unchanged. Because the internal design of xarray is still -being refined, we make no guarantees (at this point) that objects pickled with -this version of xarray will work in future versions. - -.. note:: + Alignment of coordinates is currently not checked when modifying an + existing Zarr store. It is up to the user to ensure that coordinates are + consistent. - When pickling an object opened from a NetCDF file, the pickle file will - contain a reference to the file on disk. If you want to store the actual - array values, load it into memory first with :py:meth:`Dataset.load` - or :py:meth:`Dataset.compute`. +To add or overwrite entire variables, simply call :py:meth:`~Dataset.to_zarr` +with ``mode='a'`` on a Dataset containing the new variables, passing in an +existing Zarr store or path to a Zarr store. -.. _dictionary io: +To resize and then append values along an existing dimension in a store, set +``append_dim``. This is a good option if data always arives in a particular +order, e.g., for time-stepping a simulation: -Dictionary ----------- +.. ipython:: python + :suppress: -We can convert a ``Dataset`` (or a ``DataArray``) to a dict using -:py:meth:`Dataset.to_dict`: + ! rm -rf path/to/directory.zarr .. ipython:: python - d = ds.to_dict() - d + ds1 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-01", periods=2), + }, + ) + ds1.to_zarr("path/to/directory.zarr") + ds2 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-03", periods=2), + }, + ) + ds2.to_zarr("path/to/directory.zarr", append_dim="t") -We can create a new xarray object from a dict using -:py:meth:`Dataset.from_dict`: +Finally, you can use ``region`` to write to limited regions of existing arrays +in an existing Zarr store. This is a good option for writing data in parallel +from independent processes. + +To scale this up to writing large datasets, the first step is creating an +initial Zarr store without writing all of its array data. This can be done by +first creating a ``Dataset`` with dummy values stored in :ref:`dask `, +and then calling ``to_zarr`` with ``compute=False`` to write only metadata +(including ``attrs``) to Zarr: .. ipython:: python + :suppress: - ds_dict = xr.Dataset.from_dict(d) - ds_dict + ! rm -rf path/to/directory.zarr -Dictionary support allows for flexible use of xarray objects. It doesn't -require external libraries and dicts can easily be pickled, or converted to -json, or geojson. All the values are converted to lists, so dicts might -be quite large. +.. ipython:: python -To export just the dataset schema without the data itself, use the -``data=False`` option: + import dask.array + + # The values of this dask array are entirely irrelevant; only the dtype, + # shape and chunks are used + dummies = dask.array.zeros(30, chunks=10) + ds = xr.Dataset({"foo": ("x", dummies)}) + path = "path/to/directory.zarr" + # Now we write the metadata without computing any array values + ds.to_zarr(path, compute=False) + +Now, a Zarr store with the correct variable shapes and attributes exists that +can be filled out by subsequent calls to ``to_zarr``. The ``region`` provides a +mapping from dimension names to Python ``slice`` objects indicating where the +data should be written (in index space, not coordinate space), e.g., .. ipython:: python - ds.to_dict(data=False) + # For convenience, we'll slice a single dataset, but in the real use-case + # we would create them separately possibly even from separate processes. + ds = xr.Dataset({"foo": ("x", np.arange(30))}) + ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)}) + ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)}) + ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)}) -This can be useful for generating indices of dataset contents to expose to -search indices or other automated data discovery tools. +Concurrent writes with ``region`` are safe as long as they modify distinct +chunks in the underlying Zarr arrays (or use an appropriate ``lock``). -.. ipython:: python - :suppress: +As a safety check to make it harder to inadvertently override existing values, +if you set ``region`` then *all* variables included in a Dataset must have +dimensions included in ``region``. Other variables (typically coordinates) +need to be explicitly dropped and/or written in a separate calls to ``to_zarr`` +with ``mode='a'``. - import os +.. _io.iris: - os.remove("saved_on_disk.nc") +Iris +---- -.. _io.rasterio: +The Iris_ tool allows easy reading of common meteorological and climate model formats +(including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very +similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is +installed, xarray can convert a ``DataArray`` into a ``Cube`` using +:py:meth:`DataArray.to_iris`: -Rasterio --------- +.. ipython:: python -GeoTIFFs and other gridded raster datasets can be opened using `rasterio`_, if -rasterio is installed. Here is an example of how to use -:py:func:`open_rasterio` to read one of rasterio's `test files`_: + da = xr.DataArray( + np.random.rand(4, 5), + dims=["x", "y"], + coords=dict(x=[10, 20, 30, 40], y=pd.date_range("2000-01-01", periods=5)), + ) -.. deprecated:: 0.20.0 + cube = da.to_iris() + cube - Deprecated in favor of rioxarray. - For information about transitioning, see: - https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html +Conversely, we can create a new ``DataArray`` object from a ``Cube`` using +:py:meth:`DataArray.from_iris`: -.. ipython:: - :verbatim: +.. ipython:: python - In [7]: rio = xr.open_rasterio("RGB.byte.tif") + da_cube = xr.DataArray.from_iris(cube) + da_cube - In [8]: rio - Out[8]: - - [1703814 values with dtype=uint8] - Coordinates: - * band (band) int64 1 2 3 - * y (y) float64 2.827e+06 2.826e+06 2.826e+06 2.826e+06 2.826e+06 ... - * x (x) float64 1.021e+05 1.024e+05 1.027e+05 1.03e+05 1.033e+05 ... - Attributes: - res: (300.0379266750948, 300.041782729805) - transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.041782729805, 28... - is_tiled: 0 - crs: +init=epsg:32618 +.. _Iris: https://scitools.org.uk/iris -The ``x`` and ``y`` coordinates are generated out of the file's metadata -(``bounds``, ``width``, ``height``), and they can be understood as cartesian -coordinates defined in the file's projection provided by the ``crs`` attribute. -``crs`` is a PROJ4 string which can be parsed by e.g. `pyproj`_ or rasterio. -See :ref:`/examples/visualization_gallery.ipynb#Parsing-rasterio-geocoordinates` -for an example of how to convert these to longitudes and latitudes. +OPeNDAP +------- -Additionally, you can use `rioxarray`_ for reading in GeoTiff, netCDF or other -GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIFF. -`rioxarray`_ can also handle geospatial related tasks such as re-projecting and clipping. +Xarray includes support for `OPeNDAP`__ (via the netCDF4 library or Pydap), which +lets us access large datasets over HTTP. -.. ipython:: - :verbatim: +__ https://www.opendap.org/ - In [1]: import rioxarray +For example, we can open a connection to GBs of weather data produced by the +`PRISM`__ project, and hosted by `IRI`__ at Columbia: - In [2]: rds = rioxarray.open_rasterio("RGB.byte.tif") +__ https://www.prism.oregonstate.edu/ +__ https://iri.columbia.edu/ - In [3]: rds - Out[3]: - - [1703814 values with dtype=uint8] - Coordinates: - * band (band) int64 1 2 3 - * y (y) float64 2.827e+06 2.826e+06 ... 2.612e+06 2.612e+06 - * x (x) float64 1.021e+05 1.024e+05 ... 3.389e+05 3.392e+05 - spatial_ref int64 0 - Attributes: - STATISTICS_MAXIMUM: 255 - STATISTICS_MEAN: 29.947726688477 - STATISTICS_MINIMUM: 0 - STATISTICS_STDDEV: 52.340921626611 - transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.0417827... - _FillValue: 0.0 - scale_factor: 1.0 - add_offset: 0.0 - grid_mapping: spatial_ref +.. ipython source code for this section + we don't use this to avoid hitting the DAP server on every doc build. - In [4]: rds.rio.crs - Out[4]: CRS.from_epsg(32618) + remote_data = xr.open_dataset( + 'http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods', + decode_times=False) + tmax = remote_data.tmax[:500, ::3, ::3] + tmax - In [5]: rds4326 = rds.rio.reproject("epsg:4326") + @savefig opendap-prism-tmax.png + tmax[0].plot() - In [6]: rds4326.rio.crs - Out[6]: CRS.from_epsg(4326) +.. ipython:: + :verbatim: - In [7]: rds4326.rio.to_raster("RGB.byte.4326.tif") + In [3]: remote_data = xr.open_dataset( + ...: "http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods", + ...: decode_times=False, + ...: ) + In [4]: remote_data + Out[4]: + + Dimensions: (T: 1422, X: 1405, Y: 621) + Coordinates: + * X (X) float32 -125.0 -124.958 -124.917 -124.875 -124.833 -124.792 -124.75 ... + * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 -772.5 -771.5 ... + * Y (Y) float32 49.9167 49.875 49.8333 49.7917 49.75 49.7083 49.6667 49.625 ... + Data variables: + ppt (T, Y, X) float64 ... + tdmean (T, Y, X) float64 ... + tmax (T, Y, X) float64 ... + tmin (T, Y, X) float64 ... + Attributes: + Conventions: IRIDL + expires: 1375315200 -.. _rasterio: https://rasterio.readthedocs.io/en/latest/ -.. _rioxarray: https://corteva.github.io/rioxarray/stable/ -.. _test files: https://github.com/rasterio/rasterio/blob/master/tests/data/RGB.byte.tif -.. _pyproj: https://github.com/pyproj4/pyproj +.. TODO: update this example to show off decode_cf? -.. _io.zarr: +.. note:: -Zarr ----- + Like many real-world datasets, this dataset does not entirely follow + `CF conventions`_. Unexpected formats will usually cause xarray's automatic + decoding to fail. The way to work around this is to either set + ``decode_cf=False`` in ``open_dataset`` to turn off all use of CF + conventions, or by only disabling the troublesome parser. + In this case, we set ``decode_times=False`` because the time axis here + provides the calendar attribute in a format that xarray does not expect + (the integer ``360`` instead of a string like ``'360_day'``). -`Zarr`_ is a Python package that provides an implementation of chunked, compressed, -N-dimensional arrays. -Zarr has the ability to store arrays in a range of ways, including in memory, -in files, and in cloud-based object storage such as `Amazon S3`_ and -`Google Cloud Storage`_. -Xarray's Zarr backend allows xarray to leverage these capabilities, including -the ability to store and analyze datasets far too large fit onto disk -(particularly :ref:`in combination with dask `). +We can select and slice this data any number of times, and nothing is loaded +over the network until we look at particular values: -Xarray can't open just any zarr dataset, because xarray requires special -metadata (attributes) describing the dataset dimensions and coordinates. -At this time, xarray can only open zarr datasets that have been written by -xarray. For implementation details, see :ref:`zarr_encoding`. +.. ipython:: + :verbatim: -To write a dataset with zarr, we use the :py:meth:`Dataset.to_zarr` method. + In [4]: tmax = remote_data["tmax"][:500, ::3, ::3] -To write to a local directory, we pass a path to a directory: + In [5]: tmax + Out[5]: + + [48541500 values with dtype=float64] + Coordinates: + * Y (Y) float32 49.9167 49.7917 49.6667 49.5417 49.4167 49.2917 ... + * X (X) float32 -125.0 -124.875 -124.75 -124.625 -124.5 -124.375 ... + * T (T) float32 -779.5 -778.5 -777.5 -776.5 -775.5 -774.5 -773.5 ... + Attributes: + pointwidth: 120 + standard_name: air_temperature + units: Celsius_scale + expires: 1443657600 -.. ipython:: python - :suppress: + # the data is downloaded automatically when we make the plot + In [6]: tmax[0].plot() - ! rm -rf path/to/directory.zarr +.. image:: ../_static/opendap-prism-tmax.png -.. ipython:: python +Some servers require authentication before we can access the data. For this +purpose we can explicitly create a :py:class:`backends.PydapDataStore` +and pass in a `Requests`__ session object. For example for +HTTP Basic authentication:: - ds = xr.Dataset( - {"foo": (("x", "y"), np.random.rand(4, 5))}, - coords={ - "x": [10, 20, 30, 40], - "y": pd.date_range("2000-01-01", periods=5), - "z": ("x", list("abcd")), - }, - ) - ds.to_zarr("path/to/directory.zarr") + import xarray as xr + import requests -(The suffix ``.zarr`` is optional--just a reminder that a zarr store lives -there.) If the directory does not exist, it will be created. If a zarr -store is already present at that path, an error will be raised, preventing it -from being overwritten. To override this behavior and overwrite an existing -store, add ``mode='w'`` when invoking :py:meth:`~Dataset.to_zarr`. + session = requests.Session() + session.auth = ('username', 'password') -To store variable length strings, convert them to object arrays first with -``dtype=object``. + store = xr.backends.PydapDataStore.open('http://example.com/data', + session=session) + ds = xr.open_dataset(store) -To read back a zarr dataset that has been created this way, we use the -:py:func:`open_zarr` method: +`Pydap's cas module`__ has functions that generate custom sessions for +servers that use CAS single sign-on. For example, to connect to servers +that require NASA's URS authentication:: -.. ipython:: python + import xarray as xr + from pydata.cas.urs import setup_session - ds_zarr = xr.open_zarr("path/to/directory.zarr") - ds_zarr + ds_url = 'https://gpm1.gesdisc.eosdis.nasa.gov/opendap/hyrax/example.nc' -Cloud Storage Buckets -~~~~~~~~~~~~~~~~~~~~~ + session = setup_session('username', 'password', check_url=ds_url) + store = xr.backends.PydapDataStore.open(ds_url, session=session) -It is possible to read and write xarray datasets directly from / to cloud -storage buckets using zarr. This example uses the `gcsfs`_ package to provide -an interface to `Google Cloud Storage`_. + ds = xr.open_dataset(store) -From v0.16.2: general `fsspec`_ URLs are parsed and the store set up for you -automatically when reading, such that you can open a dataset in a single -call. You should include any arguments to the storage backend as the -key ``storage_options``, part of ``backend_kwargs``. +__ https://docs.python-requests.org +__ https://www.pydap.org/en/latest/client.html#authentication -.. code:: python +.. _io.pickle: - ds_gcs = xr.open_dataset( - "gcs:///path.zarr", - backend_kwargs={ - "storage_options": {"project": "", "token": None} - }, - engine="zarr", - ) +Pickle +------ +The simplest way to serialize an xarray object is to use Python's built-in pickle +module: -This also works with ``open_mfdataset``, allowing you to pass a list of paths or -a URL to be interpreted as a glob string. +.. ipython:: python -For older versions, and for writing, you must explicitly set up a ``MutableMapping`` -instance and pass this, as follows: + import pickle -.. code:: python + # use the highest protocol (-1) because it is way faster than the default + # text based pickle format + pkl = pickle.dumps(ds, protocol=-1) - import gcsfs + pickle.loads(pkl) - fs = gcsfs.GCSFileSystem(project="", token=None) - gcsmap = gcsfs.mapping.GCSMap("", gcs=fs, check=True, create=False) - # write to the bucket - ds.to_zarr(store=gcsmap) - # read it back - ds_gcs = xr.open_zarr(gcsmap) +Pickling is important because it doesn't require any external libraries +and lets you use xarray objects with Python modules like +:py:mod:`multiprocessing` or :ref:`Dask `. However, pickling is +**not recommended for long-term storage**. -(or use the utility function ``fsspec.get_mapper()``). +Restoring a pickle requires that the internal structure of the types for the +pickled data remain unchanged. Because the internal design of xarray is still +being refined, we make no guarantees (at this point) that objects pickled with +this version of xarray will work in future versions. -.. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ -.. _Zarr: https://zarr.readthedocs.io/ -.. _Amazon S3: https://aws.amazon.com/s3/ -.. _Google Cloud Storage: https://cloud.google.com/storage/ -.. _gcsfs: https://github.com/fsspec/gcsfs +.. note:: -Zarr Compressors and Filters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + When pickling an object opened from a NetCDF file, the pickle file will + contain a reference to the file on disk. If you want to store the actual + array values, load it into memory first with :py:meth:`Dataset.load` + or :py:meth:`Dataset.compute`. -There are many different options for compression and filtering possible with -zarr. These are described in the -`zarr documentation `_. -These options can be passed to the ``to_zarr`` method as variable encoding. -For example: +.. _dictionary io: -.. ipython:: python - :suppress: +Dictionary +---------- - ! rm -rf foo.zarr +We can convert a ``Dataset`` (or a ``DataArray``) to a dict using +:py:meth:`Dataset.to_dict`: .. ipython:: python - import zarr - - compressor = zarr.Blosc(cname="zstd", clevel=3, shuffle=2) - ds.to_zarr("foo.zarr", encoding={"foo": {"compressor": compressor}}) + d = ds.to_dict() + d -.. note:: +We can create a new xarray object from a dict using +:py:meth:`Dataset.from_dict`: - Not all native zarr compression and filtering options have been tested with - xarray. +.. ipython:: python -.. _io.zarr.consolidated_metadata: + ds_dict = xr.Dataset.from_dict(d) + ds_dict -Consolidated Metadata -~~~~~~~~~~~~~~~~~~~~~ +Dictionary support allows for flexible use of xarray objects. It doesn't +require external libraries and dicts can easily be pickled, or converted to +json, or geojson. All the values are converted to lists, so dicts might +be quite large. -Xarray needs to read all of the zarr metadata when it opens a dataset. -In some storage mediums, such as with cloud object storage (e.g. amazon S3), -this can introduce significant overhead, because two separate HTTP calls to the -object store must be made for each variable in the dataset. -As of xarray version 0.18, xarray by default uses a feature called -*consolidated metadata*, storing all metadata for the entire dataset with a -single key (by default called ``.zmetadata``). This typically drastically speeds -up opening the store. (For more information on this feature, consult the -`zarr docs `_.) +To export just the dataset schema without the data itself, use the +``data=False`` option: -By default, xarray writes consolidated metadata and attempts to read stores -with consolidated metadata, falling back to use non-consolidated metadata for -reads. Because this fall-back option is so much slower, xarray issues a -``RuntimeWarning`` with guidance when reading with consolidated metadata fails: +.. ipython:: python - Failed to open Zarr store with consolidated metadata, falling back to try - reading non-consolidated metadata. This is typically much slower for - opening a dataset. To silence this warning, consider: + ds.to_dict(data=False) - 1. Consolidating metadata in this existing store with - :py:func:`zarr.consolidate_metadata`. - 2. Explicitly setting ``consolidated=False``, to avoid trying to read - consolidate metadata. - 3. Explicitly setting ``consolidated=True``, to raise an error in this case - instead of falling back to try reading non-consolidated metadata. +This can be useful for generating indices of dataset contents to expose to +search indices or other automated data discovery tools. -.. _io.zarr.appending: +.. ipython:: python + :suppress: -Appending to existing Zarr stores -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + import os -Xarray supports several ways of incrementally writing variables to a Zarr -store. These options are useful for scenarios when it is infeasible or -undesirable to write your entire dataset at once. + os.remove("saved_on_disk.nc") -.. tip:: +.. _io.rasterio: - If you can load all of your data into a single ``Dataset`` using dask, a - single call to ``to_zarr()`` will write all of your data in parallel. +Rasterio +-------- -.. warning:: +GeoTIFFs and other gridded raster datasets can be opened using `rasterio`_, if +rasterio is installed. Here is an example of how to use +:py:func:`open_rasterio` to read one of rasterio's `test files`_: - Alignment of coordinates is currently not checked when modifying an - existing Zarr store. It is up to the user to ensure that coordinates are - consistent. +.. deprecated:: 0.20.0 -To add or overwrite entire variables, simply call :py:meth:`~Dataset.to_zarr` -with ``mode='a'`` on a Dataset containing the new variables, passing in an -existing Zarr store or path to a Zarr store. + Deprecated in favor of rioxarray. + For information about transitioning, see: + https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html -To resize and then append values along an existing dimension in a store, set -``append_dim``. This is a good option if data always arives in a particular -order, e.g., for time-stepping a simulation: +.. ipython:: + :verbatim: -.. ipython:: python - :suppress: + In [7]: rio = xr.open_rasterio("RGB.byte.tif") - ! rm -rf path/to/directory.zarr + In [8]: rio + Out[8]: + + [1703814 values with dtype=uint8] + Coordinates: + * band (band) int64 1 2 3 + * y (y) float64 2.827e+06 2.826e+06 2.826e+06 2.826e+06 2.826e+06 ... + * x (x) float64 1.021e+05 1.024e+05 1.027e+05 1.03e+05 1.033e+05 ... + Attributes: + res: (300.0379266750948, 300.041782729805) + transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.041782729805, 28... + is_tiled: 0 + crs: +init=epsg:32618 -.. ipython:: python - ds1 = xr.Dataset( - {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, - coords={ - "x": [10, 20, 30, 40], - "y": [1, 2, 3, 4, 5], - "t": pd.date_range("2001-01-01", periods=2), - }, - ) - ds1.to_zarr("path/to/directory.zarr") - ds2 = xr.Dataset( - {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, - coords={ - "x": [10, 20, 30, 40], - "y": [1, 2, 3, 4, 5], - "t": pd.date_range("2001-01-03", periods=2), - }, - ) - ds2.to_zarr("path/to/directory.zarr", append_dim="t") +The ``x`` and ``y`` coordinates are generated out of the file's metadata +(``bounds``, ``width``, ``height``), and they can be understood as cartesian +coordinates defined in the file's projection provided by the ``crs`` attribute. +``crs`` is a PROJ4 string which can be parsed by e.g. `pyproj`_ or rasterio. +See :ref:`/examples/visualization_gallery.ipynb#Parsing-rasterio-geocoordinates` +for an example of how to convert these to longitudes and latitudes. -Finally, you can use ``region`` to write to limited regions of existing arrays -in an existing Zarr store. This is a good option for writing data in parallel -from independent processes. -To scale this up to writing large datasets, the first step is creating an -initial Zarr store without writing all of its array data. This can be done by -first creating a ``Dataset`` with dummy values stored in :ref:`dask `, -and then calling ``to_zarr`` with ``compute=False`` to write only metadata -(including ``attrs``) to Zarr: +Additionally, you can use `rioxarray`_ for reading in GeoTiff, netCDF or other +GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIFF. +`rioxarray`_ can also handle geospatial related tasks such as re-projecting and clipping. -.. ipython:: python - :suppress: +.. ipython:: + :verbatim: - ! rm -rf path/to/directory.zarr + In [1]: import rioxarray -.. ipython:: python + In [2]: rds = rioxarray.open_rasterio("RGB.byte.tif") - import dask.array + In [3]: rds + Out[3]: + + [1703814 values with dtype=uint8] + Coordinates: + * band (band) int64 1 2 3 + * y (y) float64 2.827e+06 2.826e+06 ... 2.612e+06 2.612e+06 + * x (x) float64 1.021e+05 1.024e+05 ... 3.389e+05 3.392e+05 + spatial_ref int64 0 + Attributes: + STATISTICS_MAXIMUM: 255 + STATISTICS_MEAN: 29.947726688477 + STATISTICS_MINIMUM: 0 + STATISTICS_STDDEV: 52.340921626611 + transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.0417827... + _FillValue: 0.0 + scale_factor: 1.0 + add_offset: 0.0 + grid_mapping: spatial_ref - # The values of this dask array are entirely irrelevant; only the dtype, - # shape and chunks are used - dummies = dask.array.zeros(30, chunks=10) - ds = xr.Dataset({"foo": ("x", dummies)}) - path = "path/to/directory.zarr" - # Now we write the metadata without computing any array values - ds.to_zarr(path, compute=False) + In [4]: rds.rio.crs + Out[4]: CRS.from_epsg(32618) -Now, a Zarr store with the correct variable shapes and attributes exists that -can be filled out by subsequent calls to ``to_zarr``. The ``region`` provides a -mapping from dimension names to Python ``slice`` objects indicating where the -data should be written (in index space, not coordinate space), e.g., + In [5]: rds4326 = rds.rio.reproject("epsg:4326") -.. ipython:: python + In [6]: rds4326.rio.crs + Out[6]: CRS.from_epsg(4326) - # For convenience, we'll slice a single dataset, but in the real use-case - # we would create them separately possibly even from separate processes. - ds = xr.Dataset({"foo": ("x", np.arange(30))}) - ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)}) - ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)}) - ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)}) + In [7]: rds4326.rio.to_raster("RGB.byte.4326.tif") -Concurrent writes with ``region`` are safe as long as they modify distinct -chunks in the underlying Zarr arrays (or use an appropriate ``lock``). -As a safety check to make it harder to inadvertently override existing values, -if you set ``region`` then *all* variables included in a Dataset must have -dimensions included in ``region``. Other variables (typically coordinates) -need to be explicitly dropped and/or written in a separate calls to ``to_zarr`` -with ``mode='a'``. +.. _rasterio: https://rasterio.readthedocs.io/en/latest/ +.. _rioxarray: https://corteva.github.io/rioxarray/stable/ +.. _test files: https://github.com/rasterio/rasterio/blob/master/tests/data/RGB.byte.tif +.. _pyproj: https://github.com/pyproj4/pyproj .. _io.cfgrib: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 37abf931eb4..aa48bd619e8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -5133,7 +5133,7 @@ Enhancements .. ipython:: python ds = xray.Dataset(coords={"x": range(100), "y": range(100)}) - ds["distance"] = np.sqrt(ds.x ** 2 + ds.y ** 2) + ds["distance"] = np.sqrt(ds.x**2 + ds.y**2) @savefig where_example.png width=4in height=4in ds.distance.where(ds.distance < 100).plot() @@ -5341,7 +5341,7 @@ Enhancements .. ipython:: python ds = xray.Dataset({"y": ("x", [1, 2, 3])}) - ds.assign(z=lambda ds: ds.y ** 2) + ds.assign(z=lambda ds: ds.y**2) ds.assign_coords(z=("x", ["a", "b", "c"])) These methods return a new Dataset (or DataArray) with updated data or diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 88eefbdc441..c11bd1a78a4 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -944,7 +944,7 @@ def apply_ufunc( Calculate the vector magnitude of two arguments: >>> def magnitude(a, b): - ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2) + ... func = lambda x, y: np.sqrt(x**2 + y**2) ... return xr.apply_ufunc(func, a, b) ... From b86a7c1a5614e7333267cdaab2e9f05e42e2cc0f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 22 Feb 2022 19:50:04 -0800 Subject: [PATCH 07/34] Align language def in bugreport.yml with schema (#6290) * Align language def in bugreport.yml with schema Not sure if this matters at all, but VSCode's linter was complaining * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bugreport.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index bb1febae2ff..043584f3ea6 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -33,14 +33,14 @@ body: Bug reports that follow these guidelines are easier to diagnose, and so are often handled much more quickly. This will be automatically formatted into code, so no need for markdown backticks. - render: python + render: Python - type: textarea id: log-output attributes: label: Relevant log output description: Please copy and paste any relevant output. This will be automatically formatted into code, so no need for markdown backticks. - render: python + render: Python - type: textarea id: extra From b760807646cce493cf8f4f619fcda9a14e84670f Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Wed, 23 Feb 2022 10:51:26 +0100 Subject: [PATCH 08/34] Amended docstring to reflect the actual behaviour of Dataset.map (#6232) * Amended docstring to reflect the actual behaviour of Dataset.map * Update --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index af59f5cd2f1..fb30cf22e04 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5116,9 +5116,9 @@ def map( to transform each DataArray `x` in this dataset into another DataArray. keep_attrs : bool, optional - If True, the dataset's attributes (`attrs`) will be copied from - the original object to the new one. If False, the new object will - be returned without attributes. + If True, both the dataset's and variables' attributes (`attrs`) will be + copied from the original objects to the new ones. If False, the new dataset + and variables will be returned without copying the attributes. args : tuple, optional Positional arguments passed on to `func`. **kwargs : Any From 964bee8ba4f84794be41e279bdce62931b669269 Mon Sep 17 00:00:00 2001 From: Romain Caneill Date: Wed, 23 Feb 2022 18:54:28 +0100 Subject: [PATCH 09/34] Adding the new wrapper gsw-xarray (#6294) --- doc/ecosystem.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index 469f83d37c1..2b49b1529e1 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -17,6 +17,7 @@ Geosciences - `climpred `_: Analysis of ensemble forecast models for climate prediction. - `geocube `_: Tool to convert geopandas vector data into rasterized xarray data. - `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope). +- `gsw-xarray `_: a wrapper around `gsw `_ that adds CF compliant attributes when possible, units, name. - `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meteorology data - `marc_analysis `_: Analysis package for CESM/MARC experiments and output. - `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data. From de965f342e1c9c5de92ab135fbc4062e21e72453 Mon Sep 17 00:00:00 2001 From: Lukas Pilz Date: Wed, 23 Feb 2022 18:54:46 +0100 Subject: [PATCH 10/34] Amended docs on how to add a new backend (#6292) Co-authored-by: Lukas Pilz --- doc/internals/how-to-add-new-backend.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index ceb59c8a3bd..4e4ae4e2e14 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -273,7 +273,7 @@ If you are using `Poetry `_ for your build system, y .. code-block:: toml - [tool.poetry.plugins."xarray_backends"] + [tool.poetry.plugins."xarray.backends"] "my_engine" = "my_package.my_module:MyBackendEntryClass" See https://python-poetry.org/docs/pyproject/#plugins for more information on Poetry plugins. From 17acbb027326ac5f2379fc6cabf425459759e0ca Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:08:30 -0500 Subject: [PATCH 11/34] Drop duplicates over multiple dims, and add Dataset.drop_duplicates (#6307) * tests for da.drop_duplicates over multiple dims * pass tests * test for Dataset.drop_duplicates * piped both paths through dataset.drop_duplicates * added dataset.drop_duplicates to API docs * whats-new entry * correct small bug when raising error --- doc/api.rst | 1 + doc/whats-new.rst | 4 ++- xarray/core/dataarray.py | 17 ++++++---- xarray/core/dataset.py | 39 ++++++++++++++++++++++ xarray/tests/test_dataarray.py | 59 ++++++++++++++++++++++++---------- xarray/tests/test_dataset.py | 31 ++++++++++++++++++ 6 files changed, 126 insertions(+), 25 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index b552bc6b4d2..d2c222da4db 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -106,6 +106,7 @@ Dataset contents Dataset.swap_dims Dataset.expand_dims Dataset.drop_vars + Dataset.drop_duplicates Dataset.drop_dims Dataset.set_coords Dataset.reset_coords diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aa48bd619e8..24a8042ee66 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -31,7 +31,9 @@ New Features - Enbable to provide more keyword arguments to `pydap` backend when reading OpenDAP datasets (:issue:`6274`). By `Jonas Gliß `. - +- Allow :py:meth:`DataArray.drop_duplicates` to drop duplicates along multiple dimensions at once, + and add :py:meth:`Dataset.drop_duplicates`. (:pull:`6307`) + By `Tom Nicholas `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 20e829d293e..b3c45d65818 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4659,14 +4659,15 @@ def curvefit( def drop_duplicates( self, - dim: Hashable, - keep: (str | bool) = "first", + dim: Hashable | Iterable[Hashable] | ..., + keep: Literal["first", "last"] | Literal[False] = "first", ): """Returns a new DataArray with duplicate dimension values removed. Parameters ---------- - dim : dimension label, optional + dim : dimension label or labels + Pass `...` to drop duplicates along all dimensions. keep : {"first", "last", False}, default: "first" Determines which duplicates (if any) to keep. - ``"first"`` : Drop duplicates except for the first occurrence. @@ -4676,11 +4677,13 @@ def drop_duplicates( Returns ------- DataArray + + See Also + -------- + Dataset.drop_duplicates """ - if dim not in self.dims: - raise ValueError(f"'{dim}' not found in dimensions") - indexes = {dim: ~self.get_index(dim).duplicated(keep=keep)} - return self.isel(indexes) + deduplicated = self._to_temp_dataset().drop_duplicates(dim, keep=keep) + return self._from_temp_dataset(deduplicated) def convert_calendar( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fb30cf22e04..be9df9d2e2d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7770,6 +7770,45 @@ def _wrapper(Y, *coords_, **kwargs): return result + def drop_duplicates( + self, + dim: Hashable | Iterable[Hashable] | ..., + keep: Literal["first", "last"] | Literal[False] = "first", + ): + """Returns a new Dataset with duplicate dimension values removed. + + Parameters + ---------- + dim : dimension label or labels + Pass `...` to drop duplicates along all dimensions. + keep : {"first", "last", False}, default: "first" + Determines which duplicates (if any) to keep. + - ``"first"`` : Drop duplicates except for the first occurrence. + - ``"last"`` : Drop duplicates except for the last occurrence. + - False : Drop all duplicates. + + Returns + ------- + Dataset + + See Also + -------- + DataArray.drop_duplicates + """ + if isinstance(dim, str): + dims = (dim,) + elif dim is ...: + dims = self.dims + else: + dims = dim + + missing_dims = set(dims) - set(self.dims) + if missing_dims: + raise ValueError(f"'{missing_dims}' not found in dimensions") + + indexes = {dim: ~self.get_index(dim).duplicated(keep=keep) for dim in dims} + return self.isel(indexes) + def convert_calendar( self, calendar: str, diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8d73f9ec7ee..fc82c03c5d9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6618,25 +6618,50 @@ def test_clip(da): result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1])) -@pytest.mark.parametrize("keep", ["first", "last", False]) -def test_drop_duplicates(keep): - ds = xr.DataArray( - [0, 5, 6, 7], dims="time", coords={"time": [0, 0, 1, 2]}, name="test" - ) +class TestDropDuplicates: + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_drop_duplicates_1d(self, keep): + da = xr.DataArray( + [0, 5, 6, 7], dims="time", coords={"time": [0, 0, 1, 2]}, name="test" + ) - if keep == "first": - data = [0, 6, 7] - time = [0, 1, 2] - elif keep == "last": - data = [5, 6, 7] - time = [0, 1, 2] - else: - data = [6, 7] - time = [1, 2] + if keep == "first": + data = [0, 6, 7] + time = [0, 1, 2] + elif keep == "last": + data = [5, 6, 7] + time = [0, 1, 2] + else: + data = [6, 7] + time = [1, 2] + + expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test") + result = da.drop_duplicates("time", keep=keep) + assert_equal(expected, result) + + with pytest.raises(ValueError, match="['space'] not found"): + da.drop_duplicates("space", keep=keep) + + def test_drop_duplicates_2d(self): + da = xr.DataArray( + [[0, 5, 6, 7], [2, 1, 3, 4]], + dims=["space", "time"], + coords={"space": [10, 10], "time": [0, 0, 1, 2]}, + name="test", + ) + + expected = xr.DataArray( + [[0, 6, 7]], + dims=["space", "time"], + coords={"time": ("time", [0, 1, 2]), "space": ("space", [10])}, + name="test", + ) + + result = da.drop_duplicates(["time", "space"], keep="first") + assert_equal(expected, result) - expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test") - result = ds.drop_duplicates("time", keep=keep) - assert_equal(expected, result) + result = da.drop_duplicates(..., keep="first") + assert_equal(expected, result) class TestNumpyCoercion: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c4fa847e664..7ff75fb791b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6546,6 +6546,37 @@ def test_clip(ds): assert result.dims == ds.dims +class TestDropDuplicates: + @pytest.mark.parametrize("keep", ["first", "last", False]) + def test_drop_duplicates_1d(self, keep): + ds = xr.Dataset( + {"a": ("time", [0, 5, 6, 7]), "b": ("time", [9, 3, 8, 2])}, + coords={"time": [0, 0, 1, 2]}, + ) + + if keep == "first": + a = [0, 6, 7] + b = [9, 8, 2] + time = [0, 1, 2] + elif keep == "last": + a = [5, 6, 7] + b = [3, 8, 2] + time = [0, 1, 2] + else: + a = [6, 7] + b = [8, 2] + time = [1, 2] + + expected = xr.Dataset( + {"a": ("time", a), "b": ("time", b)}, coords={"time": time} + ) + result = ds.drop_duplicates("time", keep=keep) + assert_equal(expected, result) + + with pytest.raises(ValueError, match="['space'] not found"): + ds.drop_duplicates("space", keep=keep) + + class TestNumpyCoercion: def test_from_numpy(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) From 4292bdebd7c9c461b0814605509e90453fe47754 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 28 Feb 2022 10:11:01 +0100 Subject: [PATCH 12/34] from_dict: doctest (#6302) --- xarray/core/dataarray.py | 47 ++++++++++++++++------------ xarray/core/dataset.py | 67 +++++++++++++++++++++++++--------------- 2 files changed, 70 insertions(+), 44 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b3c45d65818..3d720f7fc8b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2873,25 +2873,7 @@ def to_dict(self, data: bool = True) -> dict: @classmethod def from_dict(cls, d: dict) -> DataArray: - """ - Convert a dictionary into an xarray.DataArray - - Input dict can take several forms: - - .. code:: python - - d = {"dims": "t", "data": x} - - d = { - "coords": {"t": {"dims": "t", "data": t, "attrs": {"units": "s"}}}, - "attrs": {"title": "air temperature"}, - "dims": "t", - "data": x, - "name": "a", - } - - where "t" is the name of the dimension, "a" is the name of the array, - and x and t are lists, numpy.arrays, or pandas objects. + """Convert a dictionary into an xarray.DataArray Parameters ---------- @@ -2906,6 +2888,33 @@ def from_dict(cls, d: dict) -> DataArray: -------- DataArray.to_dict Dataset.from_dict + + Examples + -------- + >>> d = {"dims": "t", "data": [1, 2, 3]} + >>> da = xr.DataArray.from_dict(d) + >>> da + + array([1, 2, 3]) + Dimensions without coordinates: t + + >>> d = { + ... "coords": { + ... "t": {"dims": "t", "data": [0, 1, 2], "attrs": {"units": "s"}} + ... }, + ... "attrs": {"title": "air temperature"}, + ... "dims": "t", + ... "data": [10, 20, 30], + ... "name": "a", + ... } + >>> da = xr.DataArray.from_dict(d) + >>> da + + array([10, 20, 30]) + Coordinates: + * t (t) int64 0 1 2 + Attributes: + title: air temperature """ coords = None if "coords" in d: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index be9df9d2e2d..90684c4db87 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5641,31 +5641,7 @@ def to_dict(self, data=True): @classmethod def from_dict(cls, d): - """ - Convert a dictionary into an xarray.Dataset. - - Input dict can take several forms: - - .. code:: python - - d = { - "t": {"dims": ("t"), "data": t}, - "a": {"dims": ("t"), "data": x}, - "b": {"dims": ("t"), "data": y}, - } - - d = { - "coords": {"t": {"dims": "t", "data": t, "attrs": {"units": "s"}}}, - "attrs": {"title": "air temperature"}, - "dims": "t", - "data_vars": { - "a": {"dims": "t", "data": x}, - "b": {"dims": "t", "data": y}, - }, - } - - where "t" is the name of the dimesion, "a" and "b" are names of data - variables and t, x, and y are lists, numpy.arrays or pandas objects. + """Convert a dictionary into an xarray.Dataset. Parameters ---------- @@ -5682,6 +5658,47 @@ def from_dict(cls, d): -------- Dataset.to_dict DataArray.from_dict + + Examples + -------- + >>> d = { + ... "t": {"dims": ("t"), "data": [0, 1, 2]}, + ... "a": {"dims": ("t"), "data": ["a", "b", "c"]}, + ... "b": {"dims": ("t"), "data": [10, 20, 30]}, + ... } + >>> ds = xr.Dataset.from_dict(d) + >>> ds + + Dimensions: (t: 3) + Coordinates: + * t (t) int64 0 1 2 + Data variables: + a (t) >> d = { + ... "coords": { + ... "t": {"dims": "t", "data": [0, 1, 2], "attrs": {"units": "s"}} + ... }, + ... "attrs": {"title": "air temperature"}, + ... "dims": "t", + ... "data_vars": { + ... "a": {"dims": "t", "data": [10, 20, 30]}, + ... "b": {"dims": "t", "data": ["a", "b", "c"]}, + ... }, + ... } + >>> ds = xr.Dataset.from_dict(d) + >>> ds + + Dimensions: (t: 3) + Coordinates: + * t (t) int64 0 1 2 + Data variables: + a (t) int64 10 20 30 + b (t) Date: Mon, 28 Feb 2022 04:53:21 -0500 Subject: [PATCH 13/34] On Windows, enable successful test of opening a dataset containing a cftime index (#6305) * Exit cluster context before deleting temporary directory Previously, on Windows, the scheduler in the outer context prevented deleting the temporary directory upon exiting the inner context of the latter. That caused the test to fail and the temporary directory and file to remain. * Use fixture instead of context manager for temporary directory * Edit whats-new entry --- doc/whats-new.rst | 2 +- xarray/tests/test_distributed.py | 15 ++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 24a8042ee66..4729b74640f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -52,7 +52,7 @@ Bug fixes - Variables which are chunked using dask in larger (but aligned) chunks than the target zarr chunk size can now be stored using `to_zarr()` (:pull:`6258`) By `Tobias Kölling `_. -- Multi-file datasets containing encoded :py:class:`cftime.datetime` objects can be read in parallel again (:issue:`6226`, :pull:`6249`). By `Martin Bergemann `_. +- Multi-file datasets containing encoded :py:class:`cftime.datetime` objects can be read in parallel again (:issue:`6226`, :pull:`6249`, :pull:`6305`). By `Martin Bergemann `_ and `Stan West `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index b97032014c4..773733b7b89 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -1,8 +1,6 @@ """ isort:skip_file """ -import os import pickle import numpy as np -import tempfile import pytest from packaging.version import Version @@ -113,19 +111,18 @@ def test_dask_distributed_netcdf_roundtrip( @requires_cftime @requires_netCDF4 -def test_open_mfdataset_can_open_files_with_cftime_index(): +def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) da = xr.DataArray(data, coords={"time": T, "Lon": Lon}, name="test") + file_path = tmp_path / "test.nc" + da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): - with tempfile.TemporaryDirectory() as td: - data_file = os.path.join(td, "test.nc") - da.to_netcdf(data_file) - for parallel in (False, True): - with xr.open_mfdataset(data_file, parallel=parallel) as tf: - assert_identical(tf["test"], da) + for parallel in (False, True): + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) From 555a70e290d8a13deb178da64d1d994a10f6acaf Mon Sep 17 00:00:00 2001 From: Stijn Van Hoey Date: Tue, 1 Mar 2022 16:01:39 +0100 Subject: [PATCH 14/34] Fix class attributes versus init parameters (#6312) --- doc/internals/how-to-add-new-backend.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index 4e4ae4e2e14..940fa7a4d1e 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -338,8 +338,8 @@ This is an example ``BackendArray`` subclass implementation: # other backend specific keyword arguments ): self.shape = shape - self.dtype = lock - self.lock = dtype + self.dtype = dtype + self.lock = lock def __getitem__( self, key: xarray.core.indexing.ExplicitIndexer From 0f91f05e532f4424f3d6afd6e6d5bd5a02ceed55 Mon Sep 17 00:00:00 2001 From: Stan West <38358698+stanwest@users.noreply.github.com> Date: Tue, 1 Mar 2022 11:00:33 -0500 Subject: [PATCH 15/34] Enable running sphinx-build on Windows (#6237) --- .gitignore | 4 +-- doc/conf.py | 4 +-- doc/getting-started-guide/quick-overview.rst | 4 ++- doc/internals/zarr-encoding-spec.rst | 7 +++++ doc/user-guide/dask.rst | 26 +++++++-------- doc/user-guide/io.rst | 33 ++++++++++++-------- doc/user-guide/weather-climate.rst | 4 ++- doc/whats-new.rst | 3 ++ 8 files changed, 53 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index 90f4a10ed5f..686c7efa701 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,9 @@ __pycache__ .hypothesis/ # temp files from docs build +doc/*.nc doc/auto_gallery -doc/example.nc +doc/rasm.zarr doc/savefig # C extensions @@ -72,4 +73,3 @@ xarray/tests/data/*.grib.*.idx Icon* .ipynb_checkpoints -doc/rasm.zarr diff --git a/doc/conf.py b/doc/conf.py index 5c4c0a52d43..8ce9efdce88 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -28,9 +28,9 @@ print("python exec:", sys.executable) print("sys.path:", sys.path) -if "conda" in sys.executable: +if "CONDA_DEFAULT_ENV" in os.environ or "conda" in sys.executable: print("conda environment:") - subprocess.run(["conda", "list"]) + subprocess.run([os.environ.get("CONDA_EXE", "conda"), "list"]) else: print("pip environment:") subprocess.run([sys.executable, "-m", "pip", "list"]) diff --git a/doc/getting-started-guide/quick-overview.rst b/doc/getting-started-guide/quick-overview.rst index cd4b66d2f6f..ee13fea8bf1 100644 --- a/doc/getting-started-guide/quick-overview.rst +++ b/doc/getting-started-guide/quick-overview.rst @@ -215,13 +215,15 @@ You can directly read and write xarray objects to disk using :py:meth:`~xarray.D .. ipython:: python ds.to_netcdf("example.nc") - xr.open_dataset("example.nc") + reopened = xr.open_dataset("example.nc") + reopened .. ipython:: python :suppress: import os + reopened.close() os.remove("example.nc") diff --git a/doc/internals/zarr-encoding-spec.rst b/doc/internals/zarr-encoding-spec.rst index f809ea337d5..7fb2383935f 100644 --- a/doc/internals/zarr-encoding-spec.rst +++ b/doc/internals/zarr-encoding-spec.rst @@ -63,3 +63,10 @@ re-open it directly with Zarr: print(os.listdir("rasm.zarr")) print(zgroup.tree()) dict(zgroup["Tair"].attrs) + +.. ipython:: python + :suppress: + + import shutil + + shutil.rmtree("rasm.zarr") diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst index 4d8715d9c51..5110a970390 100644 --- a/doc/user-guide/dask.rst +++ b/doc/user-guide/dask.rst @@ -55,6 +55,8 @@ argument to :py:func:`~xarray.open_dataset` or using the .. ipython:: python :suppress: + import os + import numpy as np import pandas as pd import xarray as xr @@ -129,6 +131,11 @@ will return a ``dask.delayed`` object that can be computed later. with ProgressBar(): results = delayed_obj.compute() +.. ipython:: python + :suppress: + + os.remove("manipulated-example-data.nc") # Was not opened. + .. note:: When using Dask's distributed scheduler to write NETCDF4 files, @@ -147,13 +154,6 @@ A dataset can also be converted to a Dask DataFrame using :py:meth:`~xarray.Data Dask DataFrames do not support multi-indexes so the coordinate variables from the dataset are included as columns in the Dask DataFrame. -.. ipython:: python - :suppress: - - import os - - os.remove("example-data.nc") - os.remove("manipulated-example-data.nc") Using Dask with xarray ---------------------- @@ -210,7 +210,7 @@ Dask arrays using the :py:meth:`~xarray.Dataset.persist` method: .. ipython:: python - ds = ds.persist() + persisted = ds.persist() :py:meth:`~xarray.Dataset.persist` is particularly useful when using a distributed cluster because the data will be loaded into distributed memory @@ -232,11 +232,6 @@ chunk size depends both on your data and on the operations you want to perform. With xarray, both converting data to a Dask arrays and converting the chunk sizes of Dask arrays is done with the :py:meth:`~xarray.Dataset.chunk` method: -.. ipython:: python - :suppress: - - ds = ds.chunk({"time": 10}) - .. ipython:: python rechunked = ds.chunk({"latitude": 100, "longitude": 100}) @@ -508,6 +503,11 @@ Notice that the 0-shaped sizes were not printed to screen. Since ``template`` ha expected = ds + 10 + 10 mapped.identical(expected) +.. ipython:: python + :suppress: + + ds.close() # Closes "example-data.nc". + os.remove("example-data.nc") .. tip:: diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 834e9ad2464..ddde0bf5888 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -11,6 +11,8 @@ format (recommended). .. ipython:: python :suppress: + import os + import numpy as np import pandas as pd import xarray as xr @@ -84,6 +86,13 @@ We can load netCDF files to create a new Dataset using ds_disk = xr.open_dataset("saved_on_disk.nc") ds_disk +.. ipython:: python + :suppress: + + # Close "saved_on_disk.nc", but retain the file until after closing or deleting other + # datasets that will refer to it. + ds_disk.close() + Similarly, a DataArray can be saved to disk using the :py:meth:`DataArray.to_netcdf` method, and loaded from disk using the :py:func:`open_dataarray` function. As netCDF files @@ -204,11 +213,6 @@ You can view this encoding information (among others) in the Note that all operations that manipulate variables other than indexing will remove encoding information. -.. ipython:: python - :suppress: - - ds_disk.close() - .. _combining multiple files: @@ -484,13 +488,13 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: da.to_netcdf("complex.nc", engine="h5netcdf", invalid_netcdf=True) # Reading it back - xr.open_dataarray("complex.nc", engine="h5netcdf") + reopened = xr.open_dataarray("complex.nc", engine="h5netcdf") + reopened .. ipython:: python :suppress: - import os - + reopened.close() os.remove("complex.nc") .. warning:: @@ -989,16 +993,19 @@ To export just the dataset schema without the data itself, use the ds.to_dict(data=False) -This can be useful for generating indices of dataset contents to expose to -search indices or other automated data discovery tools. - .. ipython:: python :suppress: - import os - + # We're now done with the dataset named `ds`. Although the `with` statement closed + # the dataset, displaying the unpickled pickle of `ds` re-opened "saved_on_disk.nc". + # However, `ds` (rather than the unpickled dataset) refers to the open file. Delete + # `ds` to close the file. + del ds os.remove("saved_on_disk.nc") +This can be useful for generating indices of dataset contents to expose to +search indices or other automated data discovery tools. + .. _io.rasterio: Rasterio diff --git a/doc/user-guide/weather-climate.rst b/doc/user-guide/weather-climate.rst index d11c7c3a4f9..3c957978acf 100644 --- a/doc/user-guide/weather-climate.rst +++ b/doc/user-guide/weather-climate.rst @@ -218,13 +218,15 @@ For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: .. ipython:: python da.to_netcdf("example-no-leap.nc") - xr.open_dataset("example-no-leap.nc") + reopened = xr.open_dataset("example-no-leap.nc") + reopened .. ipython:: python :suppress: import os + reopened.close() os.remove("example-no-leap.nc") - And resampling along the time dimension for data indexed by a :py:class:`~xarray.CFTimeIndex`: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4729b74640f..ae51b9f6ce1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -57,6 +57,9 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Delete files of datasets saved to disk while building the documentation and enable + building on Windows via `sphinx-build` (:pull:`6237`). + By `Stan West `_. Internal Changes From f9037c41e36254bfe7efa9feaedd3eae5512bd11 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Tue, 1 Mar 2022 09:57:43 -0700 Subject: [PATCH 16/34] Disable CI runs on forks (#6315) --- .github/workflows/ci-additional.yaml | 1 + .github/workflows/ci.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index b476c224df6..6bd9bcd9d6b 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -119,6 +119,7 @@ jobs: doctest: name: Doctests runs-on: "ubuntu-latest" + if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4f6cbbc3871..1e5db3a73ed 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -115,6 +115,7 @@ jobs: event_file: name: "Event File" runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' steps: - name: Upload uses: actions/upload-artifact@v2 From 2ab9f36641a8a0248df0497aadd59246586b65ea Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Wed, 2 Mar 2022 06:56:25 -0700 Subject: [PATCH 17/34] Add General issue template (#6314) * Add General issue template * Update description Co-authored-by: Mathias Hauser * Remove title Co-authored-by: Mathias Hauser Co-authored-by: Mathias Hauser --- .github/ISSUE_TEMPLATE/misc.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/misc.yml diff --git a/.github/ISSUE_TEMPLATE/misc.yml b/.github/ISSUE_TEMPLATE/misc.yml new file mode 100644 index 00000000000..94dd2d86567 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/misc.yml @@ -0,0 +1,17 @@ +name: Issue +description: General Issue or discussion topic. For usage questions, please follow the "Usage question" link +labels: ["needs triage"] +body: + - type: markdown + attributes: + value: | + Please describe your issue here. + - type: textarea + id: issue-description + attributes: + label: What is your issue? + description: | + Thank you for filing an issue! Please give us further information on how we can help you. + placeholder: Please describe your issue. + validations: + required: true From cdab326bab0cf86f96bcce4292f4fae24bddc7b6 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 2 Mar 2022 14:57:29 +0100 Subject: [PATCH 18/34] fix typos (using codespell) (#6316) * fix typos (using codespell) * revert 'split' --- doc/examples/ROMS_ocean_model.ipynb | 2 +- doc/gallery/plot_colorbar_center.py | 2 +- doc/internals/how-to-add-new-backend.rst | 4 ++-- doc/internals/zarr-encoding-spec.rst | 2 +- doc/roadmap.rst | 2 +- doc/user-guide/time-series.rst | 2 +- doc/whats-new.rst | 12 ++++++------ xarray/backends/common.py | 2 +- xarray/backends/file_manager.py | 2 +- xarray/backends/pseudonetcdf_.py | 2 +- xarray/backends/zarr.py | 2 +- xarray/convert.py | 2 +- xarray/core/accessor_str.py | 4 ++-- xarray/core/combine.py | 2 +- xarray/core/computation.py | 2 +- xarray/core/concat.py | 2 +- xarray/core/dataarray.py | 4 ++-- xarray/core/dataset.py | 12 ++++++------ xarray/core/merge.py | 2 +- xarray/core/missing.py | 8 ++++---- xarray/core/rolling.py | 4 ++-- xarray/core/variable.py | 2 +- xarray/tests/test_backends.py | 6 +++--- xarray/tests/test_calendar_ops.py | 2 +- xarray/tests/test_coarsen.py | 4 ++-- xarray/tests/test_computation.py | 2 +- xarray/tests/test_dask.py | 2 +- xarray/tests/test_dataarray.py | 6 +++--- xarray/tests/test_dataset.py | 2 +- xarray/tests/test_interp.py | 2 +- xarray/tests/test_missing.py | 2 +- xarray/tests/test_plot.py | 6 +++--- 32 files changed, 56 insertions(+), 56 deletions(-) diff --git a/doc/examples/ROMS_ocean_model.ipynb b/doc/examples/ROMS_ocean_model.ipynb index 82d7a8d58af..d5c76380525 100644 --- a/doc/examples/ROMS_ocean_model.ipynb +++ b/doc/examples/ROMS_ocean_model.ipynb @@ -77,7 +77,7 @@ "ds = xr.tutorial.open_dataset(\"ROMS_example.nc\", chunks={\"ocean_time\": 1})\n", "\n", "# This is a way to turn on chunking and lazy evaluation. Opening with mfdataset, or\n", - "# setting the chunking in the open_dataset would also achive this.\n", + "# setting the chunking in the open_dataset would also achieve this.\n", "ds" ] }, diff --git a/doc/gallery/plot_colorbar_center.py b/doc/gallery/plot_colorbar_center.py index 42d6448adf6..da3447a1f25 100644 --- a/doc/gallery/plot_colorbar_center.py +++ b/doc/gallery/plot_colorbar_center.py @@ -38,6 +38,6 @@ ax4.set_title("Celsius: center=False") ax4.set_ylabel("") -# Mke it nice +# Make it nice plt.tight_layout() plt.show() diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index 940fa7a4d1e..17f94189504 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -317,7 +317,7 @@ grouped in three types of indexes :py:class:`~xarray.core.indexing.OuterIndexer` and :py:class:`~xarray.core.indexing.VectorizedIndexer`. This implies that the implementation of the method ``__getitem__`` can be tricky. -In oder to simplify this task, Xarray provides a helper function, +In order to simplify this task, Xarray provides a helper function, :py:func:`~xarray.core.indexing.explicit_indexing_adapter`, that transforms all the input ``indexer`` types (`basic`, `outer`, `vectorized`) in a tuple which is interpreted correctly by your backend. @@ -426,7 +426,7 @@ The ``OUTER_1VECTOR`` indexing shall supports number, slices and at most one list. The behaviour with the list shall be the same of ``OUTER`` indexing. If you support more complex indexing as `explicit indexing` or -`numpy indexing`, you can have a look to the implemetation of Zarr backend and Scipy backend, +`numpy indexing`, you can have a look to the implementation of Zarr backend and Scipy backend, currently available in :py:mod:`~xarray.backends` module. .. _RST preferred_chunks: diff --git a/doc/internals/zarr-encoding-spec.rst b/doc/internals/zarr-encoding-spec.rst index 7fb2383935f..f8bffa6e82f 100644 --- a/doc/internals/zarr-encoding-spec.rst +++ b/doc/internals/zarr-encoding-spec.rst @@ -14,7 +14,7 @@ for the storage of the NetCDF data model in Zarr; see discussion. First, Xarray can only read and write Zarr groups. There is currently no support -for reading / writting individual Zarr arrays. Zarr groups are mapped to +for reading / writing individual Zarr arrays. Zarr groups are mapped to Xarray ``Dataset`` objects. Second, from Xarray's point of view, the key difference between diff --git a/doc/roadmap.rst b/doc/roadmap.rst index c59d56fdd6d..d4098cfd35a 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -112,7 +112,7 @@ A cleaner model would be to elevate ``indexes`` to an explicit part of xarray's data model, e.g., as attributes on the ``Dataset`` and ``DataArray`` classes. Indexes would need to be propagated along with coordinates in xarray operations, but will no longer would need to have -a one-to-one correspondance with coordinate variables. Instead, an index +a one-to-one correspondence with coordinate variables. Instead, an index should be able to refer to multiple (possibly multidimensional) coordinates that define it. See `GH 1603 `__ for full details diff --git a/doc/user-guide/time-series.rst b/doc/user-guide/time-series.rst index 36a57e37475..f6b9d0bc35b 100644 --- a/doc/user-guide/time-series.rst +++ b/doc/user-guide/time-series.rst @@ -101,7 +101,7 @@ You can also select a particular time by indexing with a ds.sel(time=datetime.time(12)) -For more details, read the pandas documentation and the section on `Indexing Using Datetime Components `_ (i.e. using the ``.dt`` acessor). +For more details, read the pandas documentation and the section on `Indexing Using Datetime Components `_ (i.e. using the ``.dt`` accessor). .. _dt_accessor: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ae51b9f6ce1..6f939a36f25 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -138,7 +138,7 @@ Bug fixes By `Michael Delgado `_. - `dt.season `_ can now handle NaN and NaT. (:pull:`5876`). By `Pierre Loicq `_. -- Determination of zarr chunks handles empty lists for encoding chunks or variable chunks that occurs in certain cirumstances (:pull:`5526`). By `Chris Roat `_. +- Determination of zarr chunks handles empty lists for encoding chunks or variable chunks that occurs in certain circumstances (:pull:`5526`). By `Chris Roat `_. Internal Changes ~~~~~~~~~~~~~~~~ @@ -706,7 +706,7 @@ Breaking changes By `Alessandro Amici `_. - Functions that are identities for 0d data return the unchanged data if axis is empty. This ensures that Datasets where some variables do - not have the averaged dimensions are not accidentially changed + not have the averaged dimensions are not accidentally changed (:issue:`4885`, :pull:`5207`). By `David Schwörer `_. - :py:attr:`DataArray.coarsen` and :py:attr:`Dataset.coarsen` no longer support passing ``keep_attrs`` @@ -1419,7 +1419,7 @@ New Features Enhancements ~~~~~~~~~~~~ - Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp` - We performs independant interpolation sequentially rather than interpolating in + We performs independent interpolation sequentially rather than interpolating in one large multidimensional space. (:issue:`2223`) By `Keisuke Fujii `_. - :py:meth:`DataArray.interp` now support interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux `_. @@ -2770,7 +2770,7 @@ Breaking changes - ``Dataset.T`` has been removed as a shortcut for :py:meth:`Dataset.transpose`. Call :py:meth:`Dataset.transpose` directly instead. - Iterating over a ``Dataset`` now includes only data variables, not coordinates. - Similarily, calling ``len`` and ``bool`` on a ``Dataset`` now + Similarly, calling ``len`` and ``bool`` on a ``Dataset`` now includes only data variables. - ``DataArray.__contains__`` (used by Python's ``in`` operator) now checks array data, not coordinates. @@ -3908,7 +3908,7 @@ Bug fixes (:issue:`1606`). By `Joe Hamman `_. -- Fix bug when using ``pytest`` class decorators to skiping certain unittests. +- Fix bug when using ``pytest`` class decorators to skipping certain unittests. The previous behavior unintentionally causing additional tests to be skipped (:issue:`1531`). By `Joe Hamman `_. @@ -5656,7 +5656,7 @@ Bug fixes - Several bug fixes related to decoding time units from netCDF files (:issue:`316`, :issue:`330`). Thanks Stefan Pfenninger! - xray no longer requires ``decode_coords=False`` when reading datasets with - unparseable coordinate attributes (:issue:`308`). + unparsable coordinate attributes (:issue:`308`). - Fixed ``DataArray.loc`` indexing with ``...`` (:issue:`318`). - Fixed an edge case that resulting in an error when reindexing multi-dimensional variables (:issue:`315`). diff --git a/xarray/backends/common.py b/xarray/backends/common.py index f659d71760b..ad92a6c5869 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -160,7 +160,7 @@ def sync(self, compute=True): import dask.array as da # TODO: consider wrapping targets with dask.delayed, if this makes - # for any discernable difference in perforance, e.g., + # for any discernible difference in perforance, e.g., # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store( diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 47a4201539b..06be03e4e44 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -204,7 +204,7 @@ def _acquire_with_cache_info(self, needs_lock=True): kwargs["mode"] = self._mode file = self._opener(*self._args, **kwargs) if self._mode == "w": - # ensure file doesn't get overriden when opened again + # ensure file doesn't get overridden when opened again self._mode = "a" self._cache[self._key] = file return file, False diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index da178926dbe..a2ca7f0206c 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -105,7 +105,7 @@ class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): available = has_pseudonetcdf # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, - # unless the open_dataset_parameters are explicity defined like this: + # unless the open_dataset_parameters are explicitly defined like this: open_dataset_parameters = ( "filename_or_obj", "mask_and_scale", diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index b3f62bb798d..97517818d07 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -179,7 +179,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): - # Zarr arrays do not have dimenions. To get around this problem, we add + # Zarr arrays do not have dimensions. To get around this problem, we add # an attribute that specifies the dimension. We have to hide this attribute # when we send the attributes to the user. # zarr_obj can be either a zarr group or zarr array diff --git a/xarray/convert.py b/xarray/convert.py index 0fbd1e13163..93b0a30e57b 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -235,7 +235,7 @@ def _iris_cell_methods_to_str(cell_methods_obj): def _name(iris_obj, default="unknown"): - """Mimicks `iris_obj.name()` but with different name resolution order. + """Mimics `iris_obj.name()` but with different name resolution order. Similar to iris_obj.name() method, but using iris_obj.var_name first to enable roundtripping. diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 9c9de76c0ed..54c9b857a7a 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -456,7 +456,7 @@ def cat( Strings or array-like of strings to concatenate elementwise with the current DataArray. sep : str or array-like of str, default: "". - Seperator to use between strings. + Separator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -539,7 +539,7 @@ def join( Only one dimension is allowed at a time. Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. sep : str or array-like, default: "". - Seperator to use between strings. + Separator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 081b53391ba..d23a58522e6 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -135,7 +135,7 @@ def _infer_concat_order_from_coords(datasets): order = rank.astype(int).values - 1 # Append positions along extra dimension to structure which - # encodes the multi-dimensional concatentation order + # encodes the multi-dimensional concatenation order tile_ids = [ tile_id + (position,) for tile_id, position in zip(tile_ids, order) ] diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c11bd1a78a4..ce37251576a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1941,7 +1941,7 @@ def unify_chunks(*objects: T_Xarray) -> tuple[T_Xarray, ...]: for obj in objects ] - # Get argumets to pass into dask.array.core.unify_chunks + # Get arguments to pass into dask.array.core.unify_chunks unify_chunks_args = [] sizes: dict[Hashable, int] = {} for ds in datasets: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 4621e622d42..1e6e246322e 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -455,7 +455,7 @@ def _dataset_concat( if (dim in coord_names or dim in data_names) and dim not in dim_names: datasets = [ds.expand_dims(dim) for ds in datasets] - # determine which variables to concatentate + # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( datasets, dim, dim_names, data_vars, coords, compat ) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3d720f7fc8b..3df9f7ca8a4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2584,7 +2584,7 @@ def interpolate_na( ) def ffill(self, dim: Hashable, limit: int = None) -> DataArray: - """Fill NaN values by propogating values forward + """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -2609,7 +2609,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> DataArray: return ffill(self, dim, limit=limit) def bfill(self, dim: Hashable, limit: int = None) -> DataArray: - """Fill NaN values by propogating values backward + """Fill NaN values by propagating values backward *Requires bottleneck.* diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 90684c4db87..a1d7209bc75 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -379,7 +379,7 @@ def _check_chunks_compatibility(var, chunks, preferred_chunks): def _get_chunk(var, chunks): - # chunks need to be explicity computed to take correctly into accout + # chunks need to be explicitly computed to take correctly into account # backend preferred chunking import dask.array as da @@ -1529,7 +1529,7 @@ def __setitem__(self, key: Hashable | list[Hashable] | Mapping, value) -> None: except Exception as e: if processed: raise RuntimeError( - "An error occured while setting values of the" + "An error occurred while setting values of the" f" variable '{name}'. The following variables have" f" been successfully updated:\n{processed}" ) from e @@ -1976,7 +1976,7 @@ def to_zarr( metadata for existing stores (falling back to non-consolidated). append_dim : hashable, optional If set, the dimension along which the data will be appended. All - other dimensions on overriden variables must remain the same size. + other dimensions on overridden variables must remain the same size. region : dict, optional Optional mapping from dimension names to integer slices along dataset dimensions to indicate the region of existing zarr array(s) @@ -2001,7 +2001,7 @@ def to_zarr( Set False to override this restriction; however, data may become corrupted if Zarr arrays are written in parallel. This option may be useful in combination with ``compute=False`` to initialize a Zarr from an existing - Dataset with aribtrary chunk structure. + Dataset with arbitrary chunk structure. storage_options : dict, optional Any additional parameters for the storage backend (ignored for local paths). @@ -4930,7 +4930,7 @@ def interpolate_na( return new def ffill(self, dim: Hashable, limit: int = None) -> Dataset: - """Fill NaN values by propogating values forward + """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -4956,7 +4956,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> Dataset: return new def bfill(self, dim: Hashable, limit: int = None) -> Dataset: - """Fill NaN values by propogating values backward + """Fill NaN values by propagating values backward *Requires bottleneck.* diff --git a/xarray/core/merge.py b/xarray/core/merge.py index d5307678f89..e5407ae79c3 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -587,7 +587,7 @@ def merge_core( Parameters ---------- objects : list of mapping - All values must be convertable to labeled arrays. + All values must be convertible to labeled arrays. compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional Compatibility checks to use when merging variables. join : {"outer", "inner", "left", "right"}, optional diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 39e7730dd58..c1776145e21 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -573,7 +573,7 @@ def _localize(var, indexes_coords): def _floatize_x(x, new_x): """Make x and new_x float. - This is particulary useful for datetime dtype. + This is particularly useful for datetime dtype. x, new_x: tuple of np.ndarray """ x = list(x) @@ -624,7 +624,7 @@ def interp(var, indexes_coords, method, **kwargs): kwargs["bounds_error"] = kwargs.get("bounds_error", False) result = var - # decompose the interpolation into a succession of independant interpolation + # decompose the interpolation into a succession of independent interpolation for indexes_coords in decompose_interp(indexes_coords): var = result @@ -731,7 +731,7 @@ def interp_func(var, x, new_x, method, kwargs): for i in range(new_x[0].ndim) } - # if usefull, re-use localize for each chunk of new_x + # if useful, re-use localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) # scipy.interpolate.interp1d always forces to float. @@ -825,7 +825,7 @@ def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True def decompose_interp(indexes_coords): - """Decompose the interpolation into a succession of independant interpolation keeping the order""" + """Decompose the interpolation into a succession of independent interpolation keeping the order""" dest_dims = [ dest[1].dims if dest[1].ndim > 0 else [dim] diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 0bc07c1aaeb..f2ac9d979ae 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -33,7 +33,7 @@ Returns ------- reduced : same type as caller - New object with `{name}` applied along its rolling dimnension. + New object with `{name}` applied along its rolling dimension. """ @@ -767,7 +767,7 @@ def __init__(self, obj, windows, boundary, side, coord_func): exponential window along (e.g. `time`) to the size of the moving window. boundary : 'exact' | 'trim' | 'pad' If 'exact', a ValueError will be raised if dimension size is not a - multiple of window size. If 'trim', the excess indexes are trimed. + multiple of window size. If 'trim', the excess indexes are trimmed. If 'pad', NA will be padded. side : 'left' or 'right' or mapping from dimension to 'left' or 'right' coord_func : mapping from coordinate name to func. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6db795ce26f..74f394b68ca 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -700,7 +700,7 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None def _nonzero(self): - """Equivalent numpy's nonzero but returns a tuple of Varibles.""" + """Equivalent numpy's nonzero but returns a tuple of Variables.""" # TODO we should replace dask's native nonzero # after https://github.com/dask/dask/issues/1076 is implemented. nonzeros = np.nonzero(self.data) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1d0342dd344..e0bc0b10437 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1937,7 +1937,7 @@ def test_chunk_encoding_with_dask(self): with self.roundtrip(ds_chunk4) as actual: assert (4,) == actual["var1"].encoding["chunks"] - # TODO: remove this failure once syncronized overlapping writes are + # TODO: remove this failure once synchronized overlapping writes are # supported by xarray ds_chunk4["var1"].encoding.update({"chunks": 5}) with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"): @@ -2255,7 +2255,7 @@ def test_write_region_mode(self, mode): @requires_dask def test_write_preexisting_override_metadata(self): - """Metadata should be overriden if mode="a" but not in mode="r+".""" + """Metadata should be overridden if mode="a" but not in mode="r+".""" original = Dataset( {"u": (("x",), np.zeros(10), {"variable": "original"})}, attrs={"global": "original"}, @@ -2967,7 +2967,7 @@ def test_open_fileobj(self): with pytest.raises(TypeError, match="not a valid NetCDF 3"): open_dataset(f, engine="scipy") - # TOOD: this additional open is required since scipy seems to close the file + # TODO: this additional open is required since scipy seems to close the file # when it fails on the TypeError (though didn't when we used # `raises_regex`?). Ref https://github.com/pydata/xarray/pull/5191 with open(tmp_file, "rb") as f: diff --git a/xarray/tests/test_calendar_ops.py b/xarray/tests/test_calendar_ops.py index 8d1ddcf4689..0f0948aafc5 100644 --- a/xarray/tests/test_calendar_ops.py +++ b/xarray/tests/test_calendar_ops.py @@ -161,7 +161,7 @@ def test_convert_calendar_errors(): with pytest.raises(ValueError, match="Argument `align_on` must be specified"): convert_calendar(src_nl, "360_day") - # Standard doesn't suuport year 0 + # Standard doesn't support year 0 with pytest.raises( ValueError, match="Source time coordinate contains dates with year 0" ): diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index 9b1919b136c..7d6613421d5 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -158,7 +158,7 @@ def test_coarsen_keep_attrs(funcname, argument) -> None: @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) def test_coarsen_reduce(ds, window, name) -> None: - # Use boundary="trim" to accomodate all window sizes used in tests + # Use boundary="trim" to accommodate all window sizes used in tests coarsen_obj = ds.coarsen(time=window, boundary="trim") # add nan prefix to numpy methods to get similar behavior as bottleneck @@ -241,7 +241,7 @@ def test_coarsen_da_reduce(da, window, name) -> None: if da.isnull().sum() > 1 and window == 1: pytest.skip("These parameters lead to all-NaN slices") - # Use boundary="trim" to accomodate all window sizes used in tests + # Use boundary="trim" to accommodate all window sizes used in tests coarsen_obj = da.coarsen(time=window, boundary="trim") # add nan prefix to numpy methods to get similar # behavior as bottleneck diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 7f601c6195a..dac3c17b1f1 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -2038,7 +2038,7 @@ def test_polyval(use_dask, use_datetime) -> None: "cartesian", -1, ], - [ # Test filling inbetween with coords: + [ # Test filling in between with coords: xr.DataArray( [1, 2], dims=["cartesian"], diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 48432f319b2..42d8df57cb7 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -460,7 +460,7 @@ def test_concat_loads_variables(self): assert isinstance(out["c"].data, dask.array.Array) out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=[], coords=[]) - # variables are loaded once as we are validing that they're identical + # variables are loaded once as we are validating that they're identical assert kernel_call_count == 12 assert isinstance(out["d"].data, np.ndarray) assert isinstance(out["c"].data, np.ndarray) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fc82c03c5d9..3939da08a67 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -174,7 +174,7 @@ def test_get_index_size_zero(self): def test_struct_array_dims(self): """ - This test checks subraction of two DataArrays for the case + This test checks subtraction of two DataArrays for the case when dimension is a structured array. """ # GH837, GH861 @@ -197,7 +197,7 @@ def test_struct_array_dims(self): assert_identical(actual, expected) - # checking array subraction when dims are not the same + # checking array subtraction when dims are not the same p_data_alt = np.array( [("Abe", 180), ("Stacy", 151), ("Dick", 200)], dtype=[("name", "|S256"), ("height", object)], @@ -213,7 +213,7 @@ def test_struct_array_dims(self): assert_identical(actual, expected) - # checking array subraction when dims are not the same and one + # checking array subtraction when dims are not the same and one # is np.nan p_data_nan = np.array( [("Abe", 180), ("Stacy", np.nan), ("Dick", 200)], diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7ff75fb791b..d726920acce 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -587,7 +587,7 @@ def test_get_index(self): def test_attr_access(self): ds = Dataset( - {"tmin": ("x", [42], {"units": "Celcius"})}, attrs={"title": "My test data"} + {"tmin": ("x", [42], {"units": "Celsius"})}, attrs={"title": "My test data"} ) assert_identical(ds.tmin, ds["tmin"]) assert_identical(ds.tmin.x, ds.x) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index f6d8a7cfcb0..2a6de0be550 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -770,7 +770,7 @@ def test_decompose(method): ], ) def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked): - """Interpolate nd array with multiple independant indexers + """Interpolate nd array with multiple independent indexers It should do a series of 1d interpolation """ diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 2bf5af31fa5..3721c92317d 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -116,7 +116,7 @@ def test_interpolate_pd_compat(): @pytest.mark.parametrize("method", ["barycentric", "krog", "pchip", "spline", "akima"]) def test_scipy_methods_function(method): # Note: Pandas does some wacky things with these methods and the full - # integration tests wont work. + # integration tests won't work. da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) actual = da.interpolate_na(method=method, dim="time") assert (da.count("time") <= actual.count("time")).all() diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c0b6d355441..8ded4c6515f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -526,7 +526,7 @@ def test__infer_interval_breaks_logscale_invalid_coords(self): x = np.linspace(0, 5, 6) with pytest.raises(ValueError): _infer_interval_breaks(x, scale="log") - # Check if error is raised after nagative values in the array + # Check if error is raised after negative values in the array x = np.linspace(-5, 5, 11) with pytest.raises(ValueError): _infer_interval_breaks(x, scale="log") @@ -1506,7 +1506,7 @@ def test_convenient_facetgrid(self): else: assert "" == ax.get_xlabel() - # Infering labels + # Inferring labels g = self.plotfunc(d, col="z", col_wrap=2) assert_array_equal(g.axes.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axes): @@ -1986,7 +1986,7 @@ def test_convenient_facetgrid(self): assert "y" == ax.get_ylabel() assert "x" == ax.get_xlabel() - # Infering labels + # Inferring labels g = self.plotfunc(d, col="z", col_wrap=2) assert_array_equal(g.axes.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axes): From 9b4d0b29c319f4b68f89328b1bf558711f339504 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Wed, 2 Mar 2022 10:49:23 -0500 Subject: [PATCH 19/34] v2022.03.0 release notes (#6319) * release summary * update first calver relase number --- HOW_TO_RELEASE.md | 2 +- doc/whats-new.rst | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index 893a6d77168..8d82277ae55 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -109,6 +109,6 @@ upstream https://github.com/pydata/xarray (push) ## Note on version numbering -As of 2022.02.0, we utilize the [CALVER](https://calver.org/) version system. +As of 2022.03.0, we utilize the [CALVER](https://calver.org/) version system. Specifically, we have adopted the pattern `YYYY.MM.X`, where `YYYY` is a 4-digit year (e.g. `2022`), `MM` is a 2-digit zero-padded month (e.g. `01` for January), and `X` is the release number (starting at zero at the start of each month and incremented once for each additional release). diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6f939a36f25..25e7071f9d6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,10 +14,18 @@ What's New np.random.seed(123456) -.. _whats-new.2022.02.0: +.. _whats-new.2022.03.0: -v2022.02.0 (unreleased) ------------------------ +v2022.03.0 (2 March 2022) +------------------------- + +This release brings a number of small improvements, as well as a move to `calendar versioning `_ (:issue:`6176`). + +Many thanks to the 16 contributors to the v2022.02.0 release! + +Aaron Spring, Alan D. Snow, Anderson Banihirwe, crusaderky, Illviljan, Joe Hamman, Jonas Gliß, +Lukas Pilz, Martin Bergemann, Mathias Hauser, Maximilian Roos, Romain Caneill, Stan West, Stijn Van Hoey, +Tobias Kölling, and Tom Nicholas. New Features @@ -27,7 +35,6 @@ New Features :py:meth:`CFTimeIndex.shift` if ``shift_freq`` is between ``Day`` and ``Microsecond``. (:issue:`6134`, :pull:`6135`). By `Aaron Spring `_. - - Enbable to provide more keyword arguments to `pydap` backend when reading OpenDAP datasets (:issue:`6274`). By `Jonas Gliß `. From 56c7b8a0782d132e8c4d5913fb6a40ecb0ee60e8 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 2 Mar 2022 11:20:29 -0500 Subject: [PATCH 20/34] New whatsnew section --- doc/whats-new.rst | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6f939a36f25..f12d6525133 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,35 @@ What's New np.random.seed(123456) +.. _whats-new.2022.03.1: + +v2022.03.1 (unreleased) +--------------------- + +New Features +~~~~~~~~~~~~ + + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + + .. _whats-new.2022.02.0: v2022.02.0 (unreleased) From a01460bd90e3ae31a32e40005f33c2919efba8bb Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 3 Mar 2022 10:43:35 +0100 Subject: [PATCH 21/34] quantile: use skipna=None (#6303) * quantile: use skipna=None * better term * move whats new entry * remove duplicated entry --- doc/whats-new.rst | 4 +++- xarray/core/dataarray.py | 7 +++++-- xarray/core/dataset.py | 7 +++++-- xarray/core/groupby.py | 7 +++++-- xarray/core/variable.py | 12 ++++++++++-- xarray/tests/test_dataarray.py | 12 ++++++++---- xarray/tests/test_dataset.py | 3 ++- xarray/tests/test_groupby.py | 25 +++++++++++++++++++++++++ xarray/tests/test_variable.py | 12 ++++++++---- 9 files changed, 71 insertions(+), 18 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fdedc18300f..bfd5f27cbec 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and + ensure it skips missing values for float dtypes (consistent with other methods). This should + not change the behavior (:pull:`6303`). By `Mathias Hauser `_. Documentation ~~~~~~~~~~~~~ @@ -86,7 +89,6 @@ Deprecations Bug fixes ~~~~~~~~~ - - Variables which are chunked using dask in larger (but aligned) chunks than the target zarr chunk size can now be stored using `to_zarr()` (:pull:`6258`) By `Tobias Kölling `_. - Multi-file datasets containing encoded :py:class:`cftime.datetime` objects can be read in parallel again (:issue:`6226`, :pull:`6249`, :pull:`6305`). By `Martin Bergemann `_ and `Stan West `_. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3df9f7ca8a4..e04e5cb9c51 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3440,7 +3440,7 @@ def quantile( dim: str | Sequence[Hashable] | None = None, method: QUANTILE_METHODS = "linear", keep_attrs: bool = None, - skipna: bool = True, + skipna: bool = None, interpolation: QUANTILE_METHODS = None, ) -> DataArray: """Compute the qth quantile of the data along the specified dimension. @@ -3486,7 +3486,10 @@ def quantile( the original object to the new one. If False (default), the new object will be returned without attributes. skipna : bool, optional - Whether to skip missing values when aggregating. + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). Returns ------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a1d7209bc75..b3112bdc7ab 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6160,7 +6160,7 @@ def quantile( method: QUANTILE_METHODS = "linear", numeric_only: bool = False, keep_attrs: bool = None, - skipna: bool = True, + skipna: bool = None, interpolation: QUANTILE_METHODS = None, ): """Compute the qth quantile of the data along the specified dimension. @@ -6209,7 +6209,10 @@ def quantile( numeric_only : bool, optional If True, only apply ``func`` to variables with a numeric dtype. skipna : bool, optional - Whether to skip missing values when aggregating. + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). Returns ------- diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d7aa6749592..d3ec824159c 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -554,7 +554,7 @@ def quantile( dim=None, method="linear", keep_attrs=None, - skipna=True, + skipna=None, interpolation=None, ): """Compute the qth quantile over each array in the groups and @@ -597,7 +597,10 @@ def quantile( version 1.22.0. skipna : bool, optional - Whether to skip missing values when aggregating. + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). Returns ------- diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 74f394b68ca..c8d46d20d46 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1978,7 +1978,7 @@ def quantile( dim: str | Sequence[Hashable] | None = None, method: QUANTILE_METHODS = "linear", keep_attrs: bool = None, - skipna: bool = True, + skipna: bool = None, interpolation: QUANTILE_METHODS = None, ) -> Variable: """Compute the qth quantile of the data along the specified dimension. @@ -2024,6 +2024,11 @@ def quantile( If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). Returns ------- @@ -2059,7 +2064,10 @@ def quantile( method = interpolation - _quantile_func = np.nanquantile if skipna else np.quantile + if skipna or (skipna is None and self.dtype.kind in "cfO"): + _quantile_func = np.nanquantile + else: + _quantile_func = np.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3939da08a67..55c68b7ff6b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2516,15 +2516,19 @@ def test_reduce_out(self): with pytest.raises(TypeError): orig.mean(out=np.ones(orig.shape)) - @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) @pytest.mark.parametrize( "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) ) def test_quantile(self, q, axis, dim, skipna) -> None: - actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna) - _percentile_func = np.nanpercentile if skipna else np.percentile - expected = _percentile_func(self.dv.values, np.array(q) * 100, axis=axis) + + va = self.va.copy(deep=True) + va[0, 0] = np.NaN + + actual = DataArray(va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna) + _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile + expected = _percentile_func(va.values, np.array(q) * 100, axis=axis) np.testing.assert_allclose(actual.values, expected) if is_scalar(q): assert "quantile" not in actual.dims diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d726920acce..8d6c4f96857 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4718,10 +4718,11 @@ def test_reduce_keepdims(self): ) assert_identical(expected, actual) - @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) def test_quantile(self, q, skipna) -> None: ds = create_test_data(seed=123) + ds.var1.data[0, 0] = np.NaN for dim in [None, "dim1", ["dim1"]]: ds_quantile = ds.quantile(q, dim=dim, skipna=skipna) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1ec2a53c131..4b6da82bdc7 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -203,6 +203,17 @@ def test_da_groupby_quantile() -> None: actual = array.groupby("x").quantile([0, 1]) assert_identical(expected, actual) + array = xr.DataArray( + data=[np.NaN, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + + for skipna in (True, False, None): + e = [np.NaN, 5] if skipna is False else [2.5, 5] + + expected = xr.DataArray(data=e, coords={"x": [1, 2], "quantile": 0.5}, dims="x") + actual = array.groupby("x").quantile(0.5, skipna=skipna) + assert_identical(expected, actual) + # Multiple dimensions array = xr.DataArray( data=[[1, 11, 26], [2, 12, 22], [3, 13, 23], [4, 16, 24], [5, 15, 25]], @@ -306,6 +317,20 @@ def test_ds_groupby_quantile() -> None: actual = ds.groupby("x").quantile([0, 1]) assert_identical(expected, actual) + ds = xr.Dataset( + data_vars={"a": ("x", [np.NaN, 2, 3, 4, 5, 6])}, + coords={"x": [1, 1, 1, 2, 2, 2]}, + ) + + for skipna in (True, False, None): + e = [np.NaN, 5] if skipna is False else [2.5, 5] + + expected = xr.Dataset( + data_vars={"a": ("x", e)}, coords={"quantile": 0.5, "x": [1, 2]} + ) + actual = ds.groupby("x").quantile(0.5, skipna=skipna) + assert_identical(expected, actual) + # Multiple dimensions ds = xr.Dataset( data_vars={ diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index a88d5a22c0d..b8e2f6f4582 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1700,16 +1700,20 @@ def raise_if_called(*args, **kwargs): with set_options(use_bottleneck=False): v.min() - @pytest.mark.parametrize("skipna", [True, False]) + @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) @pytest.mark.parametrize( "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) ) def test_quantile(self, q, axis, dim, skipna): - v = Variable(["x", "y"], self.d) + + d = self.d.copy() + d[0, 0] = np.NaN + + v = Variable(["x", "y"], d) actual = v.quantile(q, dim=dim, skipna=skipna) - _percentile_func = np.nanpercentile if skipna else np.percentile - expected = _percentile_func(self.d, np.array(q) * 100, axis=axis) + _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile + expected = _percentile_func(d, np.array(q) * 100, axis=axis) np.testing.assert_allclose(actual.values, expected) @requires_dask From f42ac28629b7b2047f859f291e1d755c36f2e834 Mon Sep 17 00:00:00 2001 From: Stan West <38358698+stanwest@users.noreply.github.com> Date: Thu, 3 Mar 2022 12:01:14 -0500 Subject: [PATCH 22/34] Lengthen underline, correct spelling, and reword (#6326) --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bfd5f27cbec..b22c6e4d858 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,7 +17,7 @@ What's New .. _whats-new.2022.03.1: v2022.03.1 (unreleased) ---------------------- +----------------------- New Features ~~~~~~~~~~~~ @@ -68,7 +68,7 @@ New Features :py:meth:`CFTimeIndex.shift` if ``shift_freq`` is between ``Day`` and ``Microsecond``. (:issue:`6134`, :pull:`6135`). By `Aaron Spring `_. -- Enbable to provide more keyword arguments to `pydap` backend when reading +- Enable providing more keyword arguments to the `pydap` backend when reading OpenDAP datasets (:issue:`6274`). By `Jonas Gliß `. - Allow :py:meth:`DataArray.drop_duplicates` to drop duplicates along multiple dimensions at once, From 29a87cc110f1a1ff7b21c308ba7277963b51ada3 Mon Sep 17 00:00:00 2001 From: Stan West <38358698+stanwest@users.noreply.github.com> Date: Mon, 7 Mar 2022 08:13:50 -0500 Subject: [PATCH 23/34] In documentation on adding a new backend, add missing import and tweak headings (#6330) --- doc/internals/how-to-add-new-backend.rst | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index 17f94189504..506a8eb21be 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -172,6 +172,7 @@ Xarray :py:meth:`~xarray.open_dataset`, and returns a boolean. Decoders ^^^^^^^^ + The decoders implement specific operations to transform data from on-disk representation to Xarray representation. @@ -199,6 +200,11 @@ performs the inverse transformation. In the following an example on how to use the coders ``decode`` method: +.. ipython:: python + :suppress: + + import xarray as xr + .. ipython:: python var = xr.Variable( @@ -239,7 +245,7 @@ interface only the boolean keywords related to the supported decoders. .. _RST backend_registration: How to register a backend -+++++++++++++++++++++++++++ ++++++++++++++++++++++++++ Define a new entrypoint in your ``setup.py`` (or ``setup.cfg``) with: @@ -280,8 +286,9 @@ See https://python-poetry.org/docs/pyproject/#plugins for more information on Po .. _RST lazy_loading: -How to support Lazy Loading +How to support lazy loading +++++++++++++++++++++++++++ + If you want to make your backend effective with big datasets, then you should support lazy loading. Basically, you shall replace the :py:class:`numpy.ndarray` inside the @@ -380,8 +387,9 @@ opening files, we therefore suggest to use the helper class provided by Xarray .. _RST indexing: -Indexing Examples +Indexing examples ^^^^^^^^^^^^^^^^^ + **BASIC** In the ``BASIC`` indexing support, numbers and slices are supported. From 265ec7b4b8f6ee46120f125875685569e4115634 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Mar 2022 07:28:15 -0700 Subject: [PATCH 24/34] Bump actions/checkout from 2 to 3 (#6337) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/benchmarks.yml | 2 +- .github/workflows/ci-additional.yaml | 8 ++++---- .github/workflows/ci.yaml | 4 ++-- .github/workflows/pypi-release.yaml | 2 +- .github/workflows/upstream-dev-ci.yaml | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index de506546ac9..6d482445f96 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -16,7 +16,7 @@ jobs: steps: # We need the full repo to avoid this issue # https://github.com/actions/checkout/issues/23 - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 6bd9bcd9d6b..50c95cdebb7 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1.1 @@ -53,7 +53,7 @@ jobs: "py39-flaky", ] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -125,7 +125,7 @@ jobs: shell: bash -l {0} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 # Fetch all history for all branches and tags. - uses: conda-incubator/setup-miniconda@v2 @@ -162,7 +162,7 @@ jobs: shell: bash -l {0} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 # Fetch all history for all branches and tags. - uses: conda-incubator/setup-miniconda@v2 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1e5db3a73ed..4747b5ae20d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1.1 @@ -44,7 +44,7 @@ jobs: # Bookend python versions python-version: ["3.8", "3.9", "3.10"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set environment variables diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index f09291b9c6e..19ba1f0aa22 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest if: github.repository == 'pydata/xarray' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 - uses: actions/setup-python@v2 diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index f6f97fd67e3..6860f99c4da 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -24,7 +24,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1.1 @@ -52,7 +52,7 @@ jobs: outputs: artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 # Fetch all history for all branches and tags. - uses: conda-incubator/setup-miniconda@v2 @@ -110,7 +110,7 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v2 with: python-version: "3.x" From d293f50f9590251ce09543319d1f0dc760466f1b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Mar 2022 07:58:07 -0700 Subject: [PATCH 25/34] Bump actions/setup-python from 2 to 3 (#6338) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pypi-release.yaml | 4 ++-- .github/workflows/upstream-dev-ci.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index 19ba1f0aa22..c88cf556a50 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v3 name: Install Python with: python-version: 3.8 @@ -50,7 +50,7 @@ jobs: needs: build-artifacts runs-on: ubuntu-latest steps: - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v3 name: Install Python with: python-version: 3.8 diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 6860f99c4da..6091306ed8b 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -111,7 +111,7 @@ jobs: shell: bash steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v3 with: python-version: "3.x" - uses: actions/download-artifact@v2 From 9a71cc071603b1624f8eeca238a514267cec1bbc Mon Sep 17 00:00:00 2001 From: keewis Date: Fri, 11 Mar 2022 15:54:43 +0100 Subject: [PATCH 26/34] explicitly install `ipython_genutils` (#6350) * explicitly install ipython_genutils * try upgrading sphinx [skip-ci] * undo the upgrade [skip-ci] --- ci/requirements/doc.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index e5fcc500f70..f0511ce2006 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -12,6 +12,7 @@ dependencies: - h5netcdf>=0.7.4 - ipykernel - ipython + - ipython_genutils # remove once `nbconvert` fixed its dependencies - iris>=2.3 - jupyter_client - matplotlib-base From 229dad93e2c964c363afa4ab1673bb5b37151f7b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 12 Mar 2022 11:17:47 +0300 Subject: [PATCH 27/34] Generate reductions for DataArray, Dataset, GroupBy and Resample (#5950) Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Mathias Hauser Co-authored-by: Stephan Hoyer Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/user-guide/computation.rst | 2 + xarray/core/_reductions.py | 3527 ++++++++++++++++++++++----- xarray/core/arithmetic.py | 2 - xarray/core/common.py | 16 +- xarray/core/dataarray.py | 6 +- xarray/core/dataset.py | 6 +- xarray/core/groupby.py | 46 +- xarray/core/resample.py | 25 +- xarray/tests/test_dataarray.py | 4 +- xarray/tests/test_dataset.py | 26 +- xarray/tests/test_duck_array_ops.py | 82 - xarray/util/generate_reductions.py | 535 ++-- 12 files changed, 3307 insertions(+), 970 deletions(-) diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index d830076e37b..de2afa9060c 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -107,6 +107,8 @@ Xarray also provides the ``max_gap`` keyword argument to limit the interpolation data gaps of length ``max_gap`` or smaller. See :py:meth:`~xarray.DataArray.interpolate_na` for more. +.. _agg: + Aggregation =========== diff --git a/xarray/core/_reductions.py b/xarray/core/_reductions.py index 83aaa10a20c..31365f39e65 100644 --- a/xarray/core/_reductions.py +++ b/xarray/core/_reductions.py @@ -1,43 +1,1975 @@ """Mixin classes with reduction operations.""" # This file was generated using xarray.util.generate_reductions. Do not edit manually. -from typing import Any, Callable, Hashable, Optional, Protocol, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . import duck_array_ops -from .types import T_DataArray, T_Dataset +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + + +class DatasetReductions: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "Dataset": + raise NotImplementedError() + + def count( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.count + dask.array.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.count() + + Dimensions: () + Data variables: + da int64 5 + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([True, True, True, True, True, False], dtype=bool), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.all() + + Dimensions: () + Data variables: + da bool False + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([True, True, True, True, True, False], dtype=bool), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.any() + + Dimensions: () + Data variables: + da bool True + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.max() + + Dimensions: () + Data variables: + da float64 3.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.max(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.min() + + Dimensions: () + Data variables: + da float64 1.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.min(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.mean() + + Dimensions: () + Data variables: + da float64 1.8 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.mean(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + min_count: Optional[int] = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.prod() + + Dimensions: () + Data variables: + da float64 12.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.prod(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> ds.prod(skipna=True, min_count=2) + + Dimensions: () + Data variables: + da float64 12.0 + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + min_count: Optional[int] = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.sum() + + Dimensions: () + Data variables: + da float64 9.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.sum(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> ds.sum(skipna=True, min_count=2) + + Dimensions: () + Data variables: + da float64 9.0 + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + ddof: int = 0, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.std() + + Dimensions: () + Data variables: + da float64 0.7483 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.std(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.std(skipna=True, ddof=1) + + Dimensions: () + Data variables: + da float64 0.8367 + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + ddof: int = 0, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.var() + + Dimensions: () + Data variables: + da float64 0.56 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.var(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.var(skipna=True, ddof=1) + + Dimensions: () + Data variables: + da float64 0.7 + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "Dataset": + """ + Reduce this Dataset's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : Dataset + New Dataset with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da)) + >>> ds + + Dimensions: (time: 6) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> ds.median() + + Dimensions: () + Data variables: + da float64 2.0 + + Use ``skipna`` to control whether NaNs are ignored. + + >>> ds.median(skipna=False) + + Dimensions: () + Data variables: + da float64 nan + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) + + +class DataArrayReductions: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "DataArray": + raise NotImplementedError() + + def count( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.count + dask.array.count + Dataset.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.count() + + array(5) + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) + + def all( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([True, True, True, True, True, False], dtype=bool), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ True, True, True, True, True, False]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.all() + + array(False) + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) + + def any( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([True, True, True, True, True, False], dtype=bool), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ True, True, True, True, True, False]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.any() + + array(True) + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) + + def max( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.max() + + array(3.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.max(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) + + def min( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.min() + + array(1.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.min(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) + + def mean( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.mean() + + array(1.8) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.mean(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) + + def prod( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + min_count: Optional[int] = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.prod() + + array(12.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.prod(skipna=False) + + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> da.prod(skipna=True, min_count=2) + + array(12.) + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) + + def sum( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + min_count: Optional[int] = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.sum() + + array(9.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.sum(skipna=False) + + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> da.sum(skipna=True, min_count=2) + + array(9.) + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) + + def std( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + ddof: int = 0, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.std() + + array(0.74833148) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.std(skipna=False) + + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.std(skipna=True, ddof=1) + + array(0.83666003) + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) + + def var( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + ddof: int = 0, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.var() + + array(0.56) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.var(skipna=False) + + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.var(skipna=True, ddof=1) + + array(0.7) + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) + + def median( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + skipna: bool = None, + keep_attrs: bool = None, + **kwargs, + ) -> "DataArray": + """ + Reduce this DataArray's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : DataArray + New DataArray with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> da = xr.DataArray( + ... np.array([1, 2, 3, 1, 2, np.nan]), + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> da + + array([ 1., 2., 3., 1., 2., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 + labels (time) >> da.median() + + array(2.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> da.median(skipna=False) + + array(nan) + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) + + +class DatasetGroupByReductions: + __slots__ = () -class DatasetReduce(Protocol): def reduce( self, func: Callable[..., Any], dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, axis: Union[None, int, Sequence[int]] = None, keep_attrs: bool = None, keepdims: bool = False, **kwargs: Any, - ) -> T_Dataset: - ... - - -class DatasetGroupByReductions: - __slots__ = () + ) -> "Dataset": + raise NotImplementedError() def count( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``count`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -45,6 +1977,7 @@ def count( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -52,6 +1985,14 @@ def count( New Dataset with ``count`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.count + dask.array.count + Dataset.count + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -79,13 +2020,6 @@ def count( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) int64 1 2 2 - - See Also - -------- - numpy.count - Dataset.count - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.count, @@ -96,20 +2030,20 @@ def count( ) def all( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``all`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -117,6 +2051,7 @@ def all( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -124,6 +2059,14 @@ def all( New Dataset with ``all`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.all + dask.array.all + Dataset.all + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -151,13 +2094,6 @@ def all( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) bool False True True - - See Also - -------- - numpy.all - Dataset.all - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.array_all, @@ -168,20 +2104,20 @@ def all( ) def any( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``any`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -189,6 +2125,7 @@ def any( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -196,6 +2133,14 @@ def any( New Dataset with ``any`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.any + dask.array.any + Dataset.any + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -223,13 +2168,6 @@ def any( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) bool True True True - - See Also - -------- - numpy.any - Dataset.any - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.array_any, @@ -240,25 +2178,25 @@ def any( ) def max( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``max`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -267,6 +2205,7 @@ def max( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -274,6 +2213,14 @@ def max( New Dataset with ``max`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.max + dask.array.max + Dataset.max + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -311,13 +2258,6 @@ def max( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 2.0 3.0 - - See Also - -------- - numpy.max - Dataset.max - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.max, @@ -329,25 +2269,25 @@ def max( ) def min( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``min`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -356,6 +2296,7 @@ def min( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -363,6 +2304,14 @@ def min( New Dataset with ``min`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.min + dask.array.min + Dataset.min + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -400,13 +2349,6 @@ def min( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 2.0 1.0 - - See Also - -------- - numpy.min - Dataset.min - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.min, @@ -418,25 +2360,25 @@ def min( ) def mean( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``mean`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -445,6 +2387,7 @@ def mean( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -452,6 +2395,18 @@ def mean( New Dataset with ``mean`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -489,13 +2444,6 @@ def mean( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 2.0 2.0 - - See Also - -------- - numpy.mean - Dataset.mean - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.mean, @@ -507,26 +2455,26 @@ def mean( ) def prod( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``prod`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -541,6 +2489,7 @@ def prod( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -548,6 +2497,18 @@ def prod( New Dataset with ``prod`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -595,13 +2556,6 @@ def prod( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 4.0 3.0 - - See Also - -------- - numpy.prod - Dataset.prod - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.prod, @@ -614,26 +2568,26 @@ def prod( ) def sum( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``sum`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -648,6 +2602,7 @@ def sum( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -655,6 +2610,18 @@ def sum( New Dataset with ``sum`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -702,13 +2669,6 @@ def sum( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 4.0 4.0 - - See Also - -------- - numpy.sum - Dataset.sum - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.sum, @@ -721,26 +2681,30 @@ def sum( ) def std( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``std`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -748,6 +2712,7 @@ def std( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -755,6 +2720,18 @@ def std( New Dataset with ``std`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.std + dask.array.std + Dataset.std + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -793,43 +2770,51 @@ def std( Data variables: da (labels) float64 nan 0.0 1.0 - See Also - -------- - numpy.std - Dataset.std - :ref:`groupby` - User guide on groupby operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.groupby("labels").std(skipna=True, ddof=1) + + Dimensions: (labels: 3) + Coordinates: + * labels (labels) object 'a' 'b' 'c' + Data variables: + da (labels) float64 nan 0.0 1.414 """ return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, + ddof=ddof, numeric_only=True, keep_attrs=keep_attrs, **kwargs, ) def var( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``var`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -837,6 +2822,7 @@ def var( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -844,6 +2830,18 @@ def var( New Dataset with ``var`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.var + dask.array.var + Dataset.var + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -882,42 +2880,46 @@ def var( Data variables: da (labels) float64 nan 0.0 1.0 - See Also - -------- - numpy.var - Dataset.var - :ref:`groupby` - User guide on groupby operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.groupby("labels").var(skipna=True, ddof=1) + + Dimensions: (labels: 3) + Coordinates: + * labels (labels) object 'a' 'b' 'c' + Data variables: + da (labels) float64 nan 0.0 2.0 """ return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, + ddof=ddof, numeric_only=True, keep_attrs=keep_attrs, **kwargs, ) def median( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``median`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -926,6 +2928,7 @@ def median( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -933,6 +2936,18 @@ def median( New Dataset with ``median`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.median + dask.array.median + Dataset.median + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -970,13 +2985,6 @@ def median( * labels (labels) object 'a' 'b' 'c' Data variables: da (labels) float64 nan 2.0 2.0 - - See Also - -------- - numpy.median - Dataset.median - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.median, @@ -991,21 +2999,33 @@ def median( class DatasetResampleReductions: __slots__ = () + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "Dataset": + raise NotImplementedError() + def count( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``count`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1013,6 +3033,7 @@ def count( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1020,6 +3041,14 @@ def count( New Dataset with ``count`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.count + dask.array.count + Dataset.count + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -1047,13 +3076,6 @@ def count( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) int64 1 3 1 - - See Also - -------- - numpy.count - Dataset.count - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.count, @@ -1064,20 +3086,20 @@ def count( ) def all( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``all`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1085,6 +3107,7 @@ def all( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1092,6 +3115,14 @@ def all( New Dataset with ``all`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.all + dask.array.all + Dataset.all + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -1119,13 +3150,6 @@ def all( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) bool True True False - - See Also - -------- - numpy.all - Dataset.all - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.array_all, @@ -1136,20 +3160,20 @@ def all( ) def any( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``any`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1157,6 +3181,7 @@ def any( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1164,6 +3189,14 @@ def any( New Dataset with ``any`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.any + dask.array.any + Dataset.any + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -1185,19 +3218,12 @@ def any( da (time) bool True True True True True False >>> ds.resample(time="3M").any() - - Dimensions: (time: 3) - Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - Data variables: - da (time) bool True True True - - See Also - -------- - numpy.any - Dataset.any - :ref:`resampling` - User guide on resampling operations. + + Dimensions: (time: 3) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + Data variables: + da (time) bool True True True """ return self.reduce( duck_array_ops.array_any, @@ -1208,25 +3234,25 @@ def any( ) def max( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``max`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -1235,6 +3261,7 @@ def max( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1242,6 +3269,14 @@ def max( New Dataset with ``max`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.max + dask.array.max + Dataset.max + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -1279,13 +3314,6 @@ def max( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 1.0 3.0 nan - - See Also - -------- - numpy.max - Dataset.max - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.max, @@ -1297,25 +3325,25 @@ def max( ) def min( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``min`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -1324,6 +3352,7 @@ def min( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1331,6 +3360,14 @@ def min( New Dataset with ``min`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.min + dask.array.min + Dataset.min + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -1368,13 +3405,6 @@ def min( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 1.0 1.0 nan - - See Also - -------- - numpy.min - Dataset.min - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.min, @@ -1386,25 +3416,25 @@ def min( ) def mean( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``mean`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -1413,6 +3443,7 @@ def mean( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1420,6 +3451,18 @@ def mean( New Dataset with ``mean`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1457,13 +3500,6 @@ def mean( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 1.0 2.0 nan - - See Also - -------- - numpy.mean - Dataset.mean - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.mean, @@ -1475,26 +3511,26 @@ def mean( ) def prod( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``prod`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -1509,6 +3545,7 @@ def prod( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1516,6 +3553,18 @@ def prod( New Dataset with ``prod`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1563,13 +3612,6 @@ def prod( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 nan 6.0 nan - - See Also - -------- - numpy.prod - Dataset.prod - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.prod, @@ -1582,26 +3624,26 @@ def prod( ) def sum( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``sum`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -1616,6 +3658,7 @@ def sum( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1623,6 +3666,18 @@ def sum( New Dataset with ``sum`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1670,13 +3725,6 @@ def sum( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 nan 6.0 nan - - See Also - -------- - numpy.sum - Dataset.sum - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.sum, @@ -1689,26 +3737,30 @@ def sum( ) def std( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``std`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1716,6 +3768,7 @@ def std( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1723,6 +3776,18 @@ def std( New Dataset with ``std`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.std + dask.array.std + Dataset.std + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1761,43 +3826,51 @@ def std( Data variables: da (time) float64 0.0 0.8165 nan - See Also - -------- - numpy.std - Dataset.std - :ref:`resampling` - User guide on resampling operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.resample(time="3M").std(skipna=True, ddof=1) + + Dimensions: (time: 3) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + Data variables: + da (time) float64 nan 1.0 nan """ return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, + ddof=ddof, numeric_only=True, keep_attrs=keep_attrs, **kwargs, ) def var( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``var`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1805,6 +3878,7 @@ def var( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1812,6 +3886,18 @@ def var( New Dataset with ``var`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.var + dask.array.var + Dataset.var + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1850,42 +3936,46 @@ def var( Data variables: da (time) float64 0.0 0.6667 nan - See Also - -------- - numpy.var - Dataset.var - :ref:`resampling` - User guide on resampling operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> ds.resample(time="3M").var(skipna=True, ddof=1) + + Dimensions: (time: 3) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + Data variables: + da (time) float64 nan 1.0 nan """ return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, + ddof=ddof, numeric_only=True, keep_attrs=keep_attrs, **kwargs, ) def median( - self: DatasetReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_Dataset: + ) -> "Dataset": """ Reduce this Dataset's data by applying ``median`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -1894,6 +3984,7 @@ def median( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -1901,6 +3992,18 @@ def median( New Dataset with ``median`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.median + dask.array.median + Dataset.median + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -1938,13 +4041,6 @@ def median( * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 Data variables: da (time) float64 1.0 2.0 nan - - See Also - -------- - numpy.median - Dataset.median - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.median, @@ -1956,37 +4052,36 @@ def median( ) -class DataArrayReduce(Protocol): +class DataArrayGroupByReductions: + __slots__ = () + def reduce( self, func: Callable[..., Any], dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, axis: Union[None, int, Sequence[int]] = None, keep_attrs: bool = None, keepdims: bool = False, **kwargs: Any, - ) -> T_DataArray: - ... - - -class DataArrayGroupByReductions: - __slots__ = () + ) -> "DataArray": + raise NotImplementedError() def count( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``count`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -1994,6 +4089,7 @@ def count( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2001,6 +4097,14 @@ def count( New DataArray with ``count`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.count + dask.array.count + DataArray.count + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -2023,13 +4127,6 @@ def count( array([1, 2, 2]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.count - DataArray.count - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.count, @@ -2039,20 +4136,20 @@ def count( ) def all( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``all`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2060,6 +4157,7 @@ def all( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2067,6 +4165,14 @@ def all( New DataArray with ``all`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.all + dask.array.all + DataArray.all + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -2089,13 +4195,6 @@ def all( array([False, True, True]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.all - DataArray.all - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.array_all, @@ -2105,20 +4204,20 @@ def all( ) def any( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``any`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2126,6 +4225,7 @@ def any( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2133,6 +4233,14 @@ def any( New DataArray with ``any`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.any + dask.array.any + DataArray.any + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -2155,13 +4263,6 @@ def any( array([ True, True, True]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.any - DataArray.any - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.array_any, @@ -2171,25 +4272,25 @@ def any( ) def max( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``max`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -2198,6 +4299,7 @@ def max( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2205,6 +4307,14 @@ def max( New DataArray with ``max`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.max + dask.array.max + DataArray.max + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -2235,13 +4345,6 @@ def max( array([nan, 2., 3.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.max - DataArray.max - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.max, @@ -2252,25 +4355,25 @@ def max( ) def min( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``min`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -2279,6 +4382,7 @@ def min( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2286,6 +4390,14 @@ def min( New DataArray with ``min`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.min + dask.array.min + DataArray.min + :ref:`groupby` + User guide on groupby operations. + Examples -------- >>> da = xr.DataArray( @@ -2316,13 +4428,6 @@ def min( array([nan, 2., 1.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.min - DataArray.min - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.min, @@ -2333,25 +4438,25 @@ def min( ) def mean( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``mean`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -2360,6 +4465,7 @@ def mean( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2367,6 +4473,18 @@ def mean( New DataArray with ``mean`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.mean + dask.array.mean + DataArray.mean + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2397,13 +4515,6 @@ def mean( array([nan, 2., 2.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.mean - DataArray.mean - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.mean, @@ -2414,26 +4525,26 @@ def mean( ) def prod( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``prod`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -2448,6 +4559,7 @@ def prod( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2455,6 +4567,18 @@ def prod( New DataArray with ``prod`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.prod + dask.array.prod + DataArray.prod + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2493,13 +4617,6 @@ def prod( array([nan, 4., 3.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.prod - DataArray.prod - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.prod, @@ -2511,26 +4628,26 @@ def prod( ) def sum( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``sum`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -2545,6 +4662,7 @@ def sum( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2552,6 +4670,18 @@ def sum( New DataArray with ``sum`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.sum + dask.array.sum + DataArray.sum + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2590,13 +4720,6 @@ def sum( array([nan, 4., 4.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.sum - DataArray.sum - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.sum, @@ -2608,26 +4731,30 @@ def sum( ) def std( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``std`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2635,6 +4762,7 @@ def std( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2642,6 +4770,18 @@ def std( New DataArray with ``std`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.std + dask.array.std + DataArray.std + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2673,42 +4813,48 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' - See Also - -------- - numpy.std - DataArray.std - :ref:`groupby` - User guide on groupby operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.groupby("labels").std(skipna=True, ddof=1) + + array([ nan, 0. , 1.41421356]) + Coordinates: + * labels (labels) object 'a' 'b' 'c' """ return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, + ddof=ddof, keep_attrs=keep_attrs, **kwargs, ) def var( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``var`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2716,6 +4862,7 @@ def var( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2723,6 +4870,18 @@ def var( New DataArray with ``var`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.var + dask.array.var + DataArray.var + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2754,41 +4913,43 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' - See Also - -------- - numpy.var - DataArray.var - :ref:`groupby` - User guide on groupby operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.groupby("labels").var(skipna=True, ddof=1) + + array([nan, 0., 2.]) + Coordinates: + * labels (labels) object 'a' 'b' 'c' """ return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, + ddof=ddof, keep_attrs=keep_attrs, **kwargs, ) def median( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``median`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -2797,6 +4958,7 @@ def median( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2804,6 +4966,18 @@ def median( New DataArray with ``median`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.median + dask.array.median + DataArray.median + :ref:`groupby` + User guide on groupby operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -2834,13 +5008,6 @@ def median( array([nan, 2., 2.]) Coordinates: * labels (labels) object 'a' 'b' 'c' - - See Also - -------- - numpy.median - DataArray.median - :ref:`groupby` - User guide on groupby operations. """ return self.reduce( duck_array_ops.median, @@ -2854,21 +5021,33 @@ def median( class DataArrayResampleReductions: __slots__ = () + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "DataArray": + raise NotImplementedError() + def count( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``count`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2876,6 +5055,7 @@ def count( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2883,6 +5063,14 @@ def count( New DataArray with ``count`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.count + dask.array.count + DataArray.count + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -2905,13 +5093,6 @@ def count( array([1, 3, 1]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.count - DataArray.count - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.count, @@ -2921,20 +5102,20 @@ def count( ) def all( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``all`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -2942,6 +5123,7 @@ def all( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -2949,6 +5131,14 @@ def all( New DataArray with ``all`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.all + dask.array.all + DataArray.all + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -2971,13 +5161,6 @@ def all( array([ True, True, False]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.all - DataArray.all - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.array_all, @@ -2987,20 +5170,20 @@ def all( ) def any( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``any`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -3008,6 +5191,7 @@ def any( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3015,6 +5199,14 @@ def any( New DataArray with ``any`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.any + dask.array.any + DataArray.any + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -3037,13 +5229,6 @@ def any( array([ True, True, True]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.any - DataArray.any - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.array_any, @@ -3053,25 +5238,25 @@ def any( ) def max( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``max`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -3080,6 +5265,7 @@ def max( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3087,6 +5273,14 @@ def max( New DataArray with ``max`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.max + dask.array.max + DataArray.max + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -3117,13 +5311,6 @@ def max( array([ 1., 3., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.max - DataArray.max - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.max, @@ -3134,25 +5321,25 @@ def max( ) def min( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``min`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -3161,6 +5348,7 @@ def min( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3168,6 +5356,14 @@ def min( New DataArray with ``min`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.min + dask.array.min + DataArray.min + :ref:`resampling` + User guide on resampling operations. + Examples -------- >>> da = xr.DataArray( @@ -3198,13 +5394,6 @@ def min( array([ 1., 1., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.min - DataArray.min - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.min, @@ -3215,25 +5404,25 @@ def min( ) def mean( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``mean`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -3242,6 +5431,7 @@ def mean( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3249,6 +5439,18 @@ def mean( New DataArray with ``mean`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.mean + dask.array.mean + DataArray.mean + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3279,13 +5481,6 @@ def mean( array([ 1., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.mean - DataArray.mean - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.mean, @@ -3296,26 +5491,26 @@ def mean( ) def prod( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``prod`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -3330,6 +5525,7 @@ def prod( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3337,6 +5533,18 @@ def prod( New DataArray with ``prod`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.prod + dask.array.prod + DataArray.prod + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3375,13 +5583,6 @@ def prod( array([nan, 6., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.prod - DataArray.prod - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.prod, @@ -3393,26 +5594,26 @@ def prod( ) def sum( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, min_count: Optional[int] = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``sum`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None The required number of valid values to perform the operation. If @@ -3427,6 +5628,7 @@ def sum( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3434,6 +5636,18 @@ def sum( New DataArray with ``sum`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.sum + dask.array.sum + DataArray.sum + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3472,13 +5686,6 @@ def sum( array([nan, 6., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.sum - DataArray.sum - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.sum, @@ -3490,26 +5697,30 @@ def sum( ) def std( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``std`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -3517,6 +5728,7 @@ def std( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3524,6 +5736,18 @@ def std( New DataArray with ``std`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.std + dask.array.std + DataArray.std + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3555,42 +5779,48 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - See Also - -------- - numpy.std - DataArray.std - :ref:`resampling` - User guide on resampling operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.resample(time="3M").std(skipna=True, ddof=1) + + array([nan, 1., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ return self.reduce( duck_array_ops.std, dim=dim, skipna=skipna, + ddof=ddof, keep_attrs=keep_attrs, **kwargs, ) def var( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, + ddof: int = 0, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``var`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. keep_attrs : bool, optional If True, ``attrs`` will be copied from the original object to the new one. If False (default), the new object will be @@ -3598,6 +5828,7 @@ def var( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3605,6 +5836,18 @@ def var( New DataArray with ``var`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.var + dask.array.var + DataArray.var + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3636,41 +5879,43 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - See Also - -------- - numpy.var - DataArray.var - :ref:`resampling` - User guide on resampling operations. + Specify ``ddof=1`` for an unbiased estimate. + + >>> da.resample(time="3M").var(skipna=True, ddof=1) + + array([nan, 1., nan]) + Coordinates: + * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ return self.reduce( duck_array_ops.var, dim=dim, skipna=skipna, + ddof=ddof, keep_attrs=keep_attrs, **kwargs, ) def median( - self: DataArrayReduce, + self, dim: Union[None, Hashable, Sequence[Hashable]] = None, - skipna: bool = True, + *, + skipna: bool = None, keep_attrs: bool = None, **kwargs, - ) -> T_DataArray: + ) -> "DataArray": """ Reduce this DataArray's data by applying ``median`` along some dimension(s). Parameters ---------- - dim : hashable or iterable of hashable, optional + dim : hashable or iterable of hashable, default: None Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. If ``None``, will reduce over all dimensions - present in the grouped variable. - skipna : bool, optional + or ``dim=["x", "y"]``. If None, will reduce over all dimensions. + skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64). keep_attrs : bool, optional If True, ``attrs`` will be copied from the original @@ -3679,6 +5924,7 @@ def median( **kwargs : dict Additional keyword arguments passed on to the appropriate array function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. Returns ------- @@ -3686,6 +5932,18 @@ def median( New DataArray with ``median`` applied to its data and the indicated dimension(s) removed + See Also + -------- + numpy.median + dask.array.median + DataArray.median + :ref:`resampling` + User guide on resampling operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + Examples -------- >>> da = xr.DataArray( @@ -3716,13 +5974,6 @@ def median( array([ 1., 2., nan]) Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 - - See Also - -------- - numpy.median - DataArray.median - :ref:`resampling` - User guide on resampling operations. """ return self.reduce( duck_array_ops.median, diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 814e9a59877..bf8d6ccaeb6 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -105,7 +105,6 @@ class VariableArithmetic( class DatasetArithmetic( ImplementsDatasetReduce, - IncludeReduceMethods, IncludeCumMethods, SupportsArithmetic, DatasetOpsMixin, @@ -116,7 +115,6 @@ class DatasetArithmetic( class DataArrayArithmetic( ImplementsArrayReduce, - IncludeReduceMethods, IncludeCumMethods, IncludeNumpySameMethods, SupportsArithmetic, diff --git a/xarray/core/common.py b/xarray/core/common.py index bee59c6cc7d..cb6da986892 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -55,12 +55,14 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool if include_skipna: def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs): - return self.reduce(func, dim, axis, skipna=skipna, **kwargs) + return self.reduce( + func=func, dim=dim, axis=axis, skipna=skipna, **kwargs + ) else: def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore[misc] - return self.reduce(func, dim, axis, **kwargs) + return self.reduce(func=func, dim=dim, axis=axis, **kwargs) return wrapped_func @@ -93,13 +95,19 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool def wrapped_func(self, dim=None, skipna=None, **kwargs): return self.reduce( - func, dim, skipna=skipna, numeric_only=numeric_only, **kwargs + func=func, + dim=dim, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, ) else: def wrapped_func(self, dim=None, **kwargs): # type: ignore[misc] - return self.reduce(func, dim, numeric_only=numeric_only, **kwargs) + return self.reduce( + func=func, dim=dim, numeric_only=numeric_only, **kwargs + ) return wrapped_func diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e04e5cb9c51..d7c3fd9bab7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -32,6 +32,7 @@ utils, weighted, ) +from ._reductions import DataArrayReductions from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor from .alignment import ( @@ -213,7 +214,9 @@ def __setitem__(self, key, value) -> None: _THIS_ARRAY = ReprObject("") -class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic): +class DataArray( + AbstractArray, DataWithCoords, DataArrayArithmetic, DataArrayReductions +): """N-dimensional array with labeled coordinates and dimensions. DataArray provides a wrapper around numpy ndarrays that uses @@ -2655,6 +2658,7 @@ def reduce( self, func: Callable[..., Any], dim: None | Hashable | Sequence[Hashable] = None, + *, axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b3112bdc7ab..dd7807c2e7c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -48,6 +48,7 @@ utils, weighted, ) +from ._reductions import DatasetReductions from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .arithmetic import DatasetArithmetic from .common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes @@ -573,7 +574,7 @@ def __setitem__(self, key, value) -> None: self.dataset[pos_indexers] = value -class Dataset(DataWithCoords, DatasetArithmetic, Mapping): +class Dataset(DataWithCoords, DatasetReductions, DatasetArithmetic, Mapping): """A multi-dimensional, in memory, array database. A dataset resembles an in-memory representation of a NetCDF file, @@ -5004,6 +5005,7 @@ def reduce( self, func: Callable, dim: Hashable | Iterable[Hashable] = None, + *, keep_attrs: bool = None, keepdims: bool = False, numeric_only: bool = False, @@ -5039,7 +5041,7 @@ def reduce( Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. """ - if "axis" in kwargs: + if kwargs.get("axis", None) is not None: raise ValueError( "passing 'axis' to Dataset reduce methods is ambiguous." " Please use 'dim' instead." diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d3ec824159c..dea8949b84f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,5 +1,6 @@ import datetime import warnings +from typing import Any, Callable, Hashable, Sequence, Union import numpy as np import pandas as pd @@ -866,7 +867,15 @@ def _combine(self, applied, shortcut=False): return combined def reduce( - self, func, dim=None, axis=None, keep_attrs=None, shortcut=True, **kwargs + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, ): """Reduce the items in this group by applying `func` along some dimension(s). @@ -899,11 +908,15 @@ def reduce( if dim is None: dim = self._group_dim - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - def reduce_array(ar): - return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) + return ar.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) check_reduce_dims(dim, self.dims) @@ -981,7 +994,16 @@ def _combine(self, applied): combined = self._maybe_unstack(combined) return combined - def reduce(self, func, dim=None, keep_attrs=None, **kwargs): + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ): """Reduce the items in this group by applying `func` along some dimension(s). @@ -1013,11 +1035,15 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): if dim is None: dim = self._group_dim - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - def reduce_dataset(ds): - return ds.reduce(func, dim, keep_attrs, **kwargs) + return ds.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) check_reduce_dims(dim, self.dims) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index e2f599e8b4e..ed665ad4048 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,4 +1,5 @@ import warnings +from typing import Any, Callable, Hashable, Sequence, Union from ._reductions import DataArrayResampleReductions, DatasetResampleReductions from .groupby import DataArrayGroupByBase, DatasetGroupByBase @@ -157,7 +158,7 @@ def _interpolate(self, kind="linear"): ) -class DataArrayResample(DataArrayResampleReductions, DataArrayGroupByBase, Resample): +class DataArrayResample(DataArrayGroupByBase, DataArrayResampleReductions, Resample): """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ @@ -248,7 +249,7 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -class DatasetResample(DatasetResampleReductions, DatasetGroupByBase, Resample): +class DatasetResample(DatasetGroupByBase, DatasetResampleReductions, Resample): """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): @@ -316,7 +317,16 @@ def apply(self, func, args=(), shortcut=None, **kwargs): ) return self.map(func=func, shortcut=shortcut, args=args, **kwargs) - def reduce(self, func, dim=None, keep_attrs=None, **kwargs): + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ): """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. @@ -341,4 +351,11 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ - return super().reduce(func, dim, keep_attrs, **kwargs) + return super().reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 55c68b7ff6b..438bdf8bdc3 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2404,7 +2404,7 @@ def test_cumops(self): expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=["x", "y"]) assert_identical(expected, actual) - def test_reduce(self): + def test_reduce(self) -> None: coords = { "x": [-1, -2], "y": ["ab", "cd", "ef"], @@ -2445,7 +2445,7 @@ def test_reduce(self): expected = DataArray(orig.values.astype(int), dims=["x", "y"]).mean("x") assert_equal(actual, expected) - def test_reduce_keepdims(self): + def test_reduce_keepdims(self) -> None: coords = { "x": [-1, -2], "y": ["ab", "cd", "ef"], diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8d6c4f96857..41fdd9a373d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4469,7 +4469,7 @@ def test_where_drop_no_indexes(self): actual = ds.where(ds == 1, drop=True) assert_identical(expected, actual) - def test_reduce(self): + def test_reduce(self) -> None: data = create_test_data() assert len(data.mean().coords) == 0 @@ -4480,21 +4480,21 @@ def test_reduce(self): assert_equal(data.min(dim=["dim1"]), data.min(dim="dim1")) - for reduct, expected in [ + for reduct, expected_dims in [ ("dim2", ["dim3", "time", "dim1"]), (["dim2", "time"], ["dim3", "dim1"]), (("dim2", "time"), ["dim3", "dim1"]), ((), ["dim2", "dim3", "time", "dim1"]), ]: - actual = list(data.min(dim=reduct).dims) - assert actual == expected + actual_dims = list(data.min(dim=reduct).dims) + assert actual_dims == expected_dims assert_equal(data.mean(dim=[]), data) with pytest.raises(ValueError): data.mean(axis=0) - def test_reduce_coords(self): + def test_reduce_coords(self) -> None: # regression test for GH1470 data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4}) expected = xr.Dataset({"a": 2}, coords={"b": 4}) @@ -4518,7 +4518,7 @@ def test_mean_uint_dtype(self): ) assert_identical(actual, expected) - def test_reduce_bad_dim(self): + def test_reduce_bad_dim(self) -> None: data = create_test_data() with pytest.raises(ValueError, match=r"Dataset does not contain"): data.mean(dim="bad_dim") @@ -4553,7 +4553,7 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func): actual = getattr(data, func)(dim=reduct).dims assert list(actual) == expected - def test_reduce_non_numeric(self): + def test_reduce_non_numeric(self) -> None: data1 = create_test_data(seed=44) data2 = create_test_data(seed=44) add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} @@ -4570,7 +4570,7 @@ def test_reduce_non_numeric(self): @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" ) - def test_reduce_strings(self): + def test_reduce_strings(self) -> None: expected = Dataset({"x": "a"}) ds = Dataset({"x": ("y", ["a", "b"])}) ds.coords["y"] = [-10, 10] @@ -4607,7 +4607,7 @@ def test_reduce_strings(self): actual = ds.min() assert_identical(expected, actual) - def test_reduce_dtypes(self): + def test_reduce_dtypes(self) -> None: # regression test for GH342 expected = Dataset({"x": 1}) actual = Dataset({"x": True}).sum() @@ -4622,7 +4622,7 @@ def test_reduce_dtypes(self): actual = Dataset({"x": ("y", [1, 1j])}).sum() assert_identical(expected, actual) - def test_reduce_keep_attrs(self): + def test_reduce_keep_attrs(self) -> None: data = create_test_data() _attrs = {"attr1": "value1", "attr2": 2929} @@ -4664,7 +4664,7 @@ def test_reduce_scalars(self): actual = ds.var("a") assert_identical(expected, actual) - def test_reduce_only_one_axis(self): + def test_reduce_only_one_axis(self) -> None: def mean_only_one_axis(x, axis): if not isinstance(axis, integer_types): raise TypeError("non-integer axis") @@ -4680,7 +4680,7 @@ def mean_only_one_axis(x, axis): ): ds.reduce(mean_only_one_axis) - def test_reduce_no_axis(self): + def test_reduce_no_axis(self) -> None: def total_sum(x): return np.sum(x.flatten()) @@ -4692,7 +4692,7 @@ def total_sum(x): with pytest.raises(TypeError, match=r"unexpected keyword argument 'axis'"): ds.reduce(total_sum, dim="x") - def test_reduce_keepdims(self): + def test_reduce_keepdims(self) -> None: ds = Dataset( {"a": (["x", "y"], [[0, 1, 2, 3, 4]])}, coords={ diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index e12798b70c9..c329bc50c56 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1,6 +1,5 @@ import datetime as dt import warnings -from textwrap import dedent import numpy as np import pandas as pd @@ -676,87 +675,6 @@ def test_multiple_dims(dtype, dask, skipna, func): assert_allclose(actual, expected) -def test_docs(): - # with min_count - actual = DataArray.sum.__doc__ - expected = dedent( - """\ - Reduce this DataArray's data by applying `sum` along some dimension(s). - - Parameters - ---------- - dim : str or sequence of str, optional - Dimension(s) over which to apply `sum`. - axis : int or sequence of int, optional - Axis(es) over which to apply `sum`. Only one of the 'dim' - and 'axis' arguments can be supplied. If neither are supplied, then - `sum` is calculated over axes. - skipna : bool, optional - If True, skip missing values (as marked by NaN). By default, only - skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). - min_count : int, default: None - The required number of valid values to perform the operation. If - fewer than min_count non-NA values are present the result will be - NA. Only used if skipna is set to True or defaults to True for the - array's dtype. New in version 0.10.8: Added with the default being - None. Changed in version 0.17.0: if specified on an integer array - and skipna=True, the result will be a float array. - keep_attrs : bool, optional - If True, the attributes (`attrs`) will be copied from the original - object to the new one. If False (default), the new object will be - returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to the appropriate array - function for calculating `sum` on this object's data. - - Returns - ------- - reduced : DataArray - New DataArray object with `sum` applied to its data and the - indicated dimension(s) removed. - """ - ) - assert actual == expected - - # without min_count - actual = DataArray.std.__doc__ - expected = dedent( - """\ - Reduce this DataArray's data by applying `std` along some dimension(s). - - Parameters - ---------- - dim : str or sequence of str, optional - Dimension(s) over which to apply `std`. - axis : int or sequence of int, optional - Axis(es) over which to apply `std`. Only one of the 'dim' - and 'axis' arguments can be supplied. If neither are supplied, then - `std` is calculated over axes. - skipna : bool, optional - If True, skip missing values (as marked by NaN). By default, only - skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). - keep_attrs : bool, optional - If True, the attributes (`attrs`) will be copied from the original - object to the new one. If False (default), the new object will be - returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to the appropriate array - function for calculating `std` on this object's data. - - Returns - ------- - reduced : DataArray - New DataArray object with `std` applied to its data and the - indicated dimension(s) removed. - """ - ) - assert actual == expected - - def test_datetime_to_numeric_datetime64(): times = pd.date_range("2000", periods=5, freq="7D").values result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h") diff --git a/xarray/util/generate_reductions.py b/xarray/util/generate_reductions.py index 70c92d1a96f..e79c94e8907 100644 --- a/xarray/util/generate_reductions.py +++ b/xarray/util/generate_reductions.py @@ -3,157 +3,249 @@ For internal xarray development use only. Usage: - python xarray/util/generate_reductions.py > xarray/core/_reductions.py + python xarray/util/generate_reductions.py pytest --doctest-modules xarray/core/_reductions.py --accept || true - pytest --doctest-modules xarray/core/_reductions.py --accept + pytest --doctest-modules xarray/core/_reductions.py This requires [pytest-accept](https://github.com/max-sixty/pytest-accept). The second run of pytest is deliberate, since the first will return an error while replacing the doctests. """ - import collections import textwrap -from functools import partial -from typing import Callable, Optional +from dataclasses import dataclass MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" # This file was generated using xarray.util.generate_reductions. Do not edit manually. -from typing import Any, Callable, Hashable, Optional, Protocol, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . import duck_array_ops -from .types import T_DataArray, T_Dataset''' -OBJ_PREAMBLE = """ +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset''' + + +CLASS_PREAMBLE = """ + +class {obj}{cls}Reductions: + __slots__ = () -class {obj}Reduce(Protocol): def reduce( self, func: Callable[..., Any], dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, axis: Union[None, int, Sequence[int]] = None, keep_attrs: bool = None, keepdims: bool = False, **kwargs: Any, - ) -> T_{obj}: - ...""" + ) -> "{obj}": + raise NotImplementedError()""" +TEMPLATE_REDUCTION_SIGNATURE = ''' + def {method}( + self, + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *,{extra_kwargs} + keep_attrs: bool = None, + **kwargs, + ) -> "{obj}": + """ + Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). -CLASS_PREAMBLE = """ + Parameters + ----------''' -class {obj}{cls}Reductions: - __slots__ = ()""" +TEMPLATE_RETURNS = """ + Returns + ------- + reduced : {obj} + New {obj} with ``{method}`` applied to its data and the + indicated dimension(s) removed""" -_SKIPNA_DOCSTRING = """ -skipna : bool, optional +TEMPLATE_SEE_ALSO = """ + See Also + -------- + numpy.{method} + dask.array.{method} + {see_also_obj}.{method} + :ref:`{docref}` + User guide on {docref_description}.""" + +TEMPLATE_NOTES = """ + Notes + ----- + {notes}""" + +_DIM_DOCSTRING = """dim : hashable or iterable of hashable, default: None + Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If None, will reduce over all dimensions.""" + +_SKIPNA_DOCSTRING = """skipna : bool, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been + have a sentinel missing value (int) or ``skipna=True`` has not been implemented (object, datetime64 or timedelta64).""" -_MINCOUNT_DOCSTRING = """ -min_count : int, default: None +_MINCOUNT_DOCSTRING = """min_count : int, default: None The required number of valid values to perform the operation. If fewer than min_count non-NA values are present the result will be NA. Only used if skipna is set to True or defaults to True for the array's dtype. Changed in version 0.17.0: if specified on an integer array and skipna=True, the result will be a float array.""" +_DDOF_DOCSTRING = """ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements.""" + +_KEEP_ATTRS_DOCSTRING = """keep_attrs : bool, optional + If True, ``attrs`` will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes.""" + +_KWARGS_DOCSTRING = """**kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating ``{method}`` on this object's data. + These could include dask-specific kwargs like ``split_every``.""" -BOOL_REDUCE_METHODS = ["all", "any"] -NAN_REDUCE_METHODS = [ - "max", - "min", - "mean", - "prod", - "sum", - "std", - "var", - "median", -] NAN_CUM_METHODS = ["cumsum", "cumprod"] -MIN_COUNT_METHODS = ["prod", "sum"] + NUMERIC_ONLY_METHODS = [ - "mean", - "std", - "var", - "sum", - "prod", - "median", "cumsum", "cumprod", ] +_NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." + +ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") +skipna = ExtraKwarg( + docs=_SKIPNA_DOCSTRING, + kwarg="skipna: bool = None,", + call="skipna=skipna,", + example="""\n + Use ``skipna`` to control whether NaNs are ignored. -TEMPLATE_REDUCTION = ''' - def {method}( - self: {obj}Reduce, - dim: Union[None, Hashable, Sequence[Hashable]] = None,{skip_na.kwarg}{min_count.kwarg} - keep_attrs: bool = None, - **kwargs, - ) -> T_{obj}: - """ - Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). + >>> {calculation}(skipna=False)""", +) +min_count = ExtraKwarg( + docs=_MINCOUNT_DOCSTRING, + kwarg="min_count: Optional[int] = None,", + call="min_count=min_count,", + example="""\n + Specify ``min_count`` for finer control over when NaNs are ignored. - Parameters - ---------- - dim : hashable or iterable of hashable, optional - Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"`` - or ``dim=["x", "y"]``. {extra_dim}{extra_args}{skip_na.docs}{min_count.docs} - keep_attrs : bool, optional - If True, ``attrs`` will be copied from the original - object to the new one. If False (default), the new object will be - returned without attributes. - **kwargs : dict - Additional keyword arguments passed on to the appropriate array - function for calculating ``{method}`` on this object's data. + >>> {calculation}(skipna=True, min_count=2)""", +) +ddof = ExtraKwarg( + docs=_DDOF_DOCSTRING, + kwarg="ddof: int = 0,", + call="ddof=ddof,", + example="""\n + Specify ``ddof=1`` for an unbiased estimate. - Returns - ------- - reduced : {obj} - New {obj} with ``{method}`` applied to its data and the - indicated dimension(s) removed + >>> {calculation}(skipna=True, ddof=1)""", +) - Examples - --------{example} - See Also - -------- - numpy.{method} - {obj}.{method} - :ref:`{docref}` - User guide on {docref} operations. - """ - return self.reduce( - duck_array_ops.{array_method}, - dim=dim,{skip_na.call}{min_count.call}{numeric_only_call} - keep_attrs=keep_attrs, - **kwargs, - )''' +class Method: + def __init__( + self, + name, + bool_reduce=False, + extra_kwargs=tuple(), + numeric_only=False, + ): + self.name = name + self.extra_kwargs = extra_kwargs + self.numeric_only = numeric_only + + if bool_reduce: + self.array_method = f"array_{name}" + self.np_example_array = """ + ... np.array([True, True, True, True, True, False], dtype=bool),""" + + else: + self.array_method = name + self.np_example_array = """ + ... np.array([1, 2, 3, 1, 2, np.nan]),""" + +class ReductionGenerator: + def __init__( + self, + cls, + datastructure, + methods, + docref, + docref_description, + example_call_preamble, + see_also_obj=None, + ): + self.datastructure = datastructure + self.cls = cls + self.methods = methods + self.docref = docref + self.docref_description = docref_description + self.example_call_preamble = example_call_preamble + self.preamble = CLASS_PREAMBLE.format(obj=datastructure.name, cls=cls) + if not see_also_obj: + self.see_also_obj = self.datastructure.name + else: + self.see_also_obj = see_also_obj -def generate_groupby_example(obj: str, cls: str, method: str): - """Generate examples for method.""" - dx = "ds" if obj == "Dataset" else "da" - if cls == "Resample": - calculation = f'{dx}.resample(time="3M").{method}' - elif cls == "GroupBy": - calculation = f'{dx}.groupby("labels").{method}' - else: - raise ValueError + def generate_methods(self): + yield [self.preamble] + for method in self.methods: + yield self.generate_method(method) - if method in BOOL_REDUCE_METHODS: - np_array = """ - ... np.array([True, True, True, True, True, False], dtype=bool),""" + def generate_method(self, method): + template_kwargs = dict(obj=self.datastructure.name, method=method.name) - else: - np_array = """ - ... np.array([1, 2, 3, 1, 2, np.nan]),""" + if method.extra_kwargs: + extra_kwargs = "\n " + "\n ".join( + [kwarg.kwarg for kwarg in method.extra_kwargs if kwarg.kwarg] + ) + else: + extra_kwargs = "" + + yield TEMPLATE_REDUCTION_SIGNATURE.format( + **template_kwargs, + extra_kwargs=extra_kwargs, + ) + + for text in [ + _DIM_DOCSTRING.format(method=method.name), + *(kwarg.docs for kwarg in method.extra_kwargs if kwarg.docs), + _KEEP_ATTRS_DOCSTRING, + _KWARGS_DOCSTRING.format(method=method.name), + ]: + if text: + yield textwrap.indent(text, 8 * " ") + + yield TEMPLATE_RETURNS.format(**template_kwargs) + + yield TEMPLATE_SEE_ALSO.format( + **template_kwargs, + docref=self.docref, + docref_description=self.docref_description, + see_also_obj=self.see_also_obj, + ) - create_da = f""" - >>> da = xr.DataArray({np_array} + if method.numeric_only: + yield TEMPLATE_NOTES.format(notes=_NUMERIC_ONLY_NOTES) + + yield textwrap.indent(self.generate_example(method=method), "") + + yield ' """' + + yield self.generate_code(method) + + def generate_example(self, method): + create_da = f""" + >>> da = xr.DataArray({method.np_example_array} ... dims="time", ... coords=dict( ... time=("time", pd.date_range("01-01-2001", freq="M", periods=6)), @@ -161,130 +253,149 @@ def generate_groupby_example(obj: str, cls: str, method: str): ... ), ... )""" - if obj == "Dataset": - maybe_dataset = """ - >>> ds = xr.Dataset(dict(da=da)) - >>> ds""" - else: - maybe_dataset = """ - >>> da""" - - if method in NAN_REDUCE_METHODS: - maybe_skipna = f""" - - Use ``skipna`` to control whether NaNs are ignored. - - >>> {calculation}(skipna=False)""" - else: - maybe_skipna = "" - - if method in MIN_COUNT_METHODS: - maybe_mincount = f""" - - Specify ``min_count`` for finer control over when NaNs are ignored. + calculation = f"{self.datastructure.example_var_name}{self.example_call_preamble}.{method.name}" + if method.extra_kwargs: + extra_examples = "".join( + kwarg.example for kwarg in method.extra_kwargs if kwarg.example + ).format(calculation=calculation, method=method.name) + else: + extra_examples = "" - >>> {calculation}(skipna=True, min_count=2)""" - else: - maybe_mincount = "" + return f""" + Examples + --------{create_da}{self.datastructure.docstring_create} - return f"""{create_da}{maybe_dataset} + >>> {calculation}(){extra_examples}""" - >>> {calculation}(){maybe_skipna}{maybe_mincount}""" +class GenericReductionGenerator(ReductionGenerator): + def generate_code(self, method): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] -def generate_method( - obj: str, - docref: str, - method: str, - skipna: bool, - example_generator: Callable, - array_method: Optional[str] = None, -): - if not array_method: - array_method = method + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") - if obj == "Dataset": - if method in NUMERIC_ONLY_METHODS: - numeric_only_call = "\n numeric_only=True," + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ") else: - numeric_only_call = "\n numeric_only=False," - else: - numeric_only_call = "" - - kwarg = collections.namedtuple("kwarg", "docs kwarg call") - if skipna: - skip_na = kwarg( - docs=textwrap.indent(_SKIPNA_DOCSTRING, " "), - kwarg="\n skipna: bool = True,", - call="\n skipna=skipna,", - ) - else: - skip_na = kwarg(docs="", kwarg="", call="") - - if method in MIN_COUNT_METHODS: - min_count = kwarg( - docs=textwrap.indent(_MINCOUNT_DOCSTRING, " "), - kwarg="\n min_count: Optional[int] = None,", - call="\n min_count=min_count,", - ) - else: - min_count = kwarg(docs="", kwarg="", call="") - - return TEMPLATE_REDUCTION.format( - obj=obj, - docref=docref, - method=method, - array_method=array_method, - extra_dim="""If ``None``, will reduce over all dimensions - present in the grouped variable.""", - extra_args="", - skip_na=skip_na, - min_count=min_count, - numeric_only_call=numeric_only_call, - example=example_generator(obj=obj, method=method), - ) - - -def render(obj: str, cls: str, docref: str, example_generator: Callable): - yield CLASS_PREAMBLE.format(obj=obj, cls=cls) - yield generate_method( - obj, - method="count", - docref=docref, - skipna=False, - example_generator=example_generator, - ) - for method in BOOL_REDUCE_METHODS: - yield generate_method( - obj, - method=method, - docref=docref, - skipna=False, - array_method=f"array_{method}", - example_generator=example_generator, - ) - for method in NAN_REDUCE_METHODS: - yield generate_method( - obj, - method=method, - docref=docref, - skipna=True, - example_generator=example_generator, - ) + extra_kwargs = "" + return f"""\ + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + +REDUCTION_METHODS = ( + Method("count"), + Method("all", bool_reduce=True), + Method("any", bool_reduce=True), + Method("max", extra_kwargs=(skipna,)), + Method("min", extra_kwargs=(skipna,)), + Method("mean", extra_kwargs=(skipna,), numeric_only=True), + Method("prod", extra_kwargs=(skipna, min_count), numeric_only=True), + Method("sum", extra_kwargs=(skipna, min_count), numeric_only=True), + Method("std", extra_kwargs=(skipna, ddof), numeric_only=True), + Method("var", extra_kwargs=(skipna, ddof), numeric_only=True), + Method("median", extra_kwargs=(skipna,), numeric_only=True), +) + + +@dataclass +class DataStructure: + name: str + docstring_create: str + example_var_name: str + numeric_only: bool = False + + +DATASET_OBJECT = DataStructure( + name="Dataset", + docstring_create=""" + >>> ds = xr.Dataset(dict(da=da)) + >>> ds""", + example_var_name="ds", + numeric_only=True, +) +DATAARRAY_OBJECT = DataStructure( + name="DataArray", + docstring_create=""" + >>> da""", + example_var_name="da", + numeric_only=False, +) + +DATASET_GENERATOR = GenericReductionGenerator( + cls="", + datastructure=DATASET_OBJECT, + methods=REDUCTION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + see_also_obj="DataArray", +) +DATAARRAY_GENERATOR = GenericReductionGenerator( + cls="", + datastructure=DATAARRAY_OBJECT, + methods=REDUCTION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + see_also_obj="Dataset", +) + +DATAARRAY_GROUPBY_GENERATOR = GenericReductionGenerator( + cls="GroupBy", + datastructure=DATAARRAY_OBJECT, + methods=REDUCTION_METHODS, + docref="groupby", + docref_description="groupby operations", + example_call_preamble='.groupby("labels")', +) +DATAARRAY_RESAMPLE_GENERATOR = GenericReductionGenerator( + cls="Resample", + datastructure=DATAARRAY_OBJECT, + methods=REDUCTION_METHODS, + docref="resampling", + docref_description="resampling operations", + example_call_preamble='.resample(time="3M")', +) +DATASET_GROUPBY_GENERATOR = GenericReductionGenerator( + cls="GroupBy", + datastructure=DATASET_OBJECT, + methods=REDUCTION_METHODS, + docref="groupby", + docref_description="groupby operations", + example_call_preamble='.groupby("labels")', +) +DATASET_RESAMPLE_GENERATOR = GenericReductionGenerator( + cls="Resample", + datastructure=DATASET_OBJECT, + methods=REDUCTION_METHODS, + docref="resampling", + docref_description="resampling operations", + example_call_preamble='.resample(time="3M")', +) if __name__ == "__main__": - print(MODULE_PREAMBLE) - for obj in ["Dataset", "DataArray"]: - print(OBJ_PREAMBLE.format(obj=obj)) - for cls, docref in ( - ("GroupBy", "groupby"), - ("Resample", "resampling"), - ): - for line in render( - obj=obj, - cls=cls, - docref=docref, - example_generator=partial(generate_groupby_example, cls=cls), - ): - print(line) + import os + from pathlib import Path + + p = Path(os.getcwd()) + filepath = p.parent / "xarray" / "xarray" / "core" / "_reductions.py" + with open(filepath, mode="w", encoding="utf-8") as f: + f.write(MODULE_PREAMBLE + "\n") + for gen in [ + DATASET_GENERATOR, + DATAARRAY_GENERATOR, + DATASET_GROUPBY_GENERATOR, + DATASET_RESAMPLE_GENERATOR, + DATAARRAY_GROUPBY_GENERATOR, + DATAARRAY_RESAMPLE_GENERATOR, + ]: + for lines in gen.generate_methods(): + for line in lines: + f.write(line + "\n") From d535a3bf46ec26e3857bd0a350120107bb9a2199 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 13 Mar 2022 05:21:54 +0100 Subject: [PATCH 28/34] Run pyupgrade on core/groupby (#6351) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/core/groupby.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index dea8949b84f..68db14c29be 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import datetime import warnings -from typing import Any, Callable, Hashable, Sequence, Union +from typing import Any, Callable, Hashable, Sequence import numpy as np import pandas as pd @@ -869,9 +871,9 @@ def _combine(self, applied, shortcut=False): def reduce( self, func: Callable[..., Any], - dim: Union[None, Hashable, Sequence[Hashable]] = None, + dim: None | Hashable | Sequence[Hashable] = None, *, - axis: Union[None, int, Sequence[int]] = None, + axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, shortcut: bool = True, @@ -997,9 +999,9 @@ def _combine(self, applied): def reduce( self, func: Callable[..., Any], - dim: Union[None, Hashable, Sequence[Hashable]] = None, + dim: None | Hashable | Sequence[Hashable] = None, *, - axis: Union[None, int, Sequence[int]] = None, + axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, **kwargs: Any, From 573953ce85c693a9ba63667f438828ac151bbb37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:05:56 -0700 Subject: [PATCH 29/34] [pre-commit.ci] pre-commit autoupdate (#6357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/asottile/pyupgrade: v2.31.0 → v2.31.1](https://github.com/asottile/pyupgrade/compare/v2.31.0...v2.31.1) - [github.com/pre-commit/mirrors-mypy: v0.931 → v0.940](https://github.com/pre-commit/mirrors-mypy/compare/v0.931...v0.940) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 63a2871b496..9de02156417 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.31.0 + rev: v2.31.1 hooks: - id: pyupgrade args: @@ -45,7 +45,7 @@ repos: # - id: velin # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v0.940 hooks: - id: mypy # Copied from setup.cfg From 70f2f5f28fb314d197eaf14ac3177b9222f2b268 Mon Sep 17 00:00:00 2001 From: keewis Date: Tue, 15 Mar 2022 20:06:31 +0100 Subject: [PATCH 30/34] Revert "explicitly install `ipython_genutils` (#6350)" (#6361) This reverts commit 9a71cc071603b1624f8eeca238a514267cec1bbc. --- ci/requirements/doc.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index f0511ce2006..e5fcc500f70 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -12,7 +12,6 @@ dependencies: - h5netcdf>=0.7.4 - ipykernel - ipython - - ipython_genutils # remove once `nbconvert` fixed its dependencies - iris>=2.3 - jupyter_client - matplotlib-base From 95bb9ae4233c16639682a532c14b26a3ea2728f3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 16 Mar 2022 09:19:22 +0530 Subject: [PATCH 31/34] Add new tutorial video (#6353) --- doc/tutorials-and-videos.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doc/tutorials-and-videos.rst b/doc/tutorials-and-videos.rst index 6a9602bcfa6..37b8d1968eb 100644 --- a/doc/tutorials-and-videos.rst +++ b/doc/tutorials-and-videos.rst @@ -18,6 +18,20 @@ Videos .. panels:: :card: text-center + --- + Xdev Python Tutorial Seminar Series 2022 Thinking with Xarray : High-level computation patterns | Deepak Cherian + ^^^ + .. raw:: html + + + + --- + Xdev Python Tutorial Seminar Series 2021 seminar introducing xarray (2 of 2) | Anderson Banihirwe + ^^^ + .. raw:: html + + + --- Xdev Python Tutorial Seminar Series 2021 seminar introducing xarray (1 of 2) | Anderson Banihirwe ^^^ From 53172cb1e03a98759faf77ef48efaa64676ad24a Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 17 Mar 2022 04:52:10 +0000 Subject: [PATCH 32/34] Allow write_empty_chunks to be set in Zarr encoding (#6348) --- xarray/backends/zarr.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 97517818d07..aca0b8064f5 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -212,7 +212,13 @@ def extract_zarr_variable_encoding( """ encoding = variable.encoding.copy() - valid_encodings = {"chunks", "compressor", "filters", "cache_metadata"} + valid_encodings = { + "chunks", + "compressor", + "filters", + "cache_metadata", + "write_empty_chunks", + } if raise_on_invalid: invalid = [k for k in encoding if k not in valid_encodings] From 99229efa6030bf872444234723f2d3199c3ef1e3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 17 Mar 2022 07:25:21 +0100 Subject: [PATCH 33/34] Remove test_rasterio_vrt_network (#6371) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/tests/test_backends.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e0bc0b10437..4696c41552f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4775,32 +4775,6 @@ def test_rasterio_vrt_with_src_crs(self): with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: assert da.crs == src_crs - @network - def test_rasterio_vrt_network(self): - # Make sure loading w/ rasterio give same results as xarray - import rasterio - - # use same url that rasterio package uses in tests - prefix = "https://landsat-pds.s3.amazonaws.com/L8/139/045/" - image = "LC81390452014295LGN00/LC81390452014295LGN00_B1.TIF" - httpstif = prefix + image - with rasterio.Env(aws_unsigned=True): - with rasterio.open(httpstif) as src: - with rasterio.vrt.WarpedVRT(src, crs="epsg:4326") as vrt: - expected_shape = vrt.width, vrt.height - expected_res = vrt.res - # Value of single pixel in center of image - lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2) - expected_val = next(vrt.sample([(lon, lat)])) - with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: - actual_shape = da.sizes["x"], da.sizes["y"] - actual_res = da.res - actual_val = da.sel(dict(x=lon, y=lat), method="nearest").data - - assert actual_shape == expected_shape - assert actual_res == expected_res - assert expected_val == actual_val - class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self): From 3ead17ea9e99283e2511b65b9d864d1c7b10b3c4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 17 Mar 2022 18:11:40 +0100 Subject: [PATCH 34/34] Explicit indexes (#5692) * no need to wrap pandas index in lazy index adapter * multi-index default level names * refactor setting Dataset/DataArray default indexes * update multi-index (text) repr Notes: - move the multi-index formatting logic into PandasMultiIndexingAdapter._repr_inline_ - inline repr: check for _repr_inline_ implementation first * remove print * minor fixes and improvements * fix dtype of index variables created from Index * fix multi-index selection regression See https://github.com/pydata/xarray/issues/5691 * check conflicting multi-index level names * update formatting (text and html) * check level name conflicts for midx given as coord * intended behavior or unwanted side effect? see #5732 * get rid of multi-index virtual coordinates Not totally yet: need to refactor set_index / reset_index * add level coords in indexes & keep coord order * fix copying multi-index level variable data * collect index for multi-index level variables Avoid re-creating the indexes for dimension variables. Collect then directly instead. Note: the change here is working for building new Datasets but I haven't checked other cases like merging different objects, etc. So I'm not sure this is the right approach. * wip refactor label based selection - Index.query must now return a mapping of {dim_name: positional_indexer} as indexes may be based on several coordinates with different dimensions - Added `group_coords_by_index` utility function (not used yet, not sure we'll need it) TODO: - Update DataArray selection - Update .loc and other places using remap_label_indexers - Fix selection of multi-index that returns only scalar coordinates * fix index query tests * fix multi-index adapter getitem scalar Multi-index level variables now return the scalar value that corresponds to the level instead of the multi-index tuple element (all levels). Also get rid of PandasMultiIndexingAdapter.__getitem__ cache optimization, which doesn't work with level scalar values and was premature optimization anyway. * wip refactor label based selection Fixed renamed dimension in the case of multi-index -> single index Updated DataArray._overwrite_indexes Dirty fix for alignment (not tested yet) * wip: deeper refactoring label-based sel Created QueryResult and MergedQueryResults classes for convenience. * fix some tests + minor tweaks * fix indexing PandasMultiIndexingAdapater When level is not None: - if result is another adapter: propagate it properly - if result is a numpy adapter: use level values * refactor cast label indexer to coord dtype Make the fix in #3153 specific to pandas indexes (i.e., do not apply it to other, custom indexes). See #5697 for details. This should also fix #5700 although no test has been added yet (we need to refactor set_index first). * better handling of multi-index level labels * label-based selection tweaks and fixes - Use a dataclass for QueryResult - Pass Variables and DataArrays indexers un-normalized to Index.query(). Indexes have the responsibility of returning the expected types for positional (dimension) indexers by adding back dimensions and coordinates if needed. - Typing fixes and tweaks * sel: propagate multi-index vars attrs/encoding Related to https://github.com/pydata/xarray/issues/1366 * wip refactor rename Still needs tests Also need to address the problem of multi-index level coordinates (data adapters currently not updated). We'll probably need for `Index.rename()` to also return new index variables? * wip refactor rename: return new vars from Index * refactor rename: update tests Still need to address default indexes internal checks (these are disabled for now). * typing tweaks * fix html formatting: dims w/ or w/o index * minor fixes and tweaks * wip refactor set_index * wip refactor set_index - Refactor reset_index - Improve creating new xarray indexes from pandas indexes with proper propagation of variable metadata (dtype, attrs, encoding) * refactor reorder_levels * Set pd.MultiIndex name from dim name Closes #4542 * .sel() with multi-index: return scalar coords ..instead of dropping those coordinates Closes #1408 once tests are fixed/updated * fix multi-index level coordinate inline repr * wip refactor stack Some changes: Multi-indexes are created only if the stacked dimensions each have a single coordinate index. In the other cases, multi-indexes are not created, which means that it is an irreversible operation (unstack will not work) Multi-indexes are created from the stacked coordinate variables and not anymore from the unstacked dimension indexes (level product). There's a significant decrease in performance but it is probably acceptable since it's now possible to avoid the creation of multi-indexes. It is possible to stack a dimension with a multi-index, but it drops the index. Otherwise it would make it hard for unstack() to figure out what's going on. It makes it clear that this is an irreversible operation. * stack: better rule for add/skip multi-index It is more robust and allows creating a multi-index from non-default (non-dimension) coordinates, as long as there's one and only one 1-d indexed coordinate for each dimension to stack. * add utility methods to class Indexes This removes some ugly & duplicate patterns. * re-arrange class Indexes internals * wip refactor stack / unstack stack: revert to old way of creating multi-index unstack: support non-default (non-dimension) multi-index (as long as there is exactly one multi-index per specified dimension) * add PandasMultiIndex.from_product_variables Refactored utils.multiindex_from_product_levels utility function into a new PandasMultiIndex classmethod. * unstack: propagate index coordinate metadata * wip: fix/update tests * fix/refactor reindex * functools.cached_property is since py38 Use the good old way as we still support py37 * update reset_coords * fix ipython key completion test make sure the stacked dataset used has a multi-index * update test_map_index_queries * do not coerce bool indexer as float coord dtype Fixes #5727 * add Index.create_variables() method This will probably be used instead of including index coordinate variables in the signature (return type) of many methods of ``Index``. It has several advantages: - More DRY, and for custom indexes that do not need to create coordinate variables with special data adapters, it's easier to just skip implementing this method and not bother with returning empty dicts in other methods - This allows to decouple index vs. coordinates creation. For many cases this can be done at the same time but for some cases like alignment this is useful - It's a more elegant solution when we need to propagate metadata (attrs, encoding) * improve Dataset/DataArray indexes proxy class One possible solution to the recurring problem of accessing coordinate variables related to a given index in some of Xarray's internals. I feel that the ``Indexes`` proxy is the right place for storing references to those indexed coordinate variables. It's safer than relying on a public attribute/property of ``Index``. * align Indexes API with Coordinates API Expose similar properties: - variables - dims * clean-up formatting using updated Indexes API * wip: refactor alignment * wip refactor alignment Almost done rewriting `align`. * tweaks and fixes - Some fixes and clean-up in new implementation of align - Add Index.join method (only for inner/outer join) - Tweak index generic types - IndexVars type: use Variable instead of IndexVariable (IndexVariable may eventually be dropped) * wip refactor alignment and reindex Imported all re-indexing logic into the `Alignator` class, which can be reused directly for both `align()` and `obj.reindex()`. TODO: - import override indexes into `Alignator` - connect Alignator to public API entry points - clean-up (remove old implementation functions) * wip refactor alignment / reindex: fixes and tweaks * wip refactor alignment (support join='override') * wip alignment: clean-up + tests I will fix mypy errors later. * refactor alignment fix tests & types annotations All `*align*` and `*reindex*` tests are passing! (locally) Mypy doesn't complain. TODO: refactor `interp` and `interp_like` (I only plan to fix it with the new Aligner class for now) * refactor interp and interp_like It now reuses the new alignment/reindex internals, but it still only supports dimension coordinates with a single index, * wip review merge * refactor swap_dims * refactor isel * add create_index option to stack * add Index.stack and Index.unstack methods Also added a `index_cls` parameter to `Dataset.stack` and `DataArray.stack`. Custom indexes may thus implement their own logic for stack/unstack, with the current limitation that a `pandas.MultiIndex` is still explicitly required for unstack optimized versions. * fix PandasIndex.isel with scalar indexers Fix concat tests and some groupby tests * fix index adapter (variable) dtype * more indexes invariants checks Check consistency between PandasIndex objects and coordinate variables data adapters (PandasIndexingAdapter). * fix DataArray.isel() * refactor Dataset/DataArray copy Filter indexes and preserve unique index objects for multi-coordinate indexes. Also fixed indexes tests (stack/unstack). * refactor computation (and merge) * minor fixes and tweaks * propagate_indexes -> filter_indexes_from_coords We can't just look at excluded dimensions anymore... * Update xarray/core/coordinates.py * refactor expand_dims * merge: avoid compare same indexes more than once * wip refactor update coords Check in merge that prioritized elements won't corrupt collected indexes. * refactor update/remove coords Fixes and tweaks Refactored assign, update, setitem, delitem, drop_vars Added or updated tests * misc. fixes * fix Dataset,__delitem__ * fix Dataset.reset_coords (default coords) * refactor Dataset.from_dataframe * Fix .sel with DataArray and multi-index * PandasIndex.from_variables: preserve wrapped index Prevent converting special index types like ``pd.CategoricalIndex`` when such objects are wrapped in xarray variables. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * two minor fixes after merging main * filter_indexes_from_coords: preserve index order * implicit multi-index from coord: fix edge case * groupby combine: fix missing index in result When coord is None * implicit multi-index coord: more robust fix Also cover the cases where indexes are equal but not identical (caused a couple of tests failing). Perfs may not be a concern for such edge case? * backward compat fix: multi-index given as data-var ... in Dataset constructor * refactor concat Summary: - Add the ``Index.concat()`` class method with implementations for ``PandasIndex`` and ``PandasMultiIndex`` (adapted from ``IndexVariable.concat()``). - Use ``Index.concat()`` to create new indexes (and their corresponding variables) in ``_dataset_concat``. Fallback to variable concatenation (with default index) when no implementation is given for ``Index.concat`` or when no all the variables have an index. - Refactored merging the other variables (also merge the indexes) in ``_dataset_concat``. This refactor is incomplete. It should work with Pandas (multi-)indexes but it is likely that it won't work with meta-indexes involving multiple dimensions. We probably need to update ``_calc_concat_over`` to take into account such meta-indexes and their related coordinates. Other limitation (?) we use here the index of the 1st dataset for the concatenation (i.e., ``Index.concat``). No specific check is made on the type and/or coordinates of the other indexes. * sel drop=True: remove multi-index scalar coords * PandasIndex.from_variables(multi-index level var) Make sure it gets the level index (not the whole multi-index). * reindex: disable invalid dimension check From the docstrings: mis-matched dimensions are simply ignored * add index concat tests + fix positions type * add Indexes.get_all_dims convenient method * refactor pad * unstack: return copies of mindex.levels Fix error when trying to set the name of the extracted level indexes. * strip_units: prevent index->array conversion which caused alignment conflicts due to multi-index variable converted to non-index variable. * attach_units: do not preserve coord multi-indexes To prevent conflicts when trying to implicitly create multi-index coordinates. * concat: avoid multiple loads of dask arrays * groupby mindex coord: propagate level names * Dataset.copy: don't reset indexes if data is given The `data` parameter accepts new data for data variables only, it won't affect the indexes. * reindex/align: fix coord dtype of new indexes Use `as_compatible_data` to get the right dtype to assign to the new index coordinates created from reindex indexers. * fix doctests * to_stacked_array: do not coerce level dtypes If the intent was to propagate the original coordinate dtypes, this is now handled by PandasMultiIndex.stack. This fixes the case where multi-index ``.levels`` do not return labels with "nan" values (if any) but where the coordinate does, which resulted in dtype mismatch between the coordinate variable dtype (e.g., coerced to int64) and the actual label values (float64 due to nan values), eventually causing numpy errors trying to coerce nan value to int64 when indexing the coordinate. * fix indent * doc: fix user-guide build errors * stack: fix new index variables not in coordinates * unstack full-reindex: fix alignment errors We need a mapping of all index coordinate names to the full index so that Aligner can find matching indexes. (Note: reindex may eventually be refactored to be a bit more clever so that providing only a `{dim: indexer}` will be enough in this case?) * PandasIndex coord dtype: avoid convert index->array * refactor diff * quick fix level coord dtype int32/64 on win * dask presist/compute dataarray: propagate indexes * refactor roll * rename Index.query -> Index.sel Also rename: - QueryResult -> IndexSelResult - merge_query_results -> merge_sel_results * remove Index.union and Index.intersection Those are not used elsewhere. ``Index.join`` is used instead for alignment. * use future annotations in indexes.py * Index.rename: return only the new index Create new coordinate variables using Index.create_variables instead (get rid of PandasIndex.from_pandas_index). * Index.stack: return only the new index Create new coordinate variables using Index.create_variables instead (get rid of PandasIndex.from_pandas_index) * PandasMultiIndex class methods: return only the index * wip get rid of Pandas(Multi)Index.from_pandas_index * remove Pandas(Multi)Index.from_pandas_index * Index.from_variables: return the new index only Use Index.create_variables to get the new coordinate variables. * rename: propagate_attrs_encoding not needed Index.create_variables already propagates coordinate variable metadata. * align exclude dims: pass through indexes * fix set_index append=True Level variables may have updated names in this case: single index + one level. * refactor default_indexes and re-enable invariant check Default indexes invariant check is now optional (enabled by default). * fix formatting errors * fix type * Dataset._indexes, DataArray._indexes: always dict Default indexes are created by the default constructors. Indexes are always passed explicitly by the fastpath constructors. * remove _level_coords property It was only used for plotting and has been replaced by simpler logic. * clean-up Use ``_indexes`` instead of ``.xindexes`` when possible, for speed. ``xindexes`` is not cached: it would return inconsistent results if updating the object in-place. Fix some types. * optimize Dataset/DataArray copy Avoid copying pandas indexes twice (indexes + coordinate variables) * add default implementation for Index.copy Better than raising an error? * add/fix indexes tests * tweak indexes formatting This will need more work (in a follow-up PR) to be consistent with the other data model components (i.e., coordinates and data variables): add summary (inline) and detailed reprs, maybe group by coordinates, etc. * fix doctests * PandasIndex.copy: avoid too many pd.Index copies * DataArray stack test: revert to original Now that default (range) indexes are created again (create_index=True). * test indexes/indexing: rename query -> sel * alignment: use future annotations * misc. tweaks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Aligner tweaks Access aligned objects through the ``.result`` attribute. Enable type checking in internal methods. * remove assert_unique_multiindex_level_names It is not needed anymore as multi-indexes have their own coordinates and we now check for possible name conflicts when "unpacking" multi-index levels at Dataset / DataArray creation. * remove propagate_attrs_encoding Only used in one place. * assert -> check + ValueError * misc. fixes and tweaks * update what's new * concat: remove fall-backs to Variable.concat Raise an error instead when trying to concat a mix of indexed and un-indexed coordinates or when the index doesn't support concat. We still support the concatenation of a mix of scalar and dimension indexed coordinates by creating a PandasIndex on-the-fly for scalar coordinates. * fix flaky test Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: keewis --- doc/user-guide/plotting.rst | 2 +- doc/whats-new.rst | 14 + xarray/core/alignment.py | 927 +++++++++++------- xarray/core/combine.py | 2 +- xarray/core/common.py | 5 +- xarray/core/computation.py | 58 +- xarray/core/concat.py | 158 ++-- xarray/core/coordinates.py | 61 +- xarray/core/dataarray.py | 306 +++--- xarray/core/dataset.py | 1272 ++++++++++++++----------- xarray/core/formatting.py | 178 ++-- xarray/core/formatting_html.py | 59 +- xarray/core/groupby.py | 25 +- xarray/core/indexes.py | 1305 +++++++++++++++++++++----- xarray/core/indexing.py | 366 +++++--- xarray/core/merge.py | 190 ++-- xarray/core/missing.py | 6 +- xarray/core/parallel.py | 16 +- xarray/core/types.py | 2 + xarray/core/utils.py | 55 +- xarray/core/variable.py | 64 +- xarray/plot/utils.py | 19 +- xarray/testing.py | 72 +- xarray/tests/__init__.py | 18 +- xarray/tests/test_backends.py | 4 +- xarray/tests/test_combine.py | 4 +- xarray/tests/test_computation.py | 4 +- xarray/tests/test_concat.py | 31 +- xarray/tests/test_dask.py | 6 +- xarray/tests/test_dataarray.py | 190 ++-- xarray/tests/test_dataset.py | 341 +++++-- xarray/tests/test_formatting.py | 4 +- xarray/tests/test_formatting_html.py | 24 +- xarray/tests/test_groupby.py | 12 +- xarray/tests/test_indexes.py | 587 ++++++++++-- xarray/tests/test_indexing.py | 194 +++- xarray/tests/test_merge.py | 2 +- xarray/tests/test_units.py | 22 +- xarray/tests/test_utils.py | 25 - 39 files changed, 4470 insertions(+), 2160 deletions(-) diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index d81ba30f12f..f514b4ecbef 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -251,7 +251,7 @@ Finally, if a dataset does not have any coordinates it enumerates all data point .. ipython:: python :okwarning: - air1d_multi = air1d_multi.drop("date") + air1d_multi = air1d_multi.drop(["date", "time", "decimal_day"]) air1d_multi.plot() The same applies to 2D plots below. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b22c6e4d858..05bdfcf78bb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,10 +22,18 @@ v2022.03.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add a ``create_index=True`` parameter to :py:meth:`Dataset.stack` and + :py:meth:`DataArray.stack` so that the creation of multi-indexes is optional + (:pull:`5692`). By `Benoît Bovy `_. +- Multi-index levels are now accessible through their own, regular coordinates + instead of virtual coordinates (:pull:`5692`). + By `Benoît Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ +- The Dataset and DataArray ``rename*`` methods do not implicitly add or drop + indexes. (:pull:`5692`). By `Benoît Bovy `_. Deprecations ~~~~~~~~~~~~ @@ -37,6 +45,9 @@ Bug fixes - Set ``skipna=None`` for all ``quantile`` methods (e.g. :py:meth:`Dataset.quantile`) and ensure it skips missing values for float dtypes (consistent with other methods). This should not change the behavior (:pull:`6303`). By `Mathias Hauser `_. +- Many bugs fixed by the explicit indexes refactor, mainly related to multi-index (virtual) + coordinates. See the corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. Documentation ~~~~~~~~~~~~~ @@ -45,6 +56,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Many internal changes due to the explicit indexes refactor. See the + corresponding pull-request on GitHub for more details. (:pull:`5692`). + By `Benoît Bovy `_. .. _whats-new.2022.02.0: .. _whats-new.2022.03.0: diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index f9342e2a82a..d201e3a613f 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator from collections import defaultdict @@ -5,84 +7,562 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, + Generic, Hashable, Iterable, Mapping, - Optional, Tuple, + Type, TypeVar, - Union, ) import numpy as np import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex, get_indexer_nd -from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index -from .variable import IndexVariable, Variable +from .common import DataWithCoords +from .indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_all_equal +from .utils import is_dict_like, is_full_slice, safe_cast_to_index +from .variable import Variable, as_compatible_data, calculate_dimensions if TYPE_CHECKING: - from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset - DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) - - -def _get_joiner(join, index_cls): - if join == "outer": - return functools.partial(functools.reduce, index_cls.union) - elif join == "inner": - return functools.partial(functools.reduce, index_cls.intersection) - elif join == "left": - return operator.itemgetter(0) - elif join == "right": - return operator.itemgetter(-1) - elif join == "exact": - # We cannot return a function to "align" in this case, because it needs - # access to the dimension name to give a good error message. - return None - elif join == "override": - # We rewrite all indexes and then use join='left' - return operator.itemgetter(0) - else: - raise ValueError(f"invalid value for join: {join}") +DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) + + +def reindex_variables( + variables: Mapping[Any, Variable], + dim_pos_indexers: Mapping[Any, Any], + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, +) -> dict[Hashable, Variable]: + """Conform a dictionary of variables onto a new set of variables reindexed + with dimension positional indexers and possibly filled with missing values. + + Not public API. + + """ + new_variables = {} + dim_sizes = calculate_dimensions(variables) + + masked_dims = set() + unchanged_dims = set() + for dim, indxr in dim_pos_indexers.items(): + # Negative values in dim_pos_indexers mean values missing in the new index + # See ``Index.reindex_like``. + if (indxr < 0).any(): + masked_dims.add(dim) + elif np.array_equal(indxr, np.arange(dim_sizes.get(dim, 0))): + unchanged_dims.add(dim) + + for name, var in variables.items(): + if isinstance(fill_value, dict): + fill_value_ = fill_value.get(name, dtypes.NA) + else: + fill_value_ = fill_value + + if sparse: + var = var._as_sparse(fill_value=fill_value_) + indxr = tuple( + slice(None) if d in unchanged_dims else dim_pos_indexers.get(d, slice(None)) + for d in var.dims + ) + needs_masking = any(d in masked_dims for d in var.dims) + + if needs_masking: + new_var = var._getitem_with_mask(indxr, fill_value=fill_value_) + elif all(is_full_slice(k) for k in indxr): + # no reindexing necessary + # here we need to manually deal with copying data, since + # we neither created a new ndarray nor used fancy indexing + new_var = var.copy(deep=copy) + else: + new_var = var[indxr] + + new_variables[name] = new_var + + return new_variables + + +CoordNamesAndDims = Tuple[Tuple[Hashable, Tuple[Hashable, ...]], ...] +MatchingIndexKey = Tuple[CoordNamesAndDims, Type[Index]] +NormalizedIndexes = Dict[MatchingIndexKey, Index] +NormalizedIndexVars = Dict[MatchingIndexKey, Dict[Hashable, Variable]] + + +class Aligner(Generic[DataAlignable]): + """Implements all the complex logic for the re-indexing and alignment of Xarray + objects. + + For internal use only, not public API. + Usage: + + aligner = Aligner(*objects, **kwargs) + aligner.align() + aligned_objects = aligner.results + + """ + + objects: tuple[DataAlignable, ...] + results: tuple[DataAlignable, ...] + objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...] + join: str + exclude_dims: frozenset[Hashable] + exclude_vars: frozenset[Hashable] + copy: bool + fill_value: Any + sparse: bool + indexes: dict[MatchingIndexKey, Index] + index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + aligned_indexes: dict[MatchingIndexKey, Index] + aligned_index_vars: dict[MatchingIndexKey, dict[Hashable, Variable]] + reindex: dict[MatchingIndexKey, bool] + reindex_kwargs: dict[str, Any] + unindexed_dim_sizes: dict[Hashable, set] + new_indexes: Indexes[Index] + + def __init__( + self, + objects: Iterable[DataAlignable], + join: str = "inner", + indexes: Mapping[Any, Any] = None, + exclude_dims: Iterable = frozenset(), + exclude_vars: Iterable[Hashable] = frozenset(), + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, + ): + self.objects = tuple(objects) + self.objects_matching_indexes = () + + if join not in ["inner", "outer", "override", "exact", "left", "right"]: + raise ValueError(f"invalid value for join: {join}") + self.join = join + + self.copy = copy + self.fill_value = fill_value + self.sparse = sparse + + if method is None and tolerance is None: + self.reindex_kwargs = {} + else: + self.reindex_kwargs = {"method": method, "tolerance": tolerance} + + if isinstance(exclude_dims, str): + exclude_dims = [exclude_dims] + self.exclude_dims = frozenset(exclude_dims) + self.exclude_vars = frozenset(exclude_vars) + + if indexes is None: + indexes = {} + self.indexes, self.index_vars = self._normalize_indexes(indexes) + + self.all_indexes = {} + self.all_index_vars = {} + self.unindexed_dim_sizes = {} + + self.aligned_indexes = {} + self.aligned_index_vars = {} + self.reindex = {} + + self.results = tuple() + def _normalize_indexes( + self, + indexes: Mapping[Any, Any], + ) -> tuple[NormalizedIndexes, NormalizedIndexVars]: + """Normalize the indexes/indexers used for re-indexing or alignment. -def _override_indexes(objects, all_indexes, exclude): - for dim, dim_indexes in all_indexes.items(): - if dim not in exclude: - lengths = { - getattr(index, "size", index.to_pandas_index().size) - for index in dim_indexes - } - if len(lengths) != 1: + Return dictionaries of xarray Index objects and coordinate variables + such that we can group matching indexes based on the dictionary keys. + + """ + if isinstance(indexes, Indexes): + xr_variables = dict(indexes.variables) + else: + xr_variables = {} + + xr_indexes: dict[Hashable, Index] = {} + for k, idx in indexes.items(): + if not isinstance(idx, Index): + if getattr(idx, "dims", (k,)) != (k,): + raise ValueError( + f"Indexer has dimensions {idx.dims} that are different " + f"from that to be indexed along '{k}'" + ) + data = as_compatible_data(idx) + pd_idx = safe_cast_to_index(data) + pd_idx.name = k + if isinstance(pd_idx, pd.MultiIndex): + idx = PandasMultiIndex(pd_idx, k) + else: + idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype) + xr_variables.update(idx.create_variables()) + xr_indexes[k] = idx + + normalized_indexes = {} + normalized_index_vars = {} + for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + coord_names_and_dims = [] + all_dims = set() + + for name, var in index_vars.items(): + dims = var.dims + coord_names_and_dims.append((name, dims)) + all_dims.update(dims) + + exclude_dims = all_dims & self.exclude_dims + if exclude_dims == all_dims: + continue + elif exclude_dims: + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) raise ValueError( - f"Indexes along dimension {dim!r} don't have the same length." - " Cannot use join='override'." + f"cannot exclude dimension(s) {excl_dims_str} from alignment because " + "these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" ) - objects = list(objects) - for idx, obj in enumerate(objects[1:]): - new_indexes = { - dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude - } + key = (tuple(coord_names_and_dims), type(idx)) + normalized_indexes[key] = idx + normalized_index_vars[key] = index_vars + + return normalized_indexes, normalized_index_vars + + def find_matching_indexes(self) -> None: + all_indexes: dict[MatchingIndexKey, list[Index]] + all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]] + all_indexes_dim_sizes: dict[MatchingIndexKey, dict[Hashable, set]] + objects_matching_indexes: list[dict[MatchingIndexKey, Index]] + + all_indexes = defaultdict(list) + all_index_vars = defaultdict(list) + all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set)) + objects_matching_indexes = [] + + for obj in self.objects: + obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes) + objects_matching_indexes.append(obj_indexes) + for key, idx in obj_indexes.items(): + all_indexes[key].append(idx) + for key, index_vars in obj_index_vars.items(): + all_index_vars[key].append(index_vars) + for dim, size in calculate_dimensions(index_vars).items(): + all_indexes_dim_sizes[key][dim].add(size) + + self.objects_matching_indexes = tuple(objects_matching_indexes) + self.all_indexes = all_indexes + self.all_index_vars = all_index_vars + + if self.join == "override": + for dim_sizes in all_indexes_dim_sizes.values(): + for dim, sizes in dim_sizes.items(): + if len(sizes) > 1: + raise ValueError( + "cannot align objects with join='override' with matching indexes " + f"along dimension {dim!r} that don't have the same size" + ) + + def find_matching_unindexed_dims(self) -> None: + unindexed_dim_sizes = defaultdict(set) + + for obj in self.objects: + for dim in obj.dims: + if dim not in self.exclude_dims and dim not in obj.xindexes.dims: + unindexed_dim_sizes[dim].add(obj.sizes[dim]) + + self.unindexed_dim_sizes = unindexed_dim_sizes + + def assert_no_index_conflict(self) -> None: + """Check for uniqueness of both coordinate and dimension names accross all sets + of matching indexes. + + We need to make sure that all indexes used for re-indexing or alignment + are fully compatible and do not conflict each other. + + Note: perhaps we could choose less restrictive constraints and instead + check for conflicts among the dimension (position) indexers returned by + `Index.reindex_like()` for each matching pair of object index / aligned + index? + (ref: https://github.com/pydata/xarray/issues/1603#issuecomment-442965602) + + """ + matching_keys = set(self.all_indexes) | set(self.indexes) + + coord_count: dict[Hashable, int] = defaultdict(int) + dim_count: dict[Hashable, int] = defaultdict(int) + for coord_names_dims, _ in matching_keys: + dims_set: set[Hashable] = set() + for name, dims in coord_names_dims: + coord_count[name] += 1 + dims_set.update(dims) + for dim in dims_set: + dim_count[dim] += 1 + + for count, msg in [(coord_count, "coordinates"), (dim_count, "dimensions")]: + dup = {k: v for k, v in count.items() if v > 1} + if dup: + items_msg = ", ".join( + f"{k!r} ({v} conflicting indexes)" for k, v in dup.items() + ) + raise ValueError( + "cannot re-index or align objects with conflicting indexes found for " + f"the following {msg}: {items_msg}\n" + "Conflicting indexes may occur when\n" + "- they relate to different sets of coordinate and/or dimension names\n" + "- they don't have the same type\n" + "- they may be used to reindex data along common dimensions" + ) - objects[idx + 1] = obj._overwrite_indexes(new_indexes) + def _need_reindex(self, dims, cmp_indexes) -> bool: + """Whether or not we need to reindex variables for a set of + matching indexes. + + We don't reindex when all matching indexes are equal for two reasons: + - It's faster for the usual case (already aligned objects). + - It ensures it's possible to do operations that don't require alignment + on indexes with duplicate values (which cannot be reindexed with + pandas). This is useful, e.g., for overwriting such duplicate indexes. + + """ + has_unindexed_dims = any(dim in self.unindexed_dim_sizes for dim in dims) + return not (indexes_all_equal(cmp_indexes)) or has_unindexed_dims + + def _get_index_joiner(self, index_cls) -> Callable: + if self.join in ["outer", "inner"]: + return functools.partial( + functools.reduce, + functools.partial(index_cls.join, how=self.join), + ) + elif self.join == "left": + return operator.itemgetter(0) + elif self.join == "right": + return operator.itemgetter(-1) + elif self.join == "override": + # We rewrite all indexes and then use join='left' + return operator.itemgetter(0) + else: + # join='exact' return dummy lambda (error is raised) + return lambda _: None + + def align_indexes(self) -> None: + """Compute all aligned indexes and their corresponding coordinate variables.""" + + aligned_indexes = {} + aligned_index_vars = {} + reindex = {} + new_indexes = {} + new_index_vars = {} + + for key, matching_indexes in self.all_indexes.items(): + matching_index_vars = self.all_index_vars[key] + dims = {d for coord in matching_index_vars[0].values() for d in coord.dims} + index_cls = key[1] + + if self.join == "override": + joined_index = matching_indexes[0] + joined_index_vars = matching_index_vars[0] + need_reindex = False + elif key in self.indexes: + joined_index = self.indexes[key] + joined_index_vars = self.index_vars[key] + cmp_indexes = list( + zip( + [joined_index] + matching_indexes, + [joined_index_vars] + matching_index_vars, + ) + ) + need_reindex = self._need_reindex(dims, cmp_indexes) + else: + if len(matching_indexes) > 1: + need_reindex = self._need_reindex( + dims, + list(zip(matching_indexes, matching_index_vars)), + ) + else: + need_reindex = False + if need_reindex: + if self.join == "exact": + raise ValueError( + "cannot align objects with join='exact' where " + "index/labels/sizes are not equal along " + "these coordinates (dimensions): " + + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0]) + ) + joiner = self._get_index_joiner(index_cls) + joined_index = joiner(matching_indexes) + if self.join == "left": + joined_index_vars = matching_index_vars[0] + elif self.join == "right": + joined_index_vars = matching_index_vars[-1] + else: + joined_index_vars = joined_index.create_variables() + else: + joined_index = matching_indexes[0] + joined_index_vars = matching_index_vars[0] + + reindex[key] = need_reindex + aligned_indexes[key] = joined_index + aligned_index_vars[key] = joined_index_vars + + for name, var in joined_index_vars.items(): + new_indexes[name] = joined_index + new_index_vars[name] = var + + # Explicitly provided indexes that are not found in objects to align + # may relate to unindexed dimensions so we add them too + for key, idx in self.indexes.items(): + if key not in aligned_indexes: + index_vars = self.index_vars[key] + reindex[key] = False + aligned_indexes[key] = idx + aligned_index_vars[key] = index_vars + for name, var in index_vars.items(): + new_indexes[name] = idx + new_index_vars[name] = var + + self.aligned_indexes = aligned_indexes + self.aligned_index_vars = aligned_index_vars + self.reindex = reindex + self.new_indexes = Indexes(new_indexes, new_index_vars) + + def assert_unindexed_dim_sizes_equal(self) -> None: + for dim, sizes in self.unindexed_dim_sizes.items(): + index_size = self.new_indexes.dims.get(dim) + if index_size is not None: + sizes.add(index_size) + add_err_msg = ( + f" (note: an index is found along that dimension " + f"with size={index_size!r})" + ) + else: + add_err_msg = "" + if len(sizes) > 1: + raise ValueError( + f"cannot reindex or align along dimension {dim!r} " + f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg + ) - return objects + def override_indexes(self) -> None: + objects = list(self.objects) + + for i, obj in enumerate(objects[1:]): + new_indexes = {} + new_variables = {} + matching_indexes = self.objects_matching_indexes[i + 1] + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + for name, var in self.aligned_index_vars[key].items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + + objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables) + + self.results = tuple(objects) + + def _get_dim_pos_indexers( + self, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> dict[Hashable, Any]: + dim_pos_indexers = {} + + for key, aligned_idx in self.aligned_indexes.items(): + obj_idx = matching_indexes.get(key) + if obj_idx is not None: + if self.reindex[key]: + indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs) # type: ignore[call-arg] + dim_pos_indexers.update(indexers) + + return dim_pos_indexers + + def _get_indexes_and_vars( + self, + obj: DataAlignable, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes = {} + new_variables = {} + + for key, aligned_idx in self.aligned_indexes.items(): + index_vars = self.aligned_index_vars[key] + obj_idx = matching_indexes.get(key) + if obj_idx is None: + # add the index if it relates to unindexed dimensions in obj + index_vars_dims = {d for var in index_vars.values() for d in var.dims} + if index_vars_dims <= set(obj.dims): + obj_idx = aligned_idx + if obj_idx is not None: + for name, var in index_vars.items(): + new_indexes[name] = aligned_idx + new_variables[name] = var + + return new_indexes, new_variables + + def _reindex_one( + self, + obj: DataAlignable, + matching_indexes: dict[MatchingIndexKey, Index], + ) -> DataAlignable: + new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) + dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) + + new_obj = obj._reindex_callback( + self, + dim_pos_indexers, + new_variables, + new_indexes, + self.fill_value, + self.exclude_dims, + self.exclude_vars, + ) + new_obj.encoding = obj.encoding + return new_obj + + def reindex_all(self) -> None: + self.results = tuple( + self._reindex_one(obj, matching_indexes) + for obj, matching_indexes in zip( + self.objects, self.objects_matching_indexes + ) + ) + + def align(self) -> None: + if not self.indexes and len(self.objects) == 1: + # fast path for the trivial case + (obj,) = self.objects + self.results = (obj.copy(deep=self.copy),) + + self.find_matching_indexes() + self.find_matching_unindexed_dims() + self.assert_no_index_conflict() + self.align_indexes() + self.assert_unindexed_dim_sizes_equal() + + if self.join == "override": + self.override_indexes() + else: + self.reindex_all() def align( - *objects: "DataAlignable", + *objects: DataAlignable, join="inner", copy=True, indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -) -> Tuple["DataAlignable", ...]: +) -> tuple[DataAlignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -251,8 +731,7 @@ def align( >>> a, b = xr.align(x, y, join="exact") Traceback (most recent call last): ... - "indexes along dimension {!r} are not equal".format(dim) - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' ... >>> a, b = xr.align(x, y, join="override") >>> a @@ -271,107 +750,16 @@ def align( * lon (lon) float64 100.0 120.0 """ - if indexes is None: - indexes = {} - - if not indexes and len(objects) == 1: - # fast path for the trivial case - (obj,) = objects - return (obj.copy(deep=copy),) - - all_indexes = defaultdict(list) - all_coords = defaultdict(list) - unlabeled_dim_sizes = defaultdict(set) - for obj in objects: - for dim in obj.dims: - if dim not in exclude: - all_coords[dim].append(obj.coords[dim]) - try: - index = obj.xindexes[dim] - except KeyError: - unlabeled_dim_sizes[dim].add(obj.sizes[dim]) - else: - all_indexes[dim].append(index) - - if join == "override": - objects = _override_indexes(objects, all_indexes, exclude) - - # We don't reindex over dimensions with all equal indexes for two reasons: - # - It's faster for the usual case (already aligned objects). - # - It ensures it's possible to do operations that don't require alignment - # on indexes with duplicate values (which cannot be reindexed with - # pandas). This is useful, e.g., for overwriting such duplicate indexes. - joined_indexes = {} - for dim, matching_indexes in all_indexes.items(): - if dim in indexes: - index, _ = PandasIndex.from_pandas_index( - safe_cast_to_index(indexes[dim]), dim - ) - if ( - any(not index.equals(other) for other in matching_indexes) - or dim in unlabeled_dim_sizes - ): - joined_indexes[dim] = indexes[dim] - else: - if ( - any( - not matching_indexes[0].equals(other) - for other in matching_indexes[1:] - ) - or dim in unlabeled_dim_sizes - ): - if join == "exact": - raise ValueError(f"indexes along dimension {dim!r} are not equal") - joiner = _get_joiner(join, type(matching_indexes[0])) - index = joiner(matching_indexes) - # make sure str coords are not cast to object - index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim]) - joined_indexes[dim] = index - else: - index = all_coords[dim][0] - - if dim in unlabeled_dim_sizes: - unlabeled_sizes = unlabeled_dim_sizes[dim] - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - if isinstance(index, PandasIndex): - labeled_size = index.to_pandas_index().size - else: - labeled_size = index.size - if len(unlabeled_sizes | {labeled_size}) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension size(s) {unlabeled_sizes!r} " - f"than the size of the aligned dimension labels: {labeled_size!r}" - ) - - for dim, sizes in unlabeled_dim_sizes.items(): - if dim not in all_indexes and len(sizes) > 1: - raise ValueError( - f"arguments without labels along dimension {dim!r} cannot be " - f"aligned because they have different dimension sizes: {sizes!r}" - ) - - result = [] - for obj in objects: - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - valid_indexers = {} - for k, index in joined_indexes.items(): - if k in obj.dims: - if isinstance(index, Index): - valid_indexers[k] = index.to_pandas_index() - else: - valid_indexers[k] = index - if not valid_indexers: - # fast path for no reindexing necessary - new_obj = obj.copy(deep=copy) - else: - new_obj = obj.reindex( - copy=copy, fill_value=fill_value, indexers=valid_indexers - ) - new_obj.encoding = obj.encoding - result.append(new_obj) - - return tuple(result) + aligner = Aligner( + objects, + join=join, + copy=copy, + indexes=indexes, + exclude_dims=exclude, + fill_value=fill_value, + ) + aligner.align() + return aligner.results def deep_align( @@ -457,197 +845,78 @@ def is_alignable(obj): return out -def reindex_like_indexers( - target: "Union[DataArray, Dataset]", other: "Union[DataArray, Dataset]" -) -> Dict[Hashable, pd.Index]: - """Extract indexers to align target with other. - - Not public API. - - Parameters - ---------- - target : Dataset or DataArray - Object to be aligned. - other : Dataset or DataArray - Object to be aligned with. - - Returns - ------- - Dict[Hashable, pandas.Index] providing indexes for reindex keyword - arguments. - - Raises - ------ - ValueError - If any dimensions without labels have different sizes. - """ - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 - # this doesn't support yet indexes other than pd.Index - indexers = { - k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims - } - - for dim in other.dims: - if dim not in indexers and dim in target.dims: - other_size = other.sizes[dim] - target_size = target.sizes[dim] - if other_size != target_size: - raise ValueError( - "different size for unlabeled " - f"dimension on argument {dim!r}: {other_size!r} vs {target_size!r}" - ) - return indexers - - -def reindex_variables( - variables: Mapping[Any, Variable], - sizes: Mapping[Any, int], - indexes: Mapping[Any, Index], - indexers: Mapping, - method: Optional[str] = None, - tolerance: Union[Union[int, float], Iterable[Union[int, float]]] = None, +def reindex( + obj: DataAlignable, + indexers: Mapping[Any, Any], + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, - fill_value: Optional[Any] = dtypes.NA, + fill_value: Any = dtypes.NA, sparse: bool = False, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: - """Conform a dictionary of aligned variables onto a new set of variables, - filling in missing values with NaN. + exclude_vars: Iterable[Hashable] = frozenset(), +) -> DataAlignable: + """Re-index either a Dataset or a DataArray. Not public API. - Parameters - ---------- - variables : dict-like - Dictionary of xarray.Variable objects. - sizes : dict-like - Dictionary from dimension names to integer sizes. - indexes : dict-like - Dictionary of indexes associated with variables. - indexers : dict - Dictionary with keys given by dimension names and values given by - arrays of coordinates tick labels. Any mis-matched coordinate values - will be filled in with NaN, and any mis-matched dimension names will - simply be ignored. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional - Method to use for filling index values in ``indexers`` not found in - this dataset: - * None (default): don't fill gaps - * pad / ffill: propagate last valid index value forward - * backfill / bfill: propagate next valid index value backward - * nearest: use nearest valid index value - tolerance : optional - Maximum distance between original and new labels for inexact matches. - The values of the index at the matching locations must satisfy the - equation ``abs(index[indexer] - target) <= tolerance``. - Tolerance may be a scalar value, which applies the same tolerance - to all values, or list-like, which applies variable tolerance per - element. List-like must be the same size as the index and its dtype - must exactly match the index’s type. - copy : bool, optional - If ``copy=True``, data in the return values is always copied. If - ``copy=False`` and reindexing is unnecessary, or can be performed - with only slice operations, then the output may share memory with - the input. In either case, new xarray objects are always returned. - fill_value : scalar, optional - Value to use for newly missing values - sparse : bool, optional - Use an sparse-array - - Returns - ------- - reindexed : dict - Dict of reindexed variables. - new_indexes : dict - Dict of indexes associated with the reindexed variables. """ - from .dataarray import DataArray - - # create variables for the new dataset - reindexed: Dict[Hashable, Variable] = {} - - # build up indexers for assignment along each dimension - int_indexers = {} - new_indexes = dict(indexes) - masked_dims = set() - unchanged_dims = set() - - for dim, indexer in indexers.items(): - if isinstance(indexer, DataArray) and indexer.dims != (dim,): - raise ValueError( - "Indexer has dimensions {:s} that are different " - "from that to be indexed along {:s}".format(str(indexer.dims), dim) - ) - - target = safe_cast_to_index(indexers[dim]) - new_indexes[dim] = PandasIndex(target, dim) - - if dim in indexes: - # TODO (benbovy - flexible indexes): support other indexes than pd.Index? - index = indexes[dim].to_pandas_index() - - if not index.is_unique: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} because the " - "index has duplicate values" - ) - - int_indexer = get_indexer_nd(index, target, method, tolerance) - - # We uses negative values from get_indexer_nd to signify - # values that are missing in the index. - if (int_indexer < 0).any(): - masked_dims.add(dim) - elif np.array_equal(int_indexer, np.arange(len(index))): - unchanged_dims.add(dim) - int_indexers[dim] = int_indexer - - if dim in variables: - var = variables[dim] - args: tuple = (var.attrs, var.encoding) - else: - args = () - reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) - - for dim in sizes: - if dim not in indexes and dim in indexers: - existing_size = sizes[dim] - new_size = indexers[dim].size - if existing_size != new_size: - raise ValueError( - f"cannot reindex or align along dimension {dim!r} without an " - f"index because its size {existing_size!r} is different from the size of " - f"the new index {new_size!r}" - ) + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # bad_keys = [k for k in indexers if k not in obj._indexes and k not in obj.dims] + # if bad_keys: + # raise ValueError( + # f"indexer keys {bad_keys} do not correspond to any indexed coordinate " + # "or unindexed dimension in the object to reindex" + # ) + + aligner = Aligner( + (obj,), + indexes=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + sparse=sparse, + exclude_vars=exclude_vars, + ) + aligner.align() + return aligner.results[0] - for name, var in variables.items(): - if name not in indexers: - if isinstance(fill_value, dict): - fill_value_ = fill_value.get(name, dtypes.NA) - else: - fill_value_ = fill_value - if sparse: - var = var._as_sparse(fill_value=fill_value_) - key = tuple( - slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) - for d in var.dims - ) - needs_masking = any(d in masked_dims for d in var.dims) - - if needs_masking: - new_var = var._getitem_with_mask(key, fill_value=fill_value_) - elif all(is_full_slice(k) for k in key): - # no reindexing necessary - # here we need to manually deal with copying data, since - # we neither created a new ndarray nor used fancy indexing - new_var = var.copy(deep=copy) - else: - new_var = var[key] +def reindex_like( + obj: DataAlignable, + other: Dataset | DataArray, + method: str = None, + tolerance: int | float | Iterable[int | float] | None = None, + copy: bool = True, + fill_value: Any = dtypes.NA, +) -> DataAlignable: + """Re-index either a Dataset or a DataArray like another Dataset/DataArray. - reindexed[name] = new_var + Not public API. - return reindexed, new_indexes + """ + if not other._indexes: + # This check is not performed in Aligner. + for dim in other.dims: + if dim in obj.dims: + other_size = other.sizes[dim] + obj_size = obj.sizes[dim] + if other_size != obj_size: + raise ValueError( + "different size for unlabeled " + f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}" + ) + + return reindex( + obj, + indexers=other.xindexes, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, + ) def _get_broadcast_dims_map_common_coords(args, exclude): diff --git a/xarray/core/combine.py b/xarray/core/combine.py index d23a58522e6..78f016fdccd 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -86,7 +86,7 @@ def _infer_concat_order_from_coords(datasets): if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds.xindexes.get(dim) for ds in datasets] + indexes = [ds._indexes.get(dim) for ds in datasets] if any(index is None for index in indexes): raise ValueError( "Every dimension needs a coordinate for " diff --git a/xarray/core/common.py b/xarray/core/common.py index cb6da986892..c33db4a62ea 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -412,7 +412,7 @@ def get_index(self, key: Hashable) -> pd.Index: raise KeyError(key) try: - return self.xindexes[key].to_pandas_index() + return self._indexes[key].to_pandas_index() except KeyError: return pd.Index(range(self.sizes[key]), name=key) @@ -1159,8 +1159,7 @@ def resample( category=FutureWarning, ) - # TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass - if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex): + if isinstance(self._indexes[dim_name].to_pandas_index(), CFTimeIndex): from .resample_cftime import CFTimeGrouper grouper = CFTimeGrouper(freq, closed, label, base, loffset) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ce37251576a..7676d8e558c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -23,6 +23,7 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align +from .indexes import Index, filter_indexes_from_coords from .merge import merge_attrs, merge_coordinates_without_align from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array @@ -204,13 +205,13 @@ def _get_coords_list(args) -> list[Coordinates]: return coords_list -def build_output_coords( +def build_output_coords_and_indexes( args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), combine_attrs: str = "override", -) -> list[dict[Any, Variable]]: - """Build output coordinates for an operation. +) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: + """Build output coordinates and indexes for an operation. Parameters ---------- @@ -225,7 +226,7 @@ def build_output_coords( Returns ------- - Dictionary of Variable objects with merged coordinates. + Dictionaries of Variable and Index objects with merged coordinates. """ coords_list = _get_coords_list(args) @@ -233,24 +234,30 @@ def build_output_coords( # we can skip the expensive merge (unpacked_coords,) = coords_list merged_vars = dict(unpacked_coords.variables) + merged_indexes = dict(unpacked_coords.xindexes) else: - # TODO: save these merged indexes, instead of re-computing them later - merged_vars, unused_indexes = merge_coordinates_without_align( + merged_vars, merged_indexes = merge_coordinates_without_align( coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] + output_indexes = [] for output_dims in signature.output_core_dims: dropped_dims = signature.all_input_core_dims - set(output_dims) if dropped_dims: - filtered = { + filtered_coords = { k: v for k, v in merged_vars.items() if dropped_dims.isdisjoint(v.dims) } + filtered_indexes = filter_indexes_from_coords( + merged_indexes, set(filtered_coords) + ) else: - filtered = merged_vars - output_coords.append(filtered) + filtered_coords = merged_vars + filtered_indexes = merged_indexes + output_coords.append(filtered_coords) + output_indexes.append(filtered_indexes) - return output_coords + return output_coords, output_indexes def apply_dataarray_vfunc( @@ -278,7 +285,7 @@ def apply_dataarray_vfunc( else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - result_coords = build_output_coords( + result_coords, result_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) @@ -287,12 +294,19 @@ def apply_dataarray_vfunc( if signature.num_outputs > 1: out = tuple( - DataArray(variable, coords, name=name, fastpath=True) - for variable, coords in zip(result_var, result_coords) + DataArray( + variable, coords=coords, indexes=indexes, name=name, fastpath=True + ) + for variable, coords, indexes in zip( + result_var, result_coords, result_indexes + ) ) else: (coords,) = result_coords - out = DataArray(result_var, coords, name=name, fastpath=True) + (indexes,) = result_indexes + out = DataArray( + result_var, coords=coords, indexes=indexes, name=name, fastpath=True + ) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): @@ -391,7 +405,9 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] + variables: dict[Hashable, Variable], + coord_variables: Mapping[Hashable, Variable], + indexes: dict[Hashable, Index], ) -> Dataset: """Create a dataset as quickly as possible. @@ -401,7 +417,7 @@ def _fast_dataset( variables.update(coord_variables) coord_names = set(coord_variables) - return Dataset._construct_direct(variables, coord_names) + return Dataset._construct_direct(variables, coord_names, indexes=indexes) def apply_dataset_vfunc( @@ -433,7 +449,7 @@ def apply_dataset_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords( + list_of_coords, list_of_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) args = [getattr(arg, "data_vars", arg) for arg in args] @@ -443,10 +459,14 @@ def apply_dataset_vfunc( ) if signature.num_outputs > 1: - out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords)) + out = tuple( + _fast_dataset(*args) + for args in zip(result_vars, list_of_coords, list_of_indexes) + ) else: (coord_vars,) = list_of_coords - out = _fast_dataset(result_vars, coord_vars) + (indexes,) = list_of_indexes + out = _fast_dataset(result_vars, coord_vars, indexes=indexes) attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) if isinstance(out, tuple): diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 1e6e246322e..8ee4672c49a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,14 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Hashable, Iterable, Literal, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, overload import pandas as pd from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv -from .merge import _VALID_COMPAT, merge_attrs, unique_variable -from .variable import IndexVariable, Variable, as_variable +from .indexes import Index, PandasIndex +from .merge import ( + _VALID_COMPAT, + collect_variables_and_indexes, + merge_attrs, + merge_collected, +) +from .variable import Variable from .variable import concat as concat_vars if TYPE_CHECKING: @@ -28,7 +34,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -43,7 +49,7 @@ def concat( data_vars: concat_options | list[Hashable] = "all", coords: concat_options | list[Hashable] = "different", compat: compat_options = "equals", - positions: Iterable[int] | None = None, + positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -240,30 +246,31 @@ def concat( ) -def _calc_concat_dim_coord(dim): - """ - Infer the dimension name and 1d coordinate variable (if appropriate) +def _calc_concat_dim_index( + dim_or_data: Hashable | Any, +) -> tuple[Hashable, PandasIndex | None]: + """Infer the dimension name and 1d index / coordinate variable (if appropriate) for concatenating along the new dimension. + """ from .dataarray import DataArray - if isinstance(dim, str): - coord = None - elif not isinstance(dim, (DataArray, Variable)): - dim_name = getattr(dim, "name", None) - if dim_name is None: - dim_name = "concat_dim" - coord = IndexVariable(dim_name, dim) - dim = dim_name - elif not isinstance(dim, DataArray): - coord = as_variable(dim).to_index_variable() - (dim,) = coord.dims + dim: Hashable | None + + if isinstance(dim_or_data, str): + dim = dim_or_data + index = None else: - coord = dim - if coord.name is None: - coord.name = dim.dims[0] - (dim,) = coord.dims - return dim, coord + if not isinstance(dim_or_data, (DataArray, Variable)): + dim = getattr(dim_or_data, "name", None) + if dim is None: + dim = "concat_dim" + else: + (dim,) = dim_or_data.dims + coord_dtype = getattr(dim_or_data, "dtype", None) + index = PandasIndex(dim_or_data, dim, coord_dtype=coord_dtype) + + return dim, index def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat): @@ -414,7 +421,7 @@ def _dataset_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", @@ -431,7 +438,8 @@ def _dataset_concat( "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" ) - dim, coord = _calc_concat_dim_coord(dim) + dim, index = _calc_concat_dim_index(dim) + # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] datasets = list( @@ -464,22 +472,20 @@ def _dataset_concat( variables_to_merge = (coord_names | data_names) - concat_over - dim_names result_vars = {} + result_indexes = {} + if variables_to_merge: - to_merge: dict[Hashable, list[Variable]] = { - var: [] for var in variables_to_merge + grouped = { + k: v + for k, v in collect_variables_and_indexes(list(datasets)).items() + if k in variables_to_merge } + merged_vars, merged_indexes = merge_collected( + grouped, compat=compat, equals=equals + ) + result_vars.update(merged_vars) + result_indexes.update(merged_indexes) - for ds in datasets: - for var in variables_to_merge: - if var in ds: - to_merge[var].append(ds.variables[var]) - - for var in variables_to_merge: - result_vars[var] = unique_variable( - var, to_merge[var], compat=compat, equals=equals.get(var, None) - ) - else: - result_vars = {} result_vars.update(dim_coords) # assign attrs and encoding from first dataset @@ -506,22 +512,64 @@ def ensure_common_dims(vars): var = var.set_dims(common_dims, common_shape) yield var - # stack up each variable to fill-out the dataset (in order) + # get the indexes to concatenate together, create a PandasIndex + # for any scalar coordinate variable found with ``name`` matching ``dim``. + # TODO: depreciate concat a mix of scalar and dimensional indexed coodinates? + # TODO: (benbovy - explicit indexes): check index types and/or coordinates + # of all datasets? + def get_indexes(name): + for ds in datasets: + if name in ds._indexes: + yield ds._indexes[name] + elif name == dim: + var = ds._variables[name] + if not var.dims: + yield PandasIndex([var.values], dim) + + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for k in datasets[0].variables: - if k in concat_over: + for name in datasets[0].variables: + if name in concat_over and name not in result_indexes: try: - vars = ensure_common_dims([ds[k].variable for ds in datasets]) + vars = ensure_common_dims([ds[name].variable for ds in datasets]) except KeyError: - raise ValueError(f"{k!r} is not present in all datasets.") - combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) - assert isinstance(combined, Variable) - result_vars[k] = combined - elif k in result_vars: + raise ValueError(f"{name!r} is not present in all datasets.") + + # Try concatenate the indexes, concatenate the variables when no index + # is found on all datasets. + indexes: list[Index] = list(get_indexes(name)) + if indexes: + if len(indexes) < len(datasets): + raise ValueError( + f"{name!r} must have either an index or no index in all datasets, " + f"found {len(indexes)}/{len(datasets)} datasets with an index." + ) + combined_idx = indexes[0].concat(indexes, dim, positions) + if name in datasets[0]._indexes: + idx_vars = datasets[0].xindexes.get_all_coords(name) + else: + # index created from a scalar coordinate + idx_vars = {name: datasets[0][name].variable} + result_indexes.update({k: combined_idx for k in idx_vars}) + combined_idx_vars = combined_idx.create_variables(idx_vars) + for k, v in combined_idx_vars.items(): + v.attrs = merge_attrs( + [ds.variables[k].attrs for ds in datasets], + combine_attrs=combine_attrs, + ) + result_vars[k] = v + else: + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + result_vars[name] = combined_var + + elif name in result_vars: # preserves original variable order - result_vars[k] = result_vars.pop(k) + result_vars[name] = result_vars.pop(name) result = Dataset(result_vars, attrs=result_attrs) + absent_coord_names = coord_names - set(result.variables) if absent_coord_names: raise ValueError( @@ -532,9 +580,13 @@ def ensure_common_dims(vars): result = result.drop_vars(unlabeled_dims, errors="ignore") - if coord is not None: - # add concat dimension last to ensure that its in the final Dataset - result[coord.name] = coord + if index is not None: + # add concat index / coordinate last to ensure that its in the final Dataset + result[dim] = index.create_variables()[dim] + result_indexes[dim] = index + + # TODO: add indexes at Dataset creation (when it is supported) + result = result._overwrite_indexes(result_indexes) return result @@ -545,7 +597,7 @@ def _dataarray_concat( data_vars: str | list[str], coords: str | list[str], compat: str, - positions: Iterable[int] | None, + positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: str = "override", diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9dfd64e9c99..458be214f81 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -16,11 +16,11 @@ import numpy as np import pandas as pd -from . import formatting, indexing -from .indexes import Index, Indexes +from . import formatting +from .indexes import Index, Indexes, assert_no_index_corrupted from .merge import merge_coordinates_without_align, merge_coords -from .utils import Frozen, ReprObject, either_dict_or_kwargs -from .variable import Variable +from .utils import Frozen, ReprObject +from .variable import Variable, calculate_dimensions if TYPE_CHECKING: from .dataarray import DataArray @@ -49,11 +49,11 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: raise NotImplementedError() @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: return self._data.indexes # type: ignore[attr-defined] @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: return self._data.xindexes # type: ignore[attr-defined] @property @@ -272,8 +272,6 @@ def to_dataset(self) -> "Dataset": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - variables = self._data._variables.copy() variables.update(coords) @@ -335,8 +333,6 @@ def __getitem__(self, key: Hashable) -> "DataArray": def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: - from .dataset import calculate_dimensions - coords_plus_data = coords.copy() coords_plus_data[_THIS_ARRAY] = self._data.variable dims = calculate_dimensions(coords_plus_data) @@ -360,11 +356,13 @@ def to_dataset(self) -> "Dataset": from .dataset import Dataset coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()} - return Dataset._construct_direct(coords, set(coords)) + indexes = dict(self._data.xindexes) + return Dataset._construct_direct(coords, set(coords), indexes=indexes) def __delitem__(self, key: Hashable) -> None: if key not in self: raise KeyError(f"{key!r} is not a coordinate variable.") + assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] if self._data._indexes is not None and key in self._data._indexes: @@ -390,44 +388,3 @@ def assert_coordinate_consistent( f"dimension coordinate {k!r} conflicts between " f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" ) - - -def remap_label_indexers( - obj: Union["DataArray", "Dataset"], - indexers: Mapping[Any, Any] = None, - method: str = None, - tolerance=None, - **indexers_kwargs: Any, -) -> Tuple[dict, dict]: # TODO more precise return type after annotations in indexing - """Remap indexers from obj.coords. - If indexer is an instance of DataArray and it has coordinate, then this coordinate - will be attached to pos_indexers. - - Returns - ------- - pos_indexers: Same type of indexers. - np.ndarray or Variable or DataArray - new_indexes: mapping of new dimensional-coordinate. - """ - from .dataarray import DataArray - - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "remap_label_indexers") - - v_indexers = { - k: v.variable.data if isinstance(v, DataArray) else v - for k, v in indexers.items() - } - - pos_indexers, new_indexes = indexing.remap_label_indexers( - obj, v_indexers, method=method, tolerance=tolerance - ) - # attach indexer's coordinate to pos_indexers - for k, v in indexers.items(): - if isinstance(v, Variable): - pos_indexers[k] = Variable(v.dims, pos_indexers[k]) - elif isinstance(v, DataArray): - # drop coordinates found in indexers since .sel() already - # ensures alignments - coords = {k: var for k, var in v._coords.items() if k not in indexers} - pos_indexers[k] = DataArray(pos_indexers[k], coords=coords, dims=v.dims) - return pos_indexers, new_indexes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d7c3fd9bab7..df1e096b021 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -22,6 +22,7 @@ from ..plot.plot import _PlotMethods from ..plot.utils import _get_units_from_attrs from . import ( + alignment, computation, dtypes, groupby, @@ -35,25 +36,22 @@ from ._reductions import DataArrayReductions from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor -from .alignment import ( - _broadcast_helper, - _get_broadcast_dims_map_common_coords, - align, - reindex_like_indexers, -) +from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align from .arithmetic import DataArrayArithmetic from .common import AbstractArray, DataWithCoords, get_chunksizes from .computation import unify_chunks -from .coordinates import ( - DataArrayCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) -from .dataset import Dataset, split_indexes +from .coordinates import DataArrayCoordinates, assert_coordinate_consistent +from .dataset import Dataset from .formatting import format_item -from .indexes import Index, Indexes, default_indexes, propagate_indexes -from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords +from .indexes import ( + Index, + Indexes, + PandasMultiIndex, + filter_indexes_from_coords, + isel_indexes, +) +from .indexing import is_fancy_indexer, map_index_queries +from .merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords from .npcompat import QUANTILE_METHODS, ArrayLike from .options import OPTIONS, _get_keep_attrs from .utils import ( @@ -63,13 +61,7 @@ _default, either_dict_or_kwargs, ) -from .variable import ( - IndexVariable, - Variable, - as_compatible_data, - as_variable, - assert_unique_multiindex_level_names, -) +from .variable import IndexVariable, Variable, as_compatible_data, as_variable if TYPE_CHECKING: try: @@ -163,8 +155,6 @@ def _infer_coords_and_dims( "matching the dimension size" ) - assert_unique_multiindex_level_names(new_coords) - return new_coords, dims @@ -205,8 +195,8 @@ def __setitem__(self, key, value) -> None: labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - pos_indexers, _ = remap_label_indexers(self.data_array, key) - self.data_array[pos_indexers] = value + dim_indexers = map_index_queries(self.data_array, key).dim_indexers + self.data_array[dim_indexers] = value # Used as the key corresponding to a DataArray's variable when converting @@ -343,7 +333,7 @@ class DataArray( _cache: dict[str, Any] _coords: dict[Any, Variable] _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] | None + _indexes: dict[Hashable, Index] _name: Hashable | None _variable: Variable @@ -373,14 +363,20 @@ def __init__( name: Hashable = None, attrs: Mapping = None, # internal parameters - indexes: dict[Hashable, pd.Index] = None, + indexes: dict[Hashable, Index] = None, fastpath: bool = False, ): if fastpath: variable = data assert dims is None assert attrs is None + assert indexes is not None else: + # TODO: (benbovy - explicit indexes) remove + # once it becomes part of the public interface + if indexes is not None: + raise ValueError("Providing explicit indexes is not supported yet") + # try to fill in arguments from data if they weren't supplied if coords is None: @@ -404,9 +400,7 @@ def __init__( data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) variable = Variable(dims, data, attrs, fastpath=True) - indexes = dict( - _extract_indexes_from_coords(coords) - ) # needed for to_dataset + indexes, coords = _create_indexes_from_coords(coords) # These fully describe a DataArray self._variable = variable @@ -416,10 +410,29 @@ def __init__( # TODO(shoyer): document this argument, once it becomes part of the # public interface. - self._indexes = indexes + self._indexes = indexes # type: ignore[assignment] self._close = None + @classmethod + def _construct_direct( + cls, + variable: Variable, + coords: dict[Any, Variable], + name: Hashable, + indexes: dict[Hashable, Index], + ) -> DataArray: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + obj = object.__new__(cls) + obj._variable = variable + obj._coords = coords + obj._name = name + obj._indexes = indexes + obj._close = None + return obj + def _replace( self: T_DataArray, variable: Variable = None, @@ -431,9 +444,11 @@ def _replace( variable = self.variable if coords is None: coords = self._coords + if indexes is None: + indexes = self._indexes if name is _default: name = self.name - return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) + return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( self, variable: Variable, name: Hashable | None | Default = _default @@ -449,37 +464,49 @@ def _replace_maybe_drop_dims( for k, v in self._coords.items() if v.shape == tuple(new_sizes[d] for d in v.dims) } - changed_dims = [ - k for k in variable.dims if variable.sizes[k] != self.sizes[k] - ] - indexes = propagate_indexes(self._indexes, exclude=changed_dims) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) else: allowed_dims = set(variable.dims) coords = { k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims } - indexes = propagate_indexes( - self._indexes, exclude=(set(self.dims) - allowed_dims) - ) + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return self._replace(variable, coords, name, indexes=indexes) - def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> DataArray: - if not len(indexes): + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + coords: Mapping[Any, Variable] = None, + drop_coords: list[Hashable] = None, + rename_dims: Mapping[Any, Any] = None, + ) -> DataArray: + """Maybe replace indexes and their corresponding coordinates.""" + if not indexes: return self - coords = self._coords.copy() - for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx.to_pandas_index()) - obj = self._replace(coords=coords) - - # switch from dimension to level names, if necessary - dim_names: dict[Any, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + + if coords is None: + coords = {} + if drop_coords is None: + drop_coords = [] + + new_variable = self.variable.copy() + new_coords = self._coords.copy() + new_indexes = dict(self._indexes) + + for name in indexes: + new_coords[name] = coords[name] + new_indexes[name] = indexes[name] + + for name in drop_coords: + new_coords.pop(name) + new_indexes.pop(name) + + if rename_dims: + new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims] + + return self._replace( + variable=new_variable, coords=new_coords, indexes=new_indexes + ) def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) @@ -502,8 +529,8 @@ def subset(dim, label): variables = {label: subset(dim, label) for label in self.get_index(dim)} variables.update({k: v for k, v in self._coords.items() if k != dim}) - indexes = propagate_indexes(self._indexes, exclude=dim) coord_names = set(self._coords) - {dim} + indexes = filter_indexes_from_coords(self._indexes, coord_names) dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs ) @@ -708,21 +735,6 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - @property - def _level_coords(self) -> dict[Hashable, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. - """ - level_coords: dict[Hashable, Hashable] = {} - - for cname, var in self._coords.items(): - if var.ndim == 1 and isinstance(var, IndexVariable): - level_names = var.level_names - if level_names is not None: - (dim,) = var.dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _getitem_coord(self, key): from .dataset import _get_virtual_variable @@ -730,9 +742,7 @@ def _getitem_coord(self, key): var = self._coords[key] except KeyError: dim_sizes = dict(zip(self.dims, self.shape)) - _, key, var = _get_virtual_variable( - self._coords, key, self._level_coords, dim_sizes - ) + _, key, var = _get_virtual_variable(self._coords, key, dim_sizes) return self._replace_maybe_drop_dims(var, name=key) @@ -777,7 +787,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates # uses empty dict -- everything here can already be found in self.coords. yield HybridMappingProxy(keys=self.dims, mapping={}) - yield HybridMappingProxy(keys=self._level_coords, mapping={}) def __contains__(self, key: Any) -> bool: return key in self.data @@ -820,14 +829,12 @@ def indexes(self) -> Indexes: DataArray.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property def xindexes(self) -> Indexes: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._coords, self.dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property def coords(self) -> DataArrayCoordinates: @@ -855,7 +862,7 @@ def reset_coords( Dataset, or DataArray if ``drop == True`` """ if names is None: - names = set(self.coords) - set(self.dims) + names = set(self.coords) - set(self._indexes) dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) @@ -901,7 +908,8 @@ def _dask_finalize(results, name, func, *args, **kwargs): ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables - return DataArray(variable, coords, name=name, fastpath=True) + indexes = ds._indexes + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) def load(self, **kwargs) -> DataArray: """Manually trigger loading of this array's data from disk or a @@ -1037,11 +1045,15 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: pandas.DataFrame.copy """ variable = self.variable.copy(deep=deep, data=data) - coords = {k: v.copy(deep=deep) for k, v in self._coords.items()} - if self._indexes is None: - indexes = self._indexes - else: - indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()} + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + coords = {} + for k, v in self._coords.items(): + if k in index_vars: + coords[k] = index_vars[k] + else: + coords[k] = v.copy(deep=deep) + return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> DataArray: @@ -1206,19 +1218,23 @@ def isel( # lists, or zero or one-dimensional np.ndarray's variable = self._variable.isel(indexers, missing_dims=missing_dims) + indexes, index_variables = isel_indexes(self.xindexes, indexers) coords = {} for coord_name, coord_value in self._coords.items(): - coord_indexers = { - k: v for k, v in indexers.items() if k in coord_value.dims - } - if coord_indexers: - coord_value = coord_value.isel(coord_indexers) - if drop and coord_value.ndim == 0: - continue + if coord_name in index_variables: + coord_value = index_variables[coord_name] + else: + coord_indexers = { + k: v for k, v in indexers.items() if k in coord_value.dims + } + if coord_indexers: + coord_value = coord_value.isel(coord_indexers) + if drop and coord_value.ndim == 0: + continue coords[coord_name] = coord_value - return self._replace(variable=variable, coords=coords) + return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( self, @@ -1463,6 +1479,37 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> DataArray: + """Callback called from ``Aligner`` to create a new reindexed DataArray.""" + + if isinstance(fill_value, dict): + fill_value = fill_value.copy() + sentinel = object() + value = fill_value.pop(self.name, sentinel) + if value is not sentinel: + fill_value[_THIS_ARRAY] = value + + ds = self._to_temp_dataset() + reindexed = ds._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + return self._from_temp_dataset(reindexed) + def reindex_like( self, other: DataArray | Dataset, @@ -1520,9 +1567,9 @@ def reindex_like( DataArray.reindex align """ - indexers = reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex_like( + self, + other=other, method=method, tolerance=tolerance, copy=copy, @@ -1609,22 +1656,15 @@ def reindex( DataArray.reindex_like align """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") - if isinstance(fill_value, dict): - fill_value = fill_value.copy() - sentinel = object() - value = fill_value.pop(self.name, sentinel) - if value is not sentinel: - fill_value[_THIS_ARRAY] = value - - ds = self._to_temp_dataset().reindex( + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, indexers=indexers, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) - return self._from_temp_dataset(ds) def interp( self, @@ -2043,10 +2083,8 @@ def reset_index( -------- DataArray.set_index """ - coords, _ = split_indexes( - dims_or_levels, self._coords, set(), self._level_coords, drop=drop - ) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop) + return self._from_temp_dataset(ds) def reorder_levels( self, @@ -2071,21 +2109,14 @@ def reorder_levels( Another dataarray, with this dataarray's data but replaced coordinates. """ - dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") - replace_coords = {} - for dim, order in dim_order.items(): - coord = self._coords[dim] - index = coord.to_index() - if not isinstance(index, pd.MultiIndex): - raise ValueError(f"coordinate {dim!r} has no MultiIndex") - replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) - coords = self._coords.copy() - coords.update(replace_coords) - return self._replace(coords=coords) + ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs) + return self._from_temp_dataset(ds) def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: bool = True, + index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> DataArray: """ @@ -2102,6 +2133,14 @@ def stack( replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type. Must be an Xarray index that + implements `.stack()`. By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -2132,13 +2171,18 @@ def stack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') See Also -------- DataArray.unstack """ - ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs) + ds = self._to_temp_dataset().stack( + dimensions, + create_index=create_index, + index_cls=index_cls, + **dimensions_kwargs, + ) return self._from_temp_dataset(ds) def unstack( @@ -2192,7 +2236,7 @@ def unstack( ('b', 0), ('b', 1), ('b', 2)], - names=['x', 'y']) + name='z') >>> roundtripped = stacked.unstack() >>> arr.identical(roundtripped) True @@ -2244,7 +2288,7 @@ def to_unstacked_dataset(self, dim, level=0): ('a', 1.0), ('a', 2.0), ('b', nan)], - names=['variable', 'y']) + name='z') >>> roundtripped = stacked.to_unstacked_dataset(dim="z") >>> data.identical(roundtripped) True @@ -2253,10 +2297,7 @@ def to_unstacked_dataset(self, dim, level=0): -------- Dataset.to_stacked_array """ - - # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = self.xindexes[dim].to_pandas_index() + idx = self._indexes[dim].to_pandas_index() if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") @@ -4072,6 +4113,9 @@ def pad( For ``mode="constant"`` and ``constant_values=None``, integer types will be promoted to ``float`` and padded with ``np.nan``. + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd7807c2e7c..155cf21b4db 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -15,7 +15,6 @@ Any, Callable, Collection, - DefaultDict, Hashable, Iterable, Iterator, @@ -53,24 +52,21 @@ from .arithmetic import DatasetArithmetic from .common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes from .computation import unify_chunks -from .coordinates import ( - DatasetCoordinates, - assert_coordinate_consistent, - remap_label_indexers, -) +from .coordinates import DatasetCoordinates, assert_coordinate_consistent from .duck_array_ops import datetime_to_numeric from .indexes import ( Index, Indexes, PandasIndex, PandasMultiIndex, - default_indexes, - isel_variable_and_index, - propagate_indexes, + assert_no_index_corrupted, + create_default_index_implicit, + filter_indexes_from_coords, + isel_indexes, remove_unused_levels_categories, - roll_index, + roll_indexes, ) -from .indexing import is_fancy_indexer +from .indexing import is_fancy_indexer, map_index_queries from .merge import ( dataset_merge_method, dataset_update_method, @@ -100,8 +96,8 @@ IndexVariable, Variable, as_variable, - assert_unique_multiindex_level_names, broadcast_variables, + calculate_dimensions, ) if TYPE_CHECKING: @@ -136,13 +132,12 @@ def _get_virtual_variable( - variables, key: Hashable, level_vars: Mapping = None, dim_sizes: Mapping = None + variables, key: Hashable, dim_sizes: Mapping = None ) -> tuple[Hashable, Hashable, Variable]: - """Get a virtual variable (e.g., 'time.year' or a MultiIndex level) - from a dict of xarray.Variable objects (if possible) + """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable + objects (if possible) + """ - if level_vars is None: - level_vars = {} if dim_sizes is None: dim_sizes = {} @@ -155,202 +150,22 @@ def _get_virtual_variable( raise KeyError(key) split_key = key.split(".", 1) - var_name: str | None - if len(split_key) == 2: - ref_name, var_name = split_key - elif len(split_key) == 1: - ref_name, var_name = key, None - else: + if len(split_key) != 2: raise KeyError(key) - if ref_name in level_vars: - dim_var = variables[level_vars[ref_name]] - ref_var = dim_var.to_index_variable().get_level_variable(ref_name) - else: - ref_var = variables[ref_name] + ref_name, var_name = split_key + ref_var = variables[ref_name] - if var_name is None: - virtual_var = ref_var - var_name = key + if _contains_datetime_like_objects(ref_var): + ref_var = xr.DataArray(ref_var) + data = getattr(ref_var.dt, var_name).data else: - if _contains_datetime_like_objects(ref_var): - ref_var = xr.DataArray(ref_var) - data = getattr(ref_var.dt, var_name).data - else: - data = getattr(ref_var, var_name).data - virtual_var = Variable(ref_var.dims, data) + data = getattr(ref_var, var_name).data + virtual_var = Variable(ref_var.dims, data) return ref_name, var_name, virtual_var -def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: - """Calculate the dimensions corresponding to a set of variables. - - Returns dictionary mapping from dimension names to sizes. Raises ValueError - if any of the dimension sizes conflict. - """ - dims: dict[Hashable, int] = {} - last_used = {} - scalar_vars = {k for k, v in variables.items() if not v.dims} - for k, var in variables.items(): - for dim, size in zip(var.dims, var.shape): - if dim in scalar_vars: - raise ValueError( - f"dimension {dim!r} already exists as a scalar variable" - ) - if dim not in dims: - dims[dim] = size - last_used[dim] = k - elif dims[dim] != size: - raise ValueError( - f"conflicting sizes for dimension {dim!r}: " - f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" - ) - return dims - - -def merge_indexes( - indexes: Mapping[Any, Hashable | Sequence[Hashable]], - variables: Mapping[Any, Variable], - coord_names: set[Hashable], - append: bool = False, -) -> tuple[dict[Hashable, Variable], set[Hashable]]: - """Merge variables into multi-indexes. - - Not public API. Used in Dataset and DataArray set_index - methods. - """ - vars_to_replace: dict[Hashable, Variable] = {} - vars_to_remove: list[Hashable] = [] - dims_to_replace: dict[Hashable, Hashable] = {} - error_msg = "{} is not the name of an existing variable." - - for dim, var_names in indexes.items(): - if isinstance(var_names, str) or not isinstance(var_names, Sequence): - var_names = [var_names] - - names: list[Hashable] = [] - codes: list[list[int]] = [] - levels: list[list[int]] = [] - current_index_variable = variables.get(dim) - - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - if ( - current_index_variable is not None - and var.dims != current_index_variable.dims - ): - raise ValueError( - f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}" - ) - - if current_index_variable is not None and append: - current_index = current_index_variable.to_index() - if isinstance(current_index, pd.MultiIndex): - names.extend(current_index.names) - codes.extend(current_index.codes) - levels.extend(current_index.levels) - else: - names.append(f"{dim}_level_0") - cat = pd.Categorical(current_index.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - if not len(names) and len(var_names) == 1: - idx = pd.Index(variables[var_names[0]].values) - - else: # MultiIndex - for n in var_names: - try: - var = variables[n] - except KeyError: - raise ValueError(error_msg.format(n)) - names.append(n) - cat = pd.Categorical(var.values, ordered=True) - codes.append(cat.codes) - levels.append(cat.categories) - - idx = pd.MultiIndex(levels, codes, names=names) - for n in names: - dims_to_replace[n] = dim - - vars_to_replace[dim] = IndexVariable(dim, idx) - vars_to_remove.extend(var_names) - - new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} - new_variables.update(vars_to_replace) - - # update dimensions if necessary, GH: 3512 - for k, v in new_variables.items(): - if any(d in dims_to_replace for d in v.dims): - new_dims = [dims_to_replace.get(d, d) for d in v.dims] - new_variables[k] = v._replace(dims=new_dims) - new_coord_names = coord_names | set(vars_to_replace) - new_coord_names -= set(vars_to_remove) - return new_variables, new_coord_names - - -def split_indexes( - dims_or_levels: Hashable | Sequence[Hashable], - variables: Mapping[Any, Variable], - coord_names: set[Hashable], - level_coords: Mapping[Any, Hashable], - drop: bool = False, -) -> tuple[dict[Hashable, Variable], set[Hashable]]: - """Extract (multi-)indexes (levels) as variables. - - Not public API. Used in Dataset and DataArray reset_index - methods. - """ - if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): - dims_or_levels = [dims_or_levels] - - dim_levels: DefaultDict[Any, list[Hashable]] = defaultdict(list) - dims = [] - for k in dims_or_levels: - if k in level_coords: - dim_levels[level_coords[k]].append(k) - else: - dims.append(k) - - vars_to_replace = {} - vars_to_create: dict[Hashable, Variable] = {} - vars_to_remove = [] - - for d in dims: - index = variables[d].to_index() - if isinstance(index, pd.MultiIndex): - dim_levels[d] = index.names - else: - vars_to_remove.append(d) - if not drop: - vars_to_create[str(d) + "_"] = Variable(d, index, variables[d].attrs) - - for d, levs in dim_levels.items(): - index = variables[d].to_index() - if len(levs) == index.nlevels: - vars_to_remove.append(d) - else: - vars_to_replace[d] = IndexVariable(d, index.droplevel(levs)) - - if not drop: - for lev in levs: - idx = index.get_level_values(lev) - vars_to_create[idx.name] = Variable(d, idx, variables[d].attrs) - - new_variables = dict(variables) - for v in set(vars_to_remove): - del new_variables[v] - new_variables.update(vars_to_replace) - new_variables.update(vars_to_create) - new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove) - - return new_variables, new_coord_names - - def _assert_empty(args: tuple, msg: str = "%s") -> None: if args: raise ValueError(msg % args) @@ -570,8 +385,8 @@ def __setitem__(self, key, value) -> None: ) # set new values - pos_indexers, _ = remap_label_indexers(self.dataset, key) - self.dataset[pos_indexers] = value + dim_indexers = map_index_queries(self.dataset, key).dim_indexers + self.dataset[dim_indexers] = value class Dataset(DataWithCoords, DatasetReductions, DatasetArithmetic, Mapping): @@ -703,7 +518,7 @@ class Dataset(DataWithCoords, DatasetReductions, DatasetArithmetic, Mapping): _dims: dict[Hashable, int] _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] | None + _indexes: dict[Hashable, Index] _variables: dict[Hashable, Variable] __slots__ = ( @@ -1081,6 +896,8 @@ def _construct_direct( """ if dims is None: dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -1097,7 +914,7 @@ def _replace( coord_names: set[Hashable] = None, dims: dict[Any, int] = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -1118,7 +935,7 @@ def _replace( self._dims = dims if attrs is not _default: self._attrs = attrs - if indexes is not _default: + if indexes is not None: self._indexes = indexes if encoding is not _default: self._encoding = encoding @@ -1132,8 +949,8 @@ def _replace( dims = self._dims.copy() if attrs is _default: attrs = copy.copy(self._attrs) - if indexes is _default: - indexes = copy.copy(self._indexes) + if indexes is None: + indexes = self._indexes.copy() if encoding is _default: encoding = copy.copy(self._encoding) obj = self._construct_direct( @@ -1146,7 +963,7 @@ def _replace_with_new_dims( variables: dict[Hashable, Variable], coord_names: set = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None | Default = _default, + indexes: dict[Hashable, Index] = None, inplace: bool = False, ) -> Dataset: """Replace variables with recalculated dimensions.""" @@ -1174,26 +991,79 @@ def _replace_vars_and_dims( variables, coord_names, dims, attrs, indexes=None, inplace=inplace ) - def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> Dataset: + def _overwrite_indexes( + self, + indexes: Mapping[Hashable, Index], + variables: Mapping[Hashable, Variable] = None, + drop_variables: list[Hashable] = None, + drop_indexes: list[Hashable] = None, + rename_dims: Mapping[Hashable, Hashable] = None, + ) -> Dataset: + """Maybe replace indexes. + + This function may do a lot more depending on index query + results. + + """ if not indexes: return self - variables = self._variables.copy() - new_indexes = dict(self.xindexes) - for name, idx in indexes.items(): - variables[name] = IndexVariable(name, idx.to_pandas_index()) - new_indexes[name] = idx - obj = self._replace(variables, indexes=new_indexes) - - # switch from dimension to level names, if necessary - dim_names: dict[Hashable, str] = {} - for dim, idx in indexes.items(): - pd_idx = idx.to_pandas_index() - if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: - dim_names[dim] = pd_idx.name - if dim_names: - obj = obj.rename(dim_names) - return obj + if variables is None: + variables = {} + if drop_variables is None: + drop_variables = [] + if drop_indexes is None: + drop_indexes = [] + + new_variables = self._variables.copy() + new_coord_names = self._coord_names.copy() + new_indexes = dict(self._indexes) + + index_variables = {} + no_index_variables = {} + for name, var in variables.items(): + old_var = self._variables.get(name) + if old_var is not None: + var.attrs.update(old_var.attrs) + var.encoding.update(old_var.encoding) + if name in indexes: + index_variables[name] = var + else: + no_index_variables[name] = var + + for name in indexes: + new_indexes[name] = indexes[name] + + for name, var in index_variables.items(): + new_coord_names.add(name) + new_variables[name] = var + + # append no-index variables at the end + for k in no_index_variables: + new_variables.pop(k) + new_variables.update(no_index_variables) + + for name in drop_indexes: + new_indexes.pop(name) + + for name in drop_variables: + new_variables.pop(name) + new_indexes.pop(name, None) + new_coord_names.remove(name) + + replaced = self._replace( + variables=new_variables, coord_names=new_coord_names, indexes=new_indexes + ) + + if rename_dims: + # skip rename indexes: they should already have the right name(s) + dims = replaced._rename_dims(rename_dims) + new_variables, new_coord_names = replaced._rename_vars({}, rename_dims) + return replaced._replace( + variables=new_variables, coord_names=new_coord_names, dims=dims + ) + else: + return replaced def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: """Returns a copy of this dataset. @@ -1293,10 +1163,11 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: pandas.DataFrame.copy """ if data is None: - variables = {k: v.copy(deep=deep) for k, v in self._variables.items()} + data = {} elif not utils.is_dict_like(data): raise ValueError("Data must be dict-like") - else: + + if data: var_keys = set(self.data_vars.keys()) data_keys = set(data.keys()) keys_not_in_vars = data_keys - var_keys @@ -1311,14 +1182,19 @@ def copy(self, deep: bool = False, data: Mapping = None) -> Dataset: "Data must contain all variables in original " "dataset. Data is missing {}".format(keys_missing_from_data) ) - variables = { - k: v.copy(deep=deep, data=data.get(k)) - for k, v in self._variables.items() - } + + indexes, index_vars = self.xindexes.copy_indexes(deep=deep) + + variables = {} + for k, v in self._variables.items(): + if k in index_vars: + variables[k] = index_vars[k] + else: + variables[k] = v.copy(deep=deep, data=data.get(k)) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) - return self._replace(variables, attrs=attrs) + return self._replace(variables, indexes=indexes, attrs=attrs) def as_numpy(self: Dataset) -> Dataset: """ @@ -1332,21 +1208,6 @@ def as_numpy(self: Dataset) -> Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - @property - def _level_coords(self) -> dict[str, Hashable]: - """Return a mapping of all MultiIndex levels and their corresponding - coordinate name. - """ - level_coords: dict[str, Hashable] = {} - for name, index in self.xindexes.items(): - # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. - pd_index = index.to_pandas_index() - if isinstance(pd_index, pd.MultiIndex): - level_names = pd_index.names - (dim,) = self.variables[name].dims - level_coords.update({lname: dim for lname in level_names}) - return level_coords - def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. @@ -1360,13 +1221,16 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims + self._variables, name, self.dims ) variables[var_name] = var if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) if (var_name,) == var.dims: - indexes[var_name] = var._to_xindex() + index, index_vars = create_default_index_implicit(var, names) + indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) needed_dims: OrderedSet[Hashable] = OrderedSet() for v in variables.values(): @@ -1382,8 +1246,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Dataset: if set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) - if k in self.xindexes: - indexes[k] = self.xindexes[k] + + indexes.update(filter_indexes_from_coords(self._indexes, coord_names)) return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1394,9 +1258,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: try: variable = self._variables[name] except KeyError: - _, name, variable = _get_virtual_variable( - self._variables, name, self._level_coords, self.dims - ) + _, name, variable = _get_virtual_variable(self._variables, name, self.dims) needed_dims = set(variable.dims) @@ -1406,10 +1268,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: coords[k] = self.variables[k] - if self._indexes is None: - indexes = None - else: - indexes = {k: v for k, v in self._indexes.items() if k in coords} + indexes = filter_indexes_from_coords(self._indexes, set(coords)) return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) @@ -1436,9 +1295,6 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) - # uses empty dict -- everything here can already be found in self.coords. - yield HybridMappingProxy(keys=self._level_coords, mapping={}) - def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether 'key' is an array in the dataset or not. @@ -1629,11 +1485,12 @@ def _setitem_check(self, key, value): def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" + assert_no_index_corrupted(self.xindexes, {key}) + + if key in self._indexes: + del self._indexes[key] del self._variables[key] self._coord_names.discard(key) - if key in self.xindexes: - assert self._indexes is not None - del self._indexes[key] self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable @@ -1707,7 +1564,7 @@ def identical(self, other: Dataset) -> bool: return False @property - def indexes(self) -> Indexes: + def indexes(self) -> Indexes[pd.Index]: """Mapping of pandas.Index objects used for label based indexing. Raises an error if this Dataset has indexes that cannot be coerced @@ -1718,14 +1575,12 @@ def indexes(self) -> Indexes: Dataset.xindexes """ - return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + return self.xindexes.to_pandas_indexes() @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" - if self._indexes is None: - self._indexes = default_indexes(self._variables, self._dims) - return Indexes(self._indexes) + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property def coords(self) -> DatasetCoordinates: @@ -1789,14 +1644,14 @@ def reset_coords( Dataset """ if names is None: - names = self._coord_names - set(self.dims) + names = self._coord_names - set(self._indexes) else: if isinstance(names, str) or not isinstance(names, Iterable): names = [names] else: names = list(names) self._assert_all_in_dataset(names) - bad_coords = set(names) & set(self.dims) + bad_coords = set(names) & set(self._indexes) if bad_coords: raise ValueError( f"cannot remove index coordinates with reset_coords: {bad_coords}" @@ -2221,9 +2076,7 @@ def _validate_indexers( v = np.asarray(v) if v.dtype.kind in "US": - # TODO: benbovy - flexible indexes - # update when CFTimeIndex has its own xarray index class - index = self.xindexes[k].to_pandas_index() + index = self._indexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): @@ -2359,24 +2212,22 @@ def isel( variables = {} dims: dict[Hashable, int] = {} coord_names = self._coord_names.copy() - indexes = self._indexes.copy() if self._indexes is not None else None - - for var_name, var_value in self._variables.items(): - var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims} - if var_indexers: - var_value = var_value.isel(var_indexers) - if drop and var_value.ndim == 0 and var_name in coord_names: - coord_names.remove(var_name) - if indexes: - indexes.pop(var_name, None) - continue - if indexes and var_name in indexes: - if var_value.ndim == 1: - indexes[var_name] = var_value._to_xindex() - else: - del indexes[var_name] - variables[var_name] = var_value - dims.update(zip(var_value.dims, var_value.shape)) + + indexes, index_variables = isel_indexes(self.xindexes, indexers) + + for name, var in self._variables.items(): + # preserve variable order + if name in index_variables: + var = index_variables[name] + else: + var_indexers = {k: v for k, v in indexers.items() if k in var.dims} + if var_indexers: + var = var.isel(var_indexers) + if drop and var.ndim == 0 and name in coord_names: + coord_names.remove(name) + continue + variables[name] = var + dims.update(zip(var.dims, var.shape)) return self._construct_direct( variables=variables, @@ -2395,29 +2246,22 @@ def _isel_fancy( drop: bool, missing_dims: str = "raise", ) -> Dataset: - # Note: we need to preserve the original indexers variable in order to merge the - # coords below - indexers_list = list(self._validate_indexers(indexers, missing_dims)) + valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} - indexes: dict[Hashable, Index] = {} + indexes, index_variables = isel_indexes(self.xindexes, valid_indexers) for name, var in self.variables.items(): - var_indexers = {k: v for k, v in indexers_list if k in var.dims} - if drop and name in var_indexers: - continue # drop this variable - - if name in self.xindexes: - new_var, new_index = isel_variable_and_index( - name, var, self.xindexes[name], var_indexers - ) - if new_index is not None: - indexes[name] = new_index - elif var_indexers: - new_var = var.isel(indexers=var_indexers) + if name in index_variables: + new_var = index_variables[name] else: - new_var = var.copy(deep=False) - + var_indexers = { + k: v for k, v in valid_indexers.items() if k in var.dims + } + if var_indexers: + new_var = var.isel(indexers=var_indexers) + else: + new_var = var.copy(deep=False) variables[name] = new_var coord_names = self._coord_names & variables.keys() @@ -2434,7 +2278,7 @@ def sel( self, indexers: Mapping[Any, Any] = None, method: str = None, - tolerance: Number = None, + tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, ) -> Dataset: @@ -2499,15 +2343,22 @@ def sel( DataArray.sel """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") - pos_indexers, new_indexes = remap_label_indexers( + query_results = map_index_queries( self, indexers=indexers, method=method, tolerance=tolerance ) - # TODO: benbovy - flexible indexes: also use variables returned by Index.query - # (temporary dirty fix). - new_indexes = {k: v[0] for k, v in new_indexes.items()} - result = self.isel(indexers=pos_indexers, drop=drop) - return result._overwrite_indexes(new_indexes) + if drop: + no_scalar_variables = {} + for k, v in query_results.variables.items(): + if v.dims: + no_scalar_variables[k] = v + else: + if k in self._coord_names: + query_results.drop_coords.append(k) + query_results.variables = no_scalar_variables + + result = self.isel(indexers=query_results.dim_indexers, drop=drop) + return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( self, @@ -2677,6 +2528,58 @@ def broadcast_like( return _broadcast_helper(args[1], exclude, dims_map, common_coords) + def _reindex_callback( + self, + aligner: alignment.Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Dataset: + """Callback called from ``Aligner`` to create a new reindexed Dataset.""" + + new_variables = variables.copy() + new_indexes = indexes.copy() + + # pass through indexes from excluded dimensions + # no extra check needed for multi-coordinate indexes, potential conflicts + # should already have been detected when aligning the indexes + for name, idx in self._indexes.items(): + var = self._variables[name] + if set(var.dims) <= exclude_dims: + new_indexes[name] = idx + new_variables[name] = var + + if not dim_pos_indexers: + # fast path for no reindexing necessary + if set(new_indexes) - set(self._indexes): + # this only adds new indexes and their coordinate variables + reindexed = self._overwrite_indexes(new_indexes, new_variables) + else: + reindexed = self.copy(deep=aligner.copy) + else: + to_reindex = { + k: v + for k, v in self.variables.items() + if k not in variables and k not in exclude_vars + } + reindexed_vars = alignment.reindex_variables( + to_reindex, + dim_pos_indexers, + copy=aligner.copy, + fill_value=fill_value, + sparse=aligner.sparse, + ) + new_variables.update(reindexed_vars) + new_coord_names = self._coord_names | set(new_indexes) + reindexed = self._replace_with_new_dims( + new_variables, new_coord_names, indexes=new_indexes + ) + + return reindexed + def reindex_like( self, other: Dataset | DataArray, @@ -2733,13 +2636,13 @@ def reindex_like( Dataset.reindex align """ - indexers = alignment.reindex_like_indexers(self, other) - return self.reindex( - indexers=indexers, + return alignment.reindex_like( + self, + other=other, method=method, + tolerance=tolerance, copy=copy, fill_value=fill_value, - tolerance=tolerance, ) def reindex( @@ -2823,6 +2726,7 @@ def reindex( temperature (station) float64 10.98 14.3 12.06 10.9 pressure (station) float64 211.8 322.9 218.8 445.9 >>> x.indexes + Indexes: station: Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') Create a new index and reindex the dataset. By default values in the new index that @@ -2946,14 +2850,14 @@ def reindex( original dataset, use the :py:meth:`~Dataset.fillna()` method. """ - return self._reindex( - indexers, - method, - tolerance, - copy, - fill_value, - sparse=False, - **indexers_kwargs, + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, + copy=copy, + fill_value=fill_value, ) def _reindex( @@ -2967,28 +2871,18 @@ def _reindex( **indexers_kwargs: Any, ) -> Dataset: """ - same to _reindex but support sparse option + Same as reindex but supports sparse option. """ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") - - bad_dims = [d for d in indexers if d not in self.dims] - if bad_dims: - raise ValueError(f"invalid reindex dimensions: {bad_dims}") - - variables, indexes = alignment.reindex_variables( - self.variables, - self.sizes, - self.xindexes, - indexers, - method, - tolerance, + return alignment.reindex( + self, + indexers=indexers, + method=method, + tolerance=tolerance, copy=copy, fill_value=fill_value, sparse=sparse, ) - coord_names = set(self._coord_names) - coord_names.update(indexers) - return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def interp( self, @@ -3181,7 +3075,7 @@ def _validate_interp_indexer(x, new_x): } variables: dict[Hashable, Variable] = {} - to_reindex: dict[Hashable, Variable] = {} + reindex: bool = False for name, var in obj._variables.items(): if name in indexers: continue @@ -3199,42 +3093,49 @@ def _validate_interp_indexer(x, new_x): elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): # For types that we do not understand do stepwise # interpolation to avoid modifying the elements. - # Use reindex_variables instead because it supports + # reindex the variable instead because it supports # booleans and objects and retains the dtype but inside # this loop there might be some duplicate code that slows it # down, therefore collect these signals and run it later: - to_reindex[name] = var + reindex = True elif all(d not in indexers for d in var.dims): # For anything else we can only keep variables if they # are not dependent on any coords that are being # interpolated along: variables[name] = var - if to_reindex: - # Reindex variables: - variables_reindex = alignment.reindex_variables( - variables=to_reindex, - sizes=obj.sizes, - indexes=obj.xindexes, - indexers={k: v[-1] for k, v in validated_indexers.items()}, + if reindex: + reindex_indexers = { + k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,) + } + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, method=method_non_numeric, - )[0] - variables.update(variables_reindex) + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed._indexes) + variables.update(reindexed.variables) + else: + # Get the indexes that are not being interpolated along + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() - # Get the indexes that are not being interpolated along: - indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) # Attach indexer as coordinate - variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) if v.dims == (k,): - indexes[k] = v._to_xindex() + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables({k: v}) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) @@ -3295,7 +3196,14 @@ def interp_like( """ if kwargs is None: kwargs = {} - coords = alignment.reindex_like_indexers(self, other) + + # pick only dimension coordinates with a single index + coords = {} + other_indexes = other.xindexes + for dim in self.dims: + other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") + if len(other_dim_coords) == 1: + coords[dim] = other_dim_coords[dim] numeric_coords: dict[Hashable, pd.Index] = {} object_coords: dict[Hashable, pd.Index] = {} @@ -3336,28 +3244,34 @@ def _rename_vars(self, name_dict, dims_dict): def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} - def _rename_indexes(self, name_dict, dims_set): - # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645 - if self._indexes is None: - return None + def _rename_indexes(self, name_dict, dims_dict): + if not self._indexes: + return {}, {} + indexes = {} - for k, v in self.indexes.items(): - new_name = name_dict.get(k, k) - if new_name not in dims_set: - continue - if isinstance(v, pd.MultiIndex): - new_names = [name_dict.get(k, k) for k in v.names] - indexes[new_name] = PandasMultiIndex( - v.rename(names=new_names), new_name - ) - else: - indexes[new_name] = PandasIndex(v.rename(new_name), new_name) - return indexes + variables = {} + + for index, coord_names in self.xindexes.group_by_index(): + new_index = index.rename(name_dict, dims_dict) + new_coord_names = [name_dict.get(k, k) for k in coord_names] + indexes.update({k: new_index for k in new_coord_names}) + new_index_vars = new_index.create_variables( + { + new: self._variables[old] + for old, new in zip(coord_names, new_coord_names) + } + ) + variables.update(new_index_vars) + + return indexes, variables def _rename_all(self, name_dict, dims_dict): variables, coord_names = self._rename_vars(name_dict, dims_dict) dims = self._rename_dims(dims_dict) - indexes = self._rename_indexes(name_dict, dims.keys()) + + indexes, index_vars = self._rename_indexes(name_dict, dims_dict) + variables = {k: index_vars.get(k, v) for k, v in variables.items()} + return variables, coord_names, dims, indexes def rename( @@ -3399,7 +3313,6 @@ def rename( variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict=name_dict ) - assert_unique_multiindex_level_names(variables) return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( @@ -3573,21 +3486,19 @@ def swap_dims( dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) if k in result_dims: var = v.to_index_variable() - if k in self.xindexes: - indexes[k] = self.xindexes[k] + var.dims = dims + if k in self._indexes: + indexes[k] = self._indexes[k] + variables[k] = var else: - new_index = var.to_index() - if new_index.nlevels == 1: - # make sure index name matches dimension name - new_index = new_index.rename(k) - if isinstance(new_index, pd.MultiIndex): - indexes[k] = PandasMultiIndex(new_index, k) - else: - indexes[k] = PandasIndex(new_index, k) + index, index_vars = create_default_index_implicit(var) + indexes.update({name: index for name in index_vars}) + variables.update(index_vars) + coord_names.update(index_vars) else: var = v.to_base_variable() - var.dims = dims - variables[k] = var + var.dims = dims + variables[k] = var return self._replace_with_new_dims(variables, coord_names, indexes=indexes) @@ -3667,6 +3578,7 @@ def expand_dims( ) variables: dict[Hashable, Variable] = {} + indexes: dict[Hashable, Index] = dict(self._indexes) coord_names = self._coord_names.copy() # If dim is a dict, then ensure that the values are either integers # or iterables. @@ -3676,7 +3588,9 @@ def expand_dims( # save the coordinates to the variables dict, and set the # value within the dim dict to the length of the iterable # for later use. - variables[k] = xr.IndexVariable((k,), v) + index = PandasIndex(v, k) + indexes[k] = index + variables.update(index.create_variables()) coord_names.add(k) dim[k] = variables[k].size elif isinstance(v, int): @@ -3712,15 +3626,15 @@ def expand_dims( all_dims.insert(d, c) variables[k] = v.set_dims(dict(all_dims)) else: - # If dims includes a label of a non-dimension coordinate, - # it will be promoted to a 1D coordinate with a single value. - variables[k] = v.set_dims(k).to_index_variable() - - new_dims = self._dims.copy() - new_dims.update(dim) + if k not in variables: + # If dims includes a label of a non-dimension coordinate, + # it will be promoted to a 1D coordinate with a single value. + index, index_vars = create_default_index_implicit(v.set_dims(k)) + indexes[k] = index + variables.update(index_vars) - return self._replace_vars_and_dims( - variables, dims=new_dims, coord_names=coord_names + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes ) def set_index( @@ -3781,11 +3695,87 @@ def set_index( Dataset.reset_index Dataset.swap_dims """ - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") - variables, coord_names = merge_indexes( - indexes, self._variables, self._coord_names, append=append + dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") + + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + maybe_drop_indexes: list[Hashable] = [] + drop_variables: list[Hashable] = [] + replace_dims: dict[Hashable, Hashable] = {} + + for dim, _var_names in dim_coords.items(): + if isinstance(_var_names, str) or not isinstance(_var_names, Sequence): + var_names = [_var_names] + else: + var_names = list(_var_names) + + invalid_vars = set(var_names) - set(self._variables) + if invalid_vars: + raise ValueError( + ", ".join([str(v) for v in invalid_vars]) + + " variable(s) do not exist" + ) + + current_coord_names = self.xindexes.get_all_coords(dim, errors="ignore") + + # drop any pre-existing index involved + maybe_drop_indexes += list(current_coord_names) + var_names + for k in var_names: + maybe_drop_indexes += list( + self.xindexes.get_all_coords(k, errors="ignore") + ) + + drop_variables += var_names + + if len(var_names) == 1 and (not append or dim not in self._indexes): + var_name = var_names[0] + var = self._variables[var_name] + if var.dims != (dim,): + raise ValueError( + f"dimension mismatch: try setting an index for dimension {dim!r} with " + f"variable {var_name!r} that has dimensions {var.dims}" + ) + idx = PandasIndex.from_variables({dim: var}) + idx_vars = idx.create_variables({var_name: var}) + else: + if append: + current_variables = { + k: self._variables[k] for k in current_coord_names + } + else: + current_variables = {} + idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand( + dim, + current_variables, + {k: self._variables[k] for k in var_names}, + ) + for n in idx.index.names: + replace_dims[n] = dim + + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes_: dict[Any, Index] = { + k: v for k, v in self._indexes.items() if k not in maybe_drop_indexes + } + indexes_.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + # update dimensions if necessary, GH: 3512 + for k, v in variables.items(): + if any(d in replace_dims for d in v.dims): + new_dims = [replace_dims.get(d, d) for d in v.dims] + variables[k] = v._replace(dims=new_dims) + + coord_names = self._coord_names - set(drop_variables) | set(new_variables) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes_ ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) def reset_index( self, @@ -3812,14 +3802,56 @@ def reset_index( -------- Dataset.set_index """ - variables, coord_names = split_indexes( - dims_or_levels, - self._variables, - self._coord_names, - cast(Mapping[Hashable, Hashable], self._level_coords), - drop=drop, - ) - return self._replace_vars_and_dims(variables, coord_names=coord_names) + if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence): + dims_or_levels = [dims_or_levels] + + invalid_coords = set(dims_or_levels) - set(self._indexes) + if invalid_coords: + raise ValueError( + f"{tuple(invalid_coords)} are not coordinates with an index" + ) + + drop_indexes: list[Hashable] = [] + drop_variables: list[Hashable] = [] + replaced_indexes: list[PandasMultiIndex] = [] + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + + for name in dims_or_levels: + index = self._indexes[name] + drop_indexes += list(self.xindexes.get_all_coords(name)) + + if isinstance(index, PandasMultiIndex) and name not in self.dims: + # special case for pd.MultiIndex (name is an index level): + # replace by a new index with dropped level(s) instead of just drop the index + if index not in replaced_indexes: + level_names = index.index.names + level_vars = { + k: self._variables[k] + for k in level_names + if k not in dims_or_levels + } + if level_vars: + idx = index.keep_levels(level_vars) + idx_vars = idx.create_variables(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + replaced_indexes.append(index) + + if drop: + drop_variables.append(name) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) + + variables = { + k: v for k, v in self._variables.items() if k not in drop_variables + } + variables.update(new_variables) + + coord_names = set(new_variables) | self._coord_names + + return self._replace(variables, coord_names=coord_names, indexes=indexes) def reorder_levels( self, @@ -3846,61 +3878,149 @@ def reorder_levels( """ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() - indexes = dict(self.xindexes) + indexes = dict(self._indexes) + new_indexes: dict[Hashable, Index] = {} + new_variables: dict[Hashable, IndexVariable] = {} + for dim, order in dim_order.items(): - coord = self._variables[dim] - # TODO: benbovy - flexible indexes: update when MultiIndex - # has its own class inherited from xarray.Index - index = self.xindexes[dim].to_pandas_index() - if not isinstance(index, pd.MultiIndex): + index = self._indexes[dim] + + if not isinstance(index, PandasMultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") - new_index = index.reorder_levels(order) - variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = PandasMultiIndex(new_index, dim) + + level_vars = {k: self._variables[k] for k in order} + idx = index.reorder_levels(level_vars) + idx_vars = idx.create_variables(level_vars) + new_indexes.update({k: idx for k in idx_vars}) + new_variables.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in new_indexes} + indexes.update(new_indexes) + + variables = {k: v for k, v in self._variables.items() if k not in new_variables} + variables.update(new_variables) return self._replace(variables, indexes=indexes) - def _stack_once(self, dims, new_dim): + def _get_stack_index( + self, + dim, + multi=False, + create_index=False, + ) -> tuple[Index | None, dict[Hashable, Variable]]: + """Used by stack and unstack to get one pandas (multi-)index among + the indexed coordinates along dimension `dim`. + + If exactly one index is found, return it with its corresponding + coordinate variables(s), otherwise return None and an empty dict. + + If `create_index=True`, create a new index if none is found or raise + an error if multiple indexes are found. + + """ + stack_index: Index | None = None + stack_coords: dict[Hashable, Variable] = {} + + for name, index in self._indexes.items(): + var = self._variables[name] + if ( + var.ndim == 1 + and var.dims[0] == dim + and ( + # stack: must be a single coordinate index + not multi + and not self.xindexes.is_multi(name) + # unstack: must be an index that implements .unstack + or multi + and type(index).unstack is not Index.unstack + ) + ): + if stack_index is not None and index is not stack_index: + # more than one index found, stop + if create_index: + raise ValueError( + f"cannot stack dimension {dim!r} with `create_index=True` " + "and with more than one index found along that dimension" + ) + return None, {} + stack_index = index + stack_coords[name] = var + + if create_index and stack_index is None: + if dim in self._variables: + var = self._variables[dim] + else: + _, _, var = _get_virtual_variable(self._variables, dim, self.dims) + # dummy index (only `stack_coords` will be used to construct the multi-index) + stack_index = PandasIndex([0], dim) + stack_coords = {dim: var} + + return stack_index, stack_coords + + def _stack_once(self, dims, new_dim, index_cls, create_index=True): if dims == ...: raise ValueError("Please use [...] for dims, rather than just ...") if ... in dims: dims = list(infix_dims(dims, self.dims)) - variables = {} - for name, var in self.variables.items(): - if name not in dims: - if any(d in var.dims for d in dims): - add_dims = [d for d in dims if d not in var.dims] - vdims = list(var.dims) + add_dims - shape = [self.dims[d] for d in vdims] - exp_var = var.set_dims(vdims, shape) - stacked_var = exp_var.stack(**{new_dim: dims}) - variables[name] = stacked_var - else: - variables[name] = var.copy(deep=False) - # consider dropping levels that are unused? - levels = [self.get_index(dim) for dim in dims] - idx = utils.multiindex_from_product_levels(levels, names=dims) - variables[new_dim] = IndexVariable(new_dim, idx) + new_variables: dict[Hashable, Variable] = {} + stacked_var_names: list[Hashable] = [] + drop_indexes: list[Hashable] = [] - coord_names = set(self._coord_names) - set(dims) | {new_dim} - - indexes = {k: v for k, v in self.xindexes.items() if k not in dims} - indexes[new_dim] = PandasMultiIndex(idx, new_dim) + for name, var in self.variables.items(): + if any(d in var.dims for d in dims): + add_dims = [d for d in dims if d not in var.dims] + vdims = list(var.dims) + add_dims + shape = [self.dims[d] for d in vdims] + exp_var = var.set_dims(vdims, shape) + stacked_var = exp_var.stack(**{new_dim: dims}) + new_variables[name] = stacked_var + stacked_var_names.append(name) + else: + new_variables[name] = var.copy(deep=False) + + # drop indexes of stacked coordinates (if any) + for name in stacked_var_names: + drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore")) + + new_indexes = {} + new_coord_names = set(self._coord_names) + if create_index or create_index is None: + product_vars: dict[Any, Variable] = {} + for dim in dims: + idx, idx_vars = self._get_stack_index(dim, create_index=create_index) + if idx is not None: + product_vars.update(idx_vars) + + if len(product_vars) == len(dims): + idx = index_cls.stack(product_vars, new_dim) + new_indexes[new_dim] = idx + new_indexes.update({k: idx for k in product_vars}) + idx_vars = idx.create_variables(product_vars) + # keep consistent multi-index coordinate order + for k in idx_vars: + new_variables.pop(k, None) + new_variables.update(idx_vars) + new_coord_names.update(idx_vars) + + indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes} + indexes.update(new_indexes) return self._replace_with_new_dims( - variables, coord_names=coord_names, indexes=indexes + new_variables, coord_names=new_coord_names, indexes=indexes ) def stack( self, dimensions: Mapping[Any, Sequence[Hashable]] = None, + create_index: bool | None = True, + index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], ) -> Dataset: """ Stack any number of existing dimensions into a single new dimension. - New dimensions will be added at the end, and the corresponding + New dimensions will be added at the end, and by default the corresponding coordinate variables will be combined into a MultiIndex. Parameters @@ -3911,6 +4031,14 @@ def stack( ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. + create_index : bool, optional + If True (default), create a multi-index for each of the stacked dimensions. + If False, don't create any index. + If None, create a multi-index only if exactly one single (1-d) coordinate + index is found for every dimension to stack. + index_cls: class, optional + Can be used to pass a custom multi-index type (must be an Xarray index that + implements `.stack()`). By default, a pandas multi-index wrapper is used. **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -3927,7 +4055,7 @@ def stack( dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "stack") result = self for new_dim, dims in dimensions.items(): - result = result._stack_once(dims, new_dim) + result = result._stack_once(dims, new_dim, index_cls, create_index) return result def to_stacked_array( @@ -3998,9 +4126,9 @@ def to_stacked_array( array([[0, 1, 2, 6], [3, 4, 5, 7]]) Coordinates: - * z (z) MultiIndex - - variable (z) object 'a' 'a' 'a' 'b' - - y (z) object 'u' 'v' 'w' nan + * z (z) object MultiIndex + * variable (z) object 'a' 'a' 'a' 'b' + * y (z) object 'u' 'v' 'w' nan Dimensions without coordinates: x """ @@ -4036,32 +4164,30 @@ def ensure_stackable(val): stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] data_array = xr.concat(stackable_vars, dim=new_dim) - # coerce the levels of the MultiIndex to have the same type as the - # input dimensions. This code is messy, so it might be better to just - # input a dummy value for the singleton dimension. - # TODO: benbovy - flexible indexes: update when MultIndex has its own - # class inheriting from xarray.Index - idx = data_array.xindexes[new_dim].to_pandas_index() - levels = [idx.levels[0]] + [ - level.astype(self[level.name].dtype) for level in idx.levels[1:] - ] - new_idx = idx.set_levels(levels) - data_array[new_dim] = IndexVariable(new_dim, new_idx) - if name is not None: data_array.name = name return data_array - def _unstack_once(self, dim: Hashable, fill_value, sparse: bool = False) -> Dataset: - index = self.get_index(dim) - index = remove_unused_levels_categories(index) - + def _unstack_once( + self, + dim: Hashable, + index_and_vars: tuple[Index, dict[Hashable, Variable]], + fill_value, + sparse: bool = False, + ) -> Dataset: + index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() if k != dim} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + for name, idx in new_indexes.items(): + variables.update(idx.create_variables(index_vars)) for name, var in self.variables.items(): - if name != dim: + if name not in index_vars: if dim in var.dims: if isinstance(fill_value, Mapping): fill_value_ = fill_value[name] @@ -4069,55 +4195,66 @@ def _unstack_once(self, dim: Hashable, fill_value, sparse: bool = False) -> Data fill_value_ = fill_value variables[name] = var._unstack_once( - index=index, dim=dim, fill_value=fill_value_, sparse=sparse + index=clean_index, + dim=dim, + fill_value=fill_value_, + sparse=sparse, ) else: variables[name] = var - for name, lev in zip(index.names, index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) - variables[name] = idx_vars[name] - indexes[name] = idx - - coord_names = set(self._coord_names) - {dim} | set(index.names) + coord_names = set(self._coord_names) - {dim} | set(new_indexes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) - def _unstack_full_reindex(self, dim: Hashable, fill_value, sparse: bool) -> Dataset: - index = self.get_index(dim) - index = remove_unused_levels_categories(index) - full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) + def _unstack_full_reindex( + self, + dim: Hashable, + index_and_vars: tuple[Index, dict[Hashable, Variable]], + fill_value, + sparse: bool, + ) -> Dataset: + index, index_vars = index_and_vars + variables: dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self._indexes.items() if k != dim} + + new_indexes, clean_index = index.unstack() + indexes.update(new_indexes) + + new_index_variables = {} + for name, idx in new_indexes.items(): + new_index_variables.update(idx.create_variables(index_vars)) + + new_dim_sizes = {k: v.size for k, v in new_index_variables.items()} + variables.update(new_index_variables) # take a shortcut in case the MultiIndex was not modified. - if index.equals(full_idx): + full_idx = pd.MultiIndex.from_product( + clean_index.levels, names=clean_index.names + ) + if clean_index.equals(full_idx): obj = self else: + # TODO: we may depreciate implicit re-indexing with a pandas.MultiIndex + xr_full_idx = PandasMultiIndex(full_idx, dim) + indexers = Indexes( + {k: xr_full_idx for k in index_vars}, + xr_full_idx.create_variables(index_vars), + ) obj = self._reindex( - {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse + indexers, copy=False, fill_value=fill_value, sparse=sparse ) - new_dim_names = index.names - new_dim_sizes = [lev.size for lev in index.levels] - - variables: dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.xindexes.items() if k != dim} - for name, var in obj.variables.items(): - if name != dim: + if name not in index_vars: if dim in var.dims: - new_dims = dict(zip(new_dim_names, new_dim_sizes)) - variables[name] = var.unstack({dim: new_dims}) + variables[name] = var.unstack({dim: new_dim_sizes}) else: variables[name] = var - for name, lev in zip(new_dim_names, index.levels): - idx, idx_vars = PandasIndex.from_pandas_index(lev, name) - variables[name] = idx_vars[name] - indexes[name] = idx - - coord_names = set(self._coord_names) - {dim} | set(new_dim_names) + coord_names = set(self._coord_names) - {dim} | set(new_dim_sizes) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -4156,10 +4293,9 @@ def unstack( -------- Dataset.stack """ + if dim is None: - dims = [ - d for d in self.dims if isinstance(self.get_index(d), pd.MultiIndex) - ] + dims = list(self.dims) else: if isinstance(dim, str) or not isinstance(dim, Iterable): dims = [dim] @@ -4172,13 +4308,21 @@ def unstack( f"Dataset does not contain the dimensions: {missing_dims}" ) - non_multi_dims = [ - d for d in dims if not isinstance(self.get_index(d), pd.MultiIndex) - ] + # each specified dimension must have exactly one multi-index + stacked_indexes: dict[Any, tuple[Index, dict[Hashable, Variable]]] = {} + for d in dims: + idx, idx_vars = self._get_stack_index(d, multi=True) + if idx is not None: + stacked_indexes[d] = idx, idx_vars + + if dim is None: + dims = list(stacked_indexes) + else: + non_multi_dims = set(dims) - set(stacked_indexes) if non_multi_dims: raise ValueError( "cannot unstack dimensions that do not " - f"have a MultiIndex: {non_multi_dims}" + f"have exactly one multi-index: {tuple(non_multi_dims)}" ) result = self.copy(deep=False) @@ -4188,7 +4332,7 @@ def unstack( # We only check the non-index variables. # https://github.com/pydata/xarray/issues/5902 nonindexes = [ - self.variables[k] for k in set(self.variables) - set(self.xindexes) + self.variables[k] for k in set(self.variables) - set(self._indexes) ] # Notes for each of these cases: # 1. Dask arrays don't support assignment by index, which the fast unstack @@ -4210,9 +4354,13 @@ def unstack( for dim in dims: if needs_full_reindex: - result = result._unstack_full_reindex(dim, fill_value, sparse) + result = result._unstack_full_reindex( + dim, stacked_indexes[dim], fill_value, sparse + ) else: - result = result._unstack_once(dim, fill_value, sparse) + result = result._unstack_once( + dim, stacked_indexes[dim], fill_value, sparse + ) return result def update(self, other: CoercibleMapping) -> Dataset: @@ -4379,9 +4527,11 @@ def drop_vars( if errors == "raise": self._assert_all_in_dataset(names) + assert_no_index_corrupted(self.xindexes, names) + variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k not in names} + indexes = {k: v for k, v in self._indexes.items() if k not in names} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -5096,7 +5246,7 @@ def reduce( ) coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes @@ -5299,15 +5449,16 @@ def to_array(self, dim="variable", name=None): broadcast_vars = broadcast_variables(*data_vars) data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0) - coords = dict(self.coords) - coords[dim] = list(self.data_vars) - indexes = propagate_indexes(self._indexes) - dims = (dim,) + broadcast_vars[0].dims + variable = Variable(dims, data, self.attrs, fastpath=True) - return DataArray( - data, coords, dims, attrs=self.attrs, name=name, indexes=indexes - ) + coords = {k: v.variable for k, v in self.coords.items()} + indexes = filter_indexes_from_coords(self._indexes, set(coords)) + new_dim_index = PandasIndex(list(self.data_vars), dim) + indexes[dim] = new_dim_index + coords.update(new_dim_index.create_variables()) + + return DataArray._construct_direct(variable, coords, name, indexes) def _normalize_dim_order( self, dim_order: list[Hashable] = None @@ -5519,7 +5670,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Datase # forwarding arguments to pandas.Series.to_numpy? arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] - obj = cls() + indexes: dict[Hashable, Index] = {} + index_vars: dict[Hashable, Variable] = {} if isinstance(idx, pd.MultiIndex): dims = tuple( @@ -5527,11 +5679,17 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Datase for n, name in enumerate(idx.names) ) for dim, lev in zip(dims, idx.levels): - obj[dim] = (dim, lev) + xr_idx = PandasIndex(lev, dim) + indexes[dim] = xr_idx + index_vars.update(xr_idx.create_variables()) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) - obj[index_name] = (dims, idx) + xr_idx = PandasIndex(idx, index_name) + indexes[index_name] = xr_idx + index_vars.update(xr_idx.create_variables()) + + obj = cls._construct_direct(index_vars, set(index_vars), indexes=indexes) if sparse: obj._set_sparse_data_from_dataframe(idx, arrays, dims) @@ -5890,10 +6048,13 @@ def diff(self, dim, n=1, label="upper"): else: raise ValueError("The 'label' argument has to be either 'upper' or 'lower'") + indexes, index_vars = isel_indexes(self.xindexes, kwargs_new) variables = {} for name, var in self.variables.items(): - if dim in var.dims: + if name in index_vars: + variables[name] = index_vars[name] + elif dim in var.dims: if name in self.data_vars: variables[name] = var.isel(**kwargs_end) - var.isel(**kwargs_start) else: @@ -5901,16 +6062,6 @@ def diff(self, dim, n=1, label="upper"): else: variables[name] = var - indexes = dict(self.xindexes) - if dim in indexes: - if isinstance(indexes[dim], PandasIndex): - # maybe optimize? (pandas index already indexed above with var.isel) - new_index = indexes[dim].index[kwargs_new[dim]] - if isinstance(new_index, pd.MultiIndex): - indexes[dim] = PandasMultiIndex(new_index, dim) - else: - indexes[dim] = PandasIndex(new_index, dim) - difference = self._replace_with_new_dims(variables, indexes=indexes) if n > 1: @@ -6049,29 +6200,27 @@ def roll( if invalid: raise ValueError(f"dimensions {invalid!r} do not exist") - unrolled_vars = () if roll_coords else self.coords + unrolled_vars: tuple[Hashable, ...] + + if roll_coords: + indexes, index_vars = roll_indexes(self.xindexes, shifts) + unrolled_vars = () + else: + indexes = dict(self._indexes) + index_vars = dict(self.xindexes.variables) + unrolled_vars = tuple(self.coords) variables = {} for k, var in self.variables.items(): - if k not in unrolled_vars: + if k in index_vars: + variables[k] = index_vars[k] + elif k not in unrolled_vars: variables[k] = var.roll( shifts={k: s for k, s in shifts.items() if k in var.dims} ) else: variables[k] = var - if roll_coords: - indexes: dict[Hashable, Index] = {} - idx: pd.Index - for k, idx in self.xindexes.items(): - (dim,) = self.variables[k].dims - if dim in shifts: - indexes[k] = roll_index(idx, shifts[dim]) - else: - indexes[k] = idx - else: - indexes = dict(self.xindexes) - return self._replace(variables, indexes=indexes) def sortby(self, variables, ascending=True): @@ -6324,7 +6473,7 @@ def quantile( # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None @@ -6557,7 +6706,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = {k: v for k, v in self.xindexes.items() if k in variables} + indexes = {k: v for k, v in self._indexes.items() if k in variables} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -7165,6 +7314,9 @@ def pad( promoted to ``float`` and padded with ``np.nan``. To avoid type promotion specify ``constant_values=np.nan`` + Padding coordinates will drop their corresponding index (if any) and will reset default + indexes for dimension coordinates. + Examples -------- >>> ds = xr.Dataset({"foo": ("x", range(5))}) @@ -7190,6 +7342,15 @@ def pad( coord_pad_options = {} variables = {} + + # keep indexes that won't be affected by pad and drop all other indexes + xindexes = self.xindexes + pad_dims = set(pad_width) + indexes = {} + for k, idx in xindexes.items(): + if not pad_dims.intersection(xindexes.get_all_dims(k)): + indexes[k] = idx + for name, var in self.variables.items(): var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims} if not var_pad_width: @@ -7209,8 +7370,15 @@ def pad( mode=coord_pad_mode, **coord_pad_options, # type: ignore[arg-type] ) - - return self._replace_vars_and_dims(variables) + # reset default index of dimension coordinates + if (name,) == var.dims: + dim_var = {name: variables[name]} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) + indexes[name] = index + variables[name] = index_vars[name] + + return self._replace_with_new_dims(variables, indexes=indexes) def idxmin( self, diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 2a9f8a27815..81617ae38f9 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -2,6 +2,7 @@ """ import contextlib import functools +from collections import defaultdict from datetime import datetime, timedelta from itertools import chain, zip_longest from typing import Collection, Hashable, Optional @@ -266,10 +267,10 @@ def inline_sparse_repr(array): def inline_variable_array_repr(var, max_width): """Build a one-line summary of a variable's data.""" - if var._in_memory: - return format_array_flat(var, max_width) - elif hasattr(var._data, "_repr_inline_"): + if hasattr(var._data, "_repr_inline_"): return var._data._repr_inline_(max_width) + elif var._in_memory: + return format_array_flat(var, max_width) elif isinstance(var._data, dask_array_type): return inline_dask_repr(var.data) elif isinstance(var._data, sparse_array_type): @@ -282,68 +283,33 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( - name: Hashable, var, col_width: int, marker: str = " ", max_width: int = None + name: Hashable, var, col_width: int, max_width: int = None, is_index: bool = False ): """Summarize a variable in one line, e.g., for the Dataset.__repr__.""" + variable = getattr(var, "variable", var) + if max_width is None: max_width_options = OPTIONS["display_width"] if not isinstance(max_width_options, int): raise TypeError(f"`max_width` value of `{max_width}` is not a valid int") else: max_width = max_width_options + + marker = "*" if is_index else " " first_col = pretty_print(f" {marker} {name} ", col_width) - if var.dims: - dims_str = "({}) ".format(", ".join(map(str, var.dims))) + + if variable.dims: + dims_str = "({}) ".format(", ".join(map(str, variable.dims))) else: dims_str = "" - front_str = f"{first_col}{dims_str}{var.dtype} " + front_str = f"{first_col}{dims_str}{variable.dtype} " values_width = max_width - len(front_str) - values_str = inline_variable_array_repr(var, values_width) + values_str = inline_variable_array_repr(variable, values_width) return front_str + values_str -def _summarize_coord_multiindex(coord, col_width, marker): - first_col = pretty_print(f" {marker} {coord.name} ", col_width) - return f"{first_col}({str(coord.dims[0])}) MultiIndex" - - -def _summarize_coord_levels(coord, col_width, marker="-"): - if len(coord) > 100 and col_width < len(coord): - n_values = col_width - indices = list(range(0, n_values)) + list(range(-n_values, 0)) - subset = coord[indices] - else: - subset = coord - - return "\n".join( - summarize_variable( - lname, subset.get_level_variable(lname), col_width, marker=marker - ) - for lname in subset.level_names - ) - - -def summarize_datavar(name, var, col_width): - return summarize_variable(name, var.variable, col_width) - - -def summarize_coord(name: Hashable, var, col_width: int): - is_index = name in var.dims - marker = "*" if is_index else " " - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - return "\n".join( - [ - _summarize_coord_multiindex(coord, col_width, marker), - _summarize_coord_levels(coord, col_width), - ] - ) - return summarize_variable(name, var.variable, col_width, marker) - - def summarize_attr(key, value, col_width=None): """Summary for __repr__ - use ``X.attrs[key]`` for full value.""" # Indent key and add ':', then right-pad if col_width is not None @@ -359,23 +325,6 @@ def summarize_attr(key, value, col_width=None): EMPTY_REPR = " *empty*" -def _get_col_items(mapping): - """Get all column items to format, including both keys of `mapping` - and MultiIndex levels if any. - """ - from .variable import IndexVariable - - col_items = [] - for k, v in mapping.items(): - col_items.append(k) - var = getattr(v, "variable", v) - if isinstance(var, IndexVariable): - level_names = var.to_index_variable().level_names - if level_names is not None: - col_items += list(level_names) - return col_items - - def _calculate_col_width(col_items): max_name_length = max(len(str(s)) for s in col_items) if col_items else 0 col_width = max(max_name_length, 7) + 6 @@ -383,10 +332,21 @@ def _calculate_col_width(col_items): def _mapping_repr( - mapping, title, summarizer, expand_option_name, col_width=None, max_rows=None + mapping, + title, + summarizer, + expand_option_name, + col_width=None, + max_rows=None, + indexes=None, ): if col_width is None: col_width = _calculate_col_width(mapping) + + summarizer_kwargs = defaultdict(dict) + if indexes is not None: + summarizer_kwargs = {k: {"is_index": k in indexes} for k in mapping} + summary = [f"{title}:"] if mapping: len_mapping = len(mapping) @@ -396,15 +356,22 @@ def _mapping_repr( summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = calc_max_rows_first(max_rows) keys = list(mapping.keys()) - summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]] + summary += [ + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[:first_rows] + ] if max_rows > 1: last_rows = calc_max_rows_last(max_rows) summary += [pretty_print(" ...", col_width) + " ..."] summary += [ - summarizer(k, mapping[k], col_width) for k in keys[-last_rows:] + summarizer(k, mapping[k], col_width, **summarizer_kwargs[k]) + for k in keys[-last_rows:] ] else: - summary += [summarizer(k, v, col_width) for k, v in mapping.items()] + summary += [ + summarizer(k, v, col_width, **summarizer_kwargs[k]) + for k, v in mapping.items() + ] else: summary += [EMPTY_REPR] return "\n".join(summary) @@ -413,7 +380,7 @@ def _mapping_repr( data_vars_repr = functools.partial( _mapping_repr, title="Data variables", - summarizer=summarize_datavar, + summarizer=summarize_variable, expand_option_name="display_expand_data_vars", ) @@ -428,21 +395,25 @@ def _mapping_repr( def coords_repr(coords, col_width=None, max_rows=None): if col_width is None: - col_width = _calculate_col_width(_get_col_items(coords)) + col_width = _calculate_col_width(coords) return _mapping_repr( coords, title="Coordinates", - summarizer=summarize_coord, + summarizer=summarize_variable, expand_option_name="display_expand_coords", col_width=col_width, + indexes=coords.xindexes, max_rows=max_rows, ) def indexes_repr(indexes): - summary = [] - for k, v in indexes.items(): - summary.append(wrap_indent(repr(v), f"{k}: ")) + summary = ["Indexes:"] + if indexes: + for k, v in indexes.items(): + summary.append(wrap_indent(repr(v), f"{k}: ")) + else: + summary += [EMPTY_REPR] return "\n".join(summary) @@ -604,7 +575,7 @@ def array_repr(arr): if hasattr(arr, "coords"): if arr.coords: - col_width = _calculate_col_width(_get_col_items(arr.coords)) + col_width = _calculate_col_width(arr.coords) summary.append( coords_repr(arr.coords, col_width=col_width, max_rows=max_rows) ) @@ -624,7 +595,7 @@ def array_repr(arr): def dataset_repr(ds): summary = [f""] - col_width = _calculate_col_width(_get_col_items(ds.variables)) + col_width = _calculate_col_width(ds.variables) max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) @@ -655,9 +626,20 @@ def diff_dim_summary(a, b): return "" -def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): - def extra_items_repr(extra_keys, mapping, ab_side): - extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] +def _diff_mapping_repr( + a_mapping, + b_mapping, + compat, + title, + summarizer, + col_width=None, + a_indexes=None, + b_indexes=None, +): + def extra_items_repr(extra_keys, mapping, ab_side, kwargs): + extra_repr = [ + summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys + ] if extra_repr: header = f"{title} only on the {ab_side} object:" return [header] + extra_repr @@ -671,6 +653,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): diff_items = [] + a_summarizer_kwargs = defaultdict(dict) + if a_indexes is not None: + a_summarizer_kwargs = {k: {"is_index": k in a_indexes} for k in a_mapping} + b_summarizer_kwargs = defaultdict(dict) + if b_indexes is not None: + b_summarizer_kwargs = {k: {"is_index": k in b_indexes} for k in b_mapping} + for k in a_keys & b_keys: try: # compare xarray variable @@ -690,7 +679,8 @@ def extra_items_repr(extra_keys, mapping, ab_side): if not compatible: temp = [ - summarizer(k, vars[k], col_width) for vars in (a_mapping, b_mapping) + summarizer(k, a_mapping[k], col_width, **a_summarizer_kwargs[k]), + summarizer(k, b_mapping[k], col_width, **b_summarizer_kwargs[k]), ] if compat == "identical" and is_variable: @@ -712,19 +702,29 @@ def extra_items_repr(extra_keys, mapping, ab_side): if diff_items: summary += [f"Differing {title.lower()}:"] + diff_items - summary += extra_items_repr(a_keys - b_keys, a_mapping, "left") - summary += extra_items_repr(b_keys - a_keys, b_mapping, "right") + summary += extra_items_repr(a_keys - b_keys, a_mapping, "left", a_summarizer_kwargs) + summary += extra_items_repr( + b_keys - a_keys, b_mapping, "right", b_summarizer_kwargs + ) return "\n".join(summary) -diff_coords_repr = functools.partial( - _diff_mapping_repr, title="Coordinates", summarizer=summarize_coord -) +def diff_coords_repr(a, b, compat, col_width=None): + return _diff_mapping_repr( + a, + b, + compat, + "Coordinates", + summarize_variable, + col_width=col_width, + a_indexes=a.indexes, + b_indexes=b.indexes, + ) diff_data_vars_repr = functools.partial( - _diff_mapping_repr, title="Data variables", summarizer=summarize_datavar + _diff_mapping_repr, title="Data variables", summarizer=summarize_variable ) @@ -786,9 +786,7 @@ def diff_dataset_repr(a, b, compat): ) ] - col_width = _calculate_col_width( - set(_get_col_items(a.variables) + _get_col_items(b.variables)) - ) + col_width = _calculate_col_width(set(list(a.variables) + list(b.variables))) summary.append(diff_dim_summary(a, b)) summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 36c252f276e..db62466a8d3 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -31,12 +31,12 @@ def short_data_repr_html(array): return f"
{text}
" -def format_dims(dims, coord_names): +def format_dims(dims, dims_with_index): if not dims: return "" dim_css_map = { - k: " class='xr-has-index'" if k in coord_names else "" for k, v in dims.items() + dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims } dims_li = "".join( @@ -66,38 +66,7 @@ def _icon(icon_name): ) -def _summarize_coord_multiindex(name, coord): - preview = f"({', '.join(escape(l) for l in coord.level_names)})" - return summarize_variable( - name, coord, is_index=True, dtype="MultiIndex", preview=preview - ) - - -def summarize_coord(name, var): - is_index = name in var.dims - if is_index: - coord = var.variable.to_index_variable() - if coord.level_names is not None: - coords = {name: _summarize_coord_multiindex(name, coord)} - for lname in coord.level_names: - var = coord.get_level_variable(lname) - coords[lname] = summarize_variable(lname, var) - return coords - - return {name: summarize_variable(name, var, is_index)} - - -def summarize_coords(variables): - coords = {} - for k, v in variables.items(): - coords.update(**summarize_coord(k, v)) - - vars_li = "".join(f"
  • {v}
  • " for v in coords.values()) - - return f"
      {vars_li}
    " - - -def summarize_variable(name, var, is_index=False, dtype=None, preview=None): +def summarize_variable(name, var, is_index=False, dtype=None): variable = var.variable if hasattr(var, "variable") else var cssclass_idx = " class='xr-has-index'" if is_index else "" @@ -110,7 +79,7 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): data_id = "data-" + str(uuid.uuid4()) disabled = "" if len(var.attrs) else "disabled" - preview = preview or escape(inline_variable_array_repr(variable, 35)) + preview = escape(inline_variable_array_repr(variable, 35)) attrs_ul = summarize_attrs(var.attrs) data_repr = short_data_repr_html(variable) @@ -134,6 +103,17 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): ) +def summarize_coords(variables): + li_items = [] + for k, v in variables.items(): + li_content = summarize_variable(k, v, is_index=k in variables.xindexes) + li_items.append(f"
  • {li_content}
  • ") + + vars_li = "".join(li_items) + + return f"
      {vars_li}
    " + + def summarize_vars(variables): vars_li = "".join( f"
  • {summarize_variable(k, v)}
  • " @@ -184,7 +164,7 @@ def _mapping_section( def dim_section(obj): - dim_list = format_dims(obj.dims, list(obj.coords)) + dim_list = format_dims(obj.dims, obj.xindexes.dims) return collapsible_section( "Dimensions", inline_details=dim_list, enabled=False, collapsed=True @@ -265,15 +245,18 @@ def _obj_repr(obj, header_components, sections): def array_repr(arr): dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + if hasattr(arr, "xindexes"): + indexed_dims = arr.xindexes.dims + else: + indexed_dims = {} obj_type = f"xarray.{type(arr).__name__}" arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" - coord_names = list(arr.coords) if hasattr(arr, "coords") else [] header_components = [ f"
    {obj_type}
    ", f"
    {arr_name}
    ", - format_dims(dims, coord_names), + format_dims(dims, indexed_dims), ] sections = [array_section(arr)] diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 68db14c29be..3c26c2129ae 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -12,7 +12,7 @@ from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from .concat import concat from .formatting import format_array_flat -from .indexes import propagate_indexes +from .indexes import create_default_index_implicit, filter_indexes_from_coords from .options import _get_keep_attrs from .pycompat import integer_types from .utils import ( @@ -23,7 +23,7 @@ peek_at, safe_cast_to_index, ) -from .variable import IndexVariable, Variable, as_variable +from .variable import IndexVariable, Variable def check_reduce_dims(reduce_dims, dimensions): @@ -57,6 +57,8 @@ def unique_value_groups(ar, sort=True): the corresponding value in `unique_values`. """ inverse, values = pd.factorize(ar, sort=sort) + if isinstance(values, pd.MultiIndex): + values.names = ar.names groups = [[] for _ in range(len(values))] for n, g in enumerate(inverse): if g >= 0: @@ -472,6 +474,7 @@ def _infer_concat_args(self, applied_example): (dim,) = coord.dims if isinstance(coord, _DummyGroup): coord = None + coord = getattr(coord, "variable", coord) return coord, dim, positions def _binary_op(self, other, f, reflexive=False): @@ -522,7 +525,7 @@ def _maybe_unstack(self, obj): for dim in self._inserted_dims: if dim in obj.coords: del obj.coords[dim] - obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims) + obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj def fillna(self, value): @@ -766,6 +769,8 @@ def _concat_shortcut(self, applied, dim, positions=None): # speed things up, but it's not very interpretable and there are much # faster alternatives (e.g., doing the grouped aggregation in a # compiled language) + # TODO: benbovy - explicit indexes: this fast implementation doesn't + # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) reordered = _maybe_reorder(stacked, dim, positions) return self._obj._replace_maybe_drop_dims(reordered) @@ -857,13 +862,11 @@ def _combine(self, applied, shortcut=False): if isinstance(combined, type(self._obj)): # only restore dimension order for arrays combined = self._restore_dim_order(combined) - # assign coord when the applied function does not return that coord + # assign coord and index when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - if shortcut: - coord_var = as_variable(coord) - combined._coords[coord.name] = coord_var - else: - combined.coords[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, coords=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined @@ -991,7 +994,9 @@ def _combine(self, applied): combined = _maybe_reorder(combined, dim, positions) # assign coord when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: - combined[coord.name] = coord + index, index_vars = create_default_index_implicit(coord) + indexes = {k: index for k in index_vars} + combined = combined._overwrite_indexes(indexes, variables=index_vars) combined = self._maybe_restore_empty_groups(combined) combined = self._maybe_unstack(combined) return combined diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b66fbdf6504..e02e1f569b2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,44 +1,70 @@ +from __future__ import annotations + import collections.abc +import copy +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, Dict, + Generic, Hashable, Iterable, Iterator, Mapping, - Optional, Sequence, - Tuple, - Union, + TypeVar, + cast, ) import numpy as np import pandas as pd -from . import formatting, utils -from .indexing import ( - LazilyIndexedArray, - PandasIndexingAdapter, - PandasMultiIndexingAdapter, -) -from .utils import is_dict_like, is_scalar +from . import formatting, nputils, utils +from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter +from .types import T_Index +from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar if TYPE_CHECKING: - from .variable import IndexVariable, Variable + from .variable import Variable -IndexVars = Dict[Hashable, "IndexVariable"] +IndexVars = Dict[Any, "Variable"] class Index: """Base class inherited by all xarray-compatible indexes.""" @classmethod - def from_variables( - cls, variables: Mapping[Any, "Variable"] - ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover + def from_variables(cls, variables: Mapping[Any, Variable]) -> Index: + raise NotImplementedError() + + @classmethod + def concat( + cls: type[T_Index], + indexes: Sequence[T_Index], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> T_Index: + raise NotImplementedError() + + @classmethod + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Index: + raise NotImplementedError( + f"{cls!r} cannot be used for creating an index of stacked coordinates" + ) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: raise NotImplementedError() + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + if variables is not None: + # pass through + return dict(**variables) + else: + return {} + def to_pandas_index(self) -> pd.Index: """Cast this xarray index to a pandas.Index object or raise a TypeError if this is not supported. @@ -47,27 +73,54 @@ def to_pandas_index(self) -> pd.Index: pandas.Index object. """ - raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") - def query( - self, labels: Dict[Hashable, Any] - ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover - raise NotImplementedError() + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Index | None: + return None + + def sel(self, labels: dict[Any, Any]) -> IndexSelResult: + raise NotImplementedError(f"{self!r} doesn't support label-based selection") + + def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: + raise NotImplementedError( + f"{self!r} doesn't support alignment with inner/outer join method" + ) + + def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") def equals(self, other): # pragma: no cover raise NotImplementedError() - def union(self, other): # pragma: no cover - raise NotImplementedError() + def roll(self, shifts: Mapping[Any, int]) -> Index | None: + return None - def intersection(self, other): # pragma: no cover - raise NotImplementedError() + def rename( + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] + ) -> Index: + return self - def copy(self, deep: bool = True): # pragma: no cover - raise NotImplementedError() + def __copy__(self) -> Index: + return self.copy(deep=False) + + def __deepcopy__(self, memo=None) -> Index: + # memo does nothing but is required for compatibility with + # copy.deepcopy + return self.copy(deep=True) + + def copy(self, deep: bool = True) -> Index: + cls = self.__class__ + copied = cls.__new__(cls) + if deep: + for k, v in self.__dict__.items(): + setattr(copied, k, copy.deepcopy(v)) + else: + copied.__dict__.update(self.__dict__) + return copied def __getitem__(self, indexer: Any): - # if not implemented, index will be dropped from the Dataset or DataArray raise NotImplementedError() @@ -134,6 +187,23 @@ def _is_nested_tuple(possible_tuple): ) +def normalize_label(value, dtype=None) -> np.ndarray: + if getattr(value, "ndim", 1) <= 1: + value = _asarray_tuplesafe(value) + if dtype is not None and dtype.kind == "f" and value.dtype.kind != "b": + # pd.Index built from coordinate with float precision != 64 + # see https://github.com/pydata/xarray/pull/3153 for details + # bypass coercing dtype for boolean indexers (ignore index) + # see https://github.com/pydata/xarray/issues/5727 + value = np.asarray(value, dtype=dtype) + return value + + +def as_scalar(value: np.ndarray): + # see https://github.com/pydata/xarray/pull/4292 for details + return value[()] if value.dtype.kind in "mM" else value.item() + + def get_indexer_nd(index, labels, method=None, tolerance=None): """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels @@ -147,16 +217,37 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): class PandasIndex(Index): """Wrap a pandas.Index as an xarray compatible index.""" - __slots__ = ("index", "dim") + index: pd.Index + dim: Hashable + coord_dtype: Any - def __init__(self, array: Any, dim: Hashable): - self.index = utils.safe_cast_to_index(array) + __slots__ = ("index", "dim", "coord_dtype") + + def __init__(self, array: Any, dim: Hashable, coord_dtype: Any = None): + # make a shallow copy: cheap and because the index name may be updated + # here or in other constructors (cannot use pd.Index.rename as this + # constructor is also called from PandasMultiIndex) + index = utils.safe_cast_to_index(array).copy() + + if index.name is None: + index.name = dim + + self.index = index self.dim = dim - @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): - from .variable import IndexVariable + if coord_dtype is None: + coord_dtype = get_valid_numpy_dtype(index) + self.coord_dtype = coord_dtype + def _replace(self, index, dim=None, coord_dtype=None): + if dim is None: + dim = self.dim + if coord_dtype is None: + coord_dtype = self.coord_dtype + return type(self)(index, dim, coord_dtype) + + @classmethod + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasIndex: if len(variables) != 1: raise ValueError( f"PandasIndex only accepts one variable, found {len(variables)} variables" @@ -172,35 +263,113 @@ def from_variables(cls, variables: Mapping[Any, "Variable"]): dim = var.dims[0] - obj = cls(var.data, dim) + # TODO: (benbovy - explicit indexes): add __index__ to ExplicitlyIndexesNDArrayMixin? + # this could be eventually used by Variable.to_index() and would remove the need to perform + # the checks below. - data = PandasIndexingAdapter(obj.index) - index_var = IndexVariable( - dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True - ) + # preserve wrapped pd.Index (if any) + data = getattr(var._data, "array", var.data) + # multi-index level variable: get level index + if isinstance(var._data, PandasMultiIndexingAdapter): + level = var._data.level + if level is not None: + data = var._data.array.get_level_values(level) + + obj = cls(data, dim, coord_dtype=var.dtype) + assert not isinstance(obj.index, pd.MultiIndex) + obj.index.name = name + + return obj + + @staticmethod + def _concat_indexes(indexes, dim, positions=None) -> pd.Index: + new_pd_index: pd.Index + + if not indexes: + new_pd_index = pd.Index([]) + else: + if not all(idx.dim == dim for idx in indexes): + dims = ",".join({f"{idx.dim!r}" for idx in indexes}) + raise ValueError( + f"Cannot concatenate along dimension {dim!r} indexes with " + f"dimensions: {dims}" + ) + pd_indexes = [idx.index for idx in indexes] + new_pd_index = pd_indexes[0].append(pd_indexes[1:]) - return obj, {name: index_var} + if positions is not None: + indices = nputils.inverse_permutation(np.concatenate(positions)) + new_pd_index = new_pd_index.take(indices) + + return new_pd_index @classmethod - def from_pandas_index(cls, index: pd.Index, dim: Hashable): + def concat( + cls, + indexes: Sequence[PandasIndex], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> PandasIndex: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + coord_dtype = None + else: + coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) + + return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: from .variable import IndexVariable - if index.name is None: - name = dim - index = index.copy() - index.name = dim - else: - name = index.name + name = self.index.name + attrs: Mapping[Hashable, Any] | None + encoding: Mapping[Hashable, Any] | None - data = PandasIndexingAdapter(index) - index_var = IndexVariable(dim, data, fastpath=True) + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + encoding = var.encoding + else: + attrs = None + encoding = None - return cls(index, dim), {name: index_var} + data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype) + var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding) + return {name: var} def to_pandas_index(self) -> pd.Index: return self.index - def query(self, labels, method=None, tolerance=None): + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> PandasIndex | None: + from .variable import Variable + + indxr = indexers[self.dim] + if isinstance(indxr, Variable): + if indxr.dims != (self.dim,): + # can't preserve a index if result has new dimensions + return None + else: + indxr = indxr.data + if not isinstance(indxr, slice) and is_scalar(indxr): + # scalar indexer: drop index + return None + + return self._replace(self.index[indxr]) + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + from .dataarray import DataArray + from .variable import Variable + + if method is not None and not isinstance(method, str): + raise TypeError("``method`` must be a string") + assert len(labels) == 1 coord_name, label = next(iter(labels.items())) @@ -212,147 +381,423 @@ def query(self, labels, method=None, tolerance=None): "a dimension that does not have a MultiIndex" ) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if label.dtype.kind in "mM" else label.item() + label_array = normalize_label(label, dtype=self.coord_dtype) + if label_array.ndim == 0: + label_value = as_scalar(label_array) if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( - "'method' is not a valid kwarg when indexing using a CategoricalIndex." + "'method' is not supported when indexing using a CategoricalIndex." ) if tolerance is not None: raise ValueError( - "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + "'tolerance' is not supported when indexing using a CategoricalIndex." ) indexer = self.index.get_loc(label_value) else: if method is not None: - indexer = get_indexer_nd(self.index, label, method, tolerance) + indexer = get_indexer_nd( + self.index, label_array, method, tolerance + ) if np.any(indexer < 0): raise KeyError( f"not all values found in index {coord_name!r}" ) else: indexer = self.index.get_loc(label_value) - elif label.dtype.kind == "b": - indexer = label + elif label_array.dtype.kind == "b": + indexer = label_array else: - indexer = get_indexer_nd(self.index, label, method, tolerance) + indexer = get_indexer_nd(self.index, label_array, method, tolerance) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") - return indexer, None + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + indexer = DataArray(indexer, coords=label._coords, dims=label.dims) - def equals(self, other): - return self.index.equals(other.index) + return IndexSelResult({self.dim: indexer}) - def union(self, other): - new_index = self.index.union(other.index) - return type(self)(new_index, self.dim) + def equals(self, other: Index): + if not isinstance(other, PandasIndex): + return False + return self.index.equals(other.index) and self.dim == other.dim - def intersection(self, other): - new_index = self.index.intersection(other.index) - return type(self)(new_index, self.dim) + def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + if how == "outer": + index = self.index.union(other.index) + else: + # how = "inner" + index = self.index.intersection(other.index) + + coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype) + return type(self)(index, self.dim, coord_dtype=coord_dtype) + + def reindex_like( + self, other: PandasIndex, method=None, tolerance=None + ) -> dict[Hashable, Any]: + if not self.index.is_unique: + raise ValueError( + f"cannot reindex or align along dimension {self.dim!r} because the " + "(pandas) index has duplicate values" + ) + + return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)} + + def roll(self, shifts: Mapping[Any, int]) -> PandasIndex: + shift = shifts[self.dim] % self.index.shape[0] + + if shift != 0: + new_pd_idx = self.index[-shift:].append(self.index[:-shift]) + else: + new_pd_idx = self.index[:] + + return self._replace(new_pd_idx) + + def rename(self, name_dict, dims_dict): + if self.index.name not in name_dict and self.dim not in dims_dict: + return self + + new_name = name_dict.get(self.index.name, self.index.name) + index = self.index.rename(new_name) + new_dim = dims_dict.get(self.dim, self.dim) + return self._replace(index, dim=new_dim) def copy(self, deep=True): - return type(self)(self.index.copy(deep=deep), self.dim) + if deep: + index = self.index.copy(deep=True) + else: + # index will be copied in constructor + index = self.index + return self._replace(index) def __getitem__(self, indexer: Any): - return type(self)(self.index[indexer], self.dim) - + return self._replace(self.index[indexer]) -def _create_variables_from_multiindex(index, dim, level_meta=None): - from .variable import IndexVariable - if level_meta is None: - level_meta = {} +def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal"): + """Check that all multi-index variable candidates are 1-dimensional and + either share the same (single) dimension or each have a different dimension. - variables = {} + """ + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") - dim_coord_adapter = PandasMultiIndexingAdapter(index) - variables[dim] = IndexVariable( - dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True - ) + dims = {var.dims for var in variables.values()} - for level in index.names: - meta = level_meta.get(level, {}) - data = PandasMultiIndexingAdapter( - index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter + if all_dims == "equal" and len(dims) > 1: + raise ValueError( + "unmatched dimensions for multi-index variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) ) - variables[level] = IndexVariable( - dim, - data, - attrs=meta.get("attrs"), - encoding=meta.get("encoding"), - fastpath=True, + + if all_dims == "different" and len(dims) < len(variables): + raise ValueError( + "conflicting dimensions for multi-index product variables " + + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()]) ) - return variables + +def remove_unused_levels_categories(index: pd.Index) -> pd.Index: + """ + Remove unused levels from MultiIndex and unused categories from CategoricalIndex + """ + if isinstance(index, pd.MultiIndex): + index = index.remove_unused_levels() + # if it contains CategoricalIndex, we need to remove unused categories + # manually. See https://github.com/pandas-dev/pandas/issues/30846 + if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + levels = [] + for i, level in enumerate(index.levels): + if isinstance(level, pd.CategoricalIndex): + level = level[index.codes[i]].remove_unused_categories() + else: + level = level[index.codes[i]] + levels.append(level) + # TODO: calling from_array() reorders MultiIndex levels. It would + # be best to avoid this, if possible, e.g., by using + # MultiIndex.remove_unused_levels() (which does not reorder) on the + # part of the MultiIndex that is not categorical, or by fixing this + # upstream in pandas. + index = pd.MultiIndex.from_arrays(levels, names=index.names) + elif isinstance(index, pd.CategoricalIndex): + index = index.remove_unused_categories() + return index class PandasMultiIndex(PandasIndex): - @classmethod - def from_variables(cls, variables: Mapping[Any, "Variable"]): - if any([var.ndim != 1 for var in variables.values()]): - raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") + """Wrap a pandas.MultiIndex as an xarray compatible index.""" - dims = {var.dims for var in variables.values()} - if len(dims) != 1: - raise ValueError( - "unmatched dimensions for variables " - + ",".join([str(k) for k in variables]) - ) + level_coords_dtype: dict[str, Any] + + __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") + + def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None): + super().__init__(array, dim) + + # default index level names + names = [] + for i, idx in enumerate(self.index.levels): + name = idx.name or f"{dim}_level_{i}" + if name == dim: + raise ValueError( + f"conflicting multi-index level name {name!r} with dimension {dim!r}" + ) + names.append(name) + self.index.names = names + + if level_coords_dtype is None: + level_coords_dtype = { + idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels + } + self.level_coords_dtype = level_coords_dtype + + def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex: + if dim is None: + dim = self.dim + index.name = dim + if level_coords_dtype is None: + level_coords_dtype = self.level_coords_dtype + return type(self)(index, dim, level_coords_dtype) + + @classmethod + def from_variables(cls, variables: Mapping[Any, Variable]) -> PandasMultiIndex: + _check_dim_compat(variables) + dim = next(iter(variables.values())).dims[0] - dim = next(iter(dims))[0] index = pd.MultiIndex.from_arrays( [var.values for var in variables.values()], names=variables.keys() ) - obj = cls(index, dim) + index.name = dim + level_coords_dtype = {name: var.dtype for name, var in variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) - level_meta = { - name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} - for name, var in variables.items() - } - index_vars = _create_variables_from_multiindex( - index, dim, level_meta=level_meta - ) + return obj - return obj, index_vars + @classmethod + def concat( # type: ignore[override] + cls, + indexes: Sequence[PandasMultiIndex], + dim: Hashable, + positions: Iterable[Iterable[int]] = None, + ) -> PandasMultiIndex: + new_pd_index = cls._concat_indexes(indexes, dim, positions) + + if not indexes: + level_coords_dtype = None + else: + level_coords_dtype = {} + for name in indexes[0].level_coords_dtype: + level_coords_dtype[name] = np.result_type( + *[idx.level_coords_dtype[name] for idx in indexes] + ) + + return cls(new_pd_index, dim=dim, level_coords_dtype=level_coords_dtype) + + @classmethod + def stack( + cls, variables: Mapping[Any, Variable], dim: Hashable + ) -> PandasMultiIndex: + """Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a + new dimension. + + Level variables must have a dimension distinct from each other. + + Keeps levels the same (doesn't refactorize them) so that it gives back the original + labels after a stack/unstack roundtrip. + + """ + _check_dim_compat(variables, all_dims="different") + + level_indexes = [utils.safe_cast_to_index(var) for var in variables.values()] + for name, idx in zip(variables, level_indexes): + if isinstance(idx, pd.MultiIndex): + raise ValueError( + f"cannot create a multi-index along stacked dimension {dim!r} " + f"from variable {name!r} that wraps a multi-index" + ) + + split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] + + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + level_coords_dtype = {k: var.dtype for k, var in variables.items()} + + return cls(index, dim, level_coords_dtype=level_coords_dtype) + + def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: + clean_index = remove_unused_levels_categories(self.index) + + new_indexes: dict[Hashable, Index] = {} + for name, lev in zip(clean_index.names, clean_index.levels): + idx = PandasIndex( + lev.copy(), name, coord_dtype=self.level_coords_dtype[name] + ) + new_indexes[name] = idx + + return new_indexes, clean_index @classmethod - def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): - index_vars = _create_variables_from_multiindex(index, dim) - return cls(index, dim), index_vars + def from_variables_maybe_expand( + cls, + dim: Hashable, + current_variables: Mapping[Any, Variable], + variables: Mapping[Any, Variable], + ) -> tuple[PandasMultiIndex, IndexVars]: + """Create a new multi-index maybe by expanding an existing one with + new variables as index levels. + + The index and its corresponding coordinates may be created along a new dimension. + """ + names: list[Hashable] = [] + codes: list[list[int]] = [] + levels: list[list[int]] = [] + level_variables: dict[Any, Variable] = {} + + _check_dim_compat({**current_variables, **variables}) + + if len(current_variables) > 1: + # expand from an existing multi-index + data = cast( + PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data + ) + current_index = data.array + names.extend(current_index.names) + codes.extend(current_index.codes) + levels.extend(current_index.levels) + for name in current_index.names: + level_variables[name] = current_variables[name] + + elif len(current_variables) == 1: + # expand from one 1D variable (no multi-index): convert it to an index level + var = next(iter(current_variables.values())) + new_var_name = f"{dim}_level_0" + names.append(new_var_name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[new_var_name] = var + + for name, var in variables.items(): + names.append(name) + cat = pd.Categorical(var.values, ordered=True) + codes.append(cat.codes) + levels.append(cat.categories) + level_variables[name] = var + + index = pd.MultiIndex(levels, codes, names=names) + level_coords_dtype = {k: var.dtype for k, var in level_variables.items()} + obj = cls(index, dim, level_coords_dtype=level_coords_dtype) + index_vars = obj.create_variables(level_variables) + + return obj, index_vars + + def keep_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex | PandasIndex: + """Keep only the provided levels and return a new multi-index with its + corresponding coordinates. + + """ + index = self.index.droplevel( + [k for k in self.index.names if k not in level_variables] + ) + + if isinstance(index, pd.MultiIndex): + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + else: + return PandasIndex( + index, self.dim, coord_dtype=self.level_coords_dtype[index.name] + ) + + def reorder_levels( + self, level_variables: Mapping[Any, Variable] + ) -> PandasMultiIndex: + """Re-arrange index levels using input order and return a new multi-index with + its corresponding coordinates. + + """ + index = self.index.reorder_levels(level_variables.keys()) + level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} + return self._replace(index, level_coords_dtype=level_coords_dtype) + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from .variable import IndexVariable + + if variables is None: + variables = {} + + index_vars: IndexVars = {} + for name in (self.dim,) + self.index.names: + if name == self.dim: + level = None + dtype = None + else: + level = name + dtype = self.level_coords_dtype[name] + + var = variables.get(name, None) + if var is not None: + attrs = var.attrs + encoding = var.encoding + else: + attrs = {} + encoding = {} + + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + index_vars[name] = IndexVariable( + self.dim, + data, + attrs=attrs, + encoding=encoding, + fastpath=True, + ) + + return index_vars + + def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: + from .dataarray import DataArray + from .variable import Variable - def query(self, labels, method=None, tolerance=None): if method is not None or tolerance is not None: raise ValueError( "multi-index does not support ``method`` and ``tolerance``" ) new_index = None + scalar_coord_values = {} # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): - is_nested_vals = _is_nested_tuple(tuple(labels.values())) - if len(labels) == self.index.nlevels and not is_nested_vals: - indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) + label_values = {} + for k, v in labels.items(): + label_array = normalize_label(v, dtype=self.level_coords_dtype[k]) + try: + label_values[k] = as_scalar(label_array) + except ValueError: + # label should be an item not an array-like + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + + has_slice = any([isinstance(v, slice) for v in label_values.values()]) + + if len(label_values) == self.index.nlevels and not has_slice: + indexer = self.index.get_loc( + tuple(label_values[k] for k in self.index.names) + ) else: - for k, v in labels.items(): - # index should be an item (i.e. Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - f"available along coordinate {k!r} (multi-index level)" - ) indexer, new_index = self.index.get_loc_level( - tuple(labels.values()), level=tuple(labels.keys()) + tuple(label_values.values()), level=tuple(label_values.keys()) ) + scalar_coord_values.update(label_values) # GH2619. Raise a KeyError if nothing is chosen if indexer.dtype.kind == "b" and indexer.sum() == 0: raise KeyError(f"{labels} not found") @@ -376,7 +821,7 @@ def query(self, labels, method=None, tolerance=None): raise ValueError( f"invalid multi-index level names {invalid_levels}" ) - return self.query(label) + return self.sel(label) elif isinstance(label, slice): indexer = _query_slice(self.index, label, coord_name) @@ -387,94 +832,383 @@ def query(self, labels, method=None, tolerance=None): elif len(label) == self.index.nlevels: indexer = self.index.get_loc(label) else: - indexer, new_index = self.index.get_loc_level( - label, level=list(range(len(label))) - ) + levels = [self.index.names[i] for i in range(len(label))] + indexer, new_index = self.index.get_loc_level(label, level=levels) + scalar_coord_values.update({k: v for k, v in zip(levels, label)}) else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - indexer, new_index = self.index.get_loc_level(label.item(), level=0) - elif label.dtype.kind == "b": - indexer = label + label_array = normalize_label(label) + if label_array.ndim == 0: + label_value = as_scalar(label_array) + indexer, new_index = self.index.get_loc_level(label_value, level=0) + scalar_coord_values[self.index.names[0]] = label_value + elif label_array.dtype.kind == "b": + indexer = label_array else: - if label.ndim > 1: + if label_array.ndim > 1: raise ValueError( "Vectorized selection is not available along " f"coordinate {coord_name!r} with a multi-index" ) - indexer = get_indexer_nd(self.index, label) + indexer = get_indexer_nd(self.index, label_array) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") + # attach dimension names and/or coordinates to positional indexer + if isinstance(label, Variable): + indexer = Variable(label.dims, indexer) + elif isinstance(label, DataArray): + # do not include label-indexer DataArray coordinates that conflict + # with the level names of this index + coords = { + k: v + for k, v in label._coords.items() + if k not in self.index.names + } + indexer = DataArray(indexer, coords=coords, dims=label.dims) + if new_index is not None: if isinstance(new_index, pd.MultiIndex): - new_index, new_vars = PandasMultiIndex.from_pandas_index( - new_index, self.dim + level_coords_dtype = { + k: self.level_coords_dtype[k] for k in new_index.names + } + new_index = self._replace( + new_index, level_coords_dtype=level_coords_dtype ) + dims_dict = {} + drop_coords = [] else: - new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim) - return indexer, (new_index, new_vars) + new_index = PandasIndex( + new_index, + new_index.name, + coord_dtype=self.level_coords_dtype[new_index.name], + ) + dims_dict = {self.dim: new_index.index.name} + drop_coords = [self.dim] + + # variable(s) attrs and encoding metadata are propagated + # when replacing the indexes in the resulting xarray object + new_vars = new_index.create_variables() + indexes = cast(Dict[Any, Index], {k: new_index for k in new_vars}) + + # add scalar variable for each dropped level + variables = new_vars + for name, val in scalar_coord_values.items(): + variables[name] = Variable([], val) + + return IndexSelResult( + {self.dim: indexer}, + indexes=indexes, + variables=variables, + drop_indexes=list(scalar_coord_values), + drop_coords=drop_coords, + rename_dims=dims_dict, + ) + + else: + return IndexSelResult({self.dim: indexer}) + + def join(self, other, how: str = "inner"): + if how == "outer": + # bug in pandas? need to reset index.name + other_index = other.index.copy() + other_index.name = None + index = self.index.union(other_index) + index.name = self.dim else: - return indexer, None + # how = "inner" + index = self.index.intersection(other.index) + level_coords_dtype = { + k: np.result_type(lvl_dtype, other.level_coords_dtype[k]) + for k, lvl_dtype in self.level_coords_dtype.items() + } + + return type(self)(index, self.dim, level_coords_dtype=level_coords_dtype) + + def rename(self, name_dict, dims_dict): + if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict: + return self + + # pandas 1.3.0: could simply do `self.index.rename(names_dict)` + new_names = [name_dict.get(k, k) for k in self.index.names] + index = self.index.rename(new_names) + + new_dim = dims_dict.get(self.dim, self.dim) + new_level_coords_dtype = { + k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + } + return self._replace( + index, dim=new_dim, level_coords_dtype=new_level_coords_dtype + ) + + +def create_default_index_implicit( + dim_variable: Variable, + all_variables: Mapping | Iterable[Hashable] | None = None, +) -> tuple[PandasIndex, IndexVars]: + """Create a default index from a dimension variable. + + Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex, + otherwise create a PandasIndex (note that this will become obsolete once we + depreciate implcitly passing a pandas.MultiIndex as a coordinate). -def remove_unused_levels_categories(index: pd.Index) -> pd.Index: - """ - Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ - if isinstance(index, pd.MultiIndex): - index = index.remove_unused_levels() - # if it contains CategoricalIndex, we need to remove unused categories - # manually. See https://github.com/pandas-dev/pandas/issues/30846 - if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): - levels = [] - for i, level in enumerate(index.levels): - if isinstance(level, pd.CategoricalIndex): - level = level[index.codes[i]].remove_unused_categories() - else: - level = level[index.codes[i]] - levels.append(level) - # TODO: calling from_array() reorders MultiIndex levels. It would - # be best to avoid this, if possible, e.g., by using - # MultiIndex.remove_unused_levels() (which does not reorder) on the - # part of the MultiIndex that is not categorical, or by fixing this - # upstream in pandas. - index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() - return index + if all_variables is None: + all_variables = {} + if not isinstance(all_variables, Mapping): + all_variables = {k: None for k in all_variables} + + name = dim_variable.dims[0] + array = getattr(dim_variable._data, "array", None) + index: PandasIndex + + if isinstance(array, pd.MultiIndex): + index = PandasMultiIndex(array, name) + index_vars = index.create_variables() + # check for conflict between level names and variable names + duplicate_names = [k for k in index_vars if k in all_variables and k != name] + if duplicate_names: + # dirty workaround for an edge case where both the dimension + # coordinate and the level coordinates are given for the same + # multi-index object => do not raise an error + # TODO: remove this check when removing the multi-index dimension coordinate + if len(duplicate_names) < len(index.index.names): + conflict = True + else: + duplicate_vars = [all_variables[k] for k in duplicate_names] + conflict = any( + v is None or not dim_variable.equals(v) for v in duplicate_vars + ) + + if conflict: + conflict_str = "\n".join(duplicate_names) + raise ValueError( + f"conflicting MultiIndex level / variable name(s):\n{conflict_str}" + ) + else: + dim_var = {name: dim_variable} + index = PandasIndex.from_variables(dim_var) + index_vars = index.create_variables(dim_var) + return index, index_vars -class Indexes(collections.abc.Mapping): - """Immutable proxy for Dataset or DataArrary indexes.""" - __slots__ = ("_indexes",) +# generic type that represents either a pandas or an xarray index +T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex", Index, pd.Index) + + +class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): + """Immutable proxy for Dataset or DataArrary indexes. + + Keys are coordinate names and values may correspond to either pandas or + xarray indexes. + + Also provides some utility methods. + + """ + + _indexes: dict[Any, T_PandasOrXarrayIndex] + _variables: dict[Any, Variable] + + __slots__ = ( + "_indexes", + "_variables", + "_dims", + "__coord_name_id", + "__id_index", + "__id_coord_names", + ) - def __init__(self, indexes: Mapping[Any, Union[pd.Index, Index]]) -> None: - """Not for public consumption. + def __init__( + self, + indexes: dict[Any, T_PandasOrXarrayIndex], + variables: dict[Any, Variable], + ): + """Constructor not for public consumption. Parameters ---------- - indexes : Dict[Any, pandas.Index] + indexes : dict Indexes held by this object. + variables : dict + Indexed coordinate variables in this object. + """ self._indexes = indexes + self._variables = variables + + self._dims: Mapping[Hashable, int] | None = None + self.__coord_name_id: dict[Any, int] | None = None + self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None + self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None + + @property + def _coord_name_id(self) -> dict[Any, int]: + if self.__coord_name_id is None: + self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()} + return self.__coord_name_id + + @property + def _id_index(self) -> dict[int, T_PandasOrXarrayIndex]: + if self.__id_index is None: + self.__id_index = {id(idx): idx for idx in self.get_unique()} + return self.__id_index + + @property + def _id_coord_names(self) -> dict[int, tuple[Hashable, ...]]: + if self.__id_coord_names is None: + id_coord_names: Mapping[int, list[Hashable]] = defaultdict(list) + for k, v in self._coord_name_id.items(): + id_coord_names[v].append(k) + self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()} + + return self.__id_coord_names + + @property + def variables(self) -> Mapping[Hashable, Variable]: + return Frozen(self._variables) + + @property + def dims(self) -> Mapping[Hashable, int]: + from .variable import calculate_dimensions + + if self._dims is None: + self._dims = calculate_dimensions(self._variables) + + return Frozen(self._dims) + + def get_unique(self) -> list[T_PandasOrXarrayIndex]: + """Return a list of unique indexes, preserving order.""" + + unique_indexes: list[T_PandasOrXarrayIndex] = [] + seen: set[T_PandasOrXarrayIndex] = set() + + for index in self._indexes.values(): + if index not in seen: + unique_indexes.append(index) + seen.add(index) + + return unique_indexes + + def is_multi(self, key: Hashable) -> bool: + """Return True if ``key`` maps to a multi-coordinate index, + False otherwise. + """ + return len(self._id_coord_names[self._coord_name_id[key]]) > 1 + + def get_all_coords( + self, key: Hashable, errors: str = "raise" + ) -> dict[Hashable, Variable]: + """Return all coordinates having the same index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, optional + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + coords : dict + A dictionary of all coordinate variables having the same index. + + """ + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + + if key not in self._indexes: + if errors == "raise": + raise ValueError(f"no index found for {key!r} coordinate") + else: + return {} + + all_coord_names = self._id_coord_names[self._coord_name_id[key]] + return {k: self._variables[k] for k in all_coord_names} + + def get_all_dims( + self, key: Hashable, errors: str = "raise" + ) -> Mapping[Hashable, int]: + """Return all dimensions shared by an index. + + Parameters + ---------- + key : hashable + Index key. + errors : {"raise", "ignore"}, optional + If "raise", raises a ValueError if `key` is not in indexes. + If "ignore", an empty tuple is returned instead. + + Returns + ------- + dims : dict + A dictionary of all dimensions shared by an index. + + """ + from .variable import calculate_dimensions + + return calculate_dimensions(self.get_all_coords(key, errors=errors)) + + def group_by_index( + self, + ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]: + """Returns a list of unique indexes and their corresponding coordinates.""" + + index_coords = [] + + for i in self._id_index: + index = self._id_index[i] + coords = {k: self._variables[k] for k in self._id_coord_names[i]} + index_coords.append((index, coords)) + + return index_coords - def __iter__(self) -> Iterator[pd.Index]: + def to_pandas_indexes(self) -> Indexes[pd.Index]: + """Returns an immutable proxy for Dataset or DataArrary pandas indexes. + + Raises an error if this proxy contains indexes that cannot be coerced to + pandas.Index objects. + + """ + indexes: dict[Hashable, pd.Index] = {} + + for k, idx in self._indexes.items(): + if isinstance(idx, pd.Index): + indexes[k] = idx + elif isinstance(idx, Index): + indexes[k] = idx.to_pandas_index() + + return Indexes(indexes, self._variables) + + def copy_indexes( + self, deep: bool = True + ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: + """Return a new dictionary with copies of indexes, preserving + unique indexes. + + """ + new_indexes = {} + new_index_vars = {} + for idx, coords in self.group_by_index(): + new_idx = idx.copy(deep=deep) + idx_vars = idx.create_variables(coords) + new_indexes.update({k: new_idx for k in coords}) + new_index_vars.update(idx_vars) + + return new_indexes, new_index_vars + + def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]: return iter(self._indexes) - def __len__(self): + def __len__(self) -> int: return len(self._indexes) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self._indexes - def __getitem__(self, key) -> pd.Index: + def __getitem__(self, key) -> T_PandasOrXarrayIndex: return self._indexes[key] def __repr__(self): @@ -482,8 +1216,8 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, "Variable"], dims: Iterable -) -> Dict[Hashable, Index]: + coords: Mapping[Any, Variable], dims: Iterable +) -> dict[Hashable, Index]: """Default indexes for a Dataset/DataArray. Parameters @@ -498,76 +1232,167 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return {key: coords[key]._to_xindex() for key in dims if key in coords} + indexes: dict[Hashable, Index] = {} + coord_names = set(coords) + for name, var in coords.items(): + if name in dims: + index, index_vars = create_default_index_implicit(var, coords) + if set(index_vars) <= coord_names: + indexes.update({k: index for k in index_vars}) + + return indexes -def isel_variable_and_index( - name: Hashable, - variable: "Variable", - index: Index, - indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]], -) -> Tuple["Variable", Optional[Index]]: - """Index a Variable and an Index together. - If the index cannot be indexed, return None (it will be dropped). +def indexes_equal( + index: Index, + other_index: Index, + variable: Variable, + other_variable: Variable, + cache: dict[tuple[int, int], bool | None] = None, +) -> bool: + """Check if two indexes are equal, possibly with cached results. - (note: not compatible yet with xarray flexible indexes). + If the two indexes are not of the same type or they do not implement + equality, fallback to coordinate labels equality check. """ - from .variable import Variable + if cache is None: + # dummy cache + cache = {} + + key = (id(index), id(other_index)) + equal: bool | None = None + + if key not in cache: + if type(index) is type(other_index): + try: + equal = index.equals(other_index) + except NotImplementedError: + equal = None + else: + cache[key] = equal + else: + equal = None + else: + equal = cache[key] - if not indexers: - # nothing to index - return variable.copy(deep=False), index + if equal is None: + equal = variable.equals(other_variable) - if len(variable.dims) > 1: - raise NotImplementedError( - "indexing multi-dimensional variable with indexes is not supported yet" - ) + return cast(bool, equal) - new_variable = variable.isel(indexers) - if new_variable.dims != (name,): - # can't preserve a index if result has new dimensions - return new_variable, None +def indexes_all_equal( + elements: Sequence[tuple[Index, dict[Hashable, Variable]]] +) -> bool: + """Check if indexes are all equal. - # we need to compute the new index - (dim,) = variable.dims - indexer = indexers[dim] - if isinstance(indexer, Variable): - indexer = indexer.data - try: - new_index = index[indexer] - except NotImplementedError: - new_index = None + If they are not of the same type or they do not implement this check, check + if their coordinate variables are all equal instead. - return new_variable, new_index + """ + def check_variables(): + variables = [e[1] for e in elements] + return any( + not variables[0][k].equals(other_vars[k]) + for other_vars in variables[1:] + for k in variables[0] + ) -def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: - """Roll an pandas.Index.""" - pd_index = index.to_pandas_index() - count %= pd_index.shape[0] - if count != 0: - new_idx = pd_index[-count:].append(pd_index[:-count]) + indexes = [e[0] for e in elements] + same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) + if same_type: + try: + not_equal = any( + not indexes[0].equals(other_idx) for other_idx in indexes[1:] + ) + except NotImplementedError: + not_equal = check_variables() else: - new_idx = pd_index[:] - return PandasIndex(new_idx, index.dim) + not_equal = check_variables() + + return not not_equal + + +def _apply_indexes( + indexes: Indexes[Index], + args: Mapping[Any, Any], + func: str, +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()} + new_index_variables: dict[Hashable, Variable] = {} + + for index, index_vars in indexes.group_by_index(): + index_dims = {d for var in index_vars.values() for d in var.dims} + index_args = {k: v for k, v in args.items() if k in index_dims} + if index_args: + new_index = getattr(index, func)(index_args) + if new_index is not None: + new_indexes.update({k: new_index for k in index_vars}) + new_index_vars = new_index.create_variables(index_vars) + new_index_variables.update(new_index_vars) + else: + for k in index_vars: + new_indexes.pop(k, None) + return new_indexes, new_index_variables -def propagate_indexes( - indexes: Optional[Dict[Hashable, Index]], exclude: Optional[Any] = None -) -> Optional[Dict[Hashable, Index]]: - """Creates new indexes dict from existing dict optionally excluding some dimensions.""" - if exclude is None: - exclude = () - if is_scalar(exclude): - exclude = (exclude,) +def isel_indexes( + indexes: Indexes[Index], + indexers: Mapping[Any, Any], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_indexes(indexes, indexers, "isel") - if indexes is not None: - new_indexes = {k: v for k, v in indexes.items() if k not in exclude} - else: - new_indexes = None # type: ignore[assignment] - return new_indexes +def roll_indexes( + indexes: Indexes[Index], + shifts: Mapping[Any, int], +) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: + return _apply_indexes(indexes, shifts, "roll") + + +def filter_indexes_from_coords( + indexes: Mapping[Any, Index], + filtered_coord_names: set, +) -> dict[Hashable, Index]: + """Filter index items given a (sub)set of coordinate names. + + Drop all multi-coordinate related index items for any key missing in the set + of coordinate names. + + """ + filtered_indexes: dict[Any, Index] = dict(**indexes) + + index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) + for name, idx in indexes.items(): + index_coord_names[id(idx)].add(name) + + for idx_coord_names in index_coord_names.values(): + if not idx_coord_names <= filtered_coord_names: + for k in idx_coord_names: + del filtered_indexes[k] + + return filtered_indexes + + +def assert_no_index_corrupted( + indexes: Indexes[Index], + coord_names: set[Hashable], +) -> None: + """Assert removing coordinates will not corrupt indexes.""" + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of coordinate names to remove + for index, index_coords in indexes.group_by_index(): + common_names = set(index_coords) & coord_names + if common_names and len(common_names) != len(index_coords): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coords) + raise ValueError( + f"cannot remove coordinate(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{index}" + ) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 17d026baa59..c797e6652de 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1,10 +1,23 @@ import enum import functools import operator -from collections import defaultdict +from collections import Counter, defaultdict from contextlib import suppress +from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import numpy as np import pandas as pd @@ -19,7 +32,176 @@ is_duck_dask_array, sparse_array_type, ) -from .utils import maybe_cast_to_coords_dtype +from .types import T_Xarray +from .utils import either_dict_or_kwargs, get_valid_numpy_dtype + +if TYPE_CHECKING: + from .indexes import Index + from .variable import Variable + + +@dataclass +class IndexSelResult: + """Index query results. + + Attributes + ---------- + dim_indexers: dict + A dictionary where keys are array dimensions and values are + location-based indexers. + indexes: dict, optional + New indexes to replace in the resulting DataArray or Dataset. + variables : dict, optional + New variables to replace in the resulting DataArray or Dataset. + drop_coords : list, optional + Coordinate(s) to drop in the resulting DataArray or Dataset. + drop_indexes : list, optional + Index(es) to drop in the resulting DataArray or Dataset. + rename_dims : dict, optional + A dictionary in the form ``{old_dim: new_dim}`` for dimension(s) to + rename in the resulting DataArray or Dataset. + + """ + + dim_indexers: Dict[Any, Any] + indexes: Dict[Any, "Index"] = field(default_factory=dict) + variables: Dict[Any, "Variable"] = field(default_factory=dict) + drop_coords: List[Hashable] = field(default_factory=list) + drop_indexes: List[Hashable] = field(default_factory=list) + rename_dims: Dict[Any, Hashable] = field(default_factory=dict) + + def as_tuple(self): + """Unlike ``dataclasses.astuple``, return a shallow copy. + + See https://stackoverflow.com/a/51802661 + + """ + return ( + self.dim_indexers, + self.indexes, + self.variables, + self.drop_coords, + self.drop_indexes, + self.rename_dims, + ) + + +def merge_sel_results(results: List[IndexSelResult]) -> IndexSelResult: + all_dims_count = Counter([dim for res in results for dim in res.dim_indexers]) + duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1} + + if duplicate_dims: + # TODO: this message is not right when combining indexe(s) queries with + # location-based indexing on a dimension with no dimension-coordinate (failback) + fmt_dims = [ + f"{dim!r}: {count} indexes involved" + for dim, count in duplicate_dims.items() + ] + raise ValueError( + "Xarray does not support label-based selection with more than one index " + "over the following dimension(s):\n" + + "\n".join(fmt_dims) + + "\nSuggestion: use a multi-index for each of those dimension(s)." + ) + + dim_indexers = {} + indexes = {} + variables = {} + drop_coords = [] + drop_indexes = [] + rename_dims = {} + + for res in results: + dim_indexers.update(res.dim_indexers) + indexes.update(res.indexes) + variables.update(res.variables) + drop_coords += res.drop_coords + drop_indexes += res.drop_indexes + rename_dims.update(res.rename_dims) + + return IndexSelResult( + dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims + ) + + +def group_indexers_by_index( + obj: T_Xarray, + indexers: Mapping[Any, Any], + options: Mapping[str, Any], +) -> List[Tuple["Index", Dict[Any, Any]]]: + """Returns a list of unique indexes and their corresponding indexers.""" + unique_indexes = {} + grouped_indexers: Mapping[Union[int, None], Dict] = defaultdict(dict) + + for key, label in indexers.items(): + index: "Index" = obj.xindexes.get(key, None) + + if index is not None: + index_id = id(index) + unique_indexes[index_id] = index + grouped_indexers[index_id][key] = label + elif key in obj.coords: + raise KeyError(f"no index found for coordinate {key!r}") + elif key not in obj.dims: + raise KeyError(f"{key!r} is not a valid dimension or coordinate") + elif len(options): + raise ValueError( + f"cannot supply selection options {options!r} for dimension {key!r}" + "that has no associated coordinate or index" + ) + else: + # key is a dimension without a "dimension-coordinate" + # failback to location-based selection + # TODO: depreciate this implicit behavior and suggest using isel instead? + unique_indexes[None] = None + grouped_indexers[None][key] = label + + return [(unique_indexes[k], grouped_indexers[k]) for k in unique_indexes] + + +def map_index_queries( + obj: T_Xarray, + indexers: Mapping[Any, Any], + method=None, + tolerance=None, + **indexers_kwargs: Any, +) -> IndexSelResult: + """Execute index queries from a DataArray / Dataset and label-based indexers + and return the (merged) query results. + + """ + from .dataarray import DataArray + + # TODO benbovy - flexible indexes: remove when custom index options are available + if method is None and tolerance is None: + options = {} + else: + options = {"method": method, "tolerance": tolerance} + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries") + grouped_indexers = group_indexers_by_index(obj, indexers, options) + + results = [] + for index, labels in grouped_indexers: + if index is None: + # forward dimension indexers with no index/coordinate + results.append(IndexSelResult(labels)) + else: + results.append(index.sel(labels, **options)) # type: ignore[call-arg] + + merged = merge_sel_results(results) + + # drop dimension coordinates found in dimension indexers + # (also drop multi-index if any) + # (.sel() already ensures alignment) + for k, v in merged.dim_indexers.items(): + if isinstance(v, DataArray): + if k in v._indexes: + v = v.reset_index(k) + drop_coords = [name for name in v._coords if name in merged.dim_indexers] + merged.dim_indexers[k] = v.drop_vars(drop_coords) + + return merged def expanded_indexer(key, ndim): @@ -56,80 +238,6 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): - # TODO: benbovy - flexible indexes: indexers are still grouped by dimension - # - Make xarray.Index hashable so that it can be used as key in a mapping? - indexes = {} - grouped_indexers = defaultdict(dict) - - # TODO: data_obj.xindexes should eventually return the PandasIndex instance - # for each multi-index levels - xindexes = dict(data_obj.xindexes) - for level, dim in data_obj._level_coords.items(): - xindexes[level] = xindexes[dim] - - for key, label in indexers.items(): - try: - index = xindexes[key] - coord = data_obj.coords[key] - dim = coord.dims[0] - if dim not in indexes: - indexes[dim] = index - - label = maybe_cast_to_coords_dtype(label, coord.dtype) - grouped_indexers[dim][key] = label - - except KeyError: - if key in data_obj.coords: - raise KeyError(f"no index found for coordinate {key}") - elif key not in data_obj.dims: - raise KeyError(f"{key} is not a valid dimension or coordinate") - # key is a dimension without coordinate: we'll reuse the provided labels - elif method is not None or tolerance is not None: - raise ValueError( - "cannot supply ``method`` or ``tolerance`` " - "when the indexed dimension does not have " - "an associated coordinate." - ) - grouped_indexers[None][key] = label - - return indexes, grouped_indexers - - -def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): - """Given an xarray data object and label based indexers, return a mapping - of equivalent location based indexers. Also return a mapping of updated - pandas index objects (in case of multi-index level drop). - """ - if method is not None and not isinstance(method, str): - raise TypeError("``method`` must be a string") - - pos_indexers = {} - new_indexes = {} - - indexes, grouped_indexers = group_indexers_by_index( - data_obj, indexers, method, tolerance - ) - - forward_pos_indexers = grouped_indexers.pop(None, None) - if forward_pos_indexers is not None: - for dim, label in forward_pos_indexers.items(): - pos_indexers[dim] = label - - for dim, index in indexes.items(): - labels = grouped_indexers[dim] - idxr, new_idx = index.query(labels, method=method, tolerance=tolerance) - pos_indexers[dim] = idxr - if new_idx is not None: - new_indexes[dim] = new_idx - - # TODO: benbovy - flexible indexes: support the following cases: - # - an index query returns positional indexers over multiple dimensions - # - check/combine positional indexers returned by multiple indexes over the same dimension - - return pos_indexers, new_indexes - - def _normalize_slice(sl, size): """Ensure that given slice only contains positive start and stop values (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])""" @@ -1272,18 +1380,9 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None): self.array = utils.safe_cast_to_index(array) if dtype is None: - if isinstance(array, pd.PeriodIndex): - dtype_ = np.dtype("O") - elif hasattr(array, "categories"): - # category isn't a real numpy dtype - dtype_ = array.categories.dtype - elif not utils.is_valid_numpy_dtype(array.dtype): - dtype_ = np.dtype("O") - else: - dtype_ = array.dtype + self._dtype = get_valid_numpy_dtype(array) else: - dtype_ = np.dtype(dtype) # type: ignore[assignment] - self._dtype = dtype_ + self._dtype = np.dtype(dtype) # type: ignore[assignment] @property def dtype(self) -> np.dtype: @@ -1303,6 +1402,26 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: def shape(self) -> Tuple[int]: return (len(self.array),) + def _convert_scalar(self, item): + if item is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + item = np.datetime64("NaT", "ns") + elif isinstance(item, timedelta): + item = np.timedelta64(getattr(item, "value", item), "ns") + elif isinstance(item, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + item = np.asarray(item.to_datetime64()) + elif self.dtype != object: + item = np.asarray(item, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. + return utils.to_0d_array(item) + def __getitem__( self, indexer ) -> Union[ @@ -1319,34 +1438,14 @@ def __getitem__( (key,) = key if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(self.array.values)[indexer] + return NumpyIndexingAdapter(np.asarray(self))[indexer] result = self.array[key] if isinstance(result, pd.Index): - result = type(self)(result, dtype=self.dtype) + return type(self)(result, dtype=self.dtype) else: - # result is a scalar - if result is pd.NaT: - # work around the impossibility of casting NaT with asarray - # note: it probably would be better in general to return - # pd.Timestamp rather np.than datetime64 but this is easier - # (for now) - result = np.datetime64("NaT", "ns") - elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, "value", result), "ns") - elif isinstance(result, pd.Timestamp): - # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 - # numpy fails to convert pd.Timestamp to np.datetime64[ns] - result = np.asarray(result.to_datetime64()) - elif self.dtype != object: - result = np.asarray(result, dtype=self.dtype) - - # as for numpy.ndarray indexing, we always want the result to be - # a NumPy array. - result = utils.to_0d_array(result) - - return result + return self._convert_scalar(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1382,11 +1481,9 @@ def __init__( array: pd.MultiIndex, dtype: DTypeLike = None, level: Optional[str] = None, - adapter: Optional[PandasIndexingAdapter] = None, ): super().__init__(array, dtype) self.level = level - self.adapter = adapter def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if self.level is not None: @@ -1394,16 +1491,47 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: else: return super().__array__(dtype) - @functools.lru_cache(1) + def _convert_scalar(self, item): + if isinstance(item, tuple) and self.level is not None: + idx = tuple(self.array.names).index(self.level) + item = item[idx] + return super()._convert_scalar(item) + def __getitem__(self, indexer): - if self.adapter is None: - return super().__getitem__(indexer) - else: - return self.adapter.__getitem__(indexer) + result = super().__getitem__(indexer) + if isinstance(result, type(self)): + result.level = self.level + + return result def __repr__(self) -> str: if self.level is None: return super().__repr__() else: - props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + props = ( + f"(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + ) return f"{type(self).__name__}{props}" + + def _repr_inline_(self, max_width) -> str: + # special implementation to speed-up the repr for big multi-indexes + if self.level is None: + return "MultiIndex" + else: + from .formatting import format_array_flat + + if self.size > 100 and max_width < self.size: + n_values = max_width + indices = np.concatenate( + [np.arange(0, n_values), np.arange(-n_values, 0)] + ) + subset = self[OuterIndexer((indices,))] + else: + subset = self + + return format_array_flat(np.asarray(subset), max_width) + + def copy(self, deep: bool = True) -> "PandasMultiIndexingAdapter": + # see PandasIndexingAdapter.copy + array = self.array.copy(deep=True) if deep else self.array + return type(self)(array, self._dtype, self.level) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index e5407ae79c3..b428d4ae958 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from typing import ( TYPE_CHECKING, AbstractSet, @@ -19,9 +20,15 @@ from . import dtypes from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .indexes import Index, PandasIndex +from .indexes import ( + Index, + Indexes, + create_default_index_implicit, + filter_indexes_from_coords, + indexes_equal, +) from .utils import Frozen, compat_dict_union, dict_equiv, equivalent -from .variable import Variable, as_variable, assert_unique_multiindex_level_names +from .variable import Variable, as_variable, calculate_dimensions if TYPE_CHECKING: from .coordinates import Coordinates @@ -32,9 +39,9 @@ ArrayLike = Any VariableLike = Union[ ArrayLike, - Tuple[DimsLike, ArrayLike], - Tuple[DimsLike, ArrayLike, Mapping], - Tuple[DimsLike, ArrayLike, Mapping, Mapping], + tuple[DimsLike, ArrayLike], + tuple[DimsLike, ArrayLike, Mapping], + tuple[DimsLike, ArrayLike, Mapping, Mapping], ] XarrayValue = Union[DataArray, Variable, VariableLike] DatasetLike = Union[Dataset, Mapping[Any, XarrayValue]] @@ -165,11 +172,44 @@ def _assert_compat_valid(compat): MergeElement = Tuple[Variable, Optional[Index]] +def _assert_prioritized_valid( + grouped: dict[Hashable, list[MergeElement]], + prioritized: Mapping[Any, MergeElement], +) -> None: + """Make sure that elements given in prioritized will not corrupt any + index given in grouped. + """ + prioritized_names = set(prioritized) + grouped_by_index: dict[int, list[Hashable]] = defaultdict(list) + indexes: dict[int, Index] = {} + + for name, elements_list in grouped.items(): + for (_, index) in elements_list: + if index is not None: + grouped_by_index[id(index)].append(name) + indexes[id(index)] = index + + # An index may be corrupted when the set of its corresponding coordinate name(s) + # partially overlaps the set of names given in prioritized + for index_id, index_coord_names in grouped_by_index.items(): + index_names = set(index_coord_names) + common_names = index_names & prioritized_names + if common_names and len(common_names) != len(index_names): + common_names_str = ", ".join(f"{k!r}" for k in common_names) + index_names_str = ", ".join(f"{k!r}" for k in index_coord_names) + raise ValueError( + f"cannot set or update variable(s) {common_names_str}, which would corrupt " + f"the following index built from coordinates {index_names_str}:\n" + f"{indexes[index_id]!r}" + ) + + def merge_collected( grouped: dict[Hashable, list[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", - combine_attrs="override", + combine_attrs: str | None = "override", + equals: dict[Hashable, bool] = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -179,6 +219,8 @@ def merge_collected( prioritized : mapping compat : str Type of equality check to use when checking for conflicts. + equals : mapping, optional + corresponding to result of compat test Returns ------- @@ -188,11 +230,15 @@ def merge_collected( """ if prioritized is None: prioritized = {} + if equals is None: + equals = {} _assert_compat_valid(compat) + _assert_prioritized_valid(grouped, prioritized) merged_vars: dict[Hashable, Variable] = {} merged_indexes: dict[Hashable, Index] = {} + index_cmp_cache: dict[tuple[int, int], bool | None] = {} for name, elements_list in grouped.items(): if name in prioritized: @@ -206,17 +252,19 @@ def merge_collected( for variable, index in elements_list if index is not None ] - if indexed_elements: # TODO(shoyer): consider adjusting this logic. Are we really # OK throwing away variable without an index in favor of # indexed variables, without even checking if values match? variable, index = indexed_elements[0] - for _, other_index in indexed_elements[1:]: - if not index.equals(other_index): + for other_var, other_index in indexed_elements[1:]: + if not indexes_equal( + index, other_index, variable, other_var, index_cmp_cache + ): raise MergeError( - f"conflicting values for index {name!r} on objects to be " - f"combined:\nfirst value: {index!r}\nsecond value: {other_index!r}" + f"conflicting values/indexes on objects to be combined fo coordinate {name!r}\n" + f"first index: {index!r}\nsecond index: {other_index!r}\n" + f"first variable: {variable!r}\nsecond variable: {other_var!r}\n" ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: @@ -234,7 +282,9 @@ def merge_collected( else: variables = [variable for variable, _ in elements_list] try: - merged_vars[name] = unique_variable(name, variables, compat) + merged_vars[name] = unique_variable( + name, variables, compat, equals.get(name, None) + ) except MergeError: if compat != "minimal": # we need more than "minimal" compatibility (for which @@ -251,6 +301,7 @@ def merge_collected( def collect_variables_and_indexes( list_of_mappings: list[DatasetLike], + indexes: Mapping[Any, Any] | None = None, ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. @@ -260,15 +311,21 @@ def collect_variables_and_indexes( - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in an xarray.Variable - or an xarray.DataArray + + If a mapping of indexes is given, those indexes are assigned to all variables + with a matching key/name. + """ from .dataarray import DataArray from .dataset import Dataset - grouped: dict[Hashable, list[tuple[Variable, Index | None]]] = {} + if indexes is None: + indexes = {} + + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) def append(name, variable, index): - values = grouped.setdefault(name, []) - values.append((variable, index)) + grouped[name].append((variable, index)) def append_all(variables, indexes): for name, variable in variables.items(): @@ -276,27 +333,26 @@ def append_all(variables, indexes): for mapping in list_of_mappings: if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping.xindexes) + append_all(mapping.variables, mapping._indexes) continue for name, variable in mapping.items(): if isinstance(variable, DataArray): coords = variable._coords.copy() # use private API for speed - indexes = dict(variable.xindexes) + indexes = dict(variable._indexes) # explicitly overwritten variables should take precedence coords.pop(name, None) indexes.pop(name, None) append_all(coords, indexes) variable = as_variable(variable, name=name) - - if variable.dims == (name,): - idx_variable = variable.to_index_variable() - index = variable._to_xindex() - append(name, idx_variable, index) + if name in indexes: + append(name, variable, indexes[name]) + elif variable.dims == (name,): + idx, idx_vars = create_default_index_implicit(variable) + append_all(idx_vars, {k: idx for k in idx_vars}) else: - index = None - append(name, variable, index) + append(name, variable, None) return grouped @@ -305,14 +361,14 @@ def collect_from_coordinates( list_of_coords: list[Coordinates], ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes to be merged from Coordinate objects.""" - grouped: dict[Hashable, list[tuple[Variable, Index | None]]] = {} + grouped: dict[Hashable, list[MergeElement]] = defaultdict(list) for coords in list_of_coords: variables = coords.variables indexes = coords.xindexes for name, variable in variables.items(): - value = grouped.setdefault(name, []) - value.append((variable, indexes.get(name))) + grouped[name].append((variable, indexes.get(name))) + return grouped @@ -342,7 +398,14 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) + # TODO: indexes should probably be filtered in collected elements + # before merging them + merged_coords, merged_indexes = merge_collected( + filtered, prioritized, combine_attrs=combine_attrs + ) + merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords)) + + return merged_coords, merged_indexes def determine_coords( @@ -471,26 +534,56 @@ def merge_coords( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected(collected, prioritized, compat=compat) - assert_unique_multiindex_level_names(variables) return variables, out_indexes -def merge_data_and_coords(data, coords, compat="broadcast_equals", join="outer"): +def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): """Used in Dataset.__init__.""" - objects = [data, coords] + indexes, coords = _create_indexes_from_coords(coords, data_vars) + objects = [data_vars, coords] explicit_coords = coords.keys() - indexes = dict(_extract_indexes_from_coords(coords)) return merge_core( - objects, compat, join, explicit_coords=explicit_coords, indexes=indexes + objects, + compat, + join, + explicit_coords=explicit_coords, + indexes=Indexes(indexes, coords), ) -def _extract_indexes_from_coords(coords): - """Yields the name & index of valid indexes from a mapping of coords""" - for name, variable in coords.items(): - variable = as_variable(variable, name=name) +def _create_indexes_from_coords(coords, data_vars=None): + """Maybe create default indexes from a mapping of coordinates. + + Return those indexes and updated coordinates. + """ + all_variables = dict(coords) + if data_vars is not None: + all_variables.update(data_vars) + + indexes = {} + updated_coords = {} + + # this is needed for backward compatibility: when a pandas multi-index + # is given as data variable, it is promoted as index / level coordinates + # TODO: depreciate this implicit behavior + index_vars = { + k: v + for k, v in all_variables.items() + if k in coords or isinstance(v, pd.MultiIndex) + } + + for name, obj in index_vars.items(): + variable = as_variable(obj, name=name) + if variable.dims == (name,): - yield name, variable._to_xindex() + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + updated_coords.update(idx_vars) + all_variables.update(idx_vars) + else: + updated_coords[name] = obj + + return indexes, updated_coords def assert_valid_explicit_coords(variables, dims, explicit_coords): @@ -566,7 +659,7 @@ class _MergeResult(NamedTuple): variables: dict[Hashable, Variable] coord_names: set[Hashable] dims: dict[Hashable, int] - indexes: dict[Hashable, pd.Index] + indexes: dict[Hashable, Index] attrs: dict[Hashable, Any] @@ -621,7 +714,7 @@ def merge_core( MergeError if the merge cannot be done successfully. """ from .dataarray import DataArray - from .dataset import Dataset, calculate_dimensions + from .dataset import Dataset _assert_compat_valid(compat) @@ -629,13 +722,11 @@ def merge_core( aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - collected = collect_variables_and_indexes(aligned) - + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs ) - assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) @@ -870,7 +961,7 @@ def merge( >>> xr.merge([x, y, z], join="exact") Traceback (most recent call last): ... - ValueError: indexes along dimension 'lat' are not equal + ValueError: cannot align objects with join='exact' where ... Raises ------ @@ -976,18 +1067,9 @@ def dataset_update_method(dataset: Dataset, other: CoercibleMapping) -> _MergeRe if coord_names: other[key] = value.drop_vars(coord_names) - # use ds.coords and not ds.indexes, else str coords are cast to object - # TODO: benbovy - flexible indexes: make it work with any xarray index - indexes = {} - for key, index in dataset.xindexes.items(): - if isinstance(index, PandasIndex): - indexes[key] = dataset.coords[key] - else: - indexes[key] = index - return merge_core( [dataset, other], priority_arg=1, - indexes=indexes, # type: ignore + indexes=dataset.xindexes, combine_attrs="override", ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index c1776145e21..3d33631bebd 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -321,12 +321,10 @@ def interp_na( if not is_scalar(max_gap): raise ValueError("max_gap must be a scalar.") - # TODO: benbovy - flexible indexes: update when CFTimeIndex (and DatetimeIndex?) - # has its own class inheriting from xarray.Index if ( - dim in self.xindexes + dim in self._indexes and isinstance( - self.xindexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + self._indexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) ) and use_coordinate ): diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3f6bb34a36e..fd1f3f9e999 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -292,7 +292,7 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.xindexes.items(): + for name, index in result._indexes.items(): if name in expected["shapes"]: if result.sizes[name] != expected["shapes"][name]: raise ValueError( @@ -359,27 +359,27 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0].xindexes) + input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.xindexes) + input_indexes.update(arg._indexes) if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template.xindexes) + template_indexes = set(template._indexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.xindexes[k] for k in new_indexes}) + indexes.update({k: template._indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template.xindexes) + indexes = dict(template._indexes) if isinstance(template, DataArray): output_chunks = dict( zip(template.dims, template.chunks) # type: ignore[arg-type] @@ -558,7 +558,7 @@ def subset_dataset_to_block( attrs=template.attrs, ) - for index in result.xindexes: + for index in result._indexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding @@ -568,7 +568,7 @@ def subset_dataset_to_block( for dim in dims: if dim in output_chunks: var_chunks.append(output_chunks[dim]) - elif dim in result.xindexes: + elif dim in result._indexes: var_chunks.append((result.sizes[dim],)) elif dim in template.dims: # new unindexed dimension diff --git a/xarray/core/types.py b/xarray/core/types.py index 9f0f9eee54c..3f368501b25 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -9,6 +9,7 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, GroupBy + from .indexes import Index from .npcompat import ArrayLike from .variable import Variable @@ -21,6 +22,7 @@ T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Index = TypeVar("T_Index", bound="Index") # Maybe we rename this to T_Data or something less Fortran-y? T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") diff --git a/xarray/core/utils.py b/xarray/core/utils.py index da3a196a621..a0f5bfdcf27 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -22,7 +22,6 @@ Mapping, MutableMapping, MutableSet, - Sequence, TypeVar, cast, ) @@ -69,10 +68,24 @@ def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: return index -def maybe_cast_to_coords_dtype(label, coords_dtype): - if coords_dtype.kind == "f" and not isinstance(label, slice): - label = np.asarray(label, dtype=coords_dtype) - return label +def get_valid_numpy_dtype(array: np.ndarray | pd.Index): + """Return a numpy compatible dtype from either + a numpy array or a pandas.Index. + + Used for wrapping a pandas.Index as an xarray,Variable. + + """ + if isinstance(array, pd.PeriodIndex): + dtype = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype = array.categories.dtype # type: ignore[union-attr] + elif not is_valid_numpy_dtype(array.dtype): + dtype = np.dtype("O") + else: + dtype = array.dtype + + return dtype def maybe_coerce_to_str(index, original_coords): @@ -105,9 +118,14 @@ def safe_cast_to_index(array: Any) -> pd.Index: if isinstance(array, pd.Index): index = array elif hasattr(array, "to_index"): + # xarray Variable index = array.to_index() elif hasattr(array, "to_pandas_index"): + # xarray Index index = array.to_pandas_index() + elif hasattr(array, "array") and isinstance(array.array, pd.Index): + # xarray PandasIndexingAdapter + index = array.array else: kwargs = {} if hasattr(array, "dtype") and array.dtype.kind == "O": @@ -116,33 +134,6 @@ def safe_cast_to_index(array: Any) -> pd.Index: return _maybe_cast_to_cftimeindex(index) -def multiindex_from_product_levels( - levels: Sequence[pd.Index], names: Sequence[str] = None -) -> pd.MultiIndex: - """Creating a MultiIndex from a product without refactorizing levels. - - Keeping levels the same gives back the original labels when we unstack. - - Parameters - ---------- - levels : sequence of pd.Index - Values for each MultiIndex level. - names : sequence of str, optional - Names for each level. - - Returns - ------- - pandas.MultiIndex - """ - if any(not isinstance(lev, pd.Index) for lev in levels): - raise TypeError("levels must be a list of pd.Index objects") - - split_labels, levels = zip(*[lev.factorize() for lev in levels]) - labels_mesh = np.meshgrid(*split_labels, indexing="ij") - labels = [x.ravel() for x in labels_mesh] - return pd.MultiIndex(levels, labels, sortorder=0, names=names) - - def maybe_wrap_array(original, new_array): """Wrap a transformed array with __array_wrap__ if it can be done safely. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c8d46d20d46..a21cf8c2d97 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -4,7 +4,6 @@ import itertools import numbers import warnings -from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING, Any, Hashable, Mapping, Sequence @@ -17,7 +16,6 @@ from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexes import PandasIndex, PandasMultiIndex from .indexing import ( BasicIndexer, OuterIndexer, @@ -531,18 +529,6 @@ def to_index_variable(self): to_coord = utils.alias(to_index_variable, "to_coord") - def _to_xindex(self): - # temporary function used internally as a replacement of to_index() - # returns an xarray Index instance instead of a pd.Index instance - index_var = self.to_index_variable() - index = index_var.to_index() - dim = index_var.dims[0] - - if isinstance(index, pd.MultiIndex): - return PandasMultiIndex(index, dim) - else: - return PandasIndex(index, dim) - def to_index(self): """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() @@ -3007,37 +2993,27 @@ def concat( return Variable.concat(variables, dim, positions, shortcut, combine_attrs) -def assert_unique_multiindex_level_names(variables): - """Check for uniqueness of MultiIndex level names in all given - variables. +def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]: + """Calculate the dimensions corresponding to a set of variables. - Not public API. Used for checking consistency of DataArray and Dataset - objects. + Returns dictionary mapping from dimension names to sizes. Raises ValueError + if any of the dimension sizes conflict. """ - level_names = defaultdict(list) - all_level_names = set() - for var_name, var in variables.items(): - if isinstance(var._data, PandasIndexingAdapter): - idx_level_names = var.to_index_variable().level_names - if idx_level_names is not None: - for n in idx_level_names: - level_names[n].append(f"{n!r} ({var_name})") - if idx_level_names: - all_level_names.update(idx_level_names) - - for k, v in level_names.items(): - if k in variables: - v.append(f"({k})") - - duplicate_names = [v for v in level_names.values() if len(v) > 1] - if duplicate_names: - conflict_str = "\n".join(", ".join(v) for v in duplicate_names) - raise ValueError(f"conflicting MultiIndex level name(s):\n{conflict_str}") - # Check confliction between level names and dimensions GH:2299 - for k, v in variables.items(): - for d in v.dims: - if d in all_level_names: + dims: dict[Hashable, int] = {} + last_used = {} + scalar_vars = {k for k, v in variables.items() if not v.dims} + for k, var in variables.items(): + for dim, size in zip(var.dims, var.shape): + if dim in scalar_vars: + raise ValueError( + f"dimension {dim!r} already exists as a scalar variable" + ) + if dim not in dims: + dims[dim] = size + last_used[dim] = k + elif dims[dim] != size: raise ValueError( - "conflicting level / dimension names. {} " - "already exists as a level name.".format(d) + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" ) + return dims diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f09d1eb1853..d942f6656ba 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from ..core.indexes import PandasMultiIndex from ..core.options import OPTIONS from ..core.pycompat import DuckArrayModule from ..core.utils import is_scalar @@ -383,11 +384,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): _assert_valid_xy(darray, x, "x") _assert_valid_xy(darray, y, "y") - if ( - all(k in darray._level_coords for k in (x, y)) - and darray._level_coords[x] == darray._level_coords[y] - ): - raise ValueError("x and y cannot be levels of the same MultiIndex") + if darray._indexes.get(x, 1) is darray._indexes.get(y, 2): + if isinstance(darray._indexes[x], PandasMultiIndex): + raise ValueError("x and y cannot be levels of the same MultiIndex") return x, y @@ -398,11 +397,13 @@ def _assert_valid_xy(darray, xy, name): """ # MultiIndex cannot be plotted; no point in allowing them here - multiindex = {darray._level_coords[lc] for lc in darray._level_coords} + multiindex_dims = { + idx.dim + for idx in darray.xindexes.get_unique() + if isinstance(idx, PandasMultiIndex) + } - valid_xy = ( - set(darray.dims) | set(darray.coords) | set(darray._level_coords) - ) - multiindex + valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims if xy not in valid_xy: valid_xy_str = "', '".join(sorted(valid_xy)) diff --git a/xarray/testing.py b/xarray/testing.py index 4369b828daf..0df34a60e73 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -4,11 +4,12 @@ from typing import Hashable, Set, Union import numpy as np +import pandas as pd from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import Index, default_indexes +from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -251,7 +252,9 @@ def assert_chunks_equal(a, b): assert left.chunks == right.chunks -def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): +def _assert_indexes_invariants_checks( + indexes, possible_coord_variables, dims, check_default=True +): assert isinstance(indexes, dict), indexes assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() @@ -262,11 +265,42 @@ def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): } assert indexes.keys() <= index_vars, (set(indexes), index_vars) - # Note: when we support non-default indexes, these checks should be opt-in - # only! - defaults = default_indexes(possible_coord_variables, dims) - assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) - assert all(v.equals(defaults[k]) for k, v in indexes.items()), (indexes, defaults) + # check pandas index wrappers vs. coordinate data adapters + for k, index in indexes.items(): + if isinstance(index, PandasIndex): + pd_index = index.index + var = possible_coord_variables[k] + assert (index.dim,) == var.dims, (pd_index, var) + if k == index.dim: + # skip multi-index levels here (checked below) + assert index.coord_dtype == var.dtype, (index.coord_dtype, var.dtype) + assert isinstance(var._data.array, pd.Index), var._data.array + # TODO: check identity instead of equality? + assert pd_index.equals(var._data.array), (pd_index, var) + if isinstance(index, PandasMultiIndex): + pd_index = index.index + for name in index.index.names: + assert name in possible_coord_variables, (pd_index, index_vars) + var = possible_coord_variables[name] + assert (index.dim,) == var.dims, (pd_index, var) + assert index.level_coords_dtype[name] == var.dtype, ( + index.level_coords_dtype[name], + var.dtype, + ) + assert isinstance(var._data.array, pd.MultiIndex), var._data.array + assert pd_index.equals(var._data.array), (pd_index, var) + # check all all levels are in `indexes` + assert name in indexes, (name, set(indexes)) + # index identity is used to find unique indexes in `indexes` + assert index is indexes[name], (pd_index, indexes[name].index) + + if check_default: + defaults = default_indexes(possible_coord_variables, dims) + assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults)) + assert all(v.equals(defaults[k]) for k, v in indexes.items()), ( + indexes, + defaults, + ) def _assert_variable_invariants(var: Variable, name: Hashable = None): @@ -285,7 +319,7 @@ def _assert_variable_invariants(var: Variable, name: Hashable = None): assert isinstance(var._attrs, (type(None), dict)), name_or_empty + (var._attrs,) -def _assert_dataarray_invariants(da: DataArray): +def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool): assert isinstance(da._variable, Variable), da._variable _assert_variable_invariants(da._variable) @@ -302,10 +336,12 @@ def _assert_dataarray_invariants(da: DataArray): _assert_variable_invariants(v, k) if da._indexes is not None: - _assert_indexes_invariants_checks(da._indexes, da._coords, da.dims) + _assert_indexes_invariants_checks( + da._indexes, da._coords, da.dims, check_default=check_default_indexes + ) -def _assert_dataset_invariants(ds: Dataset): +def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): assert isinstance(ds._variables, dict), type(ds._variables) assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables for k, v in ds._variables.items(): @@ -336,13 +372,17 @@ def _assert_dataset_invariants(ds: Dataset): } if ds._indexes is not None: - _assert_indexes_invariants_checks(ds._indexes, ds._variables, ds._dims) + _assert_indexes_invariants_checks( + ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes + ) assert isinstance(ds._encoding, (type(None), dict)) assert isinstance(ds._attrs, (type(None), dict)) -def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]): +def _assert_internal_invariants( + xarray_obj: Union[DataArray, Dataset, Variable], check_default_indexes: bool +): """Validate that an xarray object satisfies its own internal invariants. This exists for the benefit of xarray's own test suite, but may be useful @@ -352,9 +392,13 @@ def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]) if isinstance(xarray_obj, Variable): _assert_variable_invariants(xarray_obj) elif isinstance(xarray_obj, DataArray): - _assert_dataarray_invariants(xarray_obj) + _assert_dataarray_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) elif isinstance(xarray_obj, Dataset): - _assert_dataset_invariants(xarray_obj) + _assert_dataset_invariants( + xarray_obj, check_default_indexes=check_default_indexes + ) else: raise TypeError( "{} is not a supported type for xarray invariant checks".format( diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 00fec07f793..7872fec2e62 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -177,25 +177,25 @@ def assert_no_warnings(): # invariants -def assert_equal(a, b): +def assert_equal(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_equal(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_identical(a, b): +def assert_identical(a, b, check_default_indexes=True): __tracebackhide__ = True xarray.testing.assert_identical(a, b) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) -def assert_allclose(a, b, **kwargs): +def assert_allclose(a, b, check_default_indexes=True, **kwargs): __tracebackhide__ = True xarray.testing.assert_allclose(a, b, **kwargs) - xarray.testing._assert_internal_invariants(a) - xarray.testing._assert_internal_invariants(b) + xarray.testing._assert_internal_invariants(a, check_default_indexes) + xarray.testing._assert_internal_invariants(b, check_default_indexes) def create_test_data(seed=None, add_attrs=True): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 4696c41552f..825c6f7130f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3290,7 +3290,9 @@ def test_open_mfdataset_exact_join_raises_error(self, combine, concat_dim, opt): with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): if combine == "by_coords": files.reverse() - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align objects.*join.*exact.*" + ): open_mfdataset( files, data_vars=opt, diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 6ba0c6a9be2..7e50e0d8b53 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -392,7 +392,7 @@ def test_combine_nested_join(self, join, expected): def test_combine_nested_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact"): combine_nested(objs, concat_dim="x", join="exact") def test_empty_input(self): @@ -747,7 +747,7 @@ def test_combine_coords_join(self, join, expected): def test_combine_coords_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): combine_nested(objs, concat_dim="x", join="exact") @pytest.mark.parametrize( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index dac3c17b1f1..6a86738ab2f 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -960,7 +960,7 @@ def test_dataset_join() -> None: ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) # by default, cannot have different labels - with pytest.raises(ValueError, match=r"indexes .* are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*"): apply_ufunc(operator.add, ds0, ds1) with pytest.raises(TypeError, match=r"must supply"): apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") @@ -1892,7 +1892,7 @@ def test_dot_align_coords(use_dask) -> None: xr.testing.assert_allclose(expected, actual) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.dot(da_a, da_b) # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 8a37df62261..8abede64761 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -7,6 +7,7 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes, merge +from xarray.core.indexes import PandasIndex from . import ( InaccessibleArray, @@ -259,7 +260,7 @@ def test_concat_join_kwarg(self) -> None: coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -465,7 +466,7 @@ def test_concat_dim_is_variable(self) -> None: def test_concat_multiindex(self) -> None: x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) - expected = Dataset({"x": x}) + expected = Dataset(coords={"x": x}) actual = concat( [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" ) @@ -639,7 +640,7 @@ def test_concat_join_kwarg(self) -> None: coords={"x": [0, 1], "y": [0]}, ) - with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"cannot align.*exact.*dimensions.*'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -783,3 +784,27 @@ def test_concat_typing_check() -> None: match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): concat([da, ds], dim="foo") + + +def test_concat_not_all_indexes() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + # ds2.x has no default index + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + + with pytest.raises( + ValueError, match=r"'x' must have either an index or no index in all datasets.*" + ): + concat([ds1, ds2], dim="x") + + +def test_concat_index_not_same_dim() -> None: + ds1 = Dataset(coords={"x": ("x", [1, 2])}) + ds2 = Dataset(coords={"x": ("y", [3, 4])}) + # TODO: use public API for setting a non-default index, when available + ds2._indexes["x"] = PandasIndex([3, 4], "y") + + with pytest.raises( + ValueError, + match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*", + ): + concat([ds1, ds2], dim="x") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 42d8df57cb7..872c0c6f1db 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -51,14 +51,14 @@ def assertLazyAnd(self, expected, actual, test): if isinstance(actual, Dataset): for k, v in actual.variables.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) elif isinstance(actual, DataArray): assert isinstance(actual.data, da.Array) for k, v in actual.coords.items(): - if k in actual.dims: + if k in actual.xindexes: assert isinstance(v.data, np.ndarray) else: assert isinstance(v.data, da.Array) @@ -1226,7 +1226,7 @@ def sumda(da1, da2): with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) - with pytest.raises(ValueError, match=r"indexes along dimension 'x' are not equal"): + with pytest.raises(ValueError, match=r"cannot align.*index.*are not equal"): xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) # reduction diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 438bdf8bdc3..65efb3a732c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -24,7 +24,7 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.indexes import Index, PandasIndex, propagate_indexes +from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.utils import is_scalar from xarray.tests import ( ReturnItem, @@ -93,9 +93,9 @@ def test_repr_multiindex(self): array([0, 1, 2, 3]) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2""" + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2""" ) assert expected == repr(self.mda) @@ -111,9 +111,9 @@ def test_repr_multiindex_long(self): array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' - - level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' + * level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" ) assert expected == repr(mda_long) @@ -769,11 +769,6 @@ def test_contains(self): assert 1 in data_array assert 3 not in data_array - def test_attr_sources_multiindex(self): - # make sure attr-style access for multi-index levels - # returns DataArray objects - assert isinstance(self.mda.level_1, DataArray) - def test_pickle(self): data = DataArray(np.random.random((3, 3)), dims=("id", "time")) roundtripped = pickle.loads(pickle.dumps(data)) @@ -901,7 +896,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"], stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): da.isel( x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), @@ -1007,6 +1002,21 @@ def test_sel_float(self): assert_equal(expected_scalar, actual_scalar) assert_equal(expected_16, actual_16) + def test_sel_float_multiindex(self): + # regression test https://github.com/pydata/xarray/issues/5691 + # test multi-index created from coordinates, one with dtype=float32 + lvl1 = ["a", "a", "b", "b"] + lvl2 = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + da = xr.DataArray( + [1, 2, 3, 4], dims="x", coords={"lvl1": ("x", lvl1), "lvl2": ("x", lvl2)} + ) + da = da.set_index(x=["lvl1", "lvl2"]) + + actual = da.sel(lvl1="a", lvl2=0.1) + expected = da.isel(x=0) + + assert_equal(actual, expected) + def test_sel_no_index(self): array = DataArray(np.arange(10), dims="x") assert_identical(array[0], array.sel(x=0)) @@ -1261,7 +1271,7 @@ def test_selection_multiindex_from_level(self): data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) assert data.dims == ("xy",) actual = data.sel(y="a") - expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop_vars("y") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y") assert_equal(actual, expected) def test_virtual_default_coords(self): @@ -1311,13 +1321,15 @@ def test_coords(self): assert expected == actual del da.coords["x"] - da._indexes = propagate_indexes(da._indexes, exclude="x") + da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords)) expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda["level_1"] = np.arange(4) - self.mda.coords["level_1"] = np.arange(4) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = ("x", np.arange(4)) + self.mda.coords["level_1"] = ("x", np.arange(4)) def test_coords_to_index(self): da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) @@ -1422,14 +1434,22 @@ def test_reset_coords(self): with pytest.raises(ValueError, match=r"cannot remove index"): data.reset_coords("y") + # non-dimension index coordinate + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + data = DataArray([1, 2, 3, 4], coords={"x": midx}, dims="x", name="foo") + with pytest.raises(ValueError, match=r"cannot remove index"): + data.reset_coords("lvl1") + def test_assign_coords(self): array = DataArray(10) actual = array.assign_coords(c=42) expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - self.mda.assign_coords(level_1=range(4)) + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda.assign_coords(level_1=("x", range(4))) # GH: 2112 da = xr.DataArray([0, 1, 2], dims="x") @@ -1437,6 +1457,8 @@ def test_assign_coords(self): da["x"] = [0, 1, 2, 3] # size conflict with pytest.raises(ValueError): da.coords["x"] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) @@ -1453,6 +1475,12 @@ def test_set_coords_update_index(self): actual.coords["x"] = ["a", "b", "c"] assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) + def test_set_coords_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + self.mda["level_1"] = range(4) + def test_coords_replacement_alignment(self): # regression test for GH725 arr = DataArray([0, 1, 2], dims=["abc"]) @@ -1473,6 +1501,12 @@ def test_coords_delitem_delete_indexes(self): del arr.coords["x"] assert "x" not in arr.xindexes + def test_coords_delitem_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del self.mda.coords["level_1"] + def test_broadcast_like(self): arr1 = DataArray( np.ones((2, 3)), @@ -1624,10 +1658,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # as kwargs array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") @@ -1635,10 +1666,7 @@ def test_swap_dims(self): actual = array.swap_dims(x="y") assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) # multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) @@ -1647,10 +1675,7 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].to_pandas_index(), - actual.xindexes[dim_name].to_pandas_index(), - ) + assert actual.xindexes[dim_name].equals(expected.xindexes[dim_name]) def test_expand_dims_error(self): array = DataArray( @@ -1843,49 +1868,52 @@ def test_set_index(self): array2d.set_index(x="level") # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r".*variable\(s\) do not exist"): obj.set_index(x="level_4") - assert str(excinfo.value) == "level_4 is not the name of an existing variable." def test_reset_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", self.mindex.values) expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(self.mindex.names) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 obj = self.mda.reset_index(["x", "level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert list(obj.xindexes) == ["level_2"] - coords = { - "x": ("x", self.mindex.droplevel("level_1")), - "level_1": ("x", self.mindex.get_level_values("level_1")), - } expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index(["level_1"]) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert list(obj.xindexes) == ["level_2"] + assert type(obj.xindexes["level_2"]) is PandasIndex - expected = DataArray(self.mda.values, dims="x") + coords = {k: v for k, v in coords.items() if k != "x"} + expected = DataArray(self.mda.values, coords=coords, dims="x") obj = self.mda.reset_index("x", drop=True) - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) array = self.mda.copy() array = array.reset_index(["x"], drop=True) - assert_identical(array, expected) + assert_identical(array, expected, check_default_indexes=False) # single index array = DataArray([1, 2], coords={"x": ["a", "b"]}, dims="x") - expected = DataArray([1, 2], coords={"x_": ("x", ["a", "b"])}, dims="x") - assert_identical(array.reset_index("x"), expected) + obj = array.reset_index("x") + assert_identical(obj, array, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) da = DataArray([1, 0], [coord_1]) - expected = DataArray([1, 0], {"coord_1_": coord_1}, dims=["coord_1"]) obj = da.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, da, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): midx = self.mindex.reorder_levels(["level_2", "level_1"]) @@ -2157,11 +2185,15 @@ def test_dataset_math(self): assert_identical(actual, expected) def test_stack_unstack(self): - orig = DataArray([[0, 1], [2, 3]], dims=["x", "y"], attrs={"foo": 2}) + orig = DataArray( + [[0, 1], [2, 3]], + dims=["x", "y"], + attrs={"foo": 2}, + ) assert_identical(orig, orig.unstack()) # test GH3000 - a = orig[:0, :1].stack(dim=("x", "y")).dim.to_index() + a = orig[:0, :1].stack(dim=("x", "y")).indexes["dim"] b = pd.MultiIndex( levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], codes=[[], []], @@ -2176,16 +2208,21 @@ def test_stack_unstack(self): assert_identical(orig, actual) dims = ["a", "b", "c", "d", "e"] - orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + coords = { + "a": [0], + "b": [1, 2], + "c": [3, 4, 5], + "d": [6, 7], + "e": [8], + } + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), coords=coords, dims=dims) stacked = orig.stack(ab=["a", "b"], cd=["c", "d"]) unstacked = stacked.unstack(["ab", "cd"]) - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) unstacked = stacked.unstack() - roundtripped = unstacked.drop_vars(["a", "b", "c", "d"]).transpose(*dims) - assert_identical(orig, roundtripped) + assert_identical(orig, unstacked.transpose(*dims)) def test_stack_unstack_decreasing_coordinate(self): # regression test for GH980 @@ -2317,6 +2354,19 @@ def test_drop_coordinates(self): actual = renamed.drop_vars("foo", errors="ignore") assert_identical(actual, renamed) + def test_drop_multiindex_level(self): + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + self.mda.drop_vars("level_1") + + def test_drop_all_multiindex_levels(self): + dim_levels = ["x", "level_1", "level_2"] + actual = self.mda.drop_vars(dim_levels) + # no error, multi-index dropped + for key in dim_levels: + assert key not in actual.xindexes + def test_drop_index_labels(self): arr = DataArray(np.random.randn(2, 3), coords={"y": [0, 1, 2]}, dims=["x", "y"]) actual = arr.drop_sel(y=[0, 1]) @@ -2707,7 +2757,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): align(left.isel(x=0).expand_dims("x"), right, join="override") @pytest.mark.parametrize( @@ -2726,7 +2778,9 @@ def test_align_override(self): ], ) def test_align_override_error(self, darrays): - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(*darrays, join="override") def test_align_exclude(self): @@ -2782,10 +2836,16 @@ def test_align_mixed_indexes(self): assert_identical(result1, array_with_coord) def test_align_without_indexes_errors(self): - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) - with pytest.raises(ValueError, match=r"cannot be aligned"): + with pytest.raises( + ValueError, + match=r"cannot.*align.*dimension.*conflicting.*sizes.*", + ): align( DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], coords=[("x", [0, 1])]), @@ -3584,7 +3644,9 @@ def test_dot_align_coords(self): dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da.dot(dm) da_aligned, dm_aligned = xr.align(da, dm, join="inner") @@ -3635,7 +3697,9 @@ def test_matmul_align_coords(self): assert_identical(result, expected) with xr.set_options(arithmetic_join="exact"): - with pytest.raises(ValueError, match=r"indexes along dimension"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*exact.*not equal.*" + ): da_a @ da_b def test_binary_op_propagate_indexes(self): @@ -6618,7 +6682,7 @@ def test_clip(da): assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1])) # Unclear whether we want this work, OK to adjust the test when we have decided. - with pytest.raises(ValueError, match="arguments without labels along dimension"): + with pytest.raises(ValueError, match="cannot reindex or align along dimension.*"): result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1])) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 41fdd9a373d..96fa4ef144e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -239,9 +239,9 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: - * x (x) MultiIndex - - level_1 (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2 + * x (x) object MultiIndex + * level_1 (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -259,9 +259,9 @@ def test_repr_multiindex(self): Dimensions: (x: 4) Coordinates: - * x (x) MultiIndex - - a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' - - level_2 (x) int64 1 2 1 2 + * x (x) object MultiIndex + * a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' + * level_2 (x) int64 1 2 1 2 Data variables: *empty*""" ) @@ -740,7 +740,9 @@ def test_coords_setitem_with_new_dimension(self): def test_coords_setitem_multiindex(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.coords["level_1"] = range(4) def test_coords_set(self): @@ -1132,7 +1134,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"].drop_vars(["dim2"]), stations["station"]) - with pytest.raises(ValueError, match=r"conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values/indexes on "): data.isel( dim1=DataArray( [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} @@ -1482,6 +1484,16 @@ def test_sel_drop(self): selected = data.sel(x=0, drop=True) assert_identical(expected, selected) + def test_sel_drop_mindex(self): + midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) + data = Dataset(coords={"x": midx}) + + actual = data.sel(foo="a", drop=True) + assert "foo" not in actual.coords + + actual = data.sel(foo="a", drop=False) + assert_equal(actual.foo, DataArray("a", coords={"foo": "a"})) + def test_isel_drop(self): data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) expected = Dataset({"foo": 1}) @@ -1680,7 +1692,7 @@ def test_sel_method(self): with pytest.raises(TypeError, match=r"``method``"): # this should not pass silently - data.sel(method=data) + data.sel(dim2=1, method=data) # cannot pass method if there is no associated coordinate with pytest.raises(ValueError, match=r"cannot supply"): @@ -1820,8 +1832,10 @@ def test_reindex(self): data.reindex("foo") # invalid dimension - with pytest.raises(ValueError, match=r"invalid reindex dim"): - data.reindex(invalid=0) + # TODO: (benbovy - explicit indexes): uncomment? + # --> from reindex docstrings: "any mis-matched dimension is simply ignored" + # with pytest.raises(ValueError, match=r"indexer keys.*not correspond.*"): + # data.reindex(invalid=0) # out of order expected = data.sel(dim2=data["dim2"][:5:-1]) @@ -2053,7 +2067,7 @@ def test_align_exact(self): assert_identical(left1, left) assert_identical(left2, left) - with pytest.raises(ValueError, match=r"indexes .* not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.align(left, right, join="exact") def test_align_override(self): @@ -2075,7 +2089,9 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): + with pytest.raises( + ValueError, match=r"cannot align.*join.*override.*same size" + ): xr.align(left.isel(x=0).expand_dims("x"), right, join="override") def test_align_exclude(self): @@ -2155,11 +2171,15 @@ def test_align_non_unique(self): def test_align_str_dtype(self): - a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]}) - b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]}) + a = Dataset({"foo": ("x", [0, 1])}, coords={"x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) - expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]}) - expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]}) + expected_a = Dataset( + {"foo": ("x", [0, 1, np.NaN])}, coords={"x": ["a", "b", "c"]} + ) + expected_b = Dataset( + {"foo": ("x", [np.NaN, 1, 2])}, coords={"x": ["a", "b", "c"]} + ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -2340,6 +2360,14 @@ def test_drop_variables(self): actual = data.drop({"time", "not_found_here"}, errors="ignore") assert_identical(expected, actual) + def test_drop_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + data.drop_vars("level_1") + def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) @@ -2647,12 +2675,14 @@ def test_rename_dims(self): expected = Dataset( {"x": ("x_new", [0, 1, 2]), "y": ("x_new", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x") dims_dict = {"x": "x_new"} actual = original.rename_dims(dims_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_dims(**dims_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError dims_dict_bad = {"x_bad": "x_new"} @@ -2667,25 +2697,56 @@ def test_rename_vars(self): expected = Dataset( {"x_new": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42} ) + # TODO: (benbovy - explicit indexes) update when set_index supports + # seeting index for non-dimension variables expected = expected.set_coords("x_new") name_dict = {"x": "x_new"} actual = original.rename_vars(name_dict) - assert_identical(expected, actual) + assert_identical(expected, actual, check_default_indexes=False) actual_2 = original.rename_vars(**name_dict) - assert_identical(expected, actual_2) + assert_identical(expected, actual_2, check_default_indexes=False) # Test to raise ValueError names_dict_bad = {"x_bad": "x_new"} with pytest.raises(ValueError): original.rename_vars(names_dict_bad) - def test_rename_multiindex(self): - mindex = pd.MultiIndex.from_tuples( - [([1, 2]), ([3, 4])], names=["level0", "level1"] - ) - data = Dataset({}, {"x": mindex}) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - data.rename({"x": "level0"}) + def test_rename_dimension_coord(self) -> None: + # rename a dimension corodinate to a non-dimension coordinate + # should preserve index + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + + actual = original.rename_vars({"x": "x_new"}) + assert "x_new" in actual.xindexes + + actual_2 = original.rename_dims({"x": "x_new"}) + assert "x" in actual_2.xindexes + + def test_rename_multiindex(self) -> None: + mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + original = Dataset({}, {"x": mindex}) + expected = Dataset({}, {"x": mindex.rename(["a", "c"])}) + + actual = original.rename({"b": "c"}) + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"'a' conflicts"): + original.rename({"x": "a"}) + with pytest.raises(ValueError, match=r"'x' conflicts"): + original.rename({"a": "x"}) + with pytest.raises(ValueError, match=r"'b' conflicts"): + original.rename({"a": "b"}) + + def test_rename_perserve_attrs_encoding(self) -> None: + # test propagate attrs/encoding to new variable(s) created from Index object + original = Dataset(coords={"x": ("x", [0, 1, 2])}) + expected = Dataset(coords={"y": ("y", [0, 1, 2])}) + for ds, dim in zip([original, expected], ["x", "y"]): + ds[dim].attrs = {"foo": "bar"} + ds[dim].encoding = {"foo": "bar"} + + actual = original.rename({"x": "y"}) + assert_identical(actual, expected) @requires_cftime def test_rename_does_not_change_CFTimeIndex_type(self): @@ -2745,10 +2806,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) roundtripped = actual.swap_dims({"y": "x"}) assert_identical(original.set_coords("y"), roundtripped) @@ -2779,10 +2837,7 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal( - actual.xindexes["y"].to_pandas_index(), - expected.xindexes["y"].to_pandas_index(), - ) + assert actual.xindexes["y"].equals(expected.xindexes["y"]) def test_expand_dims_error(self): original = Dataset( @@ -2974,33 +3029,49 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) + # ensure pre-existing indexes involved are removed + # (level_2 should be a coordinate with no index) + ds = create_test_multiindex() + coords = {"x": coords["level_1"], "level_2": coords["level_2"]} + expected = Dataset({}, coords=coords) + + obj = ds.set_index(x="level_1") + assert_identical(obj, expected) + # ensure set_index with no existing index and a single data var given # doesn't return multi-index ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) expected = Dataset(coords={"x": [0, 1, 2]}) assert_identical(ds.set_index(x="x_var"), expected) - # Issue 3176: Ensure clear error message on key error. - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=r"bar variable\(s\) do not exist"): ds.set_index(foo="bar") - assert str(excinfo.value) == "bar is not the name of an existing variable." + + with pytest.raises(ValueError, match=r"dimension mismatch.*"): + ds.set_index(y="x_var") def test_reset_index(self): ds = create_test_multiindex() mindex = ds["x"].to_index() indexes = [mindex.get_level_values(n) for n in mindex.names] coords = {idx.name: ("x", idx) for idx in indexes} + coords["x"] = ("x", mindex.values) expected = Dataset({}, coords=coords) obj = ds.reset_index("x") - assert_identical(obj, expected) + assert_identical(obj, expected, check_default_indexes=False) + assert len(obj.xindexes) == 0 + + ds = Dataset(coords={"y": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match=r".*not coordinates with an index"): + ds.reset_index("y") def test_reset_index_keep_attrs(self): coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) ds = Dataset({}, {"coord_1": coord_1}) - expected = Dataset({}, {"coord_1_": coord_1}) obj = ds.reset_index("coord_1") - assert_identical(expected, obj) + assert_identical(obj, ds, check_default_indexes=False) + assert len(obj.xindexes) == 0 def test_reorder_levels(self): ds = create_test_multiindex() @@ -3008,6 +3079,10 @@ def test_reorder_levels(self): midx = mindex.reorder_levels(["level_2", "level_1"]) expected = Dataset({}, coords={"x": midx}) + # check attrs propagated + ds["level_1"].attrs["foo"] = "bar" + expected["level_1"].attrs["foo"] = "bar" + reindexed = ds.reorder_levels(x=["level_2", "level_1"]) assert_identical(reindexed, expected) @@ -3017,15 +3092,22 @@ def test_reorder_levels(self): def test_stack(self): ds = Dataset( - {"a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), "y": ["a", "b"]} + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, ) exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) expected = Dataset( - {"a": ("z", [0, 0, 1, 1]), "b": ("z", [0, 1, 2, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, ) + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "x", "y"] actual = ds.stack(z=[...]) assert_identical(expected, actual) @@ -3040,17 +3122,85 @@ def test_stack(self): exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) expected = Dataset( - {"a": ("z", [0, 1, 0, 1]), "b": ("z", [0, 2, 1, 3]), "z": exp_index} + data_vars={"b": ("z", [0, 2, 1, 3])}, + coords={"z": exp_index}, ) + expected["x"].attrs["foo"] = "bar" + actual = ds.stack(z=["y", "x"]) assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "y", "x"] + + @pytest.mark.parametrize( + "create_index,expected_keys", + [ + (True, ["z", "x", "y"]), + (False, []), + (None, ["z", "x", "y"]), + ], + ) + def test_stack_create_index(self, create_index, expected_keys) -> None: + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ) + + actual = ds.stack(z=["x", "y"], create_index=create_index) + assert list(actual.xindexes) == expected_keys + + # TODO: benbovy (flexible indexes) - test error multiple indexes found + # along dimension + create_index=True + + def test_stack_multi_index(self) -> None: + # multi-index on a dimension to stack is discarded too + midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + ds = xr.Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, + coords={"x": midx, "y": [0, 1]}, + ) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, + coords={ + "x": ("z", np.repeat(midx.values, 2)), + "lvl1": ("z", np.repeat(midx.get_level_values("lvl1"), 2)), + "lvl2": ("z", np.repeat(midx.get_level_values("lvl2"), 2)), + "y": ("z", [0, 1, 0, 1] * 2), + }, + ) + actual = ds.stack(z=["x", "y"], create_index=False) + assert_identical(expected, actual) + assert len(actual.xindexes) == 0 + + with pytest.raises(ValueError, match=r"cannot create.*wraps a multi-index"): + ds.stack(z=["x", "y"], create_index=True) + + def test_stack_non_dim_coords(self): + ds = Dataset( + data_vars={"b": (("x", "y"), [[0, 1], [2, 3]])}, + coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, + ).rename_vars(x="xx") + + exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) + expected = Dataset( + data_vars={"b": ("z", [0, 1, 2, 3])}, + coords={"z": exp_index}, + ) + + actual = ds.stack(z=["x", "y"]) + assert_identical(expected, actual) + assert list(actual.xindexes) == ["z", "xx", "y"] def test_unstack(self): index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) - ds = Dataset({"b": ("z", [0, 1, 2, 3]), "z": index}) + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords={"z": index}) expected = Dataset( {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} ) + + # check attrs propagated + ds["x"].attrs["foo"] = "bar" + expected["x"].attrs["foo"] = "bar" + for dim in ["z", ["z"], None]: actual = ds.unstack(dim) assert_identical(actual, expected) @@ -3059,7 +3209,7 @@ def test_unstack_errors(self): ds = Dataset({"x": [1, 2, 3]}) with pytest.raises(ValueError, match=r"does not contain the dimensions"): ds.unstack("foo") - with pytest.raises(ValueError, match=r"do not have a MultiIndex"): + with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): ds.unstack("x") def test_unstack_fill_value(self): @@ -3147,12 +3297,11 @@ def test_stack_unstack_fast(self): def test_stack_unstack_slow(self): ds = Dataset( - { + data_vars={ "a": ("x", [0, 1]), "b": (("x", "y"), [[0, 1], [2, 3]]), - "x": [0, 1], - "y": ["a", "b"], - } + }, + coords={"x": [0, 1], "y": ["a", "b"]}, ) stacked = ds.stack(z=["x", "y"]) actual = stacked.isel(z=slice(None, None, -1)).unstack("z") @@ -3187,8 +3336,6 @@ def test_to_stacked_array_dtype_dims(self): D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - # TODO: benbovy - flexible indexes: update when MultiIndex has its own class - # inherited from xarray.Index assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") @@ -3259,6 +3406,14 @@ def test_update_overwrite_coords(self): expected = Dataset({"a": ("x", [1, 2]), "c": 5}, {"b": 3}) assert_identical(data, expected) + def test_update_multiindex_level(self): + data = create_test_multiindex() + + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): + data.update({"level_1": range(4)}) + def test_update_auto_align(self): ds = Dataset({"x": ("t", [3, 4])}, {"t": [0, 1]}) @@ -3359,34 +3514,6 @@ def test_virtual_variable_same_name(self): expected = DataArray(times.time, [("time", times)], name="time") assert_identical(actual, expected) - def test_virtual_variable_multiindex(self): - # access multi-index levels as virtual variables - data = create_test_multiindex() - expected = DataArray( - ["a", "a", "b", "b"], - name="level_1", - coords=[data["x"].to_index()], - dims="x", - ) - assert_identical(expected, data["level_1"]) - - # combine multi-index level and datetime - dr_index = pd.date_range("1/1/2011", periods=4, freq="H") - mindex = pd.MultiIndex.from_arrays( - [["a", "a", "b", "b"], dr_index], names=("level_str", "level_date") - ) - data = Dataset({}, {"x": mindex}) - expected = DataArray( - mindex.get_level_values("level_date").hour, - name="hour", - coords=[mindex], - dims="x", - ) - assert_identical(expected, data["level_date.hour"]) - - # attribute style access - assert_identical(data.level_str, data["level_str"]) - def test_time_season(self): ds = Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] @@ -3427,7 +3554,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=r"already exists as a scalar"): data1["newvar"] = ("scalar", [3, 4, 5]) # can't resize a used dimension - with pytest.raises(ValueError, match=r"arguments without labels"): + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): data1["dim1"] = data1["dim1"][:5] # override an existing value data1["A"] = 3 * data2["A"] @@ -3464,7 +3591,7 @@ def test_setitem(self): with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3]}] data3["var2"] = data3["var2"].T - err_msg = "indexes along dimension 'dim2' are not equal" + err_msg = r"cannot align objects.*not equal along these coordinates.*" with pytest.raises(ValueError, match=err_msg): data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3, 4]}] err_msg = "Dataset assignment only accepts DataArrays, Datasets, and scalars." @@ -3680,22 +3807,39 @@ def test_assign_attrs(self): def test_assign_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) - # raise an Error when any level name is used as dimension GH:2299 - with pytest.raises(ValueError): - data["y"] = ("level_1", [0, 1]) + + def test_assign_all_multiindex_coords(self): + data = create_test_multiindex() + actual = data.assign(x=range(4), level_1=range(4), level_2=range(4)) + # no error but multi-index dropped in favor of single indexes for each level + assert ( + actual.xindexes["x"] + is not actual.xindexes["level_1"] + is not actual.xindexes["level_2"] + ) def test_merge_multiindex_level(self): data = create_test_multiindex() - other = Dataset({"z": ("level_1", [0, 1])}) # conflict dimension - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", [0, 1])}) + with pytest.raises(ValueError, match=r".*conflicting dimension sizes.*"): data.merge(other) - other = Dataset({"level_1": ("x", [0, 1])}) # conflict variable name - with pytest.raises(ValueError): + + other = Dataset({"level_1": ("x", range(4))}) + with pytest.raises( + ValueError, match=r"unable to determine.*coordinates or not.*" + ): data.merge(other) + # `other` Dataset coordinates are ignored (bug or feature?) + other = Dataset(coords={"level_1": ("x", range(4))}) + assert_identical(data.merge(other), data) + def test_setitem_original_non_unique_index(self): # regression test for GH943 original = Dataset({"data": ("x", np.arange(5))}, coords={"x": [0, 1, 2, 0, 1]}) @@ -3727,7 +3871,9 @@ def test_setitem_both_non_unique_index(self): def test_setitem_multiindex_level(self): data = create_test_multiindex() - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): + with pytest.raises( + ValueError, match=r"cannot set or update variable.*corrupt.*index " + ): data["level_1"] = range(4) def test_delitem(self): @@ -3745,6 +3891,13 @@ def test_delitem(self): del actual["y"] assert_identical(expected, actual) + def test_delitem_multiindex_level(self): + data = create_test_multiindex() + with pytest.raises( + ValueError, match=r"cannot remove coordinate.*corrupt.*index " + ): + del data["level_1"] + def test_squeeze(self): data = Dataset({"foo": (["x", "y", "z"], [[[1], [2]]])}) for args in [[], [["x"]], [["x", "z"]]]: @@ -4373,7 +4526,7 @@ def test_where_other(self): with pytest.raises(ValueError, match=r"cannot set"): ds.where(ds > 1, other=0, drop=True) - with pytest.raises(ValueError, match=r"indexes .* are not equal"): + with pytest.raises(ValueError, match=r"cannot align .* are not equal"): ds.where(ds > 1, ds.isel(x=slice(3))) with pytest.raises(ValueError, match=r"exact match required"): @@ -5505,7 +5658,7 @@ def test_ipython_key_completion(self): assert sorted(actual) == sorted(expected) # MultiIndex - ds_midx = ds.stack(dim12=["dim1", "dim2"]) + ds_midx = ds.stack(dim12=["dim2", "dim3"]) actual = ds_midx._ipython_key_completions_() expected = [ "var1", diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 529382279de..105cec7e850 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -558,9 +558,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: display_expand_attrs=False, ): actual = formatting.dataset_repr(ds) - col_width = formatting._calculate_col_width( - formatting._get_col_items(ds.variables) - ) + col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( ds, col_width=col_width + 1, max_rows=display_max_rows diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 4ee80f65027..c67619e18c7 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -64,27 +64,27 @@ def test_short_data_repr_html_dask(dask_dataarray) -> None: def test_format_dims_no_dims() -> None: dims: Dict = {} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert formatted == "" def test_format_dims_unsafe_dim_name() -> None: dims = {"": 3, "y": 2} - coord_names: List = [] - formatted = fh.format_dims(dims, coord_names) + dims_with_index: List = [] + formatted = fh.format_dims(dims, dims_with_index) assert "<x>" in formatted def test_format_dims_non_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["time"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["time"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" not in formatted def test_format_dims_index() -> None: - dims, coord_names = {"x": 3, "y": 2}, ["x"] - formatted = fh.format_dims(dims, coord_names) + dims, dims_with_index = {"x": 3, "y": 2}, ["x"] + formatted = fh.format_dims(dims, dims_with_index) assert "class='xr-has-index'" in formatted @@ -119,14 +119,6 @@ def test_repr_of_dataarray(dataarray) -> None: ) -def test_summary_of_multiindex_coord(multiindex) -> None: - idx = multiindex.x.variable.to_index_variable() - formatted = fh._summarize_coord_multiindex("foo", idx) - assert "(level_1, level_2)" in formatted - assert "MultiIndex" in formatted - assert "foo" in formatted - - def test_repr_of_multiindex(multiindex) -> None: formatted = fh.dataset_repr(multiindex) assert "(x)" in formatted diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 4b6da82bdc7..f866b68dfa8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -519,10 +519,16 @@ def test_groupby_drops_nans() -> None: actual = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) expected = ( - stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset() + stacked.variable.where(stacked.id.notnull()) + .rename({"xy": "id"}) + .to_dataset() + .reset_index("id", drop=True) + .drop_vars(["lon", "lat"]) + .assign(id=stacked.id.values) + .dropna("id") + .transpose(*actual.dims) ) - expected["id"] = stacked.id.values - assert_identical(actual, expected.dropna("id").transpose(*actual.dims)) + assert_identical(actual, expected) # reduction operation along a different dimension actual = grouped.mean("time") diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 18f76df765d..7edcaa15105 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -1,10 +1,21 @@ +import copy +from typing import Any, Dict, List + import numpy as np import pandas as pd import pytest import xarray as xr -from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe -from xarray.core.variable import IndexVariable +from xarray.core.indexes import ( + Index, + Indexes, + PandasIndex, + PandasMultiIndex, + _asarray_tuplesafe, +) +from xarray.core.variable import IndexVariable, Variable + +from . import assert_equal, assert_identical def test_asarray_tuplesafe() -> None: @@ -19,23 +30,110 @@ def test_asarray_tuplesafe() -> None: assert res[1] == (1,) +class CustomIndex(Index): + def __init__(self, dims) -> None: + self.dims = dims + + +class TestIndex: + @pytest.fixture + def index(self) -> CustomIndex: + return CustomIndex({"x": 2}) + + def test_from_variables(self) -> None: + with pytest.raises(NotImplementedError): + Index.from_variables({}) + + def test_concat(self) -> None: + with pytest.raises(NotImplementedError): + Index.concat([], "x") + + def test_stack(self) -> None: + with pytest.raises(NotImplementedError): + Index.stack({}, "x") + + def test_unstack(self, index) -> None: + with pytest.raises(NotImplementedError): + index.unstack() + + def test_create_variables(self, index) -> None: + assert index.create_variables() == {} + assert index.create_variables({"x": "var"}) == {"x": "var"} + + def test_to_pandas_index(self, index) -> None: + with pytest.raises(TypeError): + index.to_pandas_index() + + def test_isel(self, index) -> None: + assert index.isel({}) is None + + def test_sel(self, index) -> None: + with pytest.raises(NotImplementedError): + index.sel({}) + + def test_join(self, index) -> None: + with pytest.raises(NotImplementedError): + index.join(CustomIndex({"y": 2})) + + def test_reindex_like(self, index) -> None: + with pytest.raises(NotImplementedError): + index.reindex_like(CustomIndex({"y": 2})) + + def test_equals(self, index) -> None: + with pytest.raises(NotImplementedError): + index.equals(CustomIndex({"y": 2})) + + def test_roll(self, index) -> None: + assert index.roll({}) is None + + def test_rename(self, index) -> None: + assert index.rename({}, {}) is index + + @pytest.mark.parametrize("deep", [True, False]) + def test_copy(self, index, deep) -> None: + copied = index.copy(deep=deep) + assert isinstance(copied, CustomIndex) + assert copied is not index + + copied.dims["x"] = 3 + if deep: + assert copied.dims != index.dims + assert copied.dims != copy.deepcopy(index).dims + else: + assert copied.dims is index.dims + assert copied.dims is copy.copy(index).dims + + def test_getitem(self, index) -> None: + with pytest.raises(NotImplementedError): + index[:] + + class TestPandasIndex: def test_constructor(self) -> None: pd_idx = pd.Index([1, 2, 3]) index = PandasIndex(pd_idx, "x") - assert index.index is pd_idx + assert index.index.equals(pd_idx) + # makes a shallow copy + assert index.index is not pd_idx assert index.dim == "x" + # test no name set for pd.Index + pd_idx.name = None + index = PandasIndex(pd_idx, "x") + assert index.index.name == "x" + def test_from_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) var = xr.Variable( - "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + "x", data, attrs={"unit": "m"}, encoding={"dtype": np.float64} ) - index, index_vars = PandasIndex.from_variables({"x": var}) - xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + index = PandasIndex.from_variables({"x": var}) assert index.dim == "x" - assert index.index.equals(index_vars["x"].to_index()) + assert index.index.equals(pd.Index(data)) + assert index.coord_dtype == data.dtype var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises(ValueError, match=r".*only accepts one variable.*"): @@ -46,94 +144,206 @@ def test_from_variables(self) -> None: ): PandasIndex.from_variables({"foo": var2}) - def test_from_pandas_index(self) -> None: - pd_idx = pd.Index([1, 2, 3], name="foo") - - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - - assert index.dim == "x" - assert index.index is pd_idx - assert index.index.name == "foo" - xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) - - # test no name set for pd.Index - pd_idx.name = None - index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") - assert "x" in index_vars - assert index.index is not pd_idx - assert index.index.name == "x" + def test_from_variables_index_adapter(self) -> None: + # test index type is preserved when variable wraps a pd.Index + data = pd.Series(["foo", "bar"], dtype="category") + pd_idx = pd.Index(data) + var = xr.Variable("x", pd_idx) + + index = PandasIndex.from_variables({"x": var}) + assert isinstance(index.index, pd.CategoricalIndex) + + def test_concat_periods(self): + periods = pd.period_range("2000-01-01", periods=10) + indexes = [PandasIndex(periods[:5], "t"), PandasIndex(periods[5:], "t")] + expected = PandasIndex(periods, "t") + actual = PandasIndex.concat(indexes, dim="t") + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + positions = [list(range(5)), list(range(5, 10))] + actual = PandasIndex.concat(indexes, dim="t", positions=positions) + assert actual.equals(expected) + assert isinstance(actual.index, pd.PeriodIndex) + + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_concat_str_dtype(self, dtype) -> None: + + a = PandasIndex(np.array(["a"], dtype=dtype), "x", coord_dtype=dtype) + b = PandasIndex(np.array(["b"], dtype=dtype), "x", coord_dtype=dtype) + expected = PandasIndex( + np.array(["a", "b"], dtype=dtype), "x", coord_dtype=dtype + ) - def to_pandas_index(self): + actual = PandasIndex.concat([a, b], "x") + assert actual.equals(expected) + assert np.issubdtype(actual.coord_dtype, dtype) + + def test_concat_empty(self) -> None: + idx = PandasIndex.concat([], "x") + assert idx.coord_dtype is np.dtype("O") + + def test_concat_dim_error(self) -> None: + indexes = [PandasIndex([0, 1], "x"), PandasIndex([2, 3], "y")] + + with pytest.raises(ValueError, match=r"Cannot concatenate.*dimensions.*"): + PandasIndex.concat(indexes, "x") + + def test_create_variables(self) -> None: + # pandas has only Float64Index but variable dtype should be preserved + data = np.array([1.1, 2.2, 3.3], dtype=np.float32) + pd_idx = pd.Index(data, name="foo") + index = PandasIndex(pd_idx, "x", coord_dtype=data.dtype) + index_vars = { + "foo": IndexVariable( + "x", data, attrs={"unit": "m"}, encoding={"fill_value": 0.0} + ) + } + + actual = index.create_variables(index_vars) + assert_identical(actual["foo"], index_vars["foo"]) + assert actual["foo"].dtype == index_vars["foo"].dtype + assert actual["foo"].dtype == index.coord_dtype + + def test_to_pandas_index(self) -> None: pd_idx = pd.Index([1, 2, 3], name="foo") index = PandasIndex(pd_idx, "x") - assert index.to_pandas_index() is pd_idx + assert index.to_pandas_index() is index.index - def test_query(self) -> None: + def test_sel(self) -> None: # TODO: add tests that aren't just for edge cases index = PandasIndex(pd.Index([1, 2, 3]), "x") with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"does not have a MultiIndex"): - index.query({"x": {"one": 0}}) + index.sel({"x": {"one": 0}}) + + def test_sel_boolean(self) -> None: + # index should be ignored and indexer dtype should not be coerced + # see https://github.com/pydata/xarray/issues/5727 + index = PandasIndex(pd.Index([0.0, 2.0, 1.0, 3.0]), "x") + actual = index.sel({"x": [False, True, False, True]}) + expected_dim_indexers = {"x": [False, True, False, True]} + np.testing.assert_array_equal( + actual.dim_indexers["x"], expected_dim_indexers["x"] + ) - def test_query_datetime(self) -> None: + def test_sel_datetime(self) -> None: index = PandasIndex( pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" ) - actual = index.query({"x": "2001-01-01"}) - expected = (1, None) - assert actual == expected + actual = index.sel({"x": "2001-01-01"}) + expected_dim_indexers = {"x": 1} + assert actual.dim_indexers == expected_dim_indexers - actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) - assert actual == expected + actual = index.sel({"x": index.to_pandas_index().to_numpy()[1]}) + assert actual.dim_indexers == expected_dim_indexers - def test_query_unsorted_datetime_index_raises(self) -> None: + def test_sel_unsorted_datetime_index_raises(self) -> None: index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") with pytest.raises(KeyError): # pandas will try to convert this into an array indexer. We should # raise instead, so we can be sure the result of indexing with a # slice is always a view. - index.query({"x": slice("2001", "2002")}) + index.sel({"x": slice("2001", "2002")}) def test_equals(self) -> None: index1 = PandasIndex([1, 2, 3], "x") index2 = PandasIndex([1, 2, 3], "x") assert index1.equals(index2) is True - def test_union(self) -> None: - index1 = PandasIndex([1, 2, 3], "x") - index2 = PandasIndex([4, 5, 6], "y") - actual = index1.union(index2) - assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6])) - assert actual.dim == "x" + def test_join(self) -> None: + index1 = PandasIndex(["a", "aa", "aaa"], "x", coord_dtype=" None: - index1 = PandasIndex([1, 2, 3], "x") - index2 = PandasIndex([2, 3, 4], "y") - actual = index1.intersection(index2) - assert actual.index.equals(pd.Index([2, 3])) - assert actual.dim == "x" + expected = PandasIndex(["aa", "aaa"], "x") + actual = index1.join(index2) + print(actual.index) + assert actual.equals(expected) + assert actual.coord_dtype == " None: + index1 = PandasIndex([0, 1, 2], "x") + index2 = PandasIndex([1, 2, 3, 4], "x") + + expected = {"x": [1, 2, -1, -1]} + actual = index1.reindex_like(index2) + assert actual.keys() == expected.keys() + np.testing.assert_array_equal(actual["x"], expected["x"]) + + index3 = PandasIndex([1, 1, 2], "x") + with pytest.raises(ValueError, match=r".*index has duplicate values"): + index3.reindex_like(index2) + + def test_rename(self) -> None: + index = PandasIndex(pd.Index([1, 2, 3], name="a"), "x", coord_dtype=np.int32) + + # shortcut + new_index = index.rename({}, {}) + assert new_index is index + + new_index = index.rename({"a": "b"}, {}) + assert new_index.index.name == "b" + assert new_index.dim == "x" + assert new_index.coord_dtype == np.int32 + + new_index = index.rename({}, {"x": "y"}) + assert new_index.index.name == "a" + assert new_index.dim == "y" + assert new_index.coord_dtype == np.int32 def test_copy(self) -> None: - expected = PandasIndex([1, 2, 3], "x") + expected = PandasIndex([1, 2, 3], "x", coord_dtype=np.int32) actual = expected.copy() assert actual.index.equals(expected.index) assert actual.index is not expected.index assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype def test_getitem(self) -> None: pd_idx = pd.Index([1, 2, 3]) - expected = PandasIndex(pd_idx, "x") + expected = PandasIndex(pd_idx, "x", coord_dtype=np.int32) actual = expected[1:] assert actual.index.equals(pd_idx[1:]) assert actual.dim == expected.dim + assert actual.coord_dtype == expected.coord_dtype class TestPandasMultiIndex: + def test_constructor(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + + index = PandasMultiIndex(pd_idx, "x") + + assert index.dim == "x" + assert index.index.equals(pd_idx) + assert index.index.names == ("foo", "bar") + assert index.index.name == "x" + assert index.level_coords_dtype == { + "foo": foo_data.dtype, + "bar": bar_data.dtype, + } + + with pytest.raises(ValueError, match=".*conflicting multi-index level name.*"): + PandasMultiIndex(pd_idx, "foo") + + # default level names + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) + index = PandasMultiIndex(pd_idx, "x") + assert index.index.names == ("x_level_0", "x_level_1") + def test_from_variables(self) -> None: v_level1 = xr.Variable( "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} @@ -142,18 +352,15 @@ def test_from_variables(self) -> None: "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} ) - index, index_vars = PandasMultiIndex.from_variables( + index = PandasMultiIndex.from_variables( {"level1": v_level1, "level2": v_level2} ) expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data]) assert index.dim == "x" assert index.index.equals(expected_idx) - - assert list(index_vars) == ["x", "level1", "level2"] - xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) - xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"]) - xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"]) + assert index.index.name == "x" + assert index.index.names == ["level1", "level2"] var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( @@ -162,35 +369,267 @@ def test_from_variables(self) -> None: PandasMultiIndex.from_variables({"var": var}) v_level3 = xr.Variable("y", [4, 5, 6]) - with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"): + with pytest.raises( + ValueError, match=r"unmatched dimensions for multi-index variables.*" + ): PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) - def test_from_pandas_index(self) -> None: - pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar")) + def test_concat(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [[0, 1, 2], ["a", "b"]], names=("foo", "bar") + ) + level_coords_dtype = {"foo": np.int32, "bar": " None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 3, 2])), + } + + index = PandasMultiIndex.stack(prod_vars, "z") + + assert index.dim == "z" + assert index.index.names == ["x", "y"] + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + ) + + with pytest.raises( + ValueError, match=r"conflicting dimensions for multi-index product.*" + ): + PandasMultiIndex.stack( + {"x": xr.Variable("x", ["a", "b"]), "x2": xr.Variable("x", [1, 2])}, + "z", + ) + + def test_stack_non_unique(self) -> None: + prod_vars = { + "x": xr.Variable("x", pd.Index(["b", "a"]), attrs={"foo": "bar"}), + "y": xr.Variable("y", pd.Index([1, 1, 2])), + } - def test_query(self) -> None: + index = PandasMultiIndex.stack(prod_vars, "z") + + np.testing.assert_array_equal( + index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + ) + np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + + def test_unstack(self) -> None: + pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2, 3]], names=["one", "two"] + ) + index = PandasMultiIndex(pd_midx, "x") + + new_indexes, new_pd_idx = index.unstack() + assert list(new_indexes) == ["one", "two"] + assert new_indexes["one"].equals(PandasIndex(["a", "b"], "one")) + assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) + assert new_pd_idx.equals(pd_midx) + + def test_create_variables(self) -> None: + foo_data = np.array([0, 0, 1], dtype="int64") + bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") + pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data], names=("foo", "bar")) + index_vars = { + "x": IndexVariable("x", pd_idx), + "foo": IndexVariable("x", foo_data, attrs={"unit": "m"}), + "bar": IndexVariable("x", bar_data, encoding={"fill_value": 0}), + } + + index = PandasMultiIndex(pd_idx, "x") + actual = index.create_variables(index_vars) + + for k, expected in index_vars.items(): + assert_identical(actual[k], expected) + assert actual[k].dtype == expected.dtype + if k != "x": + assert actual[k].dtype == index.level_coords_dtype[k] + + def test_sel(self) -> None: index = PandasMultiIndex( pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) + # test tuples inside slice are considered as scalar indexer values - assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) + actual = index.sel({"x": slice(("a", 1), ("b", 2))}) + expected_dim_indexers = {"x": slice(0, 4)} + assert actual.dim_indexers == expected_dim_indexers with pytest.raises(KeyError, match=r"not all values found"): - index.query({"x": [0]}) + index.sel({"x": [0]}) with pytest.raises(KeyError): - index.query({"x": 0}) + index.sel({"x": 0}) with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): - index.query({"one": 0, "x": "a"}) + index.sel({"one": 0, "x": "a"}) with pytest.raises(ValueError, match=r"invalid multi-index level names"): - index.query({"x": {"three": 0}}) + index.sel({"x": {"three": 0}}) with pytest.raises(IndexError): - index.query({"x": (slice(None), 1, "no_level")}) + index.sel({"x": (slice(None), 1, "no_level")}) + + def test_join(self): + midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two")) + level_coords_dtype = {"one": " None: + level_coords_dtype = {"one": " None: + level_coords_dtype = {"one": "U<1", "two": np.int32} + expected = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), + "x", + level_coords_dtype=level_coords_dtype, + ) + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + assert actual.level_coords_dtype == expected.level_coords_dtype + + +class TestIndexes: + @pytest.fixture + def unique_indexes(self) -> List[PandasIndex]: + x_idx = PandasIndex(pd.Index([1, 2, 3], name="x"), "x") + y_idx = PandasIndex(pd.Index([4, 5, 6], name="y"), "y") + z_pd_midx = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=["one", "two"] + ) + z_midx = PandasMultiIndex(z_pd_midx, "z") + + return [x_idx, y_idx, z_midx] + + @pytest.fixture + def indexes(self, unique_indexes) -> Indexes[Index]: + x_idx, y_idx, z_midx = unique_indexes + indexes: Dict[Any, Index] = { + "x": x_idx, + "y": y_idx, + "z": z_midx, + "one": z_midx, + "two": z_midx, + } + variables: Dict[Any, Variable] = {} + for idx in unique_indexes: + variables.update(idx.create_variables()) + + return Indexes(indexes, variables) + + def test_interface(self, unique_indexes, indexes) -> None: + x_idx = unique_indexes[0] + assert list(indexes) == ["x", "y", "z", "one", "two"] + assert len(indexes) == 5 + assert "x" in indexes + assert indexes["x"] is x_idx + + def test_variables(self, indexes) -> None: + assert tuple(indexes.variables) == ("x", "y", "z", "one", "two") + + def test_dims(self, indexes) -> None: + assert indexes.dims == {"x": 3, "y": 3, "z": 4} + + def test_get_unique(self, unique_indexes, indexes) -> None: + assert indexes.get_unique() == unique_indexes + + def test_is_multi(self, indexes) -> None: + assert indexes.is_multi("one") is True + assert indexes.is_multi("x") is False + + def test_get_all_coords(self, indexes) -> None: + expected = { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + } + assert indexes.get_all_coords("one") == expected + + with pytest.raises(ValueError, match="errors must be.*"): + indexes.get_all_coords("x", errors="invalid") + + with pytest.raises(ValueError, match="no index found.*"): + indexes.get_all_coords("no_coord") + + assert indexes.get_all_coords("no_coord", errors="ignore") == {} + + def test_get_all_dims(self, indexes) -> None: + expected = {"z": 4} + assert indexes.get_all_dims("one") == expected + + def test_group_by_index(self, unique_indexes, indexes): + expected = [ + (unique_indexes[0], {"x": indexes.variables["x"]}), + (unique_indexes[1], {"y": indexes.variables["y"]}), + ( + unique_indexes[2], + { + "z": indexes.variables["z"], + "one": indexes.variables["one"], + "two": indexes.variables["two"], + }, + ), + ] + + assert indexes.group_by_index() == expected + + def test_to_pandas_indexes(self, indexes) -> None: + pd_indexes = indexes.to_pandas_indexes() + assert isinstance(pd_indexes, Indexes) + assert all([isinstance(idx, pd.Index) for idx in pd_indexes.values()]) + assert indexes.variables == pd_indexes.variables + + def test_copy_indexes(self, indexes) -> None: + copied, index_vars = indexes.copy_indexes() + + assert copied.keys() == indexes.keys() + for new, original in zip(copied.values(), indexes.values()): + assert new.equals(original) + # check unique index objects preserved + assert copied["z"] is copied["one"] is copied["two"] + + assert index_vars.keys() == indexes.variables.keys() + for new, original in zip(index_vars.values(), indexes.variables.values()): + assert_identical(new, original) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 533f4a0cd62..0b40bd18223 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -1,4 +1,5 @@ import itertools +from typing import Any import numpy as np import pandas as pd @@ -6,6 +7,8 @@ from xarray import DataArray, Dataset, Variable from xarray.core import indexing, nputils +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.core.types import T_Xarray from . import IndexerMaker, ReturnItem, assert_array_equal @@ -62,31 +65,74 @@ def test_group_indexers_by_index(self) -> None: ) data.coords["y2"] = ("y", [2.0, 3.0]) - indexes, grouped_indexers = indexing.group_indexers_by_index( - data, {"z": 0, "one": "a", "two": 1, "y": 0} + grouped_indexers = indexing.group_indexers_by_index( + data, {"z": 0, "one": "a", "two": 1, "y": 0}, {} ) - assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]} - assert grouped_indexers == { - "x": {"one": "a", "two": 1}, - "y": {"y": 0}, - None: {"z": 0}, - } - - with pytest.raises(KeyError, match=r"no index found for coordinate y2"): - indexing.group_indexers_by_index(data, {"y2": 2.0}) - with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): - indexing.group_indexers_by_index(data, {"w": "a"}) + + for idx, indexers in grouped_indexers: + if idx is None: + assert indexers == {"z": 0} + elif idx.equals(data.xindexes["x"]): + assert indexers == {"one": "a", "two": 1} + elif idx.equals(data.xindexes["y"]): + assert indexers == {"y": 0} + assert len(grouped_indexers) == 3 + + with pytest.raises(KeyError, match=r"no index found for coordinate 'y2'"): + indexing.group_indexers_by_index(data, {"y2": 2.0}, {}) + with pytest.raises( + KeyError, match=r"'w' is not a valid dimension or coordinate" + ): + indexing.group_indexers_by_index(data, {"w": "a"}, {}) with pytest.raises(ValueError, match=r"cannot supply.*"): - indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") + indexing.group_indexers_by_index(data, {"z": 1}, {"method": "nearest"}) + + def test_map_index_queries(self) -> None: + def create_sel_results( + x_indexer, + x_index, + other_vars, + drop_coords, + drop_indexes, + rename_dims, + ): + dim_indexers = {"x": x_indexer} + index_vars = x_index.create_variables() + indexes = {k: x_index for k in index_vars} + variables = {} + variables.update(index_vars) + variables.update(other_vars) + + return indexing.IndexSelResult( + dim_indexers=dim_indexers, + indexes=indexes, + variables=variables, + drop_coords=drop_coords, + drop_indexes=drop_indexes, + rename_dims=rename_dims, + ) + + def test_indexer( + data: T_Xarray, + x: Any, + expected: indexing.IndexSelResult, + ) -> None: + results = indexing.map_index_queries(data, {"x": x}) + + assert results.dim_indexers.keys() == expected.dim_indexers.keys() + assert_array_equal(results.dim_indexers["x"], expected.dim_indexers["x"]) - def test_remap_label_indexers(self) -> None: - def test_indexer(data, x, expected_pos, expected_idx=None) -> None: - pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x}) - idx, _ = new_idx_vars.get("x", (None, None)) - if idx is not None: - idx = idx.to_pandas_index() - assert_array_equal(pos.get("x"), expected_pos) - assert_array_equal(idx, expected_idx) + assert results.indexes.keys() == expected.indexes.keys() + for k in results.indexes: + assert results.indexes[k].equals(expected.indexes[k]) + + assert results.variables.keys() == expected.variables.keys() + for k in results.variables: + assert_array_equal(results.variables[k], expected.variables[k]) + + assert set(results.drop_coords) == set(expected.drop_coords) + assert set(results.drop_indexes) == set(expected.drop_indexes) + assert results.rename_dims == expected.rename_dims data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -94,50 +140,96 @@ def test_indexer(data, x, expected_pos, expected_idx=None) -> None: ) mdata = DataArray(range(8), [("x", mindex)]) - test_indexer(data, 1, 0) - test_indexer(data, np.int32(1), 0) - test_indexer(data, Variable([], 1), 0) - test_indexer(mdata, ("a", 1, -1), 0) - test_indexer( - mdata, - ("a", 1), + test_indexer(data, 1, indexing.IndexSelResult({"x": 0})) + test_indexer(data, np.int32(1), indexing.IndexSelResult({"x": 0})) + test_indexer(data, Variable([], 1), indexing.IndexSelResult({"x": 0})) + test_indexer(mdata, ("a", 1, -1), indexing.IndexSelResult({"x": 0})) + + expected = create_sel_results( [True, True, False, False, False, False, False, False], - [-1, -2], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, ) - test_indexer( - mdata, - "a", + test_indexer(mdata, ("a", 1), expected) + + expected = create_sel_results( slice(0, 4, None), - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) - test_indexer( - mdata, - ("a",), + test_indexer(mdata, "a", expected) + + expected = create_sel_results( [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) - test_indexer(mdata, [("a", 1, -1), ("b", 2, -2)], [0, 7]) - test_indexer(mdata, slice("a", "b"), slice(0, 8, None)) - test_indexer(mdata, slice(("a", 1), ("b", 1)), slice(0, 6, None)) - test_indexer(mdata, {"one": "a", "two": 1, "three": -1}, 0) + test_indexer(mdata, ("a",), expected) + test_indexer( - mdata, - {"one": "a", "two": 1}, - [True, True, False, False, False, False, False, False], - [-1, -2], + mdata, [("a", 1, -1), ("b", 2, -2)], indexing.IndexSelResult({"x": [0, 7]}) + ) + test_indexer( + mdata, slice("a", "b"), indexing.IndexSelResult({"x": slice(0, 8, None)}) ) test_indexer( mdata, - {"one": "a", "three": -1}, - [True, False, True, False, False, False, False, False], - [1, 2], + slice(("a", 1), ("b", 1)), + indexing.IndexSelResult({"x": slice(0, 6, None)}), ) test_indexer( mdata, - {"one": "a"}, + {"one": "a", "two": 1, "three": -1}, + indexing.IndexSelResult({"x": 0}), + ) + + expected = create_sel_results( + [True, True, False, False, False, False, False, False], + PandasIndex(pd.Index([-1, -2]), "three"), + {"one": Variable((), "a"), "two": Variable((), 1)}, + ["x"], + ["one", "two"], + {"x": "three"}, + ) + test_indexer(mdata, {"one": "a", "two": 1}, expected) + + expected = create_sel_results( + [True, False, True, False, False, False, False, False], + PandasIndex(pd.Index([1, 2]), "two"), + {"one": Variable((), "a"), "three": Variable((), -1)}, + ["x"], + ["one", "three"], + {"x": "two"}, + ) + test_indexer(mdata, {"one": "a", "three": -1}, expected) + + expected = create_sel_results( [True, True, True, True, False, False, False, False], - pd.MultiIndex.from_product([[1, 2], [-1, -2]]), + PandasMultiIndex( + pd.MultiIndex.from_product([[1, 2], [-1, -2]], names=("two", "three")), + "x", + ), + {"one": Variable((), "a")}, + [], + ["one"], + {}, ) + test_indexer(mdata, {"one": "a"}, expected) def test_read_only_view(self) -> None: diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 555a29b1952..6dca04ed069 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -242,7 +242,7 @@ def test_merge_error(self): def test_merge_alignment_error(self): ds = xr.Dataset(coords={"x": [1, 2]}) other = xr.Dataset(coords={"x": [2, 3]}) - with pytest.raises(ValueError, match=r"indexes .* not equal"): + with pytest.raises(ValueError, match=r"cannot align.*join.*exact.*not equal.*"): xr.merge([ds, other], join="exact") def test_merge_wrong_input_error(self): diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index a083c50c3d1..bc3cc367c0e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -147,10 +147,10 @@ def strip_units(obj): new_obj = xr.Dataset(data_vars=data_vars, coords=coords) elif isinstance(obj, xr.DataArray): - data = array_strip_units(obj.data) + data = array_strip_units(obj.variable._data) coords = { strip_units(name): ( - (value.dims, array_strip_units(value.data)) + (value.dims, array_strip_units(value.variable._data)) if isinstance(value.data, Quantity) else value # to preserve multiindexes ) @@ -198,8 +198,7 @@ def attach_units(obj, units): name: ( (value.dims, array_attach_units(value.data, units.get(name) or 1)) if name in units - # to preserve multiindexes - else value + else (value.dims, value.data) ) for name, value in obj.coords.items() } @@ -3610,7 +3609,10 @@ def test_stacking_stacked(self, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_identical(expected, actual) + if func.name == "reset_index": + assert_identical(expected, actual, check_default_indexes=False) + else: + assert_identical(expected, actual) @pytest.mark.skip(reason="indexes don't support units") def test_to_unstacked_dataset(self, dtype): @@ -4719,7 +4721,10 @@ def test_stacking_stacked(self, variant, func, dtype): actual = func(stacked) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "reset_index": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.xfail( reason="stacked dimension's labels have to be hashable, but is a numpy.array" @@ -5503,7 +5508,10 @@ def test_content_manipulation(self, func, variant, dtype): actual = func(ds) assert_units_equal(expected, actual) - assert_equal(expected, actual) + if func.name == "rename_dims": + assert_equal(expected, actual, check_default_indexes=False) + else: + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ce796e9de49..0a720b23d3b 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -89,31 +89,6 @@ def test_safe_cast_to_index_datetime_datetime(): assert isinstance(actual, pd.Index) -def test_multiindex_from_product_levels(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 3, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 3, 2]) - - other = pd.MultiIndex.from_product([["b", "a"], [1, 3, 2]]) - np.testing.assert_array_equal(result.values, other.values) - - -def test_multiindex_from_product_levels_non_unique(): - result = utils.multiindex_from_product_levels( - [pd.Index(["b", "a"]), pd.Index([1, 1, 2])] - ) - np.testing.assert_array_equal( - result.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] - ) - np.testing.assert_array_equal(result.levels[0], ["b", "a"]) - np.testing.assert_array_equal(result.levels[1], [1, 2]) - - class TestArrayEquiv: def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional