Skip to content

Commit

Permalink
Fix ASHA termination condition
Browse files Browse the repository at this point in the history
Why:

The termination condition was wrong and would stop when the penultimate
rung was filled, making it impossible to ever run the final trial with
max resources.
  • Loading branch information
bouthilx committed Aug 28, 2019
1 parent a0ddc05 commit fd4f94a
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/orion/algo/asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
Params: {params}
"""

SPACE_ERROR = """
ASHA cannot be used if space does contain a fidelity dimension.
For more information on the configuration and usage of ASHA, see
https://orion.readthedocs.io/en/develop/user/algorithms.html#asha
"""


class ASHA(BaseAlgorithm):
"""Asynchronous Successive Halving Algorithm
Expand Down Expand Up @@ -72,11 +78,16 @@ def __init__(self, space, seed=None, grace_period=None, max_resources=None,
reduction_factor=None, num_rungs=None, num_brackets=1):
super(ASHA, self).__init__(
space, seed=seed, max_resources=max_resources, grace_period=grace_period,
reduction_factor=reduction_factor, num_brackets=num_brackets)
reduction_factor=reduction_factor, num_rungs=num_rungs, num_brackets=num_brackets)

self.trial_info = {} # Stores Trial -> Bracket

fidelity_dim = space.values()[self.fidelity_index]
try:
fidelity_index = self.fidelity_index
except IndexError:
raise RuntimeError(SPACE_ERROR)

fidelity_dim = space.values()[fidelity_index]

if grace_period is not None:
logger.warning(
Expand Down Expand Up @@ -106,14 +117,15 @@ def __init__(self, space, seed=None, grace_period=None, max_resources=None,
raise AttributeError("Reduction factor for ASHA needs to be at least 2.")

if num_rungs is None:
num_rungs = numpy.log(max_resources / min_resources) / numpy.log(reduction_factor) + 1
num_rungs = int(numpy.log(max_resources / min_resources) /
numpy.log(reduction_factor) + 1)

self.num_rungs = num_rungs

budgets = numpy.logspace(
numpy.log(min_resources) / numpy.log(reduction_factor),
numpy.log(max_resources) / numpy.log(reduction_factor),
num_rungs, base=reduction_factor)
num_rungs, base=reduction_factor).astype(int)

# Tracks state for new trial add
self.brackets = [
Expand Down Expand Up @@ -162,7 +174,7 @@ def suggest(self, num=1):
logger.debug('Promoting')
return [candidate]

if all(bracket.is_done for bracket in self.brackets):
if all(bracket.is_filled for bracket in self.brackets):
logger.debug('All brackets are filled.')
return None

Expand All @@ -178,7 +190,7 @@ def suggest(self, num=1):

sizes = numpy.array([len(b.rungs) for b in self.brackets])
probs = numpy.e**(sizes - sizes.max())
probs = numpy.array([prob * int(not bracket.is_done)
probs = numpy.array([prob * int(not bracket.is_filled)
for prob, bracket in zip(probs, self.brackets)])
normalized = probs / probs.sum()
idx = self.rng.choice(len(self.brackets), p=normalized)
Expand Down Expand Up @@ -298,14 +310,19 @@ def get_candidate(self, rung_id):

@property
def is_done(self):
"""Return True, if the last rung is filled."""
return len(self.rungs[-1][1])

@property
def is_filled(self):
"""Return True, if the penultimate rung is filled."""
return self.is_filled(len(self.rungs) - 2) or len(self.rungs[-1][1])
return self.has_rung_filled(len(self.rungs) - 2)

def is_filled(self, rung_id):
def has_rung_filled(self, rung_id):
"""Return True, if the rung[rung_id] is filled."""
n_rungs = len(self.rungs)
n_trials = len(self.rungs[rung_id][1])
return n_trials >= (n_rungs - rung_id - 1) ** self.reduction_factor
return n_trials >= self.reduction_factor ** (n_rungs - rung_id - 1)

def update_rungs(self):
"""Promote the first candidate that is found and return it
Expand All @@ -321,7 +338,7 @@ def update_rungs(self):
Lookup for promotion in rung l + 1 contains trials of any status.
"""
if self.is_done and self.rungs[-1][1]:
if self.is_done:
return None

for rung_id in range(len(self.rungs) - 2, -1, -1):
Expand Down
21 changes: 21 additions & 0 deletions tests/functional/algos/asha_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: demo_algo

pool_size: 1
max_trials: 100

algorithms:
asha:
seed: 1
num_rungs: 4
num_brackets: 1
grace_period: null
max_resources: null
reduction_factor: null

producer:
strategy: StubParallelStrategy

database:
type: 'mongodb'
name: 'orion_test'
host: 'mongodb://user:pass@localhost'
46 changes: 46 additions & 0 deletions tests/functional/algos/black_box.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Simple one dimensional example with noise level for a possible user's script."""
import argparse
import random

from orion.client import report_results


def function(x, noise):
"""Evaluate partial information of a quadratic."""
z = (x - 34.56789) * random.gauss(0, noise)
return 4 * z**2 + 23.4, 8 * z


def execute():
"""Execute a simple pipeline as an example."""
# 1. Receive inputs as you want
parser = argparse.ArgumentParser()
parser.add_argument('-x', type=float, required=True)
parser.add_argument('--fidelity', type=int, default=10)
inputs = parser.parse_args()

assert 0 <= inputs.fidelity <= 10

noise = (1 - inputs.fidelity / 10) + 0.0001

# 2. Perform computations
y, dy = function(inputs.x, noise)

# 3. Gather and report results
results = list()
results.append(dict(
name='example_objective',
type='objective',
value=y))
results.append(dict(
name='example_gradient',
type='gradient',
value=[dy]))

report_results(results)


if __name__ == "__main__":
execute()
13 changes: 13 additions & 0 deletions tests/functional/algos/random_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: demo_algo

pool_size: 1
max_trials: 100

algorithms:
random:
seed: 1

database:
type: 'mongodb'
name: 'orion_test'
host: 'mongodb://user:pass@localhost'
116 changes: 116 additions & 0 deletions tests/functional/algos/test_algos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Perform a functional test for algos included with orion."""
import os

import pytest
import yaml

import orion.core.cli
from orion.storage.base import get_storage


config_files = ['random_config.yaml']
fidelity_config_files = ['random_config.yaml', 'asha_config.yaml']
fidelity_only_config_files = list(set(fidelity_config_files) - set(config_files))


@pytest.mark.usefixtures("clean_db")
@pytest.mark.usefixtures("null_db_instances")
@pytest.mark.parametrize('config_file', fidelity_only_config_files)
def test_missing_fidelity(monkeypatch, config_file):
"""Test a simple usage scenario."""
monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))

with pytest.raises(RuntimeError) as exc:
orion.core.cli.main(["hunt", "--config", config_file,
"./black_box.py", "-x~uniform(-50, 50)"])
assert "https://orion.readthedocs.io/en/develop/user/algorithms.html" in str(exc.value)


@pytest.mark.usefixtures("clean_db")
@pytest.mark.usefixtures("null_db_instances")
@pytest.mark.parametrize('config_file', config_files)
def test_simple(monkeypatch, config_file):
"""Test a simple usage scenario."""
monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))
orion.core.cli.main(["hunt", "--config", config_file,
"./black_box.py", "-x~uniform(-50, 50)"])

with open(config_file, 'rb') as f:
config = yaml.safe_load(f)

storage = get_storage()
exp = list(storage.fetch_experiments({'name': config['name']}))
assert len(exp) == 1
exp = exp[0]
assert '_id' in exp
exp_id = exp['_id']
assert exp['name'] == config['name']
assert exp['pool_size'] == 1
assert exp['max_trials'] == 100
assert exp['algorithms'] == config['algorithms']
assert 'user' in exp['metadata']
assert 'datetime' in exp['metadata']
assert 'orion_version' in exp['metadata']
assert 'user_script' in exp['metadata']
assert os.path.isabs(exp['metadata']['user_script'])
assert exp['metadata']['user_args'] == ['-x~uniform(-50, 50)']

trials = storage.fetch_trials(uid=exp_id)
assert len(trials) <= config['max_trials']
assert trials[-1].status == 'completed'

best_trial = next(iter(sorted(trials, key=lambda trial: trial.objective.value)))
assert best_trial.objective.name == 'example_objective'
assert abs(best_trial.objective.value - 23.4) < 1e-5
assert len(best_trial.params) == 1
param = best_trial.params[0]
assert param.name == '/x'
assert param.type == 'real'


@pytest.mark.usefixtures("clean_db")
@pytest.mark.usefixtures("null_db_instances")
@pytest.mark.parametrize('config_file', fidelity_config_files)
def test_with_fidelity(database, monkeypatch, config_file):
"""Test a scenario with fidelity."""
monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))
orion.core.cli.main(["hunt", "--config", config_file,
"./black_box.py", "-x~uniform(-50, 50)", "--fidelity~fidelity(1,10,4)"])

with open(config_file, 'rb') as f:
config = yaml.safe_load(f)

storage = get_storage()
exp = list(storage.fetch_experiments({'name': config['name']}))
assert len(exp) == 1
exp = exp[0]
assert '_id' in exp
exp_id = exp['_id']
assert exp['name'] == config['name']
assert exp['pool_size'] == 1
assert exp['max_trials'] == 100
assert exp['algorithms'] == config['algorithms']
assert 'user' in exp['metadata']
assert 'datetime' in exp['metadata']
assert 'orion_version' in exp['metadata']
assert 'user_script' in exp['metadata']
assert os.path.isabs(exp['metadata']['user_script'])
assert exp['metadata']['user_args'] == ['-x~uniform(-50, 50)', "--fidelity~fidelity(1,10,4)"]

trials = storage.fetch_trials(uid=exp_id)
assert len(trials) <= config['max_trials']
assert trials[-1].status == 'completed'

best_trial = next(iter(sorted(trials, key=lambda trial: trial.objective.value)))
assert best_trial.objective.name == 'example_objective'
assert abs(best_trial.objective.value - 23.4) < 1e-5
assert len(best_trial.params) == 2
fidelity = best_trial.params[0]
assert fidelity.name == '/fidelity'
assert fidelity.type == 'fidelity'
assert fidelity.value == 10
param = best_trial.params[1]
assert param.name == '/x'
assert param.type == 'real'
1 change: 1 addition & 0 deletions tests/unittests/algo/test_asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def test_suggest_opt_out(self, asha, bracket, rung_0, rung_1, rung_2):
"""Test that ASHA opts out when last rung is full."""
asha.brackets = [bracket]
bracket.asha = asha
bracket.rungs[1] = rung_1
bracket.rungs[2] = rung_2

points = asha.suggest()
Expand Down

0 comments on commit fd4f94a

Please sign in to comment.