diff --git a/lib/iris/_lazy_data.py b/lib/iris/_lazy_data.py
index 7d525477d6..d2cf9b2ac9 100644
--- a/lib/iris/_lazy_data.py
+++ b/lib/iris/_lazy_data.py
@@ -95,6 +95,10 @@ def as_concrete_data(data, **kwargs):
Returns:
A NumPy `ndarray` or masked array.
+
+    .. note::
+        Specific dask options for computation are controlled by
+        :class:`iris.config.Parallel`.
+
"""
if is_lazy_data(data):
# Realise dask array, ensuring the data result is always a NumPy array.
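As a brief illustration of the note above, the sketch below realises a small lazy array while particular dask options are active. It is a minimal example assuming this patch is applied and `dask.array` is importable; the array is just a stand-in for lazy cube data.

```python
import dask.array as da

import iris.config
from iris._lazy_data import as_concrete_data, is_lazy_data

# A small lazy array standing in for lazy cube data.
lazy = da.arange(10, chunks=5)
assert is_lazy_data(lazy)

# Realise the data with two worker threads; the previous dask options are
# restored when the block exits.
with iris.config.parallel.context(scheduler='threaded', num_workers=2):
    data = as_concrete_data(lazy)

print(data)  # a plain NumPy ndarray
```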
diff --git a/lib/iris/config.py b/lib/iris/config.py
index e14ac058a4..76cbf9410d 100644
--- a/lib/iris/config.py
+++ b/lib/iris/config.py
@@ -1,4 +1,4 @@
-# (C) British Crown Copyright 2010 - 2016, Met Office
+# (C) British Crown Copyright 2010 - 2017, Met Office
#
# This file is part of Iris.
#
@@ -66,11 +66,24 @@
from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip) # noqa
-
+import six
from six.moves import configparser
+
+import contextlib
+from multiprocessing import cpu_count
+from multiprocessing.pool import ThreadPool
import os.path
+import re
import warnings
+import dask
+import dask.multiprocessing
+
+try:
+ import distributed
+except ImportError:
+ distributed = None
+
# Returns simple string options
def get_option(section, option, default=None):
@@ -154,3 +167,271 @@ def get_dir_option(section, option, default=None):
IMPORT_LOGGER = get_option(_LOGGING_SECTION, 'import_logger')
+
+
+#################
+# Runtime options
+
+class Option(object):
+ """
+    An abstract superclass that enforces certain key behaviours for all
+    `Option` subclasses.
+
+ """
+ @property
+ def _defaults_dict(self):
+ raise NotImplementedError
+
+ def __setattr__(self, name, value):
+ if value is None:
+ # Set an explicitly unset value to the default value for the name.
+ value = self._defaults_dict[name]['default']
+ if self._defaults_dict[name]['options'] is not None:
+ # Replace a bad value with the default if there is a defined set of
+ # specified good values.
+ if value not in self._defaults_dict[name]['options']:
+ good_value = self._defaults_dict[name]['default']
+ wmsg = ('Attempting to set bad value {!r} for attribute {!r}. '
+ 'Defaulting to {!r}.')
+ warnings.warn(wmsg.format(value, name, good_value))
+ value = good_value
+ self.__dict__[name] = value
+
+ def context(self):
+ raise NotImplementedError
+
+
+class Parallel(Option):
+ """
+ Control dask parallel processing options for Iris.
+
+ """
+ def __init__(self, scheduler=None, num_workers=None):
+ """
+ Set up options for dask parallel processing.
+
+ Currently accepted kwargs:
+
+ * scheduler:
+ The scheduler used to run a dask graph. Must be set to one of:
+
+ * 'threaded': (default)
+ The scheduler processes the graph in parallel using a
+ thread pool. Good for processing dask arrays and dataframes.
+ * 'multiprocessing':
+ The scheduler processes the graph in parallel using a
+ process pool. Good for processing dask bags.
+ * 'async':
+ The scheduler runs synchronously (not in parallel). Good for
+ debugging.
+ * The IP address and port of a distributed scheduler:
+ Specifies the location of a distributed scheduler that has
+ already been set up. The distributed scheduler will process the
+ graph.
+
+ For more information see
+ http://dask.pydata.org/en/latest/scheduler-overview.html.
+
+ * num_workers:
+            The number of worker threads or processes used to run the dask
+            graph in parallel. Defaults to 1 (that is, processed serially).
+
+        .. note::
+            The value for `num_workers` cannot be set equal to or greater
+            than the number of CPUs available on the host system. If such a
+            value is requested, `num_workers` is automatically limited to
+            one fewer than the number of CPUs available on the host system.
+
+ .. note::
+ Only the 'threaded' and 'multiprocessing' schedulers support
+ the `num_workers` kwarg. If it is specified with the `async` or
+ `distributed` scheduler, the kwarg is ignored:
+
+ * The 'async' scheduler runs serially so will only use a single
+ worker.
+ * The number of workers for the 'distributed' scheduler must be
+ defined when setting up the distributed scheduler. For more
+ information on setting up distributed schedulers, see
+ https://distributed.readthedocs.io/en/latest/index.html.
+
+ Example usages:
+
+ * Specify that we want to load a cube with dask parallel processing
+ using multiprocessing with six worker processes::
+
+            iris.config.Parallel(scheduler='multiprocessing', num_workers=6)
+ iris.load('my_dataset.nc')
+
+ * Specify, with a context manager, that we want to load a cube with
+ dask parallel processing using four worker threads::
+
+            with iris.config.parallel.context(scheduler='threaded',
+                                              num_workers=4):
+ iris.load('my_dataset.nc')
+
+ * Run dask parallel processing using a distributed scheduler that has
+ been set up at the IP address and port at ``192.168.0.219:8786``::
+
+            iris.config.Parallel(scheduler='192.168.0.219:8786')
+
+ """
+ # Set `__dict__` keys first.
+ self.__dict__['_scheduler'] = scheduler
+ self.__dict__['scheduler'] = None
+ self.__dict__['num_workers'] = None
+ self.__dict__['dask_scheduler'] = None
+
+ # Set `__dict__` values for each kwarg.
+ setattr(self, 'scheduler', scheduler)
+ setattr(self, 'num_workers', num_workers)
+ setattr(self, 'dask_scheduler', self.get('scheduler'))
+
+ # Activate the specified dask options.
+ self._set_dask_options()
+
+ def __repr__(self):
+ msg = 'Dask parallel options: {}.'
+
+ # Automatically populate with all currently accepted kwargs.
+ options = ['{}={}'.format(k, v)
+ for k, v in six.iteritems(self.__dict__)
+ if not k.startswith('_')]
+ joined = ', '.join(options)
+ return msg.format(joined)
+
+ def __setattr__(self, name, value):
+ if name not in self.__dict__:
+ # Can't add new names.
+ msg = "{!r} object has no attribute {!r}"
+ raise AttributeError(msg.format(self.__class__.__name__, name))
+ if value is None:
+ value = self._defaults_dict[name]['default']
+ attr_setter = self._defaults_dict[name]['setter']
+ value = attr_setter(value)
+ super(Parallel, self).__setattr__(name, value)
+
+ @property
+ def _defaults_dict(self):
+ """
+ Define the default value and available options for each settable
+ `kwarg` of this `Option`.
+
+        Note: `'options'` may be set to `None` when it is not practical to
+        enumerate all valid values; for example, when the valid values form
+        a continuous range of numbers.
+
+ """
+ return {'_scheduler': {'default': None, 'options': None,
+ 'setter': self.set__scheduler},
+ 'scheduler': {'default': 'threaded',
+ 'options': ['threaded',
+ 'multiprocessing',
+ 'async',
+ 'distributed'],
+ 'setter': self.set_scheduler},
+ 'num_workers': {'default': 1, 'options': None,
+ 'setter': self.set_num_workers},
+ 'dask_scheduler': {'default': None, 'options': None,
+ 'setter': self.set_dask_scheduler},
+ }
+
+ def set__scheduler(self, value):
+ return value
+
+ def set_scheduler(self, value):
+ default = self._defaults_dict['scheduler']['default']
+ if value is None:
+ value = default
+ elif re.match(r'^(\d{1,3}\.){3}\d{1,3}:\d{1,5}$', value):
+ if distributed is not None:
+ value = 'distributed'
+ else:
+ # Distributed not available.
+ wmsg = 'Cannot import distributed. Defaulting to {}.'
+ warnings.warn(wmsg.format(default))
+                value = default
+ elif value not in self._defaults_dict['scheduler']['options']:
+ # Invalid value for `scheduler`.
+ wmsg = 'Invalid value for scheduler: {!r}. Defaulting to {}.'
+ warnings.warn(wmsg.format(value, default))
+            value = default
+ return value
+
+ def set_num_workers(self, value):
+ default = self._defaults_dict['num_workers']['default']
+ scheduler = self.get('scheduler')
+ if scheduler == 'async' and value != default:
+ wmsg = 'Cannot set `num_workers` for the serial scheduler {!r}.'
+ warnings.warn(wmsg.format(scheduler))
+ value = None
+ elif scheduler == 'distributed' and value != default:
+ wmsg = ('Attempting to set `num_workers` with the {!r} scheduler '
+ 'requested. Please instead specify number of workers when '
+ 'setting up the distributed scheduler. See '
+ 'https://distributed.readthedocs.io/en/latest/index.html '
+ 'for more details.')
+ warnings.warn(wmsg.format(scheduler))
+ value = None
+ else:
+ if value is None:
+ value = default
+ if value >= cpu_count():
+ # Limit maximum CPUs used to 1 fewer than all available CPUs.
+ wmsg = ('Requested more CPUs ({}) than total available ({}). '
+ 'Limiting number of used CPUs to {}.')
+ warnings.warn(wmsg.format(value, cpu_count(), cpu_count()-1))
+ value = cpu_count() - 1
+ return value
+
+ def set_dask_scheduler(self, scheduler):
+ if scheduler == 'threaded':
+ value = dask.threaded.get
+ elif scheduler == 'multiprocessing':
+ value = dask.multiprocessing.get
+ elif scheduler == 'async':
+ value = dask.async.get_sync
+ elif scheduler == 'distributed':
+ value = self.get('_scheduler')
+ return value
+
+ def _set_dask_options(self):
+ """
+        Use `dask.set_options` to globally apply the specified options,
+        either for the lifetime of the session or for the duration of a
+        `context` block.
+
+ """
+ scheduler = self.get('scheduler')
+ get = self.get('dask_scheduler')
+ pool = None
+
+ if scheduler in ['threaded', 'multiprocessing']:
+ num_workers = self.get('num_workers')
+ pool = ThreadPool(num_workers)
+ elif scheduler == 'distributed':
+ get = distributed.Client(get).get
+
+ dask.set_options(get=get, pool=pool)
+
+ def get(self, item):
+ return getattr(self, item)
+
+ @contextlib.contextmanager
+ def context(self, **kwargs):
+ # Snapshot the starting state for restoration at the end of the
+ # contextmanager block.
+ starting_state = self.__dict__.copy()
+        # Update the state to reflect the requested changes.
+        for name, value in six.iteritems(kwargs):
+            if name == 'scheduler':
+                # Keep the raw scheduler value (which may be a distributed
+                # scheduler address) in step with the requested scheduler.
+                setattr(self, '_scheduler', value)
+            setattr(self, name, value)
+ setattr(self, 'dask_scheduler', self.get('scheduler'))
+ self._set_dask_options()
+ try:
+ yield
+ finally:
+ # Return the state to the starting state.
+ self.__dict__.clear()
+ self.__dict__.update(starting_state)
+ self._set_dask_options()
+
+
+parallel = Parallel()
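The following sketch shows how the validation in `Option.__setattr__` and `Parallel.__setattr__` behaves at runtime. It assumes this patch is applied and dask is installed; the values used are arbitrary.

```python
import warnings

import iris.config

# A bad scheduler name falls back to the default with a warning rather than
# raising an exception.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    opts = iris.config.Parallel(scheduler='wibble')
print(opts.get('scheduler'))   # 'threaded'
print(str(caught[0].message))  # explains the fallback

# Unknown attribute names are rejected outright.
try:
    opts.foo = 'bar'
except AttributeError as err:
    print(err)  # 'Parallel' object has no attribute 'foo'

# The repr lists the currently accepted options and their values.
print(repr(opts))
```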
diff --git a/lib/iris/tests/unit/config/__init__.py b/lib/iris/tests/unit/config/__init__.py
new file mode 100644
index 0000000000..dd625e1e91
--- /dev/null
+++ b/lib/iris/tests/unit/config/__init__.py
@@ -0,0 +1,20 @@
+# (C) British Crown Copyright 2017, Met Office
+#
+# This file is part of Iris.
+#
+# Iris is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Iris is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Iris.  If not, see <http://www.gnu.org/licenses/>.
+"""Unit tests for the :mod:`iris.config` module."""
+
+from __future__ import (absolute_import, division, print_function)
+from six.moves import (filter, input, map, range, zip) # noqa
diff --git a/lib/iris/tests/unit/config/test_Parallel.py b/lib/iris/tests/unit/config/test_Parallel.py
new file mode 100644
index 0000000000..5dc77993b0
--- /dev/null
+++ b/lib/iris/tests/unit/config/test_Parallel.py
@@ -0,0 +1,296 @@
+# (C) British Crown Copyright 2017, Met Office
+#
+# This file is part of Iris.
+#
+# Iris is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Iris is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Iris.  If not, see <http://www.gnu.org/licenses/>.
+"""Unit tests for the :class:`iris.config.Parallel` class."""
+
+from __future__ import (absolute_import, division, print_function)
+from six.moves import (filter, input, map, range, zip) # noqa
+import six
+
+# Import iris.tests first so that some things can be initialised before
+# importing anything else.
+import iris.tests as tests
+
+import warnings
+
+import dask
+
+from iris.config import Parallel
+from iris.tests import mock
+
+
+class Test__operation(tests.IrisTest):
+ def setUp(self):
+ self.parallel = Parallel()
+
+ def test_bad_name(self):
+        # Check we can't do `iris.config.parallel.foo = 'bar'`.
+ exp_emsg = "'Parallel' object has no attribute 'foo'"
+ with self.assertRaisesRegexp(AttributeError, exp_emsg):
+ self.parallel.foo = 'bar'
+
+ def test_bad_name__contextmgr(self):
+        # Check we can't do `with iris.config.parallel.context(foo='bar')`.
+ exp_emsg = "'Parallel' object has no attribute 'foo'"
+ with self.assertRaisesRegexp(AttributeError, exp_emsg):
+ with self.parallel.context(foo='bar'):
+ pass
+
+
+class Test__set_dask_options(tests.IrisTest):
+ def setUp(self):
+ ThreadPool = 'iris.config.ThreadPool'
+ self.pool = mock.sentinel.pool
+ self.patch_ThreadPool = self.patch(ThreadPool, return_value=self.pool)
+ self.default_num_workers = 1
+
+ Client = 'distributed.Client'
+ self.address = '192.168.0.128:8786'
+ mocker = mock.Mock(get=self.address)
+ self.patch_Client = self.patch(Client, return_value=mocker)
+
+ set_options = 'dask.set_options'
+ self.patch_set_options = self.patch(set_options)
+
+ def test_default(self):
+ Parallel()
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+ pool = self.pool
+ get = dask.threaded.get
+ self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+ def test__five_workers(self):
+ n_workers = 5
+ Parallel(num_workers=n_workers)
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_once_with(n_workers)
+
+ pool = self.pool
+ get = dask.threaded.get
+ self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+ def test__five_workers__contextmgr(self):
+ n_workers = 5
+ options = Parallel()
+ pool = self.pool
+ get = dask.threaded.get
+
+ with options.context(num_workers=n_workers):
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_with(n_workers)
+
+ self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+ self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+ self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+ def test_threaded(self):
+ scheduler = 'threaded'
+ Parallel(scheduler=scheduler)
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+ pool = self.pool
+ get = dask.threaded.get
+ self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+ def test_multiprocessing(self):
+ scheduler = 'multiprocessing'
+ Parallel(scheduler=scheduler)
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+ pool = self.pool
+ get = dask.multiprocessing.get
+ self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+ def test_multiprocessing__contextmgr(self):
+ scheduler = 'multiprocessing'
+ options = Parallel()
+ with options.context(scheduler=scheduler):
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+
+ pool = self.pool
+ get = dask.multiprocessing.get
+ self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+ default_get = dask.threaded.get
+ self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+ self.patch_set_options.assert_called_with(pool=pool,
+ get=default_get)
+
+ def test_async(self):
+ scheduler = 'async'
+ Parallel(scheduler=scheduler)
+ self.assertEqual(self.patch_Client.call_count, 0)
+ self.assertEqual(self.patch_ThreadPool.call_count, 0)
+
+ pool = self.pool
+ get = dask.async.get_sync
+ self.patch_set_options.assert_called_once_with(pool=None, get=get)
+
+ def test_distributed(self):
+ scheduler = self.address
+ Parallel(scheduler=scheduler)
+ self.assertEqual(self.patch_ThreadPool.call_count, 0)
+
+ get = scheduler
+ self.patch_Client.assert_called_once_with(get)
+
+ self.patch_set_options.assert_called_once_with(pool=None, get=get)
+
+
+class Test_set_schedulers(tests.IrisTest):
+ # Check that the correct scheduler is chosen given the inputs.
+ def setUp(self):
+ self.patch('iris.config.Parallel._set_dask_options')
+
+ def test_default(self):
+ opts = Parallel()
+ result = opts.get('scheduler')
+ expected = opts._defaults_dict['scheduler']['default']
+ self.assertEqual(result, expected)
+
+ def test_threaded(self):
+ scheduler = 'threaded'
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('scheduler')
+ self.assertEqual(result, scheduler)
+
+ def test_multiprocessing(self):
+ scheduler = 'multiprocessing'
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('scheduler')
+ self.assertEqual(result, scheduler)
+
+ def test_async(self):
+ scheduler = 'async'
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('scheduler')
+ self.assertEqual(result, scheduler)
+
+ def test_distributed(self):
+ scheduler = '192.168.0.128:8786'
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('scheduler')
+ self.assertEqual(result, 'distributed')
+
+ def test_bad(self):
+ scheduler = 'wibble'
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('scheduler')
+ expected = opts._defaults_dict['scheduler']['default']
+ self.assertEqual(result, expected)
+ exp_wmsg = 'Invalid value for scheduler: {!r}'
+ six.assertRegex(self, str(w[0].message), exp_wmsg.format(scheduler))
+
+
+class Test_set_num_workers(tests.IrisTest):
+ # Check that the correct `num_workers` are chosen given the inputs.
+ def setUp(self):
+ self.patch('iris.config.Parallel._set_dask_options')
+
+ def test_default(self):
+ opts = Parallel()
+ result = opts.get('num_workers')
+ expected = opts._defaults_dict['num_workers']['default']
+ self.assertEqual(result, expected)
+
+ def test_basic(self):
+ n_workers = 5
+ opts = Parallel(num_workers=n_workers)
+ result = opts.get('num_workers')
+ self.assertEqual(result, n_workers)
+
+ def test_too_many_workers(self):
+ max_cpus = 8
+ n_workers = 12
+        with mock.patch('iris.config.cpu_count', return_value=max_cpus):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ opts = Parallel(num_workers=n_workers)
+ result = opts.get('num_workers')
+ self.assertEqual(result, max_cpus-1)
+ exp_wmsg = ('Requested more CPUs ({}) than total available ({}). '
+ 'Limiting number of used CPUs to {}.')
+ self.assertEqual(str(w[0].message),
+ exp_wmsg.format(n_workers, max_cpus, max_cpus-1))
+
+ def test_async(self):
+ scheduler = 'async'
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ opts = Parallel(scheduler=scheduler, num_workers=5)
+ expected = opts._defaults_dict['num_workers']['default']
+ self.assertEqual(opts.get('num_workers'), expected)
+ exp_wmsg = 'Cannot set `num_workers` for the serial scheduler {!r}'
+ six.assertRegex(self, str(w[0].message), exp_wmsg.format(scheduler))
+
+ def test_distributed(self):
+ scheduler = '192.168.0.128:8786'
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ opts = Parallel(scheduler=scheduler, num_workers=5)
+ expected = opts._defaults_dict['num_workers']['default']
+ self.assertEqual(opts.get('num_workers'), expected)
+ exp_wmsg = 'Attempting to set `num_workers` with the {!r} scheduler'
+ six.assertRegex(self, str(w[0].message),
+ exp_wmsg.format('distributed'))
+
+
+class Test_set_dask_scheduler(tests.IrisTest):
+ # Check that the correct dask scheduler is chosen given the inputs.
+ def setUp(self):
+ self.patch('iris.config.Parallel._set_dask_options')
+
+ def test_default(self):
+ opts = Parallel()
+ result = opts.get('dask_scheduler')
+ expected = dask.threaded.get
+ self.assertIs(result, expected)
+
+ def test_threaded(self):
+ opts = Parallel(scheduler='threaded')
+ result = opts.get('dask_scheduler')
+ expected = dask.threaded.get
+ self.assertIs(result, expected)
+
+ def test_multiprocessing(self):
+ opts = Parallel(scheduler='multiprocessing')
+ result = opts.get('dask_scheduler')
+ expected = dask.multiprocessing.get
+ self.assertIs(result, expected)
+
+ def test_async(self):
+ opts = Parallel(scheduler='async')
+ result = opts.get('dask_scheduler')
+ expected = dask.async.get_sync
+ self.assertIs(result, expected)
+
+ def test_distributed(self):
+ scheduler = '192.168.0.128:8786'
+ opts = Parallel(scheduler=scheduler)
+ result = opts.get('dask_scheduler')
+ self.assertEqual(result, scheduler)
+
+
+if __name__ == '__main__':
+ tests.main()
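A note on the patch targets used in these tests: `mock.patch` must name an attribute where it is looked up, not where it is defined. Because `iris.config` does `from multiprocessing import cpu_count`, the bound name to patch is `iris.config.cpu_count` (just as the thread pool is patched via `iris.config.ThreadPool`). A minimal sketch of the distinction, assuming this patch is applied and using the same `mock` shim as the tests above:

```python
import multiprocessing

import iris.config
from iris.tests import mock

# Patching the name where iris.config looks it up changes what Iris sees ...
with mock.patch('iris.config.cpu_count', return_value=8):
    print(iris.config.cpu_count())      # 8

# ... whereas patching the defining module does not rebind the name already
# imported into iris.config, so Iris still calls the real function.
with mock.patch('multiprocessing.cpu_count', return_value=8):
    print(iris.config.cpu_count())      # the real CPU count
    print(multiprocessing.cpu_count())  # 8
```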