diff --git a/lib/iris/_lazy_data.py b/lib/iris/_lazy_data.py
index 7d525477d6..d2cf9b2ac9 100644
--- a/lib/iris/_lazy_data.py
+++ b/lib/iris/_lazy_data.py
@@ -95,6 +95,10 @@ def as_concrete_data(data, **kwargs):
     Returns:
         A NumPy `ndarray` or masked array.
 
+    .. note::
+        Specific dask options for computation are controlled by
+        :class:`iris.config.Parallel`.
+
     """
     if is_lazy_data(data):
         # Realise dask array, ensuring the data result is always a NumPy array.
diff --git a/lib/iris/config.py b/lib/iris/config.py
index e14ac058a4..76cbf9410d 100644
--- a/lib/iris/config.py
+++ b/lib/iris/config.py
@@ -1,4 +1,4 @@
-# (C) British Crown Copyright 2010 - 2016, Met Office
+# (C) British Crown Copyright 2010 - 2017, Met Office
 #
 # This file is part of Iris.
 #
@@ -66,11 +66,24 @@ from __future__ import (absolute_import, division, print_function)
 from six.moves import (filter, input, map, range, zip)  # noqa
-
+import six
 from six.moves import configparser
+
+import contextlib
+from multiprocessing import cpu_count
+from multiprocessing.pool import ThreadPool
 import os.path
+import re
 import warnings
 
+import dask
+import dask.multiprocessing
+
+try:
+    import distributed
+except ImportError:
+    distributed = None
+
 
 # Returns simple string options
 def get_option(section, option, default=None):
@@ -154,3 +167,271 @@ def get_dir_option(section, option, default=None):
 
 
 IMPORT_LOGGER = get_option(_LOGGING_SECTION, 'import_logger')
+
+
+#################
+# Runtime options
+
+class Option(object):
+    """
+    An abstract superclass to enforce certain key behaviours for all `Option`
+    classes.
+
+    """
+    @property
+    def _defaults_dict(self):
+        raise NotImplementedError
+
+    def __setattr__(self, name, value):
+        if value is None:
+            # Set an explicitly unset value to the default value for the name.
+            value = self._defaults_dict[name]['default']
+        if self._defaults_dict[name]['options'] is not None:
+            # Replace a bad value with the default if there is a defined set of
+            # specified good values.
+            if value not in self._defaults_dict[name]['options']:
+                good_value = self._defaults_dict[name]['default']
+                wmsg = ('Attempting to set bad value {!r} for attribute {!r}. '
+                        'Defaulting to {!r}.')
+                warnings.warn(wmsg.format(value, name, good_value))
+                value = good_value
+        self.__dict__[name] = value
+
+    def context(self):
+        raise NotImplementedError
+
+
+class Parallel(Option):
+    """
+    Control dask parallel processing options for Iris.
+
+    """
+    def __init__(self, scheduler=None, num_workers=None):
+        """
+        Set up options for dask parallel processing.
+
+        Currently accepted kwargs:
+
+        * scheduler:
+            The scheduler used to run a dask graph. Must be set to one of:
+
+            * 'threaded': (default)
+                The scheduler processes the graph in parallel using a
+                thread pool. Good for processing dask arrays and dataframes.
+            * 'multiprocessing':
+                The scheduler processes the graph in parallel using a
+                process pool. Good for processing dask bags.
+            * 'async':
+                The scheduler runs synchronously (not in parallel). Good for
+                debugging.
+            * The IP address and port of a distributed scheduler:
+                Specifies the location of a distributed scheduler that has
+                already been set up. The distributed scheduler will process
+                the graph.
+
+            For more information see
+            http://dask.pydata.org/en/latest/scheduler-overview.html.
+
+        * num_workers:
+            The number of worker threads or processes used to run the dask
+            graph in parallel. Defaults to 1 (that is, the graph is
+            processed serially).
+
+        .. note::
+            The value of `num_workers` cannot equal or exceed the number
+            of CPUs available on the host system. If such a value is
+            requested, `num_workers` is automatically limited to one less
+            than the number of CPUs available on the host system.
+
+        .. note::
+            Only the 'threaded' and 'multiprocessing' schedulers support
+            the `num_workers` kwarg. If it is specified with the `async` or
+            `distributed` scheduler, the kwarg is ignored:
+
+            * The 'async' scheduler runs serially so will only use a single
+              worker.
+            * The number of workers for the 'distributed' scheduler must be
+              defined when setting up the distributed scheduler. For more
+              information on setting up distributed schedulers, see
+              https://distributed.readthedocs.io/en/latest/index.html.
+
+        Example usages:
+
+        * Specify that we want to load a cube with dask parallel processing
+          using multiprocessing with six worker processes::
+
+              iris.config.Parallel(scheduler='multiprocessing', num_workers=6)
+              iris.load('my_dataset.nc')
+
+        * Specify, with a context manager, that we want to load a cube with
+          dask parallel processing using four worker threads::
+
+              with iris.config.parallel.context(scheduler='threaded',
+                                                num_workers=4):
+                  iris.load('my_dataset.nc')
+
+        * Run dask parallel processing using a distributed scheduler that
+          has been set up at the IP address and port ``192.168.0.219:8786``::
+
+              iris.config.Parallel(scheduler='192.168.0.219:8786')
+
+        """
+        # Set `__dict__` keys first.
+        self.__dict__['_scheduler'] = scheduler
+        self.__dict__['scheduler'] = None
+        self.__dict__['num_workers'] = None
+        self.__dict__['dask_scheduler'] = None
+
+        # Set `__dict__` values for each kwarg.
+        setattr(self, 'scheduler', scheduler)
+        setattr(self, 'num_workers', num_workers)
+        setattr(self, 'dask_scheduler', self.get('scheduler'))
+
+        # Activate the specified dask options.
+        self._set_dask_options()
+
+    def __repr__(self):
+        msg = 'Dask parallel options: {}.'
+
+        # Automatically populate with all currently accepted kwargs.
+        options = ['{}={}'.format(k, v)
+                   for k, v in six.iteritems(self.__dict__)
+                   if not k.startswith('_')]
+        joined = ', '.join(options)
+        return msg.format(joined)
+
+    def __setattr__(self, name, value):
+        if name not in self.__dict__:
+            # Can't add new names.
+            msg = "{!r} object has no attribute {!r}"
+            raise AttributeError(msg.format(self.__class__.__name__, name))
+        if value is None:
+            value = self._defaults_dict[name]['default']
+        attr_setter = self._defaults_dict[name]['setter']
+        value = attr_setter(value)
+        super(Parallel, self).__setattr__(name, value)
+
+    @property
+    def _defaults_dict(self):
+        """
+        Define the default value and available options for each settable
+        `kwarg` of this `Option`.
+
+        Note: `'options'` can be set to `None` when it is not practical to
+        enumerate all valid values, for example when they form a range of
+        numbers.
+
+        """
+        return {'_scheduler': {'default': None, 'options': None,
+                               'setter': self.set__scheduler},
+                'scheduler': {'default': 'threaded',
+                              'options': ['threaded',
+                                          'multiprocessing',
+                                          'async',
+                                          'distributed'],
+                              'setter': self.set_scheduler},
+                'num_workers': {'default': 1, 'options': None,
+                                'setter': self.set_num_workers},
+                'dask_scheduler': {'default': None, 'options': None,
+                                   'setter': self.set_dask_scheduler},
+                }
+
+    def set__scheduler(self, value):
+        return value
+
+    def set_scheduler(self, value):
+        default = self._defaults_dict['scheduler']['default']
+        if value is None:
+            value = default
+        elif re.match(r'^(\d{1,3}\.){3}\d{1,3}:\d{1,5}$', value):
+            if distributed is not None:
+                value = 'distributed'
+            else:
+                # Distributed not available.
+                wmsg = 'Cannot import distributed. Defaulting to {}.'
+                warnings.warn(wmsg.format(default))
+                value = default
+        elif value not in self._defaults_dict['scheduler']['options']:
+            # Invalid value for `scheduler`.
+            wmsg = 'Invalid value for scheduler: {!r}. Defaulting to {}.'
+            warnings.warn(wmsg.format(value, default))
+            value = default
+        return value
+
+    def set_num_workers(self, value):
+        default = self._defaults_dict['num_workers']['default']
+        scheduler = self.get('scheduler')
+        if scheduler == 'async' and value != default:
+            wmsg = 'Cannot set `num_workers` for the serial scheduler {!r}.'
+            warnings.warn(wmsg.format(scheduler))
+            value = None
+        elif scheduler == 'distributed' and value != default:
+            wmsg = ('Attempting to set `num_workers` with the {!r} scheduler '
+                    'requested. Please instead specify the number of workers '
+                    'when setting up the distributed scheduler. See '
+                    'https://distributed.readthedocs.io/en/latest/index.html '
+                    'for more details.')
+            warnings.warn(wmsg.format(scheduler))
+            value = None
+        else:
+            if value is None:
+                value = default
+            if value >= cpu_count():
+                # Limit maximum CPUs used to 1 fewer than all available CPUs.
+                wmsg = ('Requested more CPUs ({}) than total available ({}). '
+                        'Limiting number of used CPUs to {}.')
+                warnings.warn(wmsg.format(value, cpu_count(), cpu_count()-1))
+                value = cpu_count() - 1
+        return value
+
+    def set_dask_scheduler(self, scheduler):
+        if scheduler == 'threaded':
+            value = dask.threaded.get
+        elif scheduler == 'multiprocessing':
+            value = dask.multiprocessing.get
+        elif scheduler == 'async':
+            value = dask.async.get_sync
+        elif scheduler == 'distributed':
+            value = self.get('_scheduler')
+        return value
+
+    def _set_dask_options(self):
+        """
+        Use `dask.set_options` to globally apply the options specified at
+        instantiation, either for the lifetime of the session or of a
+        context manager block.
+
+        """
+        scheduler = self.get('scheduler')
+        get = self.get('dask_scheduler')
+        pool = None
+
+        if scheduler in ['threaded', 'multiprocessing']:
+            num_workers = self.get('num_workers')
+            pool = ThreadPool(num_workers)
+        elif scheduler == 'distributed':
+            get = distributed.Client(get).get
+
+        dask.set_options(get=get, pool=pool)
+
+    def get(self, item):
+        return getattr(self, item)
+
+    @contextlib.contextmanager
+    def context(self, **kwargs):
+        # Snapshot the starting state for restoration at the end of the
+        # contextmanager block.
+        starting_state = self.__dict__.copy()
+        # Update the state to reflect the requested changes.
+        for name, value in six.iteritems(kwargs):
+            setattr(self, name, value)
+        setattr(self, 'dask_scheduler', self.get('scheduler'))
+        self._set_dask_options()
+        try:
+            yield
+        finally:
+            # Restore the starting state.
+            self.__dict__.clear()
+            self.__dict__.update(starting_state)
+            self._set_dask_options()
+
+
+parallel = Parallel()
diff --git a/lib/iris/tests/unit/config/__init__.py b/lib/iris/tests/unit/config/__init__.py
new file mode 100644
index 0000000000..dd625e1e91
--- /dev/null
+++ b/lib/iris/tests/unit/config/__init__.py
@@ -0,0 +1,20 @@
+# (C) British Crown Copyright 2017, Met Office
+#
+# This file is part of Iris.
+#
+# Iris is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Iris is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Iris. If not, see <http://www.gnu.org/licenses/>.
+"""Unit tests for the :mod:`iris.config` module."""
+
+from __future__ import (absolute_import, division, print_function)
+from six.moves import (filter, input, map, range, zip)  # noqa
diff --git a/lib/iris/tests/unit/config/test_Parallel.py b/lib/iris/tests/unit/config/test_Parallel.py
new file mode 100644
index 0000000000..5dc77993b0
--- /dev/null
+++ b/lib/iris/tests/unit/config/test_Parallel.py
@@ -0,0 +1,296 @@
+# (C) British Crown Copyright 2017, Met Office
+#
+# This file is part of Iris.
+#
+# Iris is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Iris is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Iris. If not, see <http://www.gnu.org/licenses/>.
+"""Unit tests for the :class:`iris.config.Parallel` class."""
+
+from __future__ import (absolute_import, division, print_function)
+from six.moves import (filter, input, map, range, zip)  # noqa
+import six
+
+# Import iris.tests first so that some things can be initialised before
+# importing anything else.
+import iris.tests as tests
+
+import warnings
+
+import dask
+
+from iris.config import Parallel
+from iris.tests import mock
+
+
+class Test__operation(tests.IrisTest):
+    def setUp(self):
+        self.parallel = Parallel()
+
+    def test_bad_name(self):
+        # Check we can't do `iris.config.parallel.foo = 'bar'`.
+        exp_emsg = "'Parallel' object has no attribute 'foo'"
+        with self.assertRaisesRegexp(AttributeError, exp_emsg):
+            self.parallel.foo = 'bar'
+
+    def test_bad_name__contextmgr(self):
+        # Check we can't do `with iris.config.parallel.context(foo='bar')`.
+        exp_emsg = "'Parallel' object has no attribute 'foo'"
+        with self.assertRaisesRegexp(AttributeError, exp_emsg):
+            with self.parallel.context(foo='bar'):
+                pass
+
+
+class Test__set_dask_options(tests.IrisTest):
+    def setUp(self):
+        ThreadPool = 'iris.config.ThreadPool'
+        self.pool = mock.sentinel.pool
+        self.patch_ThreadPool = self.patch(ThreadPool, return_value=self.pool)
+        self.default_num_workers = 1
+
+        Client = 'distributed.Client'
+        self.address = '192.168.0.128:8786'
+        mocker = mock.Mock(get=self.address)
+        self.patch_Client = self.patch(Client, return_value=mocker)
+
+        set_options = 'dask.set_options'
+        self.patch_set_options = self.patch(set_options)
+
+    def test_default(self):
+        Parallel()
+        self.assertEqual(self.patch_Client.call_count, 0)
+        self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+        pool = self.pool
+        get = dask.threaded.get
+        self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+    def test__five_workers(self):
+        n_workers = 5
+        Parallel(num_workers=n_workers)
+        self.assertEqual(self.patch_Client.call_count, 0)
+        self.patch_ThreadPool.assert_called_once_with(n_workers)
+
+        pool = self.pool
+        get = dask.threaded.get
+        self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+    def test__five_workers__contextmgr(self):
+        n_workers = 5
+        options = Parallel()
+        pool = self.pool
+        get = dask.threaded.get
+
+        with options.context(num_workers=n_workers):
+            self.assertEqual(self.patch_Client.call_count, 0)
+            self.patch_ThreadPool.assert_called_with(n_workers)
+
+            self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+        self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+        self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+    def test_threaded(self):
+        scheduler = 'threaded'
+        Parallel(scheduler=scheduler)
+        self.assertEqual(self.patch_Client.call_count, 0)
+        self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+        pool = self.pool
+        get = dask.threaded.get
+        self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+    def test_multiprocessing(self):
+        scheduler = 'multiprocessing'
+        Parallel(scheduler=scheduler)
+        self.assertEqual(self.patch_Client.call_count, 0)
+        self.patch_ThreadPool.assert_called_once_with(self.default_num_workers)
+
+        pool = self.pool
+        get = dask.multiprocessing.get
+        self.patch_set_options.assert_called_once_with(pool=pool, get=get)
+
+    def test_multiprocessing__contextmgr(self):
+        scheduler = 'multiprocessing'
+        options = Parallel()
+        with options.context(scheduler=scheduler):
+            self.assertEqual(self.patch_Client.call_count, 0)
+            self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+
+            pool = self.pool
+            get = dask.multiprocessing.get
+            self.patch_set_options.assert_called_with(pool=pool, get=get)
+
+        default_get = dask.threaded.get
+        self.patch_ThreadPool.assert_called_with(self.default_num_workers)
+        self.patch_set_options.assert_called_with(pool=pool,
+                                                  get=default_get)
+
+    def test_async(self):
+        scheduler = 'async'
+        Parallel(scheduler=scheduler)
+        self.assertEqual(self.patch_Client.call_count, 0)
+        self.assertEqual(self.patch_ThreadPool.call_count, 0)
+
+        pool = self.pool
+        get = dask.async.get_sync
+        self.patch_set_options.assert_called_once_with(pool=None, get=get)
+
+    def test_distributed(self):
+        scheduler = self.address
+        Parallel(scheduler=scheduler)
+        self.assertEqual(self.patch_ThreadPool.call_count, 0)
+
+        get = scheduler
+        self.patch_Client.assert_called_once_with(get)
+
+        self.patch_set_options.assert_called_once_with(pool=None, get=get)
+
+
+class Test_set_schedulers(tests.IrisTest):
+    # Check that the correct scheduler is chosen given the inputs.
+    def setUp(self):
+        self.patch('iris.config.Parallel._set_dask_options')
+
+    def test_default(self):
+        opts = Parallel()
+        result = opts.get('scheduler')
+        expected = opts._defaults_dict['scheduler']['default']
+        self.assertEqual(result, expected)
+
+    def test_threaded(self):
+        scheduler = 'threaded'
+        opts = Parallel(scheduler=scheduler)
+        result = opts.get('scheduler')
+        self.assertEqual(result, scheduler)
+
+    def test_multiprocessing(self):
+        scheduler = 'multiprocessing'
+        opts = Parallel(scheduler=scheduler)
+        result = opts.get('scheduler')
+        self.assertEqual(result, scheduler)
+
+    def test_async(self):
+        scheduler = 'async'
+        opts = Parallel(scheduler=scheduler)
+        result = opts.get('scheduler')
+        self.assertEqual(result, scheduler)
+
+    def test_distributed(self):
+        scheduler = '192.168.0.128:8786'
+        opts = Parallel(scheduler=scheduler)
+        result = opts.get('scheduler')
+        self.assertEqual(result, 'distributed')
+
+    def test_bad(self):
+        scheduler = 'wibble'
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always')
+            opts = Parallel(scheduler=scheduler)
+        result = opts.get('scheduler')
+        expected = opts._defaults_dict['scheduler']['default']
+        self.assertEqual(result, expected)
+        exp_wmsg = 'Invalid value for scheduler: {!r}'
+        six.assertRegex(self, str(w[0].message), exp_wmsg.format(scheduler))
+
+
+class Test_set_num_workers(tests.IrisTest):
+    # Check that the correct `num_workers` value is chosen given the inputs.
+    def setUp(self):
+        self.patch('iris.config.Parallel._set_dask_options')
+
+    def test_default(self):
+        opts = Parallel()
+        result = opts.get('num_workers')
+        expected = opts._defaults_dict['num_workers']['default']
+        self.assertEqual(result, expected)
+
+    def test_basic(self):
+        n_workers = 5
+        opts = Parallel(num_workers=n_workers)
+        result = opts.get('num_workers')
+        self.assertEqual(result, n_workers)
+
+    def test_too_many_workers(self):
+        max_cpus = 8
+        n_workers = 12
+        with mock.patch('iris.config.cpu_count', return_value=max_cpus):
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter('always')
+                opts = Parallel(num_workers=n_workers)
+        result = opts.get('num_workers')
+        self.assertEqual(result, max_cpus-1)
+        exp_wmsg = ('Requested more CPUs ({}) than total available ({}). '
+                    'Limiting number of used CPUs to {}.')
+        self.assertEqual(str(w[0].message),
+                         exp_wmsg.format(n_workers, max_cpus, max_cpus-1))
+
+    def test_async(self):
+        scheduler = 'async'
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always')
+            opts = Parallel(scheduler=scheduler, num_workers=5)
+        expected = opts._defaults_dict['num_workers']['default']
+        self.assertEqual(opts.get('num_workers'), expected)
+        exp_wmsg = 'Cannot set `num_workers` for the serial scheduler {!r}'
+        six.assertRegex(self, str(w[0].message), exp_wmsg.format(scheduler))
+
+    def test_distributed(self):
+        scheduler = '192.168.0.128:8786'
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always')
+            opts = Parallel(scheduler=scheduler, num_workers=5)
+        expected = opts._defaults_dict['num_workers']['default']
+        self.assertEqual(opts.get('num_workers'), expected)
+        exp_wmsg = 'Attempting to set `num_workers` with the {!r} scheduler'
+        six.assertRegex(self, str(w[0].message),
+                        exp_wmsg.format('distributed'))
+
+
+class Test_set_dask_scheduler(tests.IrisTest):
+    # Check that the correct dask scheduler is chosen given the inputs.
+    def setUp(self):
+        self.patch('iris.config.Parallel._set_dask_options')
+
+    def test_default(self):
+        opts = Parallel()
+        result = opts.get('dask_scheduler')
+        expected = dask.threaded.get
+        self.assertIs(result, expected)
+
+    def test_threaded(self):
+        opts = Parallel(scheduler='threaded')
+        result = opts.get('dask_scheduler')
+        expected = dask.threaded.get
+        self.assertIs(result, expected)
+
+    def test_multiprocessing(self):
+        opts = Parallel(scheduler='multiprocessing')
+        result = opts.get('dask_scheduler')
+        expected = dask.multiprocessing.get
+        self.assertIs(result, expected)
+
+    def test_async(self):
+        opts = Parallel(scheduler='async')
+        result = opts.get('dask_scheduler')
+        expected = dask.async.get_sync
+        self.assertIs(result, expected)
+
+    def test_distributed(self):
+        scheduler = '192.168.0.128:8786'
+        opts = Parallel(scheduler=scheduler)
+        result = opts.get('dask_scheduler')
+        self.assertEqual(result, scheduler)
+
+
+if __name__ == '__main__':
+    tests.main()
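
Example usage of the runtime options added by this patch (a minimal sketch, not part of the diff above; it assumes the patch is applied against a dask version contemporary with it, one that still provides ``dask.set_options``)::

    import iris
    import iris.config

    # Apply options globally for the session: constructing a Parallel
    # instance calls dask.set_options() with the chosen scheduler and
    # worker pool.
    iris.config.Parallel(scheduler='multiprocessing', num_workers=4)
    cubes = iris.load('my_dataset.nc')  # hypothetical file name

    # Temporarily override the options via the module-level `parallel`
    # instance; the starting settings are restored on exiting the block.
    with iris.config.parallel.context(scheduler='threaded', num_workers=2):
        cubes = iris.load('my_dataset.nc')

    # Inspect the currently configured options (uses Parallel.__repr__).
    print(iris.config.parallel)

The ``context()`` form mirrors the behaviour exercised by ``Test__set_dask_options.test_multiprocessing__contextmgr`` above.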