Chunked pint arrays break on rolling() #186

Open
riley-brady opened this issue Oct 5, 2022 · 5 comments

Comments

@riley-brady

Hi folks,

I noticed that when running .rolling(...) on a chunked pint array, there is an exception raised that breaks the process:

TypeError: `pad_value` must be composed of integral typed values.

I outline three different cases below for running .rolling() on a pint-aware DataArray.

  1. Calculating the rolling sum on an in-memory pint array.
  2. Calculating the rolling sum on a chunked pint array, using xarray chunking.
    • This works, even without turning off bottleneck. However, this isn't an optimal solution for me, since one cannot query ds.pint.units on an xarray-chunked pint array. I like being able to do that for various QOL checks in a data pipeline.
  3. Calculating the rolling sum on a pint array chunked with ds.pint.chunk(...).
    • This method preserves the units, but leads to the traceback seen above and in full detail below. It also breaks when turning off bottleneck.

import pint_xarray
import xarray as xr

print(xr.__version__)
# 2022.6.0
print(pint_xarray.__version__)
# 0.3

data = xr.DataArray(range(3), dims='time').pint.quantify('kelvin')
print(data)
# <xarray.DataArray (time: 3)>
# <Quantity([0 1 2], 'kelvin')>

# Case 1: rolling sum on an in-memory `pint` array.
# Loses the units as expected, but executes properly.
rs = data.rolling(time=2).sum()
print(rs)
# <xarray.DataArray (time: 3)>
# array([nan,  1.,  3.])

# Case 2: rolling sum after `DataArray.chunk()`.
# Units are maintained after compute,
# but `data_xr_chunk.pint.units` returns `None` in the interim.
data_xr_chunk = data.chunk({'time': 1})
rs = data_xr_chunk.rolling(time=2).sum().compute()
print(rs)
# <xarray.DataArray (time: 3)>
# <Quantity([nan  1.  3.], 'kelvin')>

# Case 3: rolling sum after `DataArray.pint.chunk()`.
# Units are maintained on the chunked array, but an exception is raised
# (see full traceback below).
data_pint_chunk = data.pint.chunk({"time": 1})
rs = data_pint_chunk.rolling(time=2).sum().compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 rs = data_pint_chunk.rolling(time=2).sum().compute()

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:155, in Rolling._reduce_method.<locals>.method(self, keep_attrs, **kwargs)
    151 def method(self, keep_attrs=None, **kwargs):
    153     keep_attrs = self._get_keep_attrs(keep_attrs)
--> 155     return self._numpy_or_bottleneck_reduce(
    156         array_agg_func,
    157         bottleneck_move_func,
    158         rolling_agg_func,
    159         keep_attrs=keep_attrs,
    160         fillna=fillna,
    161         **kwargs,
    162     )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:589, in DataArrayRolling._numpy_or_bottleneck_reduce(self, array_agg_func, bottleneck_move_func, rolling_agg_func, keep_attrs, fillna, **kwargs)
    586     kwargs.setdefault("skipna", False)
    587     kwargs.setdefault("fillna", fillna)
--> 589 return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:472, in DataArrayRolling.reduce(self, func, keep_attrs, **kwargs)
    470 else:
    471     obj = self.obj
--> 472 windows = self._construct(
    473     obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
    474 )
    476 result = windows.reduce(
    477     func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
    478 )
    480 # Find valid windows based on count.

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:389, in DataArrayRolling._construct(self, obj, window_dim, stride, fill_value, keep_attrs, **window_dim_kwargs)
    384 window_dims = self._mapping_to_list(
    385     window_dim, allow_default=False, allow_allsame=False  # type: ignore[arg-type]  # https://github.com/python/mypy/issues/12506
    386 )
    387 strides = self._mapping_to_list(stride, default=1)
--> 389 window = obj.variable.rolling_window(
    390     self.dim, self.window, window_dims, self.center, fill_value=fill_value
    391 )
    393 attrs = obj.attrs if keep_attrs else {}
    395 result = DataArray(
    396     window,
    397     dims=obj.dims + tuple(window_dims),
   (...)
    400     name=obj.name,
    401 )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:2314, in Variable.rolling_window(self, dim, window, window_dim, center, fill_value)
   2311     else:
   2312         pads[d] = (win - 1, 0)
-> 2314 padded = var.pad(pads, mode="constant", constant_values=fill_value)
   2315 axis = [self.get_axis_num(d) for d in dim]
   2316 new_dims = self.dims + tuple(window_dim)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:1416, in Variable.pad(self, pad_width, mode, stat_length, constant_values, end_values, reflect_type, **pad_width_kwargs)
   1413 if reflect_type is not None:
   1414     pad_option_kwargs["reflect_type"] = reflect_type
-> 1416 array = np.pad(  # type: ignore[call-overload]
   1417     self.data.astype(dtype, copy=False),
   1418     pad_width_by_index,
   1419     mode=mode,
   1420     **pad_option_kwargs,
   1421 )
   1423 return type(self)(self.dims, array)

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/quantity.py:1730, in Quantity.__array_function__(self, func, types, args, kwargs)
   1729 def __array_function__(self, func, types, args, kwargs):
-> 1730     return numpy_wrap("function", func, args, kwargs, types)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:936, in numpy_wrap(func_type, func, args, kwargs, types)
    934 if name not in handled or any(is_upcast_type(t) for t in types):
    935     return NotImplemented
--> 936 return handled[name](*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:660, in _pad(array, pad_width, mode, **kwargs)
    656     if key in kwargs:
    657         kwargs[key] = _recursive_convert(kwargs[key], units)
    659 return units._REGISTRY.Quantity(
--> 660     np.pad(array._magnitude, pad_width, mode=mode, **kwargs), units
    661 )

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/core.py:1762, in Array.__array_function__(self, func, types, args, kwargs)
   1759 if has_keyword(da_func, "like"):
   1760     kwargs["like"] = self
-> 1762 return da_func(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:1229, in pad(array, pad_width, mode, **kwargs)
   1227 elif mode == "constant":
   1228     kwargs.setdefault("constant_values", 0)
-> 1229     return pad_edge(array, pad_width, mode, **kwargs)
   1230 elif mode == "linear_ramp":
   1231     kwargs.setdefault("end_values", 0)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in pad_edge(array, pad_width, mode, **kwargs)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960 
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in <dictcomp>(.0)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960 
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:910, in expand_pad_value(array, pad_value)
    908     pad_value = array.ndim * (tuple(pad_value[0]),)
    909 else:
--> 910     raise TypeError("`pad_value` must be composed of integral typed values.")
    912 return pad_value

TypeError: `pad_value` must be composed of integral typed values.

My solution in the interim is to do something like:

units = data.pint.units
data = data.pint.dequantify()
rs = data.rolling(time=2).sum()
rs = rs.pint.quantify(units)

@riley-brady
Author

Another side note that came up here -- I'm curious if there's any roadmap plan for recognizing integration of units for methods like rolling().sum().

E.g.,

data = xr.DataArray(range(3), dims='time').pint.quantify('mm/day')
data.pint.units
# mm/day
data = data.rolling(time=2).sum()
data.pint.units
# desired: mm

@keewis
Collaborator

keewis commented Oct 6, 2022

thanks for the report, @riley-brady. It seems that xarray operations on pint+dask are not as thoroughly tested as pint and dask on their own. I think this is a bug in pint (or dask, not sure): we enable force_ndarray_like to convert scalars to 0d arrays, which means that the final call to np.pad becomes:

np.pad(magnitude, pad_width, mode="constant", constant_values=np.array(0))

numpy seems to be fine with that, but dask is not.

@jrbourbeau, what do you think? Would it make sense to extend expand_pad_value to unpack 0d arrays (using .item()), or would you rather have the caller (pint, in this case) do that?
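
A minimal sketch of the two points above, using only NumPy. `unpack_0d` is a hypothetical helper, not part of dask or pint; it only illustrates what unpacking a 0d array with `.item()` could look like before the pad value reaches a stricter backend. The NaN fill value is an assumption chosen to mirror the rolling case.

```python
import numpy as np


def unpack_0d(pad_value):
    """Hypothetical helper: turn a 0d array into a plain Python scalar."""
    if isinstance(pad_value, np.ndarray) and pad_value.ndim == 0:
        return pad_value.item()
    # Anything else (scalars, tuples, nested sequences) passes through unchanged.
    return pad_value


magnitude = np.array([0.0, 1.0, 2.0])
fill = np.array(np.nan)  # 0d array, as produced when scalars are converted to arrays

# NumPy accepts a 0d array as `constant_values`.
padded = np.pad(magnitude, (1, 0), mode="constant", constant_values=fill)

# The unpacked value is an ordinary Python float.
scalar_fill = unpack_0d(fill)
```

Whether the unpacking belongs in the array backend or in the caller is the open question in this comment; the sketch takes no position on that.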

@keewis
Collaborator

keewis commented Oct 6, 2022

I'm curious if there's any roadmap plan for recognizing integration of units for methods like rolling().sum()

I'm not sure I follow. Why would rolling().sum() work similarly to integration, when all it does is compute a grouped sum? I'm not sure if this actually counts as integration, but you can multiply the result of the rolling sum by the diff of the time coordinate (which is a bit tricky because time is an indexed coordinate):

data = xr.DataArray(
    range(3), dims="time", coords={"time2": ("time", [1, 2, 3])}
).pint.quantify("mm/day", time="day")
dt = data.time2.pad(time=(1, 0)).diff(dim="time")
data.rolling(time=2).sum() * dt

and then you would have the correct units (with the same numerical result, because I chose the time coordinate to have increments of 1 day)
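
The arithmetic behind this can be checked with plain NumPy, leaving pint out of the picture. This is only a sketch, assuming the same toy data as above (a rate in mm/day on a daily time axis); the variable names are illustrative and the units are tracked in comments only.

```python
import numpy as np

rate = np.array([0.0, 1.0, 2.0])  # mm/day
time = np.array([1.0, 2.0, 3.0])  # day

# Step size for each sample; the first step is unknown, so it is NaN.
dt = np.diff(time, prepend=np.nan)  # day

# Rolling sum with a window of 2; the first position has no full window.
windows = np.lib.stride_tricks.sliding_window_view(rate, 2)
rolling_sum = np.concatenate([[np.nan], windows.sum(axis=-1)])  # mm/day

# Multiplying by dt turns mm/day into mm.
total = rolling_sum * dt  # mm
```

With a uniform step of 1 day the numbers do not change, only the units do. For uneven time steps, multiplying the summed rate by a single `dt` is no longer the same as summing `rate * dt` inside each window.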

@riley-brady
Author

Thanks for the quick feedback on this issue @keewis.

Also thanks for the demo with .diff(). You're right about the integration assumptions. In my specific use case I am doing a rolling sum of units mm/day with daily time steps, so in this case it should reflect total precip in mm, but that's not a fair assumption for many other cases. I'll give the .diff() method a try.

@keewis
Collaborator

keewis commented Dec 13, 2023

this should have been fixed in dask quite a while ago, but I'll leave it open until we have tests for this (probably after copying the test suite from xarray)

2 participants