Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make some things configurable #265

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions docs/src/user/searchspace.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,24 @@ Configuration file
You can use configuration files to define search space with placeholder
``'orion~dist(*args, **kwargs)'`` in yaml and json files or
``name~dist(*args, **kwargs)`` in any other text-based file.
For now Oríon can only recognize the
configuration file if it is passed with the argument ``--config`` to the user script. This should
not be confused with the argument ``--config`` of ``orion hunt``, which is the configuration of
Oríon. We are here referring the configuration of the user script, represented with
``my_script_config.txt`` in the following example.
By default Oríon will only consider the file passed through the argument ``--config`` as a
configuration file for the user script. However, it is possible to change the default argument
inside the configuration file of Oríon through the `user_script_config` argument, like this:

.. code-block:: yaml

user_script_config: configuration


.. code-block:: console

orion hunt --config my_orion_config.yaml ./my_script --configuration my_script_config.txt

As you can see, the configuration file for the user script is now passed through `--configuration`.

This should not be confused with the argument ``--config`` of ``orion hunt``,
which is the configuration of Oríon. We are here referring the configuration of the user script,
represented with ``my_script_config.txt`` in the following example.

.. code-block:: console

Expand Down
17 changes: 17 additions & 0 deletions src/orion/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def define_config():
"""Create and define the fields of the configuration object."""
config = Configuration()
define_database_config(config)
define_worker_config(config)

config.add_option(
'user_script_config', option_type=str, default='config')

return config


Expand All @@ -76,6 +81,18 @@ def define_database_config(config):
config.database = database_config


def define_worker_config(config):
"""Create and define the fields of the worker configuration."""
worker_config = Configuration()

worker_config.add_option(
'heartbeat', option_type=int, default=120)
worker_config.add_option(
'max_broken', option_type=int, default=3)

config.worker = worker_config


def build_config():
"""Define the config and fill it based on global configuration files."""
config = define_config()
Expand Down
2 changes: 1 addition & 1 deletion src/orion/core/io/orion_cmdline_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def format(self, config_path=None, trial=None, experiment=None):
configuration = self._build_configuration(trial)

if config_path is not None:
configuration['config'] = config_path
configuration[self.config_prefix] = config_path

templated = self.parser.format(configuration)

Expand Down
3 changes: 2 additions & 1 deletion src/orion/core/io/space_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from scipy.stats import distributions as sp_dists

from orion.algo.space import (Categorical, Fidelity, Integer, Real, Space)
from orion.core import config as orion_config
from orion.core.io.orion_cmdline_parser import OrionCmdlineParser

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -267,7 +268,7 @@ def build_from(self, config):
The problem's search space definition.

"""
self.parser = OrionCmdlineParser()
self.parser = OrionCmdlineParser(orion_config.user_script_config)
self.parser.parse(config)

return self.build(self.parser.priors)
Expand Down
3 changes: 2 additions & 1 deletion src/orion/core/worker/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import sys

import orion.core
from orion.core.cli.evc import fetch_branching_configuration
from orion.core.evc.adapters import Adapter, BaseAdapter
from orion.core.evc.conflicts import detect_conflicts
Expand Down Expand Up @@ -372,7 +373,7 @@ def is_broken(self):

"""
num_broken_trials = self._storage.count_broken_trials(self)
return num_broken_trials >= 3 # TODO: make this configurable ?
return num_broken_trials >= orion.core.config.worker.max_broken

@property
def space(self):
Expand Down
5 changes: 3 additions & 2 deletions src/orion/storage/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import datetime
import logging

import orion.core
from orion.core.io.convert import JSONConverter
from orion.core.io.database import Database
from orion.core.worker.trial import Trial
Expand Down Expand Up @@ -184,8 +185,8 @@ def _update_trial(self, trial: Trial, where=None, **kwargs) -> Trial:

def fetch_lost_trials(self, experiment):
"""See :func:`~orion.storage.BaseStorageProtocol.fetch_lost_trials`"""
# TODO: Configure this
threshold = datetime.datetime.utcnow() - datetime.timedelta(seconds=60 * 2)
heartbeat = orion.core.config.worker.heartbeat
threshold = datetime.datetime.utcnow() - datetime.timedelta(seconds=heartbeat)
lte_comparison = {'$lte': threshold}
query = {
'experiment': experiment._id,
Expand Down
46 changes: 46 additions & 0 deletions tests/functional/demo/black_box_w_config_other.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Simple one dimensional example for a possible user's script."""
import argparse

import yaml

from orion.client import report_results


def function(x):
"""Evaluate partial information of a quadratic."""
z = x - 34.56789
return 4 * z**2 + 23.4, 8 * z


def execute():
"""Execute a simple pipeline as an example."""
# 1. Receive inputs as you want
parser = argparse.ArgumentParser()
parser.add_argument('--configuration', required=True)
inputs = parser.parse_args()

with open(inputs.configuration, 'r') as f:
config = yaml.load(f)

# 2. Perform computations

y, dy = function(config['x'])

# 3. Gather and report results
results = list()
results.append(dict(
name='example_objective',
type='objective',
value=y))
results.append(dict(
name='example_gradient',
type='gradient',
value=[dy]))

report_results(results)


if __name__ == "__main__":
execute()
19 changes: 19 additions & 0 deletions tests/functional/demo/orion_config_other.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: voila_voici

pool_size: 1
max_trials: 100

algorithms:
gradient_descent:
learning_rate: 0.1
# dx_tolerance: 1e-7

user_script_config: configuration

producer:
strategy: NoParallelStrategy

database:
type: 'mongodb'
name: 'orion_test'
host: 'mongodb://user:pass@localhost'
48 changes: 48 additions & 0 deletions tests/functional/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,51 @@ def test_demo_with_shutdown_quickly(monkeypatch):
"./black_box.py", "-x~uniform(-50, 50)"])

assert process.wait(timeout=10) == 0


@pytest.mark.usefixtures("clean_db")
@pytest.mark.usefixtures("null_db_instances")
def test_demo_with_nondefault_config_keyword(database, monkeypatch):
"""Check that the user script configuration file is correctly used with a new keyword."""
monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))
orion.core.config.user_script_config = 'configuration'
orion.core.cli.main(["hunt", "--config", "./orion_config_other.yaml",
"./black_box_w_config_other.py", "--configuration", "script_config.yaml"])

exp = list(database.experiments.find({'name': 'voila_voici'}))
assert len(exp) == 1
exp = exp[0]
assert '_id' in exp
exp_id = exp['_id']
assert exp['name'] == 'voila_voici'
assert exp['pool_size'] == 1
assert exp['max_trials'] == 100
assert exp['algorithms'] == {'gradient_descent': {'learning_rate': 0.1,
'dx_tolerance': 1e-7}}
assert 'user' in exp['metadata']
assert 'datetime' in exp['metadata']
assert 'orion_version' in exp['metadata']
assert 'user_script' in exp['metadata']
assert os.path.isabs(exp['metadata']['user_script'])
assert exp['metadata']['user_args'] == ['--configuration', 'script_config.yaml']

trials = list(database.trials.find({'experiment': exp_id}))
assert len(trials) <= 15
assert trials[-1]['status'] == 'completed'
trials = list(sorted(trials, key=lambda trial: trial['submit_time']))
for result in trials[-1]['results']:
assert result['type'] != 'constraint'
if result['type'] == 'objective':
assert abs(result['value'] - 23.4) < 1e-6
assert result['name'] == 'example_objective'
elif result['type'] == 'gradient':
res = numpy.asarray(result['value'])
assert 0.1 * numpy.sqrt(res.dot(res)) < 1e-7
assert result['name'] == 'example_gradient'
params = trials[-1]['params']
assert len(params) == 1
assert params[0]['name'] == '/x'
assert params[0]['type'] == 'real'
assert (params[0]['value'] - 34.56789) < 1e-5

orion.core.config.user_script_config = 'config'
14 changes: 14 additions & 0 deletions tests/unittests/core/io/test_orion_cmdline_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,17 @@ def test_format_with_properties(parser, cmd_with_properties, hacked_exp):

assert trial.hash_name in cmd_line
assert 'supernaedo2-dendi' in cmd_line


def test_configurable_config_arg(parser_diff_prefix, yaml_sample_path):
"""Parse from a yaml config only."""
parser_diff_prefix.parse(["--config2", yaml_sample_path])
config = parser_diff_prefix.priors

assert len(config.keys()) == 6
assert '/layers/1/width' in config
assert '/layers/1/type' in config
assert '/layers/2/type' in config
assert '/training/lr0' in config
assert '/training/mbs' in config
assert '/something-same' in config
43 changes: 43 additions & 0 deletions tests/unittests/core/worker/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest

from orion.algo.base import BaseAlgorithm
import orion.core
from orion.core.io.database import DuplicateKeyError
from orion.core.utils.tests import OrionState
import orion.core.worker.experiment
Expand Down Expand Up @@ -671,6 +672,30 @@ def fetch_lost_trials(self, query):
m.setattr(hacked_exp.__class__, 'fetch_trials', fetch_lost_trials)
hacked_exp.fix_lost_trials()

def test_fix_lost_trials_configurable_hb(self, hacked_exp, random_dt):
"""Test that heartbeat is correctly being configured."""
exp_query = {'experiment': hacked_exp.id}
trial = hacked_exp.fetch_trials(exp_query)[0]
old_heartbeat_value = orion.core.config.worker.heartbeat
heartbeat = random_dt - datetime.timedelta(seconds=180)

get_storage().set_trial_status(trial,
status='reserved',
heartbeat=heartbeat)

trials = get_storage().fetch_trial_by_status(hacked_exp, 'reserved')

assert trial.id in [t.id for t in trials]

orion.core.config.worker.heartbeat = 210
hacked_exp.fix_lost_trials()

trials = get_storage().fetch_trial_by_status(hacked_exp, 'reserved')

assert trial.id in [t.id for t in trials]

orion.core.config.worker.heartbeat = old_heartbeat_value


def test_update_completed_trial(hacked_exp, database, random_dt):
"""Successfully push a completed trial into database."""
Expand Down Expand Up @@ -802,6 +827,24 @@ def test_broken_property(hacked_exp):
assert hacked_exp.is_broken


def test_configurable_broken_property(hacked_exp):
"""Check if max_broken changes after configuration."""
assert not hacked_exp.is_broken
trials = hacked_exp.fetch_trials({})[:3]
old_broken_value = orion.core.config.worker.max_broken

for trial in trials:
get_storage().set_trial_status(trial, status='broken')

assert hacked_exp.is_broken

orion.core.config.worker.max_broken = 4

assert not hacked_exp.is_broken

orion.core.config.worker.max_broken = old_broken_value


def test_experiment_stats(hacked_exp, exp_config, random_dt):
"""Check that property stats is returning a proper summary of experiment's results."""
stats = hacked_exp.stats
Expand Down