diff --git a/lib/iris/_lazy_data.py b/lib/iris/_lazy_data.py index b413d8752a..f940216051 100644 --- a/lib/iris/_lazy_data.py +++ b/lib/iris/_lazy_data.py @@ -42,17 +42,23 @@ def _iris_dask_defaults(): all available CPUs. .. note:: + We only want Iris to set dask options in the case where doing so will not change user-specified options that have already been set. """ - if 'pool' not in dask.context._globals and \ - 'get' not in dask.context._globals: - dask.set_options(get=dget_sync) - - -# Run this at import time to set dask options for Iris. -_iris_dask_defaults() + dask_opts = {} + dask_globals = getattr(dask.context, '_globals') + if dask_globals is not None: + if 'pool' not in dask_globals and \ + 'get' not in dask_globals: + dask_opts.update(get=dget_sync) + else: + # We may need to unset a previously-set default. + if dask_opts.get('get') is not None: + dask_opts = {key: value for key, value in dask_opts.items() + if key != 'get'} + return dask_opts def is_lazy_data(data): @@ -122,11 +128,14 @@ def as_concrete_data(data): """ if is_lazy_data(data): + # Check dask options at runtime to see if we need to set dask options + # for use in Iris. + dask_opts = _iris_dask_defaults() # Realise dask array, ensuring the data result is always a NumPy array. # In some cases dask may return a scalar numpy.int/numpy.float object # rather than a numpy.ndarray object. # Recorded in https://github.com/dask/dask/issues/2111. - data = np.asanyarray(data.compute()) + data = np.asanyarray(data.compute(**dask_opts)) return data diff --git a/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py b/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py index cc9bbdfa3e..5b9a515a63 100644 --- a/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py +++ b/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py @@ -25,38 +25,67 @@ # Import iris.tests first so that some things can be initialised before # importing anything else. import iris.tests as tests +import dask.context from iris._lazy_data import _iris_dask_defaults class Test__iris_dask_defaults(tests.IrisTest): def setUp(self): - set_options = 'dask.set_options' - self.patch_set_options = self.patch(set_options) self.mock_get_sync = tests.mock.sentinel.get_sync get_sync = 'iris._lazy_data.dget_sync' self.patch_get_sync = self.patch(get_sync, self.mock_get_sync) + self.iris_defaults = {'get': self.patch_get_sync} + + def test_dask_context_api(self): + # A first line of defence to check `dask.context._globals` + # still exists. + self.assertTrue(hasattr(dask.context, '_globals')) def test_no_user_options(self): - self.patch('dask.context._globals', {}) - _iris_dask_defaults() - self.patch_set_options.assert_called_once_with(get=self.patch_get_sync) + with tests.mock.patch('dask.context._globals', {}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, self.iris_defaults) def test_user_options__pool(self): - self.patch('dask.context._globals', {'pool': 5}) - _iris_dask_defaults() - self.assertEqual(self.patch_set_options.call_count, 0) + with tests.mock.patch('dask.context._globals', {'pool': 5}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, {}) def test_user_options__get(self): - self.patch('dask.context._globals', {'get': 'threaded'}) - _iris_dask_defaults() - self.assertEqual(self.patch_set_options.call_count, 0) + with tests.mock.patch('dask.context._globals', {'get': 'threaded'}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, {}) def test_user_options__wibble(self): # Test a user-specified dask option that does not affect Iris. - self.patch('dask.context._globals', {'wibble': 'foo'}) - _iris_dask_defaults() - self.patch_set_options.assert_called_once_with(get=self.patch_get_sync) + with tests.mock.patch('dask.context._globals', {'wibble': 'foo'}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, self.iris_defaults) + + def test_changed_options__add(self): + # Check that adding dask options during a session alters Iris dask + # processing options. + # Starting condition: no dask options set. + with tests.mock.patch('dask.context._globals', {}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, self.iris_defaults) + # Updated condition: dask option is set. + with tests.mock.patch('dask.context._globals', {'pool': 5}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, {}) + + def test_changed_options__remove(self): + # Check that removing dask options during a session alters Iris dask + # processing options. + # Starting condition: dask option is set. + with tests.mock.patch('dask.context._globals', {'get': 'threaded'}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, {}) + # Updated condition: no dask options set. + with tests.mock.patch('dask.context._globals', {}): + opts = _iris_dask_defaults() + self.assertDictEqual(opts, self.iris_defaults) if __name__ == '__main__':