Skip to content
15 changes: 15 additions & 0 deletions lib/iris/_lazy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip) # noqa

from functools import wraps

import dask
import dask.array as da
import dask.context
Expand All @@ -31,6 +33,19 @@
import numpy.ma as ma


def non_lazy(func):
    """
    Turn a lazy function into a function that returns a result immediately.

    The wrapped function is evaluated eagerly: whatever lazy (dask) result
    it produces is computed before being handed back to the caller.

    """
    @wraps(func)
    def inner(*args, **kwargs):
        """Immediately return the results of a lazy function."""
        lazy_result = func(*args, **kwargs)
        # dask.compute returns a 1-tuple of realised results; unpack it.
        (realised_result,) = dask.compute(lazy_result)
        return realised_result
    return inner


def is_lazy_data(data):
"""
Return whether the argument is an Iris 'lazy' data array.
Expand Down
149 changes: 80 additions & 69 deletions lib/iris/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from iris.analysis._regrid import RectilinearRegridder
import iris.coords
from iris.exceptions import LazyAggregatorError
import iris._lazy_data as iris_lazy_data
import iris._lazy_data

__all__ = ('COUNT', 'GMEAN', 'HMEAN', 'MAX', 'MEAN', 'MEDIAN', 'MIN',
'PEAK', 'PERCENTILE', 'PROPORTION', 'RMS', 'STD_DEV', 'SUM',
Expand Down Expand Up @@ -468,7 +468,7 @@ def lazy_aggregate(self, data, axis, **kwargs):
# provided to __init__.
kwargs = dict(list(self._kwargs.items()) + list(kwargs.items()))

return self.lazy_func(data, axis, **kwargs)
return self.lazy_func(data, axis=axis, **kwargs)

def aggregate(self, data, axis, **kwargs):
"""
Expand Down Expand Up @@ -1015,6 +1015,40 @@ def post_process(self, collapsed_cube, data_result, coords, **kwargs):
return result


def _build_dask_mdtol_function(dask_stats_function):
"""
Make a wrapped dask statistic function that supports the 'mdtol' keyword.

'dask_function' must be a dask statistical function, compatible with the
call signature : "dask_stats_function(data, axis=axis, **kwargs)".
It must be masked-data tolerant, i.e. it ignores masked input points and
performs a calculation on only the unmasked points.
For example, mean([1, --, 2]) = (1 + 2) / 2 = 1.5.

The returned value is a new function operating on dask arrays.
It has the call signature `stat(data, axis=-1, mdtol=None, **kwargs)`.

"""
@wraps(dask_stats_function)
def inner_stat(array, axis=-1, mdtol=None, **kwargs):
# Call the statistic to get the basic result (missing-data tolerant).
dask_result = dask_stats_function(array, axis=axis, **kwargs)
if mdtol is None or mdtol >= 1.0:
result = dask_result
else:
# Build a lazy computation to compare the fraction of missing
# input points at each output point to the 'mdtol' threshold.
point_mask_counts = da.sum(da.ma.getmaskarray(array), axis=axis)
points_per_calc = array.size / dask_result.size
masked_point_fractions = point_mask_counts / points_per_calc
boolean_mask = masked_point_fractions > mdtol
# Return an mdtol-masked version of the basic result.
result = da.ma.masked_array(da.ma.getdata(dask_result),
boolean_mask)
return result
return inner_stat


def _percentile(data, axis, percent, fast_percentile_method=False,
**kwargs):
"""
Expand Down Expand Up @@ -1191,20 +1225,24 @@ def _weighted_percentile(data, axis, weights, percent, returned=False,
return result


def _count(array, function, axis, **kwargs):
if not callable(function):
raise ValueError('function must be a callable. Got %s.'
% type(function))
return ma.sum(function(array), axis=axis, **kwargs)
@_build_dask_mdtol_function
def _lazy_count(array, **kwargs):
    """
    Lazily count the points of 'array' where the mandatory 'function'
    keyword argument evaluates truthy.

    Raises TypeError if 'function' is missing or not callable.

    """
    lazy_array = iris._lazy_data.as_lazy_data(array)
    func = kwargs.pop('function', None)
    if not callable(func):
        raise TypeError(
            'function must be a callable. Got {}.'.format(type(func)))
    return da.sum(func(lazy_array), **kwargs)


def _proportion(array, function, axis, **kwargs):
count = iris._lazy_data.non_lazy(_lazy_count)
# if the incoming array is masked use that to count the total number of
# values
if ma.isMaskedArray(array):
# calculate the total number of non-masked values across the given axis
total_non_masked = _count(array.mask, np.logical_not,
axis=axis, **kwargs)
total_non_masked = count(
array.mask, axis=axis, function=np.logical_not, **kwargs)
total_non_masked = ma.masked_equal(total_non_masked, 0)
else:
total_non_masked = array.shape[axis]
Expand All @@ -1215,7 +1253,7 @@ def _proportion(array, function, axis, **kwargs):
# a dtype for its data that is different to the dtype of the fill-value,
# which can cause issues outside this function.
# Reference - tests/unit/analyis/test_PROPORTION.py Test_masked.test_ma
numerator = _count(array, function, axis=axis, **kwargs)
numerator = count(array, axis=axis, function=function, **kwargs)
result = ma.asarray(numerator / total_non_masked)

return result
Expand All @@ -1228,21 +1266,23 @@ def _rms(array, axis, **kwargs):
return rval


@_build_dask_mdtol_function
def _lazy_sum(array, **kwargs):
    """
    Lazily sum 'array', optionally applying 'weights', and optionally
    returning a (sum, sum-of-weights) pair when 'returned' is truthy.

    """
    array = iris._lazy_data.as_lazy_data(array)
    axis_in = kwargs.get('axis', None)
    weights_in = kwargs.pop('weights', None)
    returned_in = kwargs.pop('returned', False)
    # Weighted or scaled sum.
    if weights_in is None:
        wsum = da.sum(array, **kwargs)
    else:
        wsum = da.sum(weights_in * array, **kwargs)
    if not returned_in:
        return wsum
    # 'returned' requests the total weight alongside the sum; unit weights
    # are assumed when none were supplied.
    if weights_in is None:
        weights = iris._lazy_data.as_lazy_data(np.ones_like(array))
    else:
        weights = weights_in
    return (wsum, da.sum(weights, axis=axis_in))
Expand Down Expand Up @@ -1352,8 +1392,9 @@ def interp_order(length):
#
# Common partial Aggregation class constructors.
#
COUNT = Aggregator('count', _count,
units_func=lambda units: 1)
# 'count' aggregator: the eager path reuses the lazy implementation via
# non_lazy(); the lazy path is the dask-based _lazy_count. Units collapse
# to 1 because a count is dimensionless.
COUNT = Aggregator('count', iris._lazy_data.non_lazy(_lazy_count),
                   units_func=lambda units: 1,
                   lazy_func=_lazy_count)
"""
An :class:`~iris.analysis.Aggregator` instance that counts the number
of :class:`~iris.cube.Cube` data occurrences that satisfy a particular
Expand Down Expand Up @@ -1419,56 +1460,6 @@ def interp_order(length):
"""


MAX = Aggregator('maximum', ma.max)
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
the maximum over a :class:`~iris.cube.Cube`, as computed by
:func:`numpy.ma.max`.

**For example**:

To compute zonal maximums over the *longitude* axis of a cube::

result = cube.collapsed('longitude', iris.analysis.MAX)

This aggregator handles masked data.

"""


def _build_dask_mdtol_function(dask_stats_function):
    """
    Make a wrapped dask statistic function that supports the 'mdtol' keyword.

    'dask_stats_function' must be a dask statistical function, compatible
    with the call signature "dask_stats_function(data, axis=axis, **kwargs)".
    It must be masked-data tolerant, i.e. it ignores masked input points and
    performs a calculation on only the unmasked points.
    For example, mean([1, --, 2]) = (1 + 2) / 2 = 1.5.

    The returned value is a new function operating on dask arrays.
    It has the call signature `stat(data, axis=-1, mdtol=None, **kwargs)`.

    """
    @wraps(dask_stats_function)
    def inner_stat(array, axis=-1, mdtol=None, **kwargs):
        # Call the statistic to get the basic result (missing-data tolerant).
        dask_result = dask_stats_function(array, axis=axis, **kwargs)
        # mdtol of None, or >= 1, imposes no extra masking on the result.
        if mdtol is None or mdtol >= 1.0:
            result = dask_result
        else:
            # Build a lazy computation to compare the fraction of missing
            # input points at each output point to the 'mdtol' threshold.
            point_mask_counts = da.sum(da.ma.getmaskarray(array), axis=axis)
            points_per_calc = array.size / dask_result.size
            masked_point_fractions = point_mask_counts / points_per_calc
            boolean_mask = masked_point_fractions > mdtol
            # Return an mdtol-masked version of the basic result.
            result = da.ma.masked_array(da.ma.getdata(dask_result),
                                        boolean_mask)
        return result
    return inner_stat

# 'mean' aggregator: eager path uses numpy.ma.average (supports weights);
# lazy path wraps dask's mean with 'mdtol' missing-data support.
MEAN = WeightedAggregator('mean', ma.average,
                          lazy_func=_build_dask_mdtol_function(da.mean))
"""
Expand Down Expand Up @@ -1534,7 +1525,8 @@ def inner_stat(array, axis=-1, mdtol=None, **kwargs):
"""


MIN = Aggregator('minimum', ma.min)
# 'minimum' aggregator: eager path uses numpy.ma.min; lazy path wraps
# dask's min with 'mdtol' missing-data support.
MIN = Aggregator('minimum', ma.min,
                 lazy_func=_build_dask_mdtol_function(da.min))
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
the minimum over a :class:`~iris.cube.Cube`, as computed by
Expand All @@ -1551,6 +1543,24 @@ def inner_stat(array, axis=-1, mdtol=None, **kwargs):
"""


# 'maximum' aggregator: eager path uses numpy.ma.max; lazy path wraps
# dask's max with 'mdtol' missing-data support.
MAX = Aggregator('maximum', ma.max,
                 lazy_func=_build_dask_mdtol_function(da.max))
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
the maximum over a :class:`~iris.cube.Cube`, as computed by
:func:`numpy.ma.max`.

**For example**:

To compute zonal maximums over the *longitude* axis of a cube::

result = cube.collapsed('longitude', iris.analysis.MAX)

This aggregator handles masked data.

"""


PEAK = Aggregator('peak', _peak)
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
Expand Down Expand Up @@ -1700,7 +1710,8 @@ def inner_stat(array, axis=-1, mdtol=None, **kwargs):
"""


SUM = WeightedAggregator('sum', _sum)
# 'sum' aggregator: the eager path reuses the lazy implementation via
# non_lazy(); the lazy path adds 'mdtol' missing-data support on top.
SUM = WeightedAggregator('sum', iris._lazy_data.non_lazy(_lazy_sum),
                         lazy_func=_build_dask_mdtol_function(_lazy_sum))
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
the sum over a :class:`~iris.cube.Cube`, as computed by :func:`numpy.ma.sum`.
Expand Down
8 changes: 4 additions & 4 deletions lib/iris/tests/unit/analysis/test_Aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_kwarg_pass_through_no_kwargs(self):
axis = mock.sentinel.axis
aggregator = Aggregator('', None, lazy_func=lazy_func)
aggregator.lazy_aggregate(data, axis)
lazy_func.assert_called_once_with(data, axis)
lazy_func.assert_called_once_with(data, axis=axis)

def test_kwarg_pass_through_call_kwargs(self):
lazy_func = mock.Mock()
Expand All @@ -292,7 +292,7 @@ def test_kwarg_pass_through_call_kwargs(self):
kwargs = dict(wibble='wobble', foo='bar')
aggregator = Aggregator('', None, lazy_func=lazy_func)
aggregator.lazy_aggregate(data, axis, **kwargs)
lazy_func.assert_called_once_with(data, axis, **kwargs)
lazy_func.assert_called_once_with(data, axis=axis, **kwargs)

def test_kwarg_pass_through_init_kwargs(self):
lazy_func = mock.Mock()
Expand All @@ -301,7 +301,7 @@ def test_kwarg_pass_through_init_kwargs(self):
kwargs = dict(wibble='wobble', foo='bar')
aggregator = Aggregator('', None, lazy_func=lazy_func, **kwargs)
aggregator.lazy_aggregate(data, axis)
lazy_func.assert_called_once_with(data, axis, **kwargs)
lazy_func.assert_called_once_with(data, axis=axis, **kwargs)

def test_kwarg_pass_through_combined_kwargs(self):
lazy_func = mock.Mock()
Expand All @@ -313,7 +313,7 @@ def test_kwarg_pass_through_combined_kwargs(self):
aggregator.lazy_aggregate(data, axis, **call_kwargs)
expected_kwargs = init_kwargs.copy()
expected_kwargs.update(call_kwargs)
lazy_func.assert_called_once_with(data, axis, **expected_kwargs)
lazy_func.assert_called_once_with(data, axis=axis, **expected_kwargs)


if __name__ == "__main__":
Expand Down
Loading