Skip to content

Commit 2aa5b8a

Browse files
shoyer authored and fujiisoup committed
Use getitem_with_mask in reindex_variables (#1847)
* WIP: use getitem_with_mask in reindex_variables
* Fix dtype promotion for where
* Add whats new
* Fix flake8
* Fix test_align_dtype and bool+str promotion
* tests and docstring for dtypes.result_type
* More dtype promotion fixes, including for concat
1 parent 33660b7 commit 2aa5b8a

File tree

11 files changed

+234
-77
lines changed

11 files changed

+234
-77
lines changed

asv_bench/asv.conf.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
"netcdf4": [""],
6464
"scipy": [""],
6565
"bottleneck": ["", null],
66-
"dask": ["", null],
66+
"dask": [""],
6767
},
6868

6969

asv_bench/benchmarks/reindexing.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import numpy as np
6+
import xarray as xr
7+
8+
from . import requires_dask
9+
10+
11+
class Reindex(object):
    """Timing benchmarks for ``Dataset.reindex`` along one or two dimensions."""

    def setup(self):
        """Create a (time=1000, x=100, y=100) dataset of seeded random data."""
        rng = np.random.RandomState(0)
        values = rng.randn(1000, 100, 100)
        data_vars = {'temperature': (('time', 'x', 'y'), values)}
        coords = {'time': np.arange(1000),
                  'x': np.arange(100),
                  'y': np.arange(100)}
        self.ds = xr.Dataset(data_vars, coords=coords)

    def time_1d_coarse(self):
        """Reindex ``time`` onto every fifth existing label."""
        new_time = np.arange(0, 1000, 5)
        self.ds.reindex(time=new_time).load()

    def time_1d_fine_all_found(self):
        """Reindex ``time`` onto a half-step grid using nearest matching."""
        new_time = np.arange(0, 1000, 0.5)
        self.ds.reindex(time=new_time, method='nearest').load()

    def time_1d_fine_some_missing(self):
        """Same half-step grid, but nearest matching limited to 0.1."""
        new_time = np.arange(0, 1000, 0.5)
        reindexed = self.ds.reindex(time=new_time, method='nearest',
                                    tolerance=0.1)
        reindexed.load()

    def time_2d_coarse(self):
        """Reindex ``x`` and ``y`` onto every second existing label."""
        new_x = np.arange(0, 100, 2)
        new_y = np.arange(0, 100, 2)
        self.ds.reindex(x=new_x, y=new_y).load()

    def time_2d_fine_all_found(self):
        """Reindex ``x`` and ``y`` onto half-step grids, nearest matching."""
        new_x = np.arange(0, 100, 0.5)
        new_y = np.arange(0, 100, 0.5)
        self.ds.reindex(x=new_x, y=new_y, method='nearest').load()

    def time_2d_fine_some_missing(self):
        """Half-step grids on ``x`` and ``y`` with a 0.1 tolerance."""
        new_x = np.arange(0, 100, 0.5)
        new_y = np.arange(0, 100, 0.5)
        reindexed = self.ds.reindex(x=new_x, y=new_y, method='nearest',
                                    tolerance=0.1)
        reindexed.load()
39+
40+
41+
class ReindexDask(Reindex):
    """Run the ``Reindex`` benchmarks against a chunked copy of the dataset."""

    def setup(self):
        """Build the base dataset, then chunk it along ``time``."""
        # NOTE(review): requires_dask presumably skips the benchmark when
        # dask is not installed -- confirm against the benchmarks package.
        requires_dask()
        super(ReindexDask, self).setup()
        chunk_sizes = {'time': 100}
        self.ds = self.ds.chunk(chunk_sizes)

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ Enhancements
8181
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
8282
- `.dt` accessor can now ceil, floor and round timestamps to specified frequency.
8383
By `Deepak Cherian <https://github.com/dcherian>`_.
84+
- Reindexing/alignment with dask arrays is now orders of magnitude faster
85+
when inserting missing values (:issue:`1847`).
86+
By `Stephan Hoyer <https://github.com/shoyer>`_.
8487

8588
.. _Zarr: http://zarr.readthedocs.io/
8689

@@ -140,6 +143,10 @@ Bug fixes
140143
``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio`
141144
(set to ``True`` per default).
142145
By `Fabien Maussion <https://github.com/fmaussion>`_.
146+
- Fixed dtype promotion rules in :py:func:`where` and :py:func:`concat` to
147+
match pandas (:issue:`1847`). A combination of strings/numbers or
148+
unicode/bytes now promotes to object dtype, instead of strings or unicode.
149+
By `Stephan Hoyer <https://github.com/shoyer>`_.
143150

144151
.. _whats-new.0.10.0:
145152

xarray/core/alignment.py

Lines changed: 34 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88

99
import numpy as np
1010

11-
from . import duck_array_ops
12-
from . import dtypes
1311
from . import utils
1412
from .indexing import get_indexer_nd
1513
from .pycompat import iteritems, OrderedDict, suppress
1614
from .utils import is_full_slice, is_dict_like
17-
from .variable import Variable, IndexVariable
15+
from .variable import IndexVariable
1816

1917

2018
def _get_joiner(join):
@@ -306,59 +304,51 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None,
306304
from .dataarray import DataArray
307305

308306
# build up indexers for assignment along each dimension
309-
to_indexers = {}
310-
from_indexers = {}
307+
int_indexers = {}
308+
targets = {}
309+
masked_dims = set()
310+
unchanged_dims = set()
311+
311312
# size of reindexed dimensions
312313
new_sizes = {}
313314

314315
for name, index in iteritems(indexes):
315316
if name in indexers:
316-
target = utils.safe_cast_to_index(indexers[name])
317317
if not index.is_unique:
318318
raise ValueError(
319319
'cannot reindex or align along dimension %r because the '
320320
'index has duplicate values' % name)
321-
indexer = get_indexer_nd(index, target, method, tolerance)
322321

322+
target = utils.safe_cast_to_index(indexers[name])
323323
new_sizes[name] = len(target)
324-
# Note pandas uses negative values from get_indexer_nd to signify
325-
# values that are missing in the index
326-
# The non-negative values thus indicate the non-missing values
327-
to_indexers[name] = indexer >= 0
328-
if to_indexers[name].all():
329-
# If an indexer includes no negative values, then the
330-
# assignment can be to a full-slice (which is much faster,
331-
# and means we won't need to fill in any missing values)
332-
to_indexers[name] = slice(None)
333-
334-
from_indexers[name] = indexer[to_indexers[name]]
335-
if np.array_equal(from_indexers[name], np.arange(len(index))):
336-
# If the indexer is equal to the original index, use a full
337-
# slice object to speed up selection and so we can avoid
338-
# unnecessary copies
339-
from_indexers[name] = slice(None)
324+
325+
int_indexer = get_indexer_nd(index, target, method, tolerance)
326+
327+
# We use negative values from get_indexer_nd to signify
328+
# values that are missing in the index.
329+
if (int_indexer < 0).any():
330+
masked_dims.add(name)
331+
elif np.array_equal(int_indexer, np.arange(len(index))):
332+
unchanged_dims.add(name)
333+
334+
int_indexers[name] = int_indexer
335+
targets[name] = target
340336

341337
for dim in sizes:
342338
if dim not in indexes and dim in indexers:
343339
existing_size = sizes[dim]
344-
new_size = utils.safe_cast_to_index(indexers[dim]).size
340+
new_size = indexers[dim].size
345341
if existing_size != new_size:
346342
raise ValueError(
347343
'cannot reindex or align along dimension %r without an '
348344
'index because its size %r is different from the size of '
349345
'the new index %r' % (dim, existing_size, new_size))
350346

351-
def any_not_full_slices(indexers):
352-
return any(not is_full_slice(idx) for idx in indexers)
353-
354-
def var_indexers(var, indexers):
355-
return tuple(indexers.get(d, slice(None)) for d in var.dims)
356-
357347
# create variables for the new dataset
358348
reindexed = OrderedDict()
359349

360350
for dim, indexer in indexers.items():
361-
if isinstance(indexer, DataArray) and indexer.dims != (dim, ):
351+
if isinstance(indexer, DataArray) and indexer.dims != (dim,):
362352
warnings.warn(
363353
"Indexer has dimensions {0:s} that are different "
364354
"from that to be indexed along {1:s}. "
@@ -375,47 +365,24 @@ def var_indexers(var, indexers):
375365

376366
for name, var in iteritems(variables):
377367
if name not in indexers:
378-
assign_to = var_indexers(var, to_indexers)
379-
assign_from = var_indexers(var, from_indexers)
380-
381-
if any_not_full_slices(assign_to):
382-
# there are missing values to in-fill
383-
data = var[assign_from].data
384-
dtype, fill_value = dtypes.maybe_promote(var.dtype)
385-
386-
if isinstance(data, np.ndarray):
387-
shape = tuple(new_sizes.get(dim, size)
388-
for dim, size in zip(var.dims, var.shape))
389-
new_data = np.empty(shape, dtype=dtype)
390-
new_data[...] = fill_value
391-
# create a new Variable so we can use orthogonal indexing
392-
# use fastpath=True to avoid dtype inference
393-
new_var = Variable(var.dims, new_data, var.attrs,
394-
fastpath=True)
395-
new_var[assign_to] = data
396-
397-
else: # dask array
398-
data = data.astype(dtype, copy=False)
399-
for axis, indexer in enumerate(assign_to):
400-
if not is_full_slice(indexer):
401-
indices = np.cumsum(indexer)[~indexer]
402-
data = duck_array_ops.insert(
403-
data, indices, fill_value, axis=axis)
404-
new_var = Variable(var.dims, data, var.attrs,
405-
fastpath=True)
406-
407-
elif any_not_full_slices(assign_from):
408-
# type coercion is not necessary as there are no missing
409-
# values
410-
new_var = var[assign_from]
411-
412-
else:
413-
# no reindexing is necessary
368+
key = tuple(slice(None)
369+
if d in unchanged_dims
370+
else int_indexers.get(d, slice(None))
371+
for d in var.dims)
372+
needs_masking = any(d in masked_dims for d in var.dims)
373+
374+
if needs_masking:
375+
new_var = var._getitem_with_mask(key)
376+
elif all(is_full_slice(k) for k in key):
377+
# no reindexing necessary
414378
# here we need to manually deal with copying data, since
415379
# we neither created a new ndarray nor used fancy indexing
416380
new_var = var.copy(deep=copy)
381+
else:
382+
new_var = var[key]
417383

418384
reindexed[name] = new_var
385+
419386
return reindexed
420387

421388

xarray/core/dtypes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
NA = utils.ReprObject('<NA>')
88

99

10+
# Pairs of types that, if both found, should be promoted to object dtype
11+
# instead of following NumPy's own type-promotion rules. These type promotion
12+
# rules match pandas instead. For reference, see the NumPy type hierarchy:
13+
# https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
14+
PROMOTE_TO_OBJECT = [
15+
{np.number, np.character}, # numpy promotes to character
16+
{np.bool_, np.character}, # numpy promotes to character
17+
{np.bytes_, np.unicode_}, # numpy promotes to unicode
18+
]
19+
20+
1021
def maybe_promote(dtype):
1122
"""Simpler equivalent of pandas.core.common._maybe_promote
1223
@@ -60,3 +71,29 @@ def is_datetime_like(dtype):
6071
"""
6172
return (np.issubdtype(dtype, np.datetime64) or
6273
np.issubdtype(dtype, np.timedelta64))
74+
75+
76+
def result_type(*arrays_and_dtypes):
    """Variant of np.result_type that follows pandas-style promotion.

    Whenever the inputs contain both members of any pair listed in
    ``PROMOTE_TO_OBJECT``, the result is object dtype; otherwise the
    answer is whatever ``np.result_type`` gives.

    Examples of changed behavior:
    number + string -> object (not string)
    bytes + unicode -> object (not unicode)

    Parameters
    ----------
    *arrays_and_dtypes : list of arrays and dtypes
        The dtype is extracted from both numpy and dask arrays.

    Returns
    -------
    numpy.dtype for the result.
    """
    # Scalar type of each individual argument, de-duplicated.
    scalar_types = set()
    for item in arrays_and_dtypes:
        scalar_types.add(np.result_type(item).type)

    def _any_subclass_of(parent):
        return any(issubclass(t, parent) for t in scalar_types)

    for left, right in PROMOTE_TO_OBJECT:
        if _any_subclass_of(left) and _any_subclass_of(right):
            return np.dtype(object)

    return np.result_type(*arrays_and_dtypes)

xarray/core/duck_array_ops.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def isnull(data):
8282

8383

8484
transpose = _dask_or_eager_func('transpose')
85-
where = _dask_or_eager_func('where', n_array_args=3)
85+
_where = _dask_or_eager_func('where', n_array_args=3)
8686
insert = _dask_or_eager_func('insert')
8787
take = _dask_or_eager_func('take')
8888
broadcast_to = _dask_or_eager_func('broadcast_to')
8989

90-
concatenate = _dask_or_eager_func('concatenate', list_of_args=True)
91-
stack = _dask_or_eager_func('stack', list_of_args=True)
90+
_concatenate = _dask_or_eager_func('concatenate', list_of_args=True)
91+
_stack = _dask_or_eager_func('stack', list_of_args=True)
9292

9393
array_all = _dask_or_eager_func('all')
9494
array_any = _dask_or_eager_func('any')
@@ -100,6 +100,17 @@ def asarray(data):
100100
return data if isinstance(data, dask_array_type) else np.asarray(data)
101101

102102

103+
def as_shared_dtype(scalars_or_arrays):
    """Cast arrays to a shared dtype using xarray's type promotion rules."""
    converted = []
    for item in scalars_or_arrays:
        converted.append(asarray(item))
    # The arrays themselves (not their dtypes) are handed to result_type so
    # that scalars get handled properly.
    # Note that result_type() safely gets the dtype from dask arrays without
    # evaluating them.
    shared_dtype = dtypes.result_type(*converted)
    return [arr.astype(shared_dtype, copy=False) for arr in converted]
112+
113+
103114
def as_like_arrays(*data):
104115
if all(isinstance(d, dask_array_type) for d in data):
105116
return data
@@ -151,6 +162,11 @@ def count(data, axis=None):
151162
return sum(~isnull(data), axis=axis)
152163

153164

165+
def where(condition, x, y):
    """Three argument where() with better dtype promotion rules."""
    # Bring both branches to a common dtype before delegating to _where.
    x_cast, y_cast = as_shared_dtype([x, y])
    return _where(condition, x_cast, y_cast)
168+
169+
154170
def where_method(data, cond, other=dtypes.NA):
155171
if other is dtypes.NA:
156172
other = dtypes.get_fill_value(data.dtype)
@@ -161,6 +177,16 @@ def fillna(data, other):
161177
return where(isnull(data), other, data)
162178

163179

180+
def concatenate(arrays, axis=0):
    """concatenate() with better dtype promotion rules."""
    # Cast every input to the shared dtype first, then join along ``axis``.
    same_dtype_arrays = as_shared_dtype(arrays)
    return _concatenate(same_dtype_arrays, axis=axis)
183+
184+
185+
def stack(arrays, axis=0):
    """stack() with better dtype promotion rules."""
    # Cast every input to the shared dtype first, then stack along ``axis``.
    same_dtype_arrays = as_shared_dtype(arrays)
    return _stack(same_dtype_arrays, axis=axis)
188+
189+
164190
@contextlib.contextmanager
165191
def _ignore_warnings_if(condition):
166192
if condition:

xarray/core/variable.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,8 +1273,6 @@ def concat(cls, variables, dim='concat_dim', positions=None,
12731273

12741274
arrays = [v.data for v in variables]
12751275

1276-
# TODO: use our own type promotion rules to ensure that
1277-
# [str, float] -> object, not str like numpy
12781276
if dim in first_var.dims:
12791277
axis = first_var.get_axis_num(dim)
12801278
dims = first_var.dims

xarray/tests/test_dataarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,12 @@ def test_where(self):
17171717
actual = arr.where(arr.x < 2, drop=True)
17181718
assert_identical(actual, expected)
17191719

1720+
def test_where_string(self):
    # Masking a string array is expected to give an object-dtype result
    # holding the kept string and NaN for the masked-out entry.
    wanted = DataArray(np.array(['a', np.nan], dtype=object))
    condition = [True, False]
    masked = DataArray(['a', 'b']).where(condition)
    assert_identical(masked, wanted)
1725+
17201726
def test_cumops(self):
17211727
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
17221728
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),

0 commit comments

Comments
 (0)