interp performance with chunked dimensions #6799

Open
slevang opened this issue Jul 17, 2022 · 12 comments

@slevang
Contributor

slevang commented Jul 17, 2022

What is your issue?

I'm trying to perform 2D interpolation on a large 3D array that is heavily chunked along the interpolation dimensions and not the third dimension. The application could be extracting a timeseries from a reanalysis dataset chunked in space but not time, to compare to observed station data with more precise coordinates.

I use the advanced interpolation method as described in the documentation, with the interpolation coordinates specified by DataArrays with a shared dimension like so:

%load_ext memory_profiler
import numpy as np
import dask.array as da
import xarray as xr

# Synthetic dataset chunked in the two interpolation dimensions
nt = 40000
nx = 200
ny = 200
ds = xr.Dataset(
    data_vars = {
        'foo':(
            ('t', 'x', 'y'), 
            da.random.random(size=(nt, nx, ny), chunks=(-1, 10, 10))),
    },
    coords = {
        't': np.linspace(0, 1, nt),
        'x': np.linspace(0, 1, nx),
        'y': np.linspace(0, 1, ny),
    }
)

# Interpolate to some random 2D locations
ni = 10
xx = xr.DataArray(np.random.random(ni), dims='z', name='x')
yy = xr.DataArray(np.random.random(ni), dims='z', name='y')
interpolated = ds.foo.interp(x=xx, y=yy)
%memit interpolated.compute()

With just 10 interpolation points, this example calculation uses about 1.5 * ds.nbytes of memory, and saturates around 2 * ds.nbytes by about 100 interpolation points.

This could definitely work better, as each interpolated point usually only requires a single chunk of the input dataset, and at most 4 if it is right on the corner of a chunk. For example we can instead do it in a loop and get very reasonable memory usage, but this isn't very scalable:

interpolated = []
for n in range(ni):
    interpolated.append(ds.foo.interp(x=xx.isel(z=n), y=yy.isel(z=n)))
interpolated = xr.concat(interpolated, dim='z')
%memit interpolated.compute()
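The locality argument above can be checked with a small NumPy sketch. The helpers `chunks_needed` and `chunks_for_point` are hypothetical (not part of xarray); they assume ascending 1D coordinates and per-dimension chunk sizes like one entry of `ds.foo.chunks`, and report which chunks hold the two grid points that bracket a target along each dimension.

```python
import numpy as np


def chunks_needed(point, coord, chunks):
    """Chunk indices along one dimension that linear interpolation at `point`
    reads: the chunks holding the two bracketing grid points (hypothetical helper)."""
    # Chunk boundaries as positional offsets, e.g. (10, 10, 10) -> [0, 10, 20, 30].
    bounds = np.concatenate([[0], np.cumsum(chunks)])
    # Positions of the two bracketing grid points (assumes an ascending coord).
    hi = int(np.clip(np.searchsorted(coord, point), 1, len(coord) - 1))
    lo = hi - 1
    # Map each position to the chunk that contains it.
    lo_chunk = np.searchsorted(bounds, lo, side="right") - 1
    hi_chunk = np.searchsorted(bounds, hi, side="right") - 1
    return sorted({int(lo_chunk), int(hi_chunk)})


def chunks_for_point(px, py, x, y, x_chunks, y_chunks):
    """All (x-chunk, y-chunk) blocks a 2D linear interpolation at (px, py) reads."""
    return [(i, j)
            for i in chunks_needed(px, x, x_chunks)
            for j in chunks_needed(py, y, y_chunks)]
```

Because the x and y lookups are independent, a point never needs more than 2 × 2 = 4 chunks, and usually just one.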

I tried adding a .chunk({'z':1}) to the interpolation coordinates but this doesn't help. We can also do .sel(x=xx, y=yy, method='nearest') with very good performance.

Any tips to make this calculation work better with existing options, or otherwise ways we might improve the interp method to handle this case? Given the performance behavior, I'm guessing we may be doing sequential interpolation over the dimensions: basically an interp1d call along x for all the xx points, and from there another along y to the yy points, which for even a small number of points would require nearly all chunks to be loaded in. But I haven't explored the code enough yet to understand the details.
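The suspected behaviour can be illustrated in plain NumPy, assuming linear interpolation and in-range targets: doing the two dimensions as sequential 1D passes gives the same numbers as pointwise bilinear interpolation, but the first pass materializes an intermediate of shape `(len(xi), len(y))` that spans every y position. On a chunked array that means reading every y-chunk in each x-chunk column that some target touches. The function names are hypothetical and this is not xarray's actual code path.

```python
import numpy as np


def interp_sequential(values, x, y, xi, yi):
    # Pass 1: interpolate along x for *every* target point, keeping all of y.
    # The intermediate has shape (len(xi), len(y)), so it needs every y position.
    stage1 = np.stack([np.interp(xi, x, values[:, j]) for j in range(len(y))], axis=1)
    # Pass 2: interpolate along y, pairing target k with row k of stage1.
    return np.array([np.interp(yi[k], y, stage1[k]) for k in range(len(xi))])


def interp_pointwise(values, x, y, xi, yi):
    # Bilinear interpolation that only reads the 2x2 cell around each point.
    out = np.empty(len(xi))
    for k, (px, py) in enumerate(zip(xi, yi)):
        i = int(np.clip(np.searchsorted(x, px), 1, len(x) - 1))
        j = int(np.clip(np.searchsorted(y, py), 1, len(y) - 1))
        fx = (px - x[i - 1]) / (x[i] - x[i - 1])
        fy = (py - y[j - 1]) / (y[j] - y[j - 1])
        cell = values[i - 1:i + 1, j - 1:j + 1]
        out[k] = ((1 - fx) * (1 - fy) * cell[0, 0] + fx * (1 - fy) * cell[1, 0]
                  + (1 - fx) * fy * cell[0, 1] + fx * fy * cell[1, 1])
    return out
```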

@slevang slevang added the needs triage Issue that has not been reviewed by xarray team member label Jul 17, 2022
@TomNicholas TomNicholas added topic-performance and removed needs triage Issue that has not been reviewed by xarray team member labels Jul 17, 2022
@dcherian
Contributor

dcherian commented Jul 18, 2022

Given the performance behavior I'm guessing we may be doing sequential interpolation for the dimensions, basically an interp1d call for all the xx points and from there another to the yy points, which for even a small number of points would require nearly all chunks to be loaded in.

Yeah I think this is right.

You could check if it was better before #4155 (if it worked at all, that is).

cc @pums974 @Illviljan

@slevang
Contributor Author

slevang commented Jul 18, 2022

Interpolating on chunked dimensions doesn't work at all prior to #4155. The changes in #4069 are also relevant.

@pums974
Contributor

pums974 commented Jul 18, 2022

You are right about the behavior of the code. I don't see any way to enhance that in the general case.

In your case, rechunking before interpolating might be a good idea.

@slevang
Contributor Author

slevang commented Jul 18, 2022

The chunking structure on disk is pretty instrumental to my application, which requires fast retrievals of full slices in the time dimension. The loop option in my first post only takes about 10 seconds with ni=1000 which is fine for my use case, so I'll probably go with that for now. It would be interesting to dig deeper though and see if there is a way to handle this better in the interp logic.

@dcherian
Contributor

The current code also has the unfortunate side-effect of merging all chunks too.

I think we should instead consider generating a dask array of weights and then using xr.dot.
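A dense-weights version of that idea, in NumPy for brevity: build one weight matrix per interpolated dimension (two nonzeros per row) and contract everything with a single einsum, which is the operation `xr.dot` generalizes. `linear_weight_matrix` is a hypothetical helper; it assumes ascending coordinates and in-range targets. The dense matrices only show the algebra; a chunked implementation would still need to exploit their sparsity so that each output reads only the chunks with nonzero weights.

```python
import numpy as np


def linear_weight_matrix(targets, grid):
    """Dense (len(targets), len(grid)) matrix W such that W @ f linearly
    interpolates f to `targets` (ascending grid, targets in range)."""
    W = np.zeros((len(targets), len(grid)))
    hi = np.clip(np.searchsorted(grid, targets), 1, len(grid) - 1)
    lo = hi - 1
    frac = (targets - grid[lo]) / (grid[hi] - grid[lo])
    rows = np.arange(len(targets))
    W[rows, lo] = 1 - frac
    W[rows, hi] = frac
    return W


rng = np.random.default_rng(0)
t, x, y = np.arange(5.0), np.linspace(0, 1, 30), np.linspace(0, 1, 20)
foo = rng.random((t.size, x.size, y.size))   # dims (t, x, y)
xx, yy = rng.random(7), rng.random(7)        # 7 target points along a shared dim "z"

Wx = linear_weight_matrix(xx, x)             # dims (z, x)
Wy = linear_weight_matrix(yy, y)             # dims (z, y)
# Contract over x and y while keeping z shared (vectorized interpolation).
result = np.einsum("txy,zx,zy->tz", foo, Wx, Wy, optimize=True)
```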

@gjoseph92

The current code also has the unfortunate side-effect of merging all chunks too

Don't really know what I'm talking about here, but it looks to me like the current dask-interpolation routine uses blockwise. That is, it's trying to simply map a function over each chunk in the array. To get the chunks into a structure where this is correct to do, you have to first merge all the chunks along the interpolation axis.

I would have expected interpolation to use map_overlap. You'd add some padding to each chunk, map the interpolation over each chunk (without combining them), then trim off the extra. By using overlap, you don't need to combine all the chunks into one big array first, so the operation can actually be parallel.
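The map_overlap idea can be emulated in 1D with NumPy to show why it avoids merging chunks: each chunk is padded with a halo of neighbouring points, interpolated on its own, and then keeps only the targets it owns. Because the outputs sit at arbitrary target locations rather than on the chunk's grid, the "trim" step becomes "keep the owned targets". This is a sketch of the concept with a hypothetical `interp_with_overlap` helper, not a call into `dask.array.map_overlap`; it assumes an ascending grid and targets inside its range.

```python
import numpy as np


def interp_with_overlap(values, grid, targets, chunk_size, depth=1):
    """Chunk-local 1D linear interpolation with a halo of `depth` points."""
    out = np.full(len(targets), np.nan)
    n = len(grid)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Halo: pad with `depth` points on each side, as map_overlap does. The
        # right-hand neighbour lets a target between this chunk's last point and
        # the next chunk's first point be interpolated locally.
        lo, hi = max(start - depth, 0), min(stop + depth, n)
        local_grid, local_vals = grid[lo:hi], values[lo:hi]
        # Targets owned by this chunk: [grid[start], grid[stop]), last chunk open-ended.
        upper = grid[stop] if stop < n else np.inf
        mask = (targets >= grid[start]) & (targets < upper)
        out[mask] = np.interp(targets[mask], local_grid, local_vals)
    return out
```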

FYI, fixing this would probably be a big deal to geospatial people—then you could do array reprojection without GDAL! Unfortunately not something I have time to work on right now, but perhaps someone else would be interested?

@dcherian
Contributor

The challenge is you could be interping to an unordered set of locations.

So perhaps we can sort the input locations, do the interp with map_overlap, and then use the argsort permutation to restore the expected order.
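The sort/unsort step is easy to get backwards, so here is a NumPy sketch. `interp_sorted` stands in for any routine that requires sorted targets (for example a chunk-by-chunk interpolation); the wrapper name `interp_unordered` is hypothetical. Scattering the sorted result back through the sorting permutation is equivalent to indexing it with `np.argsort(order)`, the inverse permutation.

```python
import numpy as np


def interp_unordered(values, grid, targets, interp_sorted):
    """Sort targets, interpolate in sorted order, then restore the caller's order."""
    order = np.argsort(targets)                 # permutation that sorts the targets
    sorted_result = interp_sorted(values, grid, targets[order])
    result = np.empty_like(sorted_result)
    result[order] = sorted_result               # same as sorted_result[np.argsort(order)]
    return result
```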

@dcherian
Contributor

Linking the dask issue: dask/dask#6474

@rhugonnet
Copy link

Hi all,
Maybe this is of interest here: with @ameliefroessl, we might have managed to reduce the memory usage for the case of a regular grid (whereas interp() supports the more general rectilinear case), see: GlacioHack/geoutils#537 with code here: https://github.com/rhugonnet/geoutils/blob/add_delayed_raster_functions/geoutils/raster/delayed.py#L242.

See a comparison of RAM/execution time here: GlacioHack/xdem#501 (reply in thread). The RAM usage is also checked automatically in our tests and doesn't seem to exceed what we expect 🙂

Using @GenevieveBuckley's very nice blogpost on ragged output (https://blog.dask.org/2021/07/02/ragged-output), we tested both map_overlap(drop_axis=) and delayed and found that the latter really performs better to minimize memory usage.

Unfortunately the implementation is not generic for Xarray: a regular (equally spaced) grid along the interpolated dimensions is only a special case. So I guess the question is: is it common enough that it could be worth implementing this functionality directly in Xarray when the interpolated dimensions are detected to be regular?
Otherwise, for raster data, it will soon become available through our Xarray accessor #8041 😉
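To illustrate why a regular grid makes this easier: the bracketing index can be computed with arithmetic on the origin and spacing instead of searching the coordinate array, so points can be routed to chunks without touching the coordinates at all. This is a hedged NumPy sketch with a hypothetical `regular_grid_index` helper, not the geoutils implementation linked above.

```python
import numpy as np


def regular_grid_index(targets, x0, dx, n):
    """Positional lower bracketing index on the regular grid x0 + dx * arange(n),
    computed arithmetically rather than by searching the coordinate array."""
    idx = np.floor((np.asarray(targets) - x0) / dx).astype(int)
    # Clip so that idx + 1 is always a valid upper neighbour.
    return np.clip(idx, 0, n - 2)
```

With equal-size chunks, the chunk holding that lower neighbour is then `idx // chunk_size`. The same formula gives the positional lower index for a descending grid (negative `dx`).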

@ks905383

This comment was marked as outdated.

@dcherian
Contributor

q=ds['bar']

Yes, all interpolation points are sent to every chunk, so this isn't working as you want it to atm IIUC.

I'm looking into this and can report back late next week.
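Avoiding that amounts to routing each point only to the block it needs. A hypothetical NumPy sketch of the routing step, assuming ascending coordinates and per-dimension chunk sizes, groups target indices by the (x-chunk, y-chunk) block holding each point's lower bracketing grid point:

```python
import numpy as np


def group_points_by_chunk(xi, yi, x, y, x_chunks, y_chunks):
    """Map each (x-chunk, y-chunk) block to the indices of the target points
    whose lower bracketing grid point lies in that block."""
    x_bounds = np.cumsum(x_chunks)[:-1]   # interior chunk boundaries (positions)
    y_bounds = np.cumsum(y_chunks)[:-1]
    # Positional lower neighbour of each target, clipped so a valid upper one exists.
    ix = np.clip(np.searchsorted(x, xi, side="right") - 1, 0, len(x) - 2)
    iy = np.clip(np.searchsorted(y, yi, side="right") - 1, 0, len(y) - 2)
    # Chunk index along each dimension.
    cx = np.searchsorted(x_bounds, ix, side="right")
    cy = np.searchsorted(y_bounds, iy, side="right")
    groups = {}
    for k, key in enumerate(zip(cx.tolist(), cy.tolist())):
        groups.setdefault(key, []).append(k)
    return groups
```

A point whose lower neighbour is the last position in a block also needs the first position of the next block, so a real implementation would add a one-point halo or send such points to both blocks.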

@dcherian
Contributor

dcherian commented Dec 19, 2024

#9881 fixes #6799 (comment) so I marked it as "Outdated" to not distract future readers


For the OP, I'm experimenting with using vectorized indexing and xr.dot for method="linear".

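from typing import Hashable

import numpy as np
import xarray as xr
from xarray import DataArray, Variable


def digitize(to, from_, right=True):
    # Position of the grid point just below each target, clipped to a valid index
    # (np.digitize handles both ascending and descending grids).
    return np.clip(np.digitize(to, from_, right=right) - 1, 0, len(from_) - 1)


def xr_interp(obj: DataArray, to: dict[Hashable, Variable]) -> DataArray:
    from_ = {dim: obj[dim].values for dim in to}

    weights = []
    indexers = {}
    sum_dims = []
    for dim in to:
        target, grid = to[dim], from_[dim]
        lo_index = digitize(target.values, grid)
        hi_index = np.minimum(lo_index + 1, obj.sizes[dim] - 1)
        span = grid[hi_index] - grid[lo_index]
        # Fractional distance from the lower neighbour. Despite its name, lo_weight
        # multiplies the value at hi_index, and hi_weight the value at lo_index.
        with np.errstate(divide="ignore", invalid="ignore"):
            lo_weight = np.where(span != 0, np.abs((target.values - grid[lo_index]) / span), 0.0)
        # Targets outside the grid become NaN (min/max so descending grids work too).
        lo_weight[(target.values < grid.min()) | (target.values > grid.max())] = np.nan
        hi_weight = 1. - lo_weight

        concat_dim = f"__{dim}__"
        sum_dims.append(concat_dim)

        # Use the target's own dims ("z" for vectorized, dim itself for outer interp)
        # so the weights and the indexers line up with each other.
        weights.append(Variable.concat(
            [Variable(target.dims, hi_weight), Variable(target.dims, lo_weight)], dim=concat_dim))
        indexers[dim] = Variable.concat(
            [Variable(target.dims, lo_index), Variable(target.dims, hi_index)], dim=concat_dim)

        # var = hi_weight * var.isel({dim: lo_index}) + lo_weight * var.isel({dim: hi_index})

    result = xr.dot(obj.isel(indexers), *weights, dim=sum_dims, optimize=True)
    return result.assign_coords(to)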

This works nicely for the vectorized interpolation example in the original post but totally falls apart inside dask for the outer interpolation case: #9907 (code reproduced below; cc @phofl). I noticed that optimize=True had a big effect for the vectorized indexing example, though the graph is not that great.

EDIT: It works a lot better with opt_einsum installed.

import dask.array
import numpy as np

import xarray as xr

arr = xr.DataArray(
    dask.array.random.random((1, 75902, 45910), chunks=(1, "auto", -1)),
    dims=["band", "y", "x"],
    coords={"x": np.linspace(-73.58, -62.11, 45910), "y": np.linspace(-36.08, -55.05, 75902)},
    name="bla",
)

arr2 = xr.DataArray(
    dask.array.random.random((1, 75902, 45910), chunks=(1, "auto", -1)),
    dims=["band", "y", "x"],
    coords={"x": np.linspace(-73.58, -62.11, 45910), "y": np.linspace(-36.08, -55.05, 75902)},
    name="bla",
)

x = arr2.interp(
    x=arr.coords["x"],
    y=arr.coords["y"],
    method="linear",
)
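The per-dimension weight logic in `xr_interp` (digitize to the lower neighbour, weight by fractional distance, NaN outside the grid) can be sanity-checked in 1D against `np.interp`. This NumPy-only mirror, `linear_weights_1d`, is hypothetical and written for illustration; it clips the digitize result and uses `grid.min()`/`grid.max()` so that exact endpoints and descending grids are handled.

```python
import numpy as np


def linear_weights_1d(target, grid):
    """Return (lo, hi, w_lo, w_hi) so that w_lo * f[lo] + w_hi * f[hi]
    linearly interpolates f to `target`; NaN weights outside the grid."""
    lo = np.clip(np.digitize(target, grid, right=True) - 1, 0, len(grid) - 1)
    hi = np.minimum(lo + 1, len(grid) - 1)
    span = grid[hi] - grid[lo]
    # Fractional distance from the lower neighbour (0 where lo == hi at the edge).
    with np.errstate(divide="ignore", invalid="ignore"):
        frac = np.where(span != 0, np.abs((target - grid[lo]) / span), 0.0)
    frac[(target < grid.min()) | (target > grid.max())] = np.nan
    return lo, hi, 1.0 - frac, frac
```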


8 participants