Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid duplicates in RandomSearch #361

Merged
merged 1 commit into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
log = logging.getLogger(__name__)


def infer_trial_id(point):
"""Compute a hashing of a point"""
return hashlib.md5(str(list(point)).encode('utf-8')).hexdigest()


# pylint: disable=too-many-public-methods
class BaseAlgorithm(object, metaclass=ABCMeta):
"""Base class describing what an algorithm can do.
Expand Down Expand Up @@ -187,11 +192,10 @@ def observe(self, points, results):
"""
for point, result in zip(points, results):
_point = list(point)
_id = hashlib.md5(str(_point).encode('utf-8')).hexdigest()
point_id = infer_trial_id(point)

if _id not in self._trials_info:
self._trials_info[_id] = result
if point_id not in self._trials_info:
self._trials_info[point_id] = result

@property
def is_done(self):
Expand Down
15 changes: 13 additions & 2 deletions src/orion/algo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import numpy

from orion.algo.base import BaseAlgorithm
from orion.algo.base import BaseAlgorithm, infer_trial_id


class Random(BaseAlgorithm):
Expand Down Expand Up @@ -57,4 +57,15 @@ def suggest(self, num=1):
.. note:: New parameters must be compliant with the problem's domain
`orion.algo.space.Space`.
"""
return self.space.sample(num, seed=tuple(self.rng.randint(0, 1000000, size=3)))
points = []
point_ids = set(self._trials_info.keys())
i = 0
while len(points) < num:
new_point = self.space.sample(1, seed=tuple(self.rng.randint(0, 1000000, size=3)))[0]
point_id = infer_trial_id(new_point)
if point_id not in point_ids:
point_ids.add(point_id)
points.append(new_point)
i += 1

return points
35 changes: 35 additions & 0 deletions tests/unittests/algo/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,38 @@ def test_set_state(space):

random_search.set_state(state)
assert numpy.allclose(a, random_search.suggest(1)[0])


def test_suggest_unique():
"""Verify that RandomSearch do not sample duplicates"""
space = Space()
space.register(Integer('yolo1', 'uniform', -3, 6))

random_search = Random(space)

n_samples = 6
values = sum(random_search.suggest(n_samples), tuple())
assert len(values) == n_samples
assert len(set(values)) == n_samples


def test_suggest_unique_history():
"""Verify that RandomSearch do not sample duplicates based observed points"""
space = Space()
space.register(Integer('yolo1', 'uniform', -3, 6))

random_search = Random(space)

n_samples = 3
values = sum(random_search.suggest(n_samples), tuple())
assert len(values) == n_samples
assert len(set(values)) == n_samples

random_search.observe([[value] for value in values], [1] * n_samples)

n_samples = 3
new_values = sum(random_search.suggest(n_samples), tuple())
assert len(new_values) == n_samples
assert len(set(new_values)) == n_samples
# No duplicates
assert (set(new_values) & set(values)) == set()