Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions lib/iris/_lazy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ def _limited_shape(shape):
return tuple(shape)


def _getall(a):
    """
    Return the full content of `a`, fetched with an empty-tuple index.

    If the indexing result is the ``numpy.ma`` masked constant, it is
    rebuilt as an ordinary masked array from the constant's data and
    mask, so the caller does not receive the masked-constant singleton.

    """
    res = a[()]
    if isinstance(res, ma.core.MaskedConstant):
        # Re-wrap the masked constant as a regular masked array.
        res = ma.masked_array(res.data, mask=res.mask)
    return res

# Delayed form of _getall; as_lazy_data uses it for inputs whose shape is ().
_getall_delayed = dask.delayed(_getall)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a decorator for this purpose...



def as_lazy_data(data, chunks=None, asarray=False):
"""
Convert the input array `data` to a dask array.
Expand Down Expand Up @@ -119,10 +128,15 @@ def as_lazy_data(data, chunks=None, asarray=False):
# but reduce it if larger than a default maximum size.
chunks = _limited_shape(data.shape)

if isinstance(data, ma.core.MaskedConstant):
data = ma.masked_array(data.data, mask=data.mask)
if not is_lazy_data(data):
data = da.from_array(data, chunks=chunks, asarray=asarray)
if data.shape == ():
# Workaround for https://github.com/dask/dask/issues/2823. Make
# sure scalar dask arrays return numpy objects.
dtype = data.dtype
data = _getall_delayed(data)
data = da.from_delayed(data, (), dtype)
else:
data = da.from_array(data, chunks=chunks, asarray=asarray)
return data


Expand Down
42 changes: 42 additions & 0 deletions lib/iris/tests/unit/lazy_data/test_as_concrete_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,26 @@
# importing anything else.
import iris.tests as tests

import unittest

import dask.array as da
import numpy as np
import numpy.ma as ma

from iris._lazy_data import as_concrete_data, as_lazy_data, is_lazy_data
from iris.tests import mock


class MyProxy(object):
    """
    Minimal array-like wrapper around an existing array.

    Exposes only `shape`, `dtype` and indexing, each taken directly from
    the wrapped array. The tests in this module pass instances to
    `as_lazy_data` and `da.from_array`.

    """
    def __init__(self, a):
        # Mirror the wrapped array's shape and dtype as plain attributes.
        self.shape = a.shape
        self.dtype = a.dtype
        self.a = a

    def __getitem__(self, keys):
        # Indexing is delegated unchanged to the wrapped array.
        return self.a[keys]


class Test_as_concrete_data(tests.IrisTest):
def test_concrete_input_data(self):
data = np.arange(24).reshape((4, 6))
Expand Down Expand Up @@ -62,6 +75,35 @@ def test_lazy_mask_data(self):
self.assertMaskedArrayEqual(result, mask_data)
self.assertEqual(result.fill_value, fill_value)

def test_lazy_scalar_proxy(self):
    # A scalar array behind a proxy should be wrapped as lazy data and
    # then realise to a non-lazy result equal to the original array.
    a = np.array(5)
    proxy = MyProxy(a)
    lazy_array = as_lazy_data(proxy)
    self.assertTrue(is_lazy_data(lazy_array))
    result = as_concrete_data(lazy_array)
    self.assertFalse(is_lazy_data(result))
    self.assertEqual(result, a)

def test_lazy_scalar_proxy_masked(self):
    # As for the unmasked scalar case, but the wrapped scalar is masked
    # and the realised result is compared as a masked array.
    a = np.ma.masked_array(5, True)
    proxy = MyProxy(a)
    lazy_array = as_lazy_data(proxy)
    self.assertTrue(is_lazy_data(lazy_array))
    result = as_concrete_data(lazy_array)
    self.assertFalse(is_lazy_data(result))
    self.assertMaskedArrayEqual(result, a)

def test_dask_scalar_proxy_pass_through(self):
# This test will fail when using a version of Dask with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little bit naughty (hello future person who is going back and reading this PR).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello

# https://github.com/dask/dask/issues/2823 fixed. At that point the
# changes introduced in https://github.com/SciTools/iris/pull/2878 can
# be reversed.
a = np.array(5)
proxy = MyProxy(a)
d = da.from_array(proxy, 1, asarray=False)
result = d.compute()
self.assertEqual(proxy, result)


# Script entry point: hand over to the iris.tests main routine.
if __name__ == '__main__':
    tests.main()