Skip to content

Commit

Permalink
added test of EvalParallel2 and polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
nikohansen committed Apr 21, 2020
1 parent 67feff2 commit ed33dc3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
15 changes: 9 additions & 6 deletions cma/optimization_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
from __future__ import absolute_import, division, print_function #, unicode_literals
import sys
import warnings
import numpy as np
from multiprocessing import Pool as ProcessingPool
# from pathos.multiprocessing import ProcessingPool
Expand Down Expand Up @@ -207,8 +208,10 @@ class EvalParallel2(object):
Examples:
>>> import cma
>>> from cma.optimization_tools import EvalParallel2
>>> for n_jobs in [None, -1, 0, 1, 2, 4]:
... with EvalParallel2(cma.fitness_functions.elli, n_jobs) as eval_all:
... res = eval_all([[1,2], [3,4]])
>>> # class usage, don't forget to call terminate
>>> ep = EvalParallel2(cma.fitness_functions.elli, 4)
>>> ep([[1,2], [3,4], [4, 5]]) # doctest:+ELLIPSIS
Expand Down Expand Up @@ -244,11 +247,11 @@ class EvalParallel2(object):
"""
def __init__(self, fitness_function=None, number_of_processes=None):
self.fitness_function = fitness_function
self.processes = number_of_processes
if self.processes is not None and self.processes <= 0:
self.pool = None
else:
self.processes = number_of_processes # for the record
if self.processes is None or self.processes > 0:
self.pool = ProcessingPool(self.processes)
else:
self.pool = None

def __call__(self, solutions, fitness_function=None, args=(), timeout=None):
"""evaluate a list/sequence of solution-"vectors", return a list
Expand All @@ -269,7 +272,7 @@ def __call__(self, solutions, fitness_function=None, args=(), timeout=None):
warning_str = ("`fitness_function` must be a function, not a"
" `lambda` or an instancemethod, in order to work with"
" `multiprocessing` under Python 2")
if sys.version[0] == '2': # not necessary anymore?
if sys.version[0] == '2':
if isinstance(fitness_function, type(self.__init__)):
warnings.warn(warning_str)
jobs = [self.pool.apply_async(fitness_function, (x,) + args)
Expand Down
7 changes: 7 additions & 0 deletions cma/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,15 @@ def various_doctests():
For VD- and VkD-CMA, see `cma.restricted_gaussian_sampler`.
>>> import sys
>>> import cma
>>> assert cma.interfaces.EvalParallel2 is not None
>>> try:
... with warnings.catch_warnings(record=True) as warn:
... with cma.optimization_tools.EvalParallel2(cma.ff.elli) as eval_all:
... res = eval_all([[1,2], [3,4]])
... except:
... assert sys.version[0] == '2'
"""

Expand Down

0 comments on commit ed33dc3

Please sign in to comment.