Rewrite interp to use apply_ufunc #9881

Merged: 26 commits, Dec 19, 2024
Changes from 5 commits
Commits
c1603c5  Don't eagerly compute dask arrays in localize (dcherian, Dec 11, 2024)
6109f6e  Clean up test (dcherian, Dec 11, 2024)
61c5b2c  Clean up Variable handling (dcherian, Dec 10, 2024)
81b73b9  Silence test warning (dcherian, Dec 11, 2024)
652bcc1  Use apply_ufunc instead (dcherian, Dec 5, 2024)
03f0b36  Add test for #4463 (dcherian, Dec 13, 2024)
9b915b2  complete tests (dcherian, Dec 13, 2024)
be5c783  Add comments (dcherian, Dec 13, 2024)
a5e1854  Clear up broadcasting (dcherian, Dec 14, 2024)
eef94fa  typing (dcherian, Dec 14, 2024)
79a7e56  try a different warning filter (dcherian, Dec 14, 2024)
6e22072  one more fix (dcherian, Dec 14, 2024)
8d4503a  types + more duck_array_ops (dcherian, Dec 14, 2024)
586f638  fixes (dcherian, Dec 14, 2024)
972e9fb  Merge branch 'main' into pr/9881 (Illviljan, Dec 14, 2024)
38e66fc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 14, 2024)
43a0691  Merge branch 'main' into pr/9881 (Illviljan, Dec 14, 2024)
97a388e  Apply suggestions from code review (dcherian, Dec 17, 2024)
ef24840  Merge branch 'main' into redo-blockwise-interp (dcherian, Dec 18, 2024)
437219f  Merge branch 'main' into pr/9881 (Illviljan, Dec 18, 2024)
c152ca3  Apply suggestions from code review (dcherian, Dec 19, 2024)
6c1dd95  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 19, 2024)
1b9845d  Apply suggestions from code review (dcherian, Dec 19, 2024)
81b8a90  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 19, 2024)
6f6dd0a  fix (dcherian, Dec 19, 2024)
e6ec62b  Revert "Apply suggestions from code review" (dcherian, Dec 19, 2024)
55 changes: 18 additions & 37 deletions xarray/core/dataset.py
@@ -2921,19 +2921,11 @@ def _validate_interp_indexers(
"""Variant of _validate_indexers to be used for interpolation"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, v
elif is_scalar(v):
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v, attrs=self.coords[k].attrs)
elif v.ndim == 1:
yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs)
else:
raise AssertionError() # Already tested by _validate_indexers
yield k, Variable(dims=(k,), data=v, attrs=self.coords[k].attrs)
else:
raise TypeError(type(v))

@@ -4127,18 +4119,6 @@ def interp(

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = dict(self._validate_interp_indexers(coords))

if coords:
Review comment (Contributor, Author):

Handled by vectorize=True. This is possibly a perf regression with numpy arrays, but a massive improvement with chunked arrays.
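
As a rough illustration of what `vectorize=True` means here: a function written only for the core dimension is looped over every remaining dimension. The sketch below is plain Python and does not call xarray or NumPy; `interp_1d`, `interp_vectorized`, and the end-point clamping are hypothetical stand-ins chosen for this example, not the code in this PR.

```python
from bisect import bisect_right


def interp_1d(xp, fp, x):
    """Linearly interpolate one target x on a sorted 1-D grid xp with values fp."""
    # Assumption of this sketch: clamp to the end points when x is out of range.
    if x <= xp[0]:
        return fp[0]
    if x >= xp[-1]:
        return fp[-1]
    i = bisect_right(xp, x) - 1          # index of the left neighbour of x
    t = (x - xp[i]) / (xp[i + 1] - xp[i])
    return fp[i] + t * (fp[i + 1] - fp[i])


def interp_vectorized(xp, rows, targets):
    """Loop the 1-D core function over the non-core (loop) dimension.

    rows[j] holds the values along the core dimension for loop index j,
    and targets[j] is the point to interpolate to for that same j.
    """
    return [interp_1d(xp, fp, x) for fp, x in zip(rows, targets)]


q = [0.0, 1.0, 2.0]                       # core dimension coordinate
rows = [[0.0, 10.0, 20.0], [5.0, 5.0, 9.0]]
targets = [0.5, 1.5]                      # one target per loop index
result = interp_vectorized(q, rows, targets)
```

The Python-level loop is why this can be slower for in-memory NumPy arrays, while each loop iteration stays independent, which is what helps with chunked arrays.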

Review comment (Contributor, Author):

For posterity, the bad thing about this approach is that it can greatly expand the number of core dimensions for the problem, limiting the potential for parallelism.

Consider the problem in #6799 (comment). In the following, dimension names are listed in [].

da[time, q, lat, lon].interp(q=bar[lat, lon]) gets rewritten to da[time, q, lat, lon].interp(q=bar[lat, lon], lat=lat[lat], lon=lon[lon]), which, thanks to our automatic rechunking, makes dask merge chunks in lat and lon too, for no benefit.
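
The rewrite described in this comment can be traced with dimension names alone. The following is a minimal sketch that reuses the set logic of the removed block on plain tuples; `data_dims`, `indexer_dims`, and `expanded` are hypothetical names for this illustration, and no real arrays or xarray objects are involved.

```python
# Dimension names only; no real arrays are needed to see the effect.
data_dims = ("time", "q", "lat", "lon")   # da[time, q, lat, lon]
indexer_dims = {"q": ("lat", "lon")}      # interp(q=bar[lat, lon])

# Same set logic as the removed block: dims shared by the data and by
# every indexer, minus the dims that are interpolated over explicitly.
shared = (
    set(data_dims)
    .intersection(*[set(dims) for dims in indexer_dims.values()])
    .difference(indexer_dims.keys())
)

# The removed code then added each shared dim as its own indexer.
expanded = dict(indexer_dims)
expanded.update({d: (d,) for d in sorted(shared)})

core_before = set(indexer_dims)           # dims interpolated over originally
core_after = set(expanded)                # dims interpolated over after the rewrite
free_dims = set(data_dims) - core_after   # dims still free to stay chunked
```

In this example only `time` is left outside the interpolation after the rewrite, whereas `lat` and `lon` could otherwise have remained independent.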

# This avoids broadcasting over coordinates that are both in
# the original array AND in the indexing array. It essentially
# forces interpolation along the shared coordinates.
sdims = (
set(self.dims)
.intersection(*[set(nx.dims) for nx in indexers.values()])
.difference(coords.keys())
)
indexers.update({d: self.variables[d] for d in sdims})

obj = self if assume_sorted else self.sortby(list(coords))

def maybe_variable(obj, k):
@@ -4169,16 +4149,18 @@ def _validate_interp_indexer(x, new_x):
for k, v in indexers.items()
}

# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]

# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
if obj.__dask_graph__():
has_chunked_array = bool(
any(is_chunked_array(v._data) for v in obj._variables.values())
)
if has_chunked_array:
# optimization: subset to coordinate range of the target index
if method in ["linear", "nearest"]:
for k, v in validated_indexers.items():
obj, newidx = missing._localize(obj, {k: v})
validated_indexers[k] = newidx[k]
# optimization: create dask coordinate arrays once per Dataset
# rather than once per Variable when dask.array.unify_chunks is called later
# GH4739
dask_indexers = {
k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
for k, (index, dest) in validated_indexers.items()
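
The "subset to coordinate range of the target index" optimization in the hunk above can be pictured with a small stand-alone sketch. It is not the implementation of `missing._localize`; the function `localize`, its `pad` parameter, and the use of `bisect` are assumptions made for this illustration only.

```python
from bisect import bisect_left


def localize(index, targets, pad=2):
    """Return a slice of the sorted `index` that brackets all `targets`.

    `pad` (hypothetical) keeps a few extra points on each side so that
    linear or nearest-neighbour interpolation still sees both neighbours.
    """
    lo = bisect_left(index, min(targets))
    hi = bisect_left(index, max(targets))
    start = max(lo - pad, 0)
    stop = min(hi + pad, len(index))
    return slice(start, stop)


index = list(range(100))                 # source coordinate, sorted
window = localize(index, [40.2, 42.7])   # targets cover a narrow range
subset = index[window]                   # only this part needs interpolating
```

Interpolating on `subset` instead of the full coordinate keeps the work proportional to the range the targets actually cover.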
@@ -4190,10 +4172,9 @@ def _validate_interp_indexer(x, new_x):
if name in indexers:
continue

if is_duck_dask_array(var.data):
use_indexers = dask_indexers
else:
use_indexers = validated_indexers
use_indexers = (
dask_indexers if is_duck_dask_array(var.data) else validated_indexers
)

dtype_kind = var.dtype.kind
if dtype_kind in "uifc":