diff --git a/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py b/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py index 4dd80f944b..b370f751d8 100644 --- a/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py +++ b/lib/iris/tests/unit/lazy_data/test_iris_dask_defaults.py @@ -26,47 +26,38 @@ # importing anything else. import iris.tests as tests +import mock + import dask.context from iris._lazy_data import _iris_dask_defaults class Test__iris_dask_defaults(tests.IrisTest): - def setUp(self): - self.context = 'dask.context' - self._globals = 'iris._lazy_data._globals' - set_options = 'dask.set_options' - self.patch_set_options = self.patch(set_options) - get_sync = 'dask.async.get_sync' - self.patch_get_sync = self.patch(get_sync) + def check_call_with_settings(self, test_settings_dict, expect_call=True): + # Check the calls to 'dask.set_options' which result from calling + # _iris_dask_defaults, with the given dask global settings. + self.patch('dask.context._globals', test_settings_dict) + set_options_patch = self.patch('dask.set_options') + _iris_dask_defaults() + self.assertEqual(dask.context._globals, test_settings_dict) + if expect_call: + expect_calls = [mock.call(get=dask.async.get_sync)] + else: + expect_calls = [] + self.assertEqual(set_options_patch.call_args_list, expect_calls) def test_no_user_options(self): - test_dict = {} - with self.patch(self.context, _globals=test_dict): - _iris_dask_defaults() - self.assertEqual(dask.context._globals, test_dict) - self.patch_set_options.assert_called_once_with(get=self.patch_get_sync) + self.check_call_with_settings({}) def test_user_options__pool(self): - test_dict = {'pool': 5} - with self.patch(self.context, _globals=test_dict): - _iris_dask_defaults() - self.assertEqual(dask.context._globals, test_dict) - self.assertEqual(self.patch_set_options.call_count, 0) + self.check_call_with_settings({'pool': 5}, expect_call=False) def test_user_options__get(self): - test_dict = {'get': 'threaded'} - with self.patch(self.context, _globals=test_dict): - _iris_dask_defaults() - self.assertEqual(dask.context._globals, test_dict) - self.assertEqual(self.patch_set_options.call_count, 0) + self.check_call_with_settings({'get': 'threaded'}, expect_call=False) def test_user_options__wibble(self): # Test a user-specified dask option that does not affect Iris. - test_dict = {'wibble': 'foo'} - with self.patch(self.context, _globals=test_dict): - _iris_dask_defaults() - self.assertEqual(dask.context._globals, test_dict) - self.patch_set_options.assert_called_once_with(get=self.patch_get_sync) + self.check_call_with_settings({'wibble': 'foo'}) if __name__ == '__main__':