diff --git a/src/orion/algo/asha.py b/src/orion/algo/asha.py index f07764d42..e92b6512c 100644 --- a/src/orion/algo/asha.py +++ b/src/orion/algo/asha.py @@ -26,6 +26,12 @@ Params: {params} """ +SPACE_ERROR = """ +ASHA cannot be used if space does contain a fidelity dimension. +For more information on the configuration and usage of ASHA, see +https://orion.readthedocs.io/en/develop/user/algorithms.html#asha +""" + class ASHA(BaseAlgorithm): """Asynchronous Successive Halving Algorithm @@ -72,11 +78,16 @@ def __init__(self, space, seed=None, grace_period=None, max_resources=None, reduction_factor=None, num_rungs=None, num_brackets=1): super(ASHA, self).__init__( space, seed=seed, max_resources=max_resources, grace_period=grace_period, - reduction_factor=reduction_factor, num_brackets=num_brackets) + reduction_factor=reduction_factor, num_rungs=num_rungs, num_brackets=num_brackets) self.trial_info = {} # Stores Trial -> Bracket - fidelity_dim = space.values()[self.fidelity_index] + try: + fidelity_index = self.fidelity_index + except IndexError: + raise RuntimeError(SPACE_ERROR) + + fidelity_dim = space.values()[fidelity_index] if grace_period is not None: logger.warning( @@ -106,14 +117,15 @@ def __init__(self, space, seed=None, grace_period=None, max_resources=None, raise AttributeError("Reduction factor for ASHA needs to be at least 2.") if num_rungs is None: - num_rungs = numpy.log(max_resources / min_resources) / numpy.log(reduction_factor) + 1 + num_rungs = int(numpy.log(max_resources / min_resources) / + numpy.log(reduction_factor) + 1) self.num_rungs = num_rungs budgets = numpy.logspace( numpy.log(min_resources) / numpy.log(reduction_factor), numpy.log(max_resources) / numpy.log(reduction_factor), - num_rungs, base=reduction_factor) + num_rungs, base=reduction_factor).astype(int) # Tracks state for new trial add self.brackets = [ @@ -162,7 +174,7 @@ def suggest(self, num=1): logger.debug('Promoting') return [candidate] - if all(bracket.is_done for bracket in self.brackets): + if all(bracket.is_filled for bracket in self.brackets): logger.debug('All brackets are filled.') return None @@ -178,7 +190,7 @@ def suggest(self, num=1): sizes = numpy.array([len(b.rungs) for b in self.brackets]) probs = numpy.e**(sizes - sizes.max()) - probs = numpy.array([prob * int(not bracket.is_done) + probs = numpy.array([prob * int(not bracket.is_filled) for prob, bracket in zip(probs, self.brackets)]) normalized = probs / probs.sum() idx = self.rng.choice(len(self.brackets), p=normalized) @@ -298,14 +310,19 @@ def get_candidate(self, rung_id): @property def is_done(self): + """Return True, if the last rung is filled.""" + return len(self.rungs[-1][1]) + + @property + def is_filled(self): """Return True, if the penultimate rung is filled.""" - return self.is_filled(len(self.rungs) - 2) or len(self.rungs[-1][1]) + return self.has_rung_filled(len(self.rungs) - 2) - def is_filled(self, rung_id): + def has_rung_filled(self, rung_id): """Return True, if the rung[rung_id] is filled.""" n_rungs = len(self.rungs) n_trials = len(self.rungs[rung_id][1]) - return n_trials >= (n_rungs - rung_id - 1) ** self.reduction_factor + return n_trials >= self.reduction_factor ** (n_rungs - rung_id - 1) def update_rungs(self): """Promote the first candidate that is found and return it @@ -321,7 +338,7 @@ def update_rungs(self): Lookup for promotion in rung l + 1 contains trials of any status. """ - if self.is_done and self.rungs[-1][1]: + if self.is_done: return None for rung_id in range(len(self.rungs) - 2, -1, -1): diff --git a/tests/functional/algos/asha_config.yaml b/tests/functional/algos/asha_config.yaml new file mode 100644 index 000000000..978d7b1fa --- /dev/null +++ b/tests/functional/algos/asha_config.yaml @@ -0,0 +1,21 @@ +name: demo_algo + +pool_size: 1 +max_trials: 100 + +algorithms: + asha: + seed: 1 + num_rungs: 4 + num_brackets: 1 + grace_period: null + max_resources: null + reduction_factor: null + +producer: + strategy: StubParallelStrategy + +database: + type: 'mongodb' + name: 'orion_test' + host: 'mongodb://user:pass@localhost' diff --git a/tests/functional/algos/black_box.py b/tests/functional/algos/black_box.py new file mode 100755 index 000000000..631781eaa --- /dev/null +++ b/tests/functional/algos/black_box.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Simple one dimensional example with noise level for a possible user's script.""" +import argparse +import random + +from orion.client import report_results + + +def function(x, noise): + """Evaluate partial information of a quadratic.""" + z = (x - 34.56789) * random.gauss(0, noise) + return 4 * z**2 + 23.4, 8 * z + + +def execute(): + """Execute a simple pipeline as an example.""" + # 1. Receive inputs as you want + parser = argparse.ArgumentParser() + parser.add_argument('-x', type=float, required=True) + parser.add_argument('--fidelity', type=int, default=10) + inputs = parser.parse_args() + + assert 0 <= inputs.fidelity <= 10 + + noise = (1 - inputs.fidelity / 10) + 0.0001 + + # 2. Perform computations + y, dy = function(inputs.x, noise) + + # 3. Gather and report results + results = list() + results.append(dict( + name='example_objective', + type='objective', + value=y)) + results.append(dict( + name='example_gradient', + type='gradient', + value=[dy])) + + report_results(results) + + +if __name__ == "__main__": + execute() diff --git a/tests/functional/algos/random_config.yaml b/tests/functional/algos/random_config.yaml new file mode 100644 index 000000000..2f3763527 --- /dev/null +++ b/tests/functional/algos/random_config.yaml @@ -0,0 +1,13 @@ +name: demo_algo + +pool_size: 1 +max_trials: 100 + +algorithms: + random: + seed: 1 + +database: + type: 'mongodb' + name: 'orion_test' + host: 'mongodb://user:pass@localhost' diff --git a/tests/functional/algos/test_algos.py b/tests/functional/algos/test_algos.py new file mode 100644 index 000000000..4375bd776 --- /dev/null +++ b/tests/functional/algos/test_algos.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Perform a functional test for algos included with orion.""" +import os + +import pytest +import yaml + +import orion.core.cli +from orion.storage.base import get_storage + + +config_files = ['random_config.yaml'] +fidelity_config_files = ['random_config.yaml', 'asha_config.yaml'] +fidelity_only_config_files = list(set(fidelity_config_files) - set(config_files)) + + +@pytest.mark.usefixtures("clean_db") +@pytest.mark.usefixtures("null_db_instances") +@pytest.mark.parametrize('config_file', fidelity_only_config_files) +def test_missing_fidelity(monkeypatch, config_file): + """Test a simple usage scenario.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + + with pytest.raises(RuntimeError) as exc: + orion.core.cli.main(["hunt", "--config", config_file, + "./black_box.py", "-x~uniform(-50, 50)"]) + assert "https://orion.readthedocs.io/en/develop/user/algorithms.html" in str(exc.value) + + +@pytest.mark.usefixtures("clean_db") +@pytest.mark.usefixtures("null_db_instances") +@pytest.mark.parametrize('config_file', config_files) +def test_simple(monkeypatch, config_file): + """Test a simple usage scenario.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + orion.core.cli.main(["hunt", "--config", config_file, + "./black_box.py", "-x~uniform(-50, 50)"]) + + with open(config_file, 'rb') as f: + config = yaml.safe_load(f) + + storage = get_storage() + exp = list(storage.fetch_experiments({'name': config['name']})) + assert len(exp) == 1 + exp = exp[0] + assert '_id' in exp + exp_id = exp['_id'] + assert exp['name'] == config['name'] + assert exp['pool_size'] == 1 + assert exp['max_trials'] == 100 + assert exp['algorithms'] == config['algorithms'] + assert 'user' in exp['metadata'] + assert 'datetime' in exp['metadata'] + assert 'orion_version' in exp['metadata'] + assert 'user_script' in exp['metadata'] + assert os.path.isabs(exp['metadata']['user_script']) + assert exp['metadata']['user_args'] == ['-x~uniform(-50, 50)'] + + trials = storage.fetch_trials(uid=exp_id) + assert len(trials) <= config['max_trials'] + assert trials[-1].status == 'completed' + + best_trial = next(iter(sorted(trials, key=lambda trial: trial.objective.value))) + assert best_trial.objective.name == 'example_objective' + assert abs(best_trial.objective.value - 23.4) < 1e-5 + assert len(best_trial.params) == 1 + param = best_trial.params[0] + assert param.name == '/x' + assert param.type == 'real' + + +@pytest.mark.usefixtures("clean_db") +@pytest.mark.usefixtures("null_db_instances") +@pytest.mark.parametrize('config_file', fidelity_config_files) +def test_with_fidelity(database, monkeypatch, config_file): + """Test a scenario with fidelity.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + orion.core.cli.main(["hunt", "--config", config_file, + "./black_box.py", "-x~uniform(-50, 50)", "--fidelity~fidelity(1,10,4)"]) + + with open(config_file, 'rb') as f: + config = yaml.safe_load(f) + + storage = get_storage() + exp = list(storage.fetch_experiments({'name': config['name']})) + assert len(exp) == 1 + exp = exp[0] + assert '_id' in exp + exp_id = exp['_id'] + assert exp['name'] == config['name'] + assert exp['pool_size'] == 1 + assert exp['max_trials'] == 100 + assert exp['algorithms'] == config['algorithms'] + assert 'user' in exp['metadata'] + assert 'datetime' in exp['metadata'] + assert 'orion_version' in exp['metadata'] + assert 'user_script' in exp['metadata'] + assert os.path.isabs(exp['metadata']['user_script']) + assert exp['metadata']['user_args'] == ['-x~uniform(-50, 50)', "--fidelity~fidelity(1,10,4)"] + + trials = storage.fetch_trials(uid=exp_id) + assert len(trials) <= config['max_trials'] + assert trials[-1].status == 'completed' + + best_trial = next(iter(sorted(trials, key=lambda trial: trial.objective.value))) + assert best_trial.objective.name == 'example_objective' + assert abs(best_trial.objective.value - 23.4) < 1e-5 + assert len(best_trial.params) == 2 + fidelity = best_trial.params[0] + assert fidelity.name == '/fidelity' + assert fidelity.type == 'fidelity' + assert fidelity.value == 10 + param = best_trial.params[1] + assert param.name == '/x' + assert param.type == 'real' diff --git a/tests/unittests/algo/test_asha.py b/tests/unittests/algo/test_asha.py index 563c58ab2..bc63ceb50 100644 --- a/tests/unittests/algo/test_asha.py +++ b/tests/unittests/algo/test_asha.py @@ -384,6 +384,7 @@ def test_suggest_opt_out(self, asha, bracket, rung_0, rung_1, rung_2): """Test that ASHA opts out when last rung is full.""" asha.brackets = [bracket] bracket.asha = asha + bracket.rungs[1] = rung_1 bracket.rungs[2] = rung_2 points = asha.suggest()