Skip to content

Commit

Permalink
Avoid duplicates in RandomSearch
Browse files Browse the repository at this point in the history
Why:

The filtering of duplicates in RandomSearch is a lot more efficient than
relying on the database with DuplicateKeyErrors. There will still be
race conditions leading to this error, but filtering upfront in
RandomSearch significantly reduces the number of registration failures
and leads to better workload balance between the workers. Without this
fix, some workers end up always updating their algo and trying to sample
the next point but lose all race conditions.

How:

Use _trials_info to filter out duplicates while sampling points...
  • Loading branch information
bouthilx committed Mar 18, 2020
1 parent c378510 commit f3464bd
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
log = logging.getLogger(__name__)


def infer_trial_id(point):
    """Return a deterministic hexadecimal identifier for a parameter point.

    The point is first normalized through ``list()`` so that any equivalent
    sequence (tuple, list, ...) maps to the same identifier, then hashed
    with MD5 over its string representation.
    """
    serialized = str(list(point)).encode('utf-8')
    return hashlib.md5(serialized).hexdigest()


# pylint: disable=too-many-public-methods
class BaseAlgorithm(object, metaclass=ABCMeta):
"""Base class describing what an algorithm can do.
Expand Down Expand Up @@ -187,11 +191,10 @@ def observe(self, points, results):
"""
for point, result in zip(points, results):
_point = list(point)
_id = hashlib.md5(str(_point).encode('utf-8')).hexdigest()
point_id = infer_trial_id(point)

if _id not in self._trials_info:
self._trials_info[_id] = result
if point_id not in self._trials_info:
self._trials_info[point_id] = result

@property
def is_done(self):
Expand Down
15 changes: 13 additions & 2 deletions src/orion/algo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import numpy

from orion.algo.base import BaseAlgorithm
from orion.algo.base import BaseAlgorithm, infer_trial_id


class Random(BaseAlgorithm):
Expand Down Expand Up @@ -57,4 +57,15 @@ def suggest(self, num=1):
.. note:: New parameters must be compliant with the problem's domain
`orion.algo.space.Space`.
"""
return self.space.sample(num, seed=tuple(self.rng.randint(0, 1000000, size=3)))
points = []
point_ids = set(self._trials_info.keys())
i = 0
while len(points) < num:
new_point = self.space.sample(1, seed=tuple(self.rng.randint(0, 1000000, size=3)))[0]
point_id = infer_trial_id(new_point)
if point_id not in point_ids:
point_ids.add(point_id)
points.append(new_point)
i += 1

return points
35 changes: 35 additions & 0 deletions tests/unittests/algo/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,38 @@ def test_set_state(space):

random_search.set_state(state)
assert numpy.allclose(a, random_search.suggest(1)[0])


def test_suggest_unique():
    """Verify that RandomSearch does not sample duplicate points."""
    search_space = Space()
    search_space.register(Integer('yolo1', 'uniform', -3, 6))

    algo = Random(search_space)

    num_points = 6
    # Flatten the list of 1-d points into a single tuple of values.
    sampled = sum(algo.suggest(num_points), tuple())
    # As many values as requested, and all of them distinct.
    assert len(sampled) == num_points
    assert len(set(sampled)) == num_points


def test_suggest_unique_history():
    """Verify that RandomSearch does not re-sample already observed points."""
    search_space = Space()
    search_space.register(Integer('yolo1', 'uniform', -3, 6))

    algo = Random(search_space)

    num_points = 3
    # Flatten the list of 1-d points into a single tuple of values.
    first_batch = sum(algo.suggest(num_points), tuple())
    assert len(first_batch) == num_points
    assert len(set(first_batch)) == num_points

    # Feed the sampled points back to the algorithm as observed trials.
    algo.observe([[value] for value in first_batch], [1] * num_points)

    second_batch = sum(algo.suggest(num_points), tuple())
    assert len(second_batch) == num_points
    assert len(set(second_batch)) == num_points
    # The new batch must not intersect the observed history.
    assert not (set(second_batch) & set(first_batch))

0 comments on commit f3464bd

Please sign in to comment.