From da37cf77aabb36cade2d355f5aaf4b03bf096b94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Corneau-Tremblay?= Date: Fri, 23 Aug 2019 13:56:06 -0400 Subject: [PATCH] Add functional test for user script and fix bug The `OrionCmdlineParser` formatting was not using the `config_prefix` attribute to recreate the configuration file path. --- src/orion/core/io/orion_cmdline_parser.py | 2 +- .../demo/black_box_w_config_other.py | 46 +++++++++++++++++++ tests/functional/demo/orion_config_other.yaml | 20 ++++++++ tests/functional/demo/test_demo.py | 46 +++++++++++++++++++ .../unittests/core/worker/test_experiment.py | 6 +-- 5 files changed, 116 insertions(+), 4 deletions(-) create mode 100755 tests/functional/demo/black_box_w_config_other.py create mode 100644 tests/functional/demo/orion_config_other.yaml diff --git a/src/orion/core/io/orion_cmdline_parser.py b/src/orion/core/io/orion_cmdline_parser.py index d43446f2e1..cc379edae4 100644 --- a/src/orion/core/io/orion_cmdline_parser.py +++ b/src/orion/core/io/orion_cmdline_parser.py @@ -367,7 +367,7 @@ def format(self, config_path=None, trial=None, experiment=None): configuration = self._build_configuration(trial) if config_path is not None: - configuration['config'] = config_path + configuration[self.config_prefix] = config_path templated = self.parser.format(configuration) diff --git a/tests/functional/demo/black_box_w_config_other.py b/tests/functional/demo/black_box_w_config_other.py new file mode 100755 index 0000000000..9de0082137 --- /dev/null +++ b/tests/functional/demo/black_box_w_config_other.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Simple one dimensional example for a possible user's script.""" +import argparse + +import yaml + +from orion.client import report_results + + +def function(x): + """Evaluate partial information of a quadratic.""" + z = x - 34.56789 + return 4 * z**2 + 23.4, 8 * z + + +def execute(): + """Execute a simple pipeline as an example.""" + # 1. Receive inputs as you want + parser = argparse.ArgumentParser() + parser.add_argument('--configuration', required=True) + inputs = parser.parse_args() + + with open(inputs.configuration, 'r') as f: + config = yaml.load(f) + + # 2. Perform computations + + y, dy = function(config['x']) + + # 3. Gather and report results + results = list() + results.append(dict( + name='example_objective', + type='objective', + value=y)) + results.append(dict( + name='example_gradient', + type='gradient', + value=[dy])) + + report_results(results) + + +if __name__ == "__main__": + execute() diff --git a/tests/functional/demo/orion_config_other.yaml b/tests/functional/demo/orion_config_other.yaml new file mode 100644 index 0000000000..443de631b7 --- /dev/null +++ b/tests/functional/demo/orion_config_other.yaml @@ -0,0 +1,20 @@ +name: voila_voici + +pool_size: 1 +max_trials: 100 + +algorithms: + gradient_descent: + learning_rate: 0.1 + # dx_tolerance: 1e-7 + +user_script_config: configuration +working_dir: ./ + +producer: + strategy: NoParallelStrategy + +database: + type: 'mongodb' + name: 'orion_test' + host: 'mongodb://user:pass@localhost' diff --git a/tests/functional/demo/test_demo.py b/tests/functional/demo/test_demo.py index 522ad2ea16..af13ffdbd0 100644 --- a/tests/functional/demo/test_demo.py +++ b/tests/functional/demo/test_demo.py @@ -478,3 +478,49 @@ def test_demo_with_shutdown_quickly(monkeypatch): "./black_box.py", "-x~uniform(-50, 50)"]) assert process.wait(timeout=10) == 0 + + +@pytest.mark.usefixtures("clean_db") +@pytest.mark.usefixtures("null_db_instances") +def test_demo_with_nondefault_config_keyword(database, monkeypatch): + """Check that the user script configuration file is correctly used with a new keyword.""" + monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) + orion.core.config.user_script_config = 'configuration' + orion.core.cli.main(["hunt", "--config", "./orion_config_other.yaml", + "./black_box_w_config_other.py", "--configuration", "script_config.yaml"]) + + exp = list(database.experiments.find({'name': 'voila_voici'})) + assert len(exp) == 1 + exp = exp[0] + assert '_id' in exp + exp_id = exp['_id'] + assert exp['name'] == 'voila_voici' + assert exp['pool_size'] == 1 + assert exp['max_trials'] == 100 + assert exp['algorithms'] == {'gradient_descent': {'learning_rate': 0.1, + 'dx_tolerance': 1e-7}} + assert 'user' in exp['metadata'] + assert 'datetime' in exp['metadata'] + assert 'orion_version' in exp['metadata'] + assert 'user_script' in exp['metadata'] + assert os.path.isabs(exp['metadata']['user_script']) + assert exp['metadata']['user_args'] == ['--configuration', 'script_config.yaml'] + + trials = list(database.trials.find({'experiment': exp_id})) + assert len(trials) <= 15 + assert trials[-1]['status'] == 'completed' + trials = list(sorted(trials, key=lambda trial: trial['submit_time'])) + for result in trials[-1]['results']: + assert result['type'] != 'constraint' + if result['type'] == 'objective': + assert abs(result['value'] - 23.4) < 1e-6 + assert result['name'] == 'example_objective' + elif result['type'] == 'gradient': + res = numpy.asarray(result['value']) + assert 0.1 * numpy.sqrt(res.dot(res)) < 1e-7 + assert result['name'] == 'example_gradient' + params = trials[-1]['params'] + assert len(params) == 1 + assert params[0]['name'] == '/x' + assert params[0]['type'] == 'real' + assert (params[0]['value'] - 34.56789) < 1e-5 diff --git a/tests/unittests/core/worker/test_experiment.py b/tests/unittests/core/worker/test_experiment.py index 3c5064d204..d56f2a6ada 100644 --- a/tests/unittests/core/worker/test_experiment.py +++ b/tests/unittests/core/worker/test_experiment.py @@ -11,7 +11,7 @@ import pytest from orion.algo.base import BaseAlgorithm -from orion.core import config +import orion.core from orion.core.io.database import DuplicateKeyError from orion.core.utils.tests import OrionState import orion.core.worker.experiment @@ -688,7 +688,7 @@ def test_fix_lost_trials_configurable_hb(self, hacked_exp, random_dt): assert len(hacked_exp.fetch_trials(exp_query)) == 1 - config.worker.heartbeat = 150 + orion.core.config.worker.heartbeat = 150 hacked_exp.fix_lost_trials() @@ -835,7 +835,7 @@ def test_configurable_broken_property(hacked_exp): assert hacked_exp.is_broken - config.worker.max_broken = 4 + orion.core.config.worker.max_broken = 4 assert not hacked_exp.is_broken