diff --git a/docs/src/whatsnew/latest.rst b/docs/src/whatsnew/latest.rst index b94cd11517..1790f51daa 100644 --- a/docs/src/whatsnew/latest.rst +++ b/docs/src/whatsnew/latest.rst @@ -32,14 +32,15 @@ This document explains the changes made to Iris for this release =========== #. `@ESadek-MO`_ edited :func:`~iris.io.expand_filespecs` to allow expansion of - non-existing paths, and added expansion functionality to :func:`~iris.io.save`. - (:issue:`4772`, :pull:`4913`) + non-existing paths, and added expansion functionality to :func:`~iris.io.save`. + (:issue:`4772`, :pull:`4913`) 🐛 Bugs Fixed ============= -#. N/A +#. `@rcomer`_ and `@pp-mo`_ (reviewer) factored masking into the returned + sum-of-weights calculation from :obj:`~iris.analysis.SUM`. (:pull:`4905`) 💣 Incompatible Changes @@ -51,7 +52,9 @@ This document explains the changes made to Iris for this release 🚀 Performance Enhancements =========================== -#. N/A +#. `@rcomer`_ and `@pp-mo`_ (reviewer) increased aggregation speed for + :obj:`~iris.analysis.SUM`, :obj:`~iris.analysis.COUNT` and + :obj:`~iris.analysis.PROPORTION` on real data. (:pull:`4905`) 🔥 Deprecations @@ -63,7 +66,8 @@ This document explains the changes made to Iris for this release 🔗 Dependencies =============== -#. N/A +#. `@rcomer`_ introduced the ``dask >=2.26`` minimum pin, so that Iris can benefit + from Dask's support for `NEP13`_ and `NEP18`_. (:pull:`4905`) 📚 Documentation @@ -89,3 +93,5 @@ This document explains the changes made to Iris for this release Whatsnew resources in alphabetical order: +.. _NEP13: https://numpy.org/neps/nep-0013-ufunc-overrides.html +.. _NEP18: https://numpy.org/neps/nep-0018-array-function-protocol.html \ No newline at end of file diff --git a/lib/iris/analysis/__init__.py b/lib/iris/analysis/__init__.py index 9d8392cd95..11810f2901 100644 --- a/lib/iris/analysis/__init__.py +++ b/lib/iris/analysis/__init__.py @@ -1499,18 +1499,21 @@ def _weighted_percentile( return result -@_build_dask_mdtol_function -def _lazy_count(array, **kwargs): - array = iris._lazy_data.as_lazy_data(array) +def _count(array, **kwargs): + """ + Counts the number of points along the axis that satisfy the condition + specified by ``function``. Uses Dask's support for NEP13/18 to work as + either a lazy or a real function. + + """ func = kwargs.pop("function", None) if not callable(func): emsg = "function must be a callable. Got {}." raise TypeError(emsg.format(type(func))) - return da.sum(func(array), **kwargs) + return np.sum(func(array), **kwargs) def _proportion(array, function, axis, **kwargs): - count = iris._lazy_data.non_lazy(_lazy_count) # if the incoming array is masked use that to count the total number of # values if ma.isMaskedArray(array): @@ -1521,7 +1524,7 @@ def _proportion(array, function, axis, **kwargs): # case pass the array shape instead of the mask: total_non_masked = array.shape[axis] else: - total_non_masked = count( + total_non_masked = _count( array.mask, axis=axis, function=np.logical_not, **kwargs ) total_non_masked = ma.masked_equal(total_non_masked, 0) @@ -1534,7 +1537,7 @@ def _proportion(array, function, axis, **kwargs): # a dtype for its data that is different to the dtype of the fill-value, # which can cause issues outside this function. # Reference - tests/unit/analyis/test_PROPORTION.py Test_masked.test_ma - numerator = count(array, axis=axis, function=function, **kwargs) + numerator = _count(array, axis=axis, function=function, **kwargs) result = ma.asarray(numerator / total_non_masked) return result @@ -1604,23 +1607,33 @@ def _lazy_rms(array, axis, **kwargs): return da.sqrt(da.mean(array**2, axis=axis, **kwargs)) -@_build_dask_mdtol_function -def _lazy_sum(array, **kwargs): - array = iris._lazy_data.as_lazy_data(array) - # weighted or scaled sum +def _sum(array, **kwargs): + """ + Weighted or scaled sum. Uses Dask's support for NEP13/18 to work as either + a lazy or a real function. + + """ axis_in = kwargs.get("axis", None) weights_in = kwargs.pop("weights", None) returned_in = kwargs.pop("returned", False) if weights_in is not None: - wsum = da.sum(weights_in * array, **kwargs) + wsum = np.sum(weights_in * array, **kwargs) else: - wsum = da.sum(array, **kwargs) + wsum = np.sum(array, **kwargs) if returned_in: + al = da if iris._lazy_data.is_lazy_data(array) else np if weights_in is None: - weights = iris._lazy_data.as_lazy_data(np.ones_like(array)) + weights = al.ones_like(array) + if al is da: + # Dask version of ones_like does not preserve masks. See dask#9301. + weights = da.ma.masked_array( + weights, da.ma.getmaskarray(array) + ) else: - weights = weights_in - rvalue = (wsum, da.sum(weights, axis=axis_in)) + weights = al.ma.masked_array( + weights_in, mask=al.ma.getmaskarray(array) + ) + rvalue = (wsum, np.sum(weights, axis=axis_in)) else: rvalue = wsum return rvalue @@ -1740,9 +1753,9 @@ def interp_order(length): # COUNT = Aggregator( "count", - iris._lazy_data.non_lazy(_lazy_count), + _count, units_func=lambda units: 1, - lazy_func=_lazy_count, + lazy_func=_build_dask_mdtol_function(_count), ) """ An :class:`~iris.analysis.Aggregator` instance that counts the number @@ -2114,8 +2127,8 @@ def interp_order(length): SUM = WeightedAggregator( "sum", - iris._lazy_data.non_lazy(_lazy_sum), - lazy_func=_build_dask_mdtol_function(_lazy_sum), + _sum, + lazy_func=_build_dask_mdtol_function(_sum), ) """ An :class:`~iris.analysis.Aggregator` instance that calculates diff --git a/lib/iris/tests/unit/analysis/test_SUM.py b/lib/iris/tests/unit/analysis/test_SUM.py index dd2dcf9f9c..64699b442f 100644 --- a/lib/iris/tests/unit/analysis/test_SUM.py +++ b/lib/iris/tests/unit/analysis/test_SUM.py @@ -9,6 +9,7 @@ # importing anything else. import iris.tests as tests # isort:skip +import dask.array as da import numpy as np import numpy.ma as ma @@ -91,6 +92,16 @@ def test_weights_and_returned(self): self.assertArrayEqual(data, [14, 9, 11, 13, 15]) self.assertArrayEqual(weights, [4, 2, 2, 2, 2]) + def test_masked_weights_and_returned(self): + array = ma.array( + self.cube_2d.data, mask=[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0]] + ) + data, weights = SUM.aggregate( + array, axis=0, weights=self.weights, returned=True + ) + self.assertArrayEqual(data, [14, 9, 8, 4, 15]) + self.assertArrayEqual(weights, [4, 2, 1, 1, 2]) + class Test_lazy_weights_and_returned(tests.IrisTest): def setUp(self): @@ -128,6 +139,17 @@ def test_weights_and_returned(self): self.assertArrayEqual(lazy_data.compute(), [14, 9, 11, 13, 15]) self.assertArrayEqual(weights, [4, 2, 2, 2, 2]) + def test_masked_weights_and_returned(self): + array = da.ma.masked_array( + self.cube_2d.lazy_data(), mask=[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0]] + ) + lazy_data, weights = SUM.lazy_aggregate( + array, axis=0, weights=self.weights, returned=True + ) + self.assertTrue(is_lazy_data(lazy_data)) + self.assertArrayEqual(lazy_data.compute(), [14, 9, 8, 4, 15]) + self.assertArrayEqual(weights, [4, 2, 1, 1, 2]) + class Test_aggregate_shape(tests.IrisTest): def test(self): diff --git a/requirements/ci/py310.yml b/requirements/ci/py310.yml index 81c33c494b..795ecdc1ce 100644 --- a/requirements/ci/py310.yml +++ b/requirements/ci/py310.yml @@ -14,7 +14,7 @@ dependencies: - cartopy >=0.20 - cf-units >=3.1 - cftime >=1.5 - - dask-core >=2 + - dask-core >=2.26 - matplotlib - netcdf4 - numpy >=1.19 diff --git a/requirements/ci/py38.yml b/requirements/ci/py38.yml index 6353cfa3a5..da5e00cfac 100644 --- a/requirements/ci/py38.yml +++ b/requirements/ci/py38.yml @@ -14,7 +14,7 @@ dependencies: - cartopy >=0.20 - cf-units >=3.1 - cftime >=1.5 - - dask-core >=2 + - dask-core >=2.26 - matplotlib - netcdf4 - numpy >=1.19 diff --git a/requirements/ci/py39.yml b/requirements/ci/py39.yml index 14a93f00fd..c5600c06b4 100644 --- a/requirements/ci/py39.yml +++ b/requirements/ci/py39.yml @@ -14,7 +14,7 @@ dependencies: - cartopy >=0.20 - cf-units >=3.1 - cftime >=1.5 - - dask-core >=2 + - dask-core >=2.26 - matplotlib - netcdf4 - numpy >=1.19 diff --git a/setup.cfg b/setup.cfg index 25b0ecca5a..142d2114e6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ install_requires = cartopy>=0.20 cf-units>=3.1 cftime>=1.5.0 - dask[array]>=2 + dask[array]>=2.26 matplotlib netcdf4 numpy>=1.19