Skip to content

Commit 87a25b6

Browse files
authored
2x~5x speed up for isel() in most cases (#3533)
* Speed up isel in most cases * What's New * Trivial * Use _replace * isort * Code review * What's New * mypy
1 parent cf17317 commit 87a25b6

11 files changed

+116
-14
lines changed

doc/whats-new.rst

+5-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ Bug fixes
3737
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
3838
By `Deepak Cherian <https://github.com/dcherian>`_.
3939

40-
4140
Documentation
4241
~~~~~~~~~~~~~
4342
- Switch doc examples to use nbsphinx and replace sphinx_gallery with
@@ -58,8 +57,10 @@ Documentation
5857

5958
Internal Changes
6059
~~~~~~~~~~~~~~~~
61-
62-
60+
- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
61+
:py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int,
62+
slice, list of int, scalar ndarray, or 1-dimensional ndarray.
63+
(:pull:`3533`) by `Guido Imperiale <https://github.com/crusaderky>`_.
6364
- Removed internal method ``Dataset._from_vars_and_coord_names``,
6465
which was dominated by ``Dataset._construct_direct``. (:pull:`3565`)
6566
By `Maximilian Roos <https://github.com/max-sixty>`_
@@ -190,6 +191,7 @@ Documentation
190191

191192
Internal Changes
192193
~~~~~~~~~~~~~~~~
194+
193195
- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
194196
(:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`)
195197
by `Justus Magin <https://github.com/keewis>`_.

xarray/coding/cftime_offsets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import re
4444
from datetime import timedelta
45+
from distutils.version import LooseVersion
4546
from functools import partial
4647
from typing import ClassVar, Optional
4748

@@ -50,7 +51,6 @@
5051
from ..core.pdcompat import count_not_none
5152
from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
5253
from .times import format_cftime_datetime
53-
from distutils.version import LooseVersion
5454

5555

5656
def get_date_type(calendar):

xarray/core/dataarray.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
)
5151
from .dataset import Dataset, split_indexes
5252
from .formatting import format_item
53-
from .indexes import Indexes, propagate_indexes, default_indexes
53+
from .indexes import Indexes, default_indexes, propagate_indexes
54+
from .indexing import is_fancy_indexer
5455
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
5556
from .options import OPTIONS
5657
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
@@ -1027,8 +1028,27 @@ def isel(
10271028
DataArray.sel
10281029
"""
10291030
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
1030-
ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers)
1031-
return self._from_temp_dataset(ds)
1031+
if any(is_fancy_indexer(idx) for idx in indexers.values()):
1032+
ds = self._to_temp_dataset()._isel_fancy(indexers, drop=drop)
1033+
return self._from_temp_dataset(ds)
1034+
1035+
# Much faster algorithm for when all indexers are ints, slices, one-dimensional
1036+
# lists, or zero or one-dimensional np.ndarray's
1037+
1038+
variable = self._variable.isel(indexers)
1039+
1040+
coords = {}
1041+
for coord_name, coord_value in self._coords.items():
1042+
coord_indexers = {
1043+
k: v for k, v in indexers.items() if k in coord_value.dims
1044+
}
1045+
if coord_indexers:
1046+
coord_value = coord_value.isel(coord_indexers)
1047+
if drop and coord_value.ndim == 0:
1048+
continue
1049+
coords[coord_name] = coord_value
1050+
1051+
return self._replace(variable=variable, coords=coords)
10321052

10331053
def sel(
10341054
self,

xarray/core/dataset.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
propagate_indexes,
6767
roll_index,
6868
)
69+
from .indexing import is_fancy_indexer
6970
from .merge import (
7071
dataset_merge_method,
7172
dataset_update_method,
@@ -78,8 +79,8 @@
7879
Default,
7980
Frozen,
8081
SortedKeysDict,
81-
_default,
8282
_check_inplace,
83+
_default,
8384
decode_numpy_dict_values,
8485
either_dict_or_kwargs,
8586
hashable,
@@ -1907,6 +1908,48 @@ def isel(
19071908
DataArray.isel
19081909
"""
19091910
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
1911+
if any(is_fancy_indexer(idx) for idx in indexers.values()):
1912+
return self._isel_fancy(indexers, drop=drop)
1913+
1914+
# Much faster algorithm for when all indexers are ints, slices, one-dimensional
1915+
# lists, or zero or one-dimensional np.ndarray's
1916+
invalid = indexers.keys() - self.dims.keys()
1917+
if invalid:
1918+
raise ValueError("dimensions %r do not exist" % invalid)
1919+
1920+
variables = {}
1921+
dims: Dict[Hashable, Tuple[int, ...]] = {}
1922+
coord_names = self._coord_names.copy()
1923+
indexes = self._indexes.copy() if self._indexes is not None else None
1924+
1925+
for var_name, var_value in self._variables.items():
1926+
var_indexers = {k: v for k, v in indexers.items() if k in var_value.dims}
1927+
if var_indexers:
1928+
var_value = var_value.isel(var_indexers)
1929+
if drop and var_value.ndim == 0 and var_name in coord_names:
1930+
coord_names.remove(var_name)
1931+
if indexes:
1932+
indexes.pop(var_name, None)
1933+
continue
1934+
if indexes and var_name in indexes:
1935+
if var_value.ndim == 1:
1936+
indexes[var_name] = var_value.to_index()
1937+
else:
1938+
del indexes[var_name]
1939+
variables[var_name] = var_value
1940+
dims.update(zip(var_value.dims, var_value.shape))
1941+
1942+
return self._construct_direct(
1943+
variables=variables,
1944+
coord_names=coord_names,
1945+
dims=dims,
1946+
attrs=self._attrs,
1947+
indexes=indexes,
1948+
encoding=self._encoding,
1949+
file_obj=self._file_obj,
1950+
)
1951+
1952+
def _isel_fancy(self, indexers: Mapping[Hashable, Any], *, drop: bool) -> "Dataset":
19101953
# Note: we need to preserve the original indexers variable in order to merge the
19111954
# coords below
19121955
indexers_list = list(self._validate_indexers(indexers))

xarray/core/formatting_html.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import uuid
2-
import pkg_resources
32
from collections import OrderedDict
43
from functools import partial
54
from html import escape
65

7-
from .formatting import inline_variable_array_repr, short_data_repr
6+
import pkg_resources
87

8+
from .formatting import inline_variable_array_repr, short_data_repr
99

1010
CSS_FILE_PATH = "/".join(("static", "css", "style.css"))
1111
CSS_STYLE = pkg_resources.resource_string("xarray", CSS_FILE_PATH).decode("utf8")

xarray/core/indexing.py

+13
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,19 @@ def posify_mask_indexer(indexer):
12131213
return type(indexer)(key)
12141214

12151215

1216+
def is_fancy_indexer(indexer: Any) -> bool:
1217+
"""Return False if indexer is a int, slice, a 1-dimensional list, or a 0 or
1218+
1-dimensional ndarray; in all other cases return True
1219+
"""
1220+
if isinstance(indexer, (int, slice)):
1221+
return False
1222+
if isinstance(indexer, np.ndarray):
1223+
return indexer.ndim > 1
1224+
if isinstance(indexer, list):
1225+
return bool(indexer) and not isinstance(indexer[0], int)
1226+
return True
1227+
1228+
12161229
class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
12171230
"""Wrap a NumPy array to use explicit indexing."""
12181231

xarray/core/variable.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,10 @@ def _broadcast_indexes_outer(self, key):
617617
k = k.data
618618
if not isinstance(k, BASIC_INDEXING_TYPES):
619619
k = np.asarray(k)
620-
if k.dtype.kind == "b":
620+
if k.size == 0:
621+
# Slice by empty list; numpy could not infer the dtype
622+
k = k.astype(int)
623+
elif k.dtype.kind == "b":
621624
(k,) = np.nonzero(k)
622625
new_key.append(k)
623626

xarray/tests/test_dask.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from xarray.testing import assert_chunks_equal
1717
from xarray.tests import mock
1818

19+
from ..core.duck_array_ops import lazy_array_equiv
1920
from . import (
2021
assert_allclose,
2122
assert_array_equal,
@@ -25,7 +26,6 @@
2526
raises_regex,
2627
requires_scipy_or_netCDF4,
2728
)
28-
from ..core.duck_array_ops import lazy_array_equiv
2929
from .test_backends import create_tmp_file
3030

3131
dask = pytest.importorskip("dask")

xarray/tests/test_dataarray.py

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from xarray.core.common import full_like
1717
from xarray.core.indexes import propagate_indexes
1818
from xarray.core.utils import is_scalar
19-
2019
from xarray.tests import (
2120
LooseVersion,
2221
ReturnItem,

xarray/tests/test_missing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
NumpyInterpolator,
1010
ScipyInterpolator,
1111
SplineInterpolator,
12-
get_clean_interp_index,
1312
_get_nan_block_lengths,
13+
get_clean_interp_index,
1414
)
1515
from xarray.core.pycompat import dask_array_type
1616
from xarray.tests import (

xarray/tests/test_variable.py

+22
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,26 @@ def test_items(self):
11561156
def test_getitem_basic(self):
11571157
v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]])
11581158

1159+
# int argument
1160+
v_new = v[0]
1161+
assert v_new.dims == ("y",)
1162+
assert_array_equal(v_new, v._data[0])
1163+
1164+
# slice argument
1165+
v_new = v[:2]
1166+
assert v_new.dims == ("x", "y")
1167+
assert_array_equal(v_new, v._data[:2])
1168+
1169+
# list arguments
1170+
v_new = v[[0]]
1171+
assert v_new.dims == ("x", "y")
1172+
assert_array_equal(v_new, v._data[[0]])
1173+
1174+
v_new = v[[]]
1175+
assert v_new.dims == ("x", "y")
1176+
assert_array_equal(v_new, v._data[[]])
1177+
1178+
# dict arguments
11591179
v_new = v[dict(x=0)]
11601180
assert v_new.dims == ("y",)
11611181
assert_array_equal(v_new, v._data[0])
@@ -1196,6 +1216,8 @@ def test_isel(self):
11961216
assert_identical(v.isel(time=0), v[0])
11971217
assert_identical(v.isel(time=slice(0, 3)), v[:3])
11981218
assert_identical(v.isel(x=0), v[:, 0])
1219+
assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]])
1220+
assert_identical(v.isel(time=[]), v[[]])
11991221
with raises_regex(ValueError, "do not exist"):
12001222
v.isel(not_a_dim=0)
12011223

0 commit comments

Comments
 (0)