Skip to content

Commit

Permalink
Remove default algo.suggest()
Browse files Browse the repository at this point in the history
Why:

The producer should be responsible of requesting as many trials as
necessary to reach `max_trials`. It is then up to the algo to return as
many as possible.
  • Loading branch information
bouthilx committed May 7, 2021
1 parent f6a45c5 commit f6aea11
Show file tree
Hide file tree
Showing 25 changed files with 209 additions and 134 deletions.
7 changes: 2 additions & 5 deletions src/orion/algo/asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,8 @@ def sample(self, num):

return samples

def suggest(self, num=None):
if num is None:
num = 1

return super(ASHA, self).suggest(num)
def suggest(self, num):
return super(ASHA, self).suggest(1)

def create_bracket(self, i, budgets, iteration):
return ASHABracket(self, budgets, iteration)
Expand Down
8 changes: 4 additions & 4 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,14 @@ def _is_fidelity(dim):
return None

@abstractmethod
def suggest(self, num=1):
def suggest(self, num):
"""Suggest a `num` of new sets of parameters.
Parameters
----------
num: int, optional
Number of points to suggest. Defaults to None, in which case the algorithms
returns the number of points it considers most optimal.
num: int
Number of points to suggest. The algorithm may return less than the number of points
requested.
Returns
-------
Expand Down
5 changes: 1 addition & 4 deletions src/orion/algo/gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def set_state(self, state_dict):
super(GridSearch, self).set_state(state_dict)
self.grid = state_dict["grid"]

def suggest(self, num=None):
def suggest(self, num):
"""Return the entire grid of suggestions
Returns
Expand All @@ -202,9 +202,6 @@ def suggest(self, num=None):
"""
if self.grid is None:
self._initialize()
if num is None:
num = len(self.grid)

i = 0
points = []
while len(points) < num and i < len(self.grid):
Expand Down
5 changes: 1 addition & 4 deletions src/orion/algo/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def sample(self, num):

return samples

def suggest(self, num=None):
def suggest(self, num):
"""Suggest a number of new sets of parameters.
Sample new points until first rung is filled. Afterwards
Expand All @@ -310,9 +310,6 @@ def suggest(self, num=None):
trials to complete), in which case it will return None.
"""
if num is None:
num = 100000

self._refresh_brackets()

samples = self.promote(num)
Expand Down
5 changes: 1 addition & 4 deletions src/orion/algo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def set_state(self, state_dict):
self.seed_rng(0)
self.rng.set_state(state_dict["rng_state"])

def suggest(self, num=None):
def suggest(self, num):
"""Suggest a `num` of new sets of parameters. Randomly draw samples
from the import space and return them.
Expand All @@ -55,9 +55,6 @@ def suggest(self, num=None):
.. note:: New parameters must be compliant with the problem's domain
`orion.algo.space.Space`.
"""
if num is None:
num = max(self.max_trials - self.n_suggested, 1)

points = []
while len(points) < num and not self.is_done:
seed = tuple(self.rng.randint(0, 1000000, size=3))
Expand Down
5 changes: 3 additions & 2 deletions src/orion/algo/tpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,9 @@ def suggest(self, num=None):
.. note:: New parameters must be compliant with the problem's domain
`orion.algo.space.Space`.
"""
if num is None:
num = max(self.n_initial_points - self.n_observed, 1)
# Only sample up to `n_initial_points` and after that only sample one at a time.
num = min(num, max(self.n_initial_points - self.n_suggested, 1))

samples = []
candidates = []
while len(samples) < num and self.n_suggested < self.space.cardinality:
Expand Down
4 changes: 2 additions & 2 deletions src/orion/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def define_experiment_config(config):
experiment_config.add_option(
"max_trials",
option_type=int,
default=int(10e8),
default=1000,
env_var="ORION_EXP_MAX_TRIALS",
help="number of trials to be completed for the experiment. This value "
"will be saved within the experiment configuration and reused "
Expand All @@ -144,7 +144,7 @@ def define_experiment_config(config):
experiment_config.add_option(
"worker_trials",
option_type=int,
default=int(10e8),
default=1000,
deprecate=dict(
version="v0.3",
alternative="worker.max_trials",
Expand Down
2 changes: 1 addition & 1 deletion src/orion/core/worker/primary_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def set_state(self, state_dict):
"""
self.algorithm.set_state(state_dict)

def suggest(self, num=None):
def suggest(self, num):
"""Suggest a `num` of new sets of parameters.
:param num: how many sets to be suggested.
Expand Down
12 changes: 11 additions & 1 deletion src/orion/core/worker/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, experiment, max_idle_time=None):
self.params_hashes = set()
self.naive_trials_history = None
self.failure_count = 0
self.num_trials = 0
self.num_broken = 0

@property
def pool_size(self):
Expand All @@ -81,6 +83,12 @@ def _sample_guard(self, start):
)
)

def suggest(self):
"""Try suggesting new points with the naive algorithm"""
num_pending = self.num_trials - self.num_broken
num = max(self.experiment.max_trials - num_pending, 1)
return self.naive_algorithm.suggest(num)

def produce(self):
"""Create and register new trials."""
sampled_points = 0
Expand All @@ -95,7 +103,7 @@ def produce(self):
self._sample_guard(start)

log.debug("### Algorithm suggests new points.")
new_points = self.naive_algorithm.suggest()
new_points = self.suggest()

# Sync state of original algo so that state continues evolving.
self.algorithm.set_state(self.naive_algorithm.state_dict)
Expand Down Expand Up @@ -174,6 +182,8 @@ def update(self):
ones.
"""
trials = self.experiment.fetch_trials(with_evc_tree=True)
self.num_trials = len(trials)
self.num_broken = len([trial for trial in trials if trial.status == "broken"])

self._update_algorithm(
[trial for trial in trials if trial.status == "completed"]
Expand Down
43 changes: 22 additions & 21 deletions src/orion/testing/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import inspect
import itertools
from collections import defaultdict

import numpy
import pytest
Expand Down Expand Up @@ -46,16 +47,24 @@ def spy_attr(mocker, algo, attribute):
return mocker.spy(obj, attr_to_mock)


methods_with_phase = set()
methods_with_phase = defaultdict(set)


phase_docstring = """\
This test is parametrizable with phases.
See ``orion.testing.algo.BaseAlgoTests.set_phases``.\
"""


def phase(method):
"""Decorator to mark methods that must be parametrized with phases."""
methods_with_phase.add(method.__name__)
method.__doc__ += (
"\n\nThis test is parametrizable with phases. See "
"``orion.testing.algo.BaseAlgoTests.set_phases``."
)
class_name = ".".join(method.__qualname__.split(".")[:-1])
methods_with_phase[class_name].add(method.__name__)

if method.__doc__ is None:
method.__doc__ = phase_docstring
else:
method.__doc__ += "\n\n" + phase_docstring

return method

Expand Down Expand Up @@ -137,7 +146,11 @@ def set_phases(cls, phases):
ids = [phase[0] for phase in phases]
attrs = [phase[1:] for phase in phases]

for method_name in sorted(methods_with_phase):
cls_methods_with_phase = (
methods_with_phase["BaseAlgoTests"] | methods_with_phase[cls.__name__]
)

for method_name in sorted(cls_methods_with_phase):
parametrize_this(cls, method_name, attrs, ids)

def create_algo(self, config=None, space=None, **kwargs):
Expand Down Expand Up @@ -317,7 +330,7 @@ def assert_dim_type_supported(self, mocker, num, attr, test_space):

spy = self.spy_phase(mocker, num, algo, attr)

points = algo.suggest()
points = algo.suggest(1)
assert points[0] in space
spy.call_count == 1
self.observe_points(points, algo, 1)
Expand Down Expand Up @@ -400,18 +413,6 @@ def test_state_dict(self, mocker, num, attr):

self.assert_callbacks(spy, num, algo)

@phase
def test_suggest(self, mocker, num, attr):
"""Verify that suggest returns correct number of points.
This method will likely require to be overriden based on the behavior of the algorith
for ``suggest(num=None)``.
"""
algo = self.create_algo()
spy = self.spy_phase(mocker, num, algo, attr)
points = algo.suggest()
assert len(points) == 1

@phase
def test_suggest_n(self, mocker, num, attr):
"""Verify that suggest returns correct number of trials if ``num`` is specified in ``suggest``."""
Expand Down Expand Up @@ -607,7 +608,7 @@ def test_optimize_branin(self):
break

if not points:
points = algo.suggest()
points = algo.suggest(MAX_TRIALS - len(objectives))

point = points.pop(0)
results = task(*point)
Expand Down
7 changes: 3 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
self._score_point = None
self._judge_point = None
self._measurements = None
self.default_num = 1
self.pool_size = 1
self.possible_values = [value]
super(DumbAlgo, self).__init__(
space,
Expand Down Expand Up @@ -139,10 +139,9 @@ def set_state(self, state_dict):
self._num = state_dict["num"]
self.done = state_dict["done"]

def suggest(self, num=None):
def suggest(self, num):
"""Suggest based on `value`."""
if num is None:
num = self.default_num
num = min(num, self.pool_size)
self._num += num

rval = []
Expand Down
6 changes: 4 additions & 2 deletions tests/functional/algos/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,15 @@ def test_with_evc(algorithm):
name="exp",
space=space_with_fidelity,
algorithms=algorithm_configs["random"],
max_trials=10,
)
base_exp.workon(rosenbrock, max_trials=10)

exp = create_experiment(
name="exp",
space=space_with_fidelity,
algorithms=algorithm,
max_trials=30,
branching={"branch_from": "exp"},
)

Expand All @@ -229,13 +231,13 @@ def test_with_evc(algorithm):
assert len(trials) >= 20

trials_with_evc = exp.fetch_trials(with_evc_tree=True)
assert len(trials_with_evc) >= 31
assert len(trials_with_evc) >= 30
assert len(trials_with_evc) - len(trials) == 10

completed_trials = [
trial for trial in trials_with_evc if trial.status == "completed"
]
assert len(completed_trials) > 30
assert len(completed_trials) >= 30

results = [trial.objective.value for trial in completed_trials]
best_trial = next(
Expand Down
22 changes: 20 additions & 2 deletions tests/functional/client/test_cli_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def test_interrupt(storage, monkeypatch, capsys):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"1",
"--worker-trials",
"2",
"1",
"python",
"black_box.py",
"interrupt_trial",
Expand Down Expand Up @@ -65,6 +67,8 @@ def empty_env(self, trial, results_file=None):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand All @@ -89,8 +93,10 @@ def empty_env(self, trial, results_file=None):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trial",
"1",
"--worker-trials",
"2",
"1",
"python",
"black_box.py",
"interrupt_trial",
Expand Down Expand Up @@ -131,6 +137,8 @@ def test_report_no_name(storage, monkeypatch, fct):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand Down Expand Up @@ -165,6 +173,8 @@ def test_report_with_name(storage, monkeypatch, fct):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand Down Expand Up @@ -202,6 +212,8 @@ def test_report_with_bad_objective(storage, monkeypatch, fct):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand All @@ -227,6 +239,8 @@ def test_report_with_bad_trial_no_objective(storage, monkeypatch):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand Down Expand Up @@ -258,6 +272,8 @@ def test_report_with_bad_trial_with_data(storage, monkeypatch):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand Down Expand Up @@ -295,6 +311,8 @@ def test_no_report(storage, monkeypatch, capsys):
"hunt",
"--config",
"./orion_config.yaml",
"--exp-max-trials",
"2",
"--worker-trials",
"2",
"python",
Expand Down
Loading

0 comments on commit f6aea11

Please sign in to comment.