Skip to content

Commit

Permalink
Merge pull request #361 from bouthilx/fix/balance_workload
Browse files Browse the repository at this point in the history
Avoid duplicates in RandomSearch
  • Loading branch information
bouthilx authored Mar 18, 2020
2 parents c378510 + fc9a477 commit e7e2b23
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
log = logging.getLogger(__name__)


def infer_trial_id(point):
    """Compute a hash identifying a point.

    The point is converted to a list, the list's string representation is
    encoded as UTF-8, and the MD5 hex digest of those bytes is returned.
    Because of the list conversion, a tuple and a list holding the same
    values in the same order map to the same id.
    """
    serialized = str(list(point))
    digest = hashlib.md5(serialized.encode('utf-8'))
    return digest.hexdigest()


# pylint: disable=too-many-public-methods
class BaseAlgorithm(object, metaclass=ABCMeta):
"""Base class describing what an algorithm can do.
Expand Down Expand Up @@ -187,11 +192,10 @@ def observe(self, points, results):
"""
for point, result in zip(points, results):
_point = list(point)
_id = hashlib.md5(str(_point).encode('utf-8')).hexdigest()
point_id = infer_trial_id(point)

if _id not in self._trials_info:
self._trials_info[_id] = result
if point_id not in self._trials_info:
self._trials_info[point_id] = result

@property
def is_done(self):
Expand Down
15 changes: 13 additions & 2 deletions src/orion/algo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import numpy

from orion.algo.base import BaseAlgorithm
from orion.algo.base import BaseAlgorithm, infer_trial_id


class Random(BaseAlgorithm):
Expand Down Expand Up @@ -57,4 +57,15 @@ def suggest(self, num=1):
.. note:: New parameters must be compliant with the problem's domain
`orion.algo.space.Space`.
"""
return self.space.sample(num, seed=tuple(self.rng.randint(0, 1000000, size=3)))
points = []
point_ids = set(self._trials_info.keys())
i = 0
while len(points) < num:
new_point = self.space.sample(1, seed=tuple(self.rng.randint(0, 1000000, size=3)))[0]
point_id = infer_trial_id(new_point)
if point_id not in point_ids:
point_ids.add(point_id)
points.append(new_point)
i += 1

return points
35 changes: 35 additions & 0 deletions tests/unittests/algo/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,38 @@ def test_set_state(space):

random_search.set_state(state)
assert numpy.allclose(a, random_search.suggest(1)[0])


def test_suggest_unique():
    """Verify that RandomSearch does not sample duplicates"""
    space = Space()
    space.register(Integer('yolo1', 'uniform', -3, 6))

    random_search = Random(space)

    n_samples = 6
    # Summing with an empty tuple as the start value concatenates the
    # suggested points (expected to be tuples) into one flat tuple of values.
    values = sum(random_search.suggest(n_samples), tuple())

    # The requested number of values came back, and all are distinct.
    assert len(values) == n_samples
    assert len(set(values)) == n_samples


def test_suggest_unique_history():
    """Verify that RandomSearch does not sample duplicates of observed points"""
    space = Space()
    space.register(Integer('yolo1', 'uniform', -3, 6))

    random_search = Random(space)

    n_samples = 3

    # First batch: summing with an empty tuple flattens the suggested points
    # (expected to be tuples) into one flat tuple of values.
    values = sum(random_search.suggest(n_samples), tuple())
    assert len(values) == n_samples
    assert len(set(values)) == n_samples

    # Report the first batch back as observed, one single-value point each,
    # so the algorithm records them in its history.
    random_search.observe([[value] for value in values], [1] * n_samples)

    # Second batch of the same size: unique within itself...
    new_values = sum(random_search.suggest(n_samples), tuple())
    assert len(new_values) == n_samples
    assert len(set(new_values)) == n_samples
    # ...and sharing no value with the already observed batch.
    assert set(new_values).isdisjoint(values)

0 comments on commit e7e2b23

Please sign in to comment.