-
Notifications
You must be signed in to change notification settings - Fork 7
/
tune.py
128 lines (102 loc) · 4.33 KB
/
tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import time
from ray.tune.error import TuneError
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.log_sync import wait_for_log_sync
from ray.tune.trial_runner import TrialRunner
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
FIFOScheduler, MedianStoppingRule)
from ray.tune.web_server import TuneServer
logger = logging.getLogger(__name__)

# Registry of the user-facing scheduler names accepted by
# ``_make_scheduler``, each mapped to the scheduler class it instantiates.
_SCHEDULERS = dict(
    FIFO=FIFOScheduler,
    MedianStopping=MedianStoppingRule,
    HyperBand=HyperBandScheduler,
    AsyncHyperBand=AsyncHyperBandScheduler,
)
def _make_scheduler(args):
    """Instantiate the trial scheduler named by ``args.scheduler``.

    Args:
        args: Arguments object. ``args.scheduler`` is looked up in
            ``_SCHEDULERS``, and ``args.scheduler_config`` is expanded as
            keyword arguments to the selected scheduler class.

    Returns:
        A new instance of the selected scheduler class.

    Raises:
        TuneError: If ``args.scheduler`` is not a registered scheduler name.
    """
    scheduler_cls = _SCHEDULERS.get(args.scheduler)
    if scheduler_cls is None:
        # Show the valid names as a plain sorted list. Formatting
        # ``dict.keys()`` directly renders as ``dict_keys([...])`` on
        # Python 3, and in arbitrary order on Python 2.
        raise TuneError("Unknown scheduler: {}, should be one of {}".format(
            args.scheduler, sorted(_SCHEDULERS)))
    return scheduler_cls(**args.scheduler_config)
def run_experiments(experiments=None,
                    search_alg=None,
                    scheduler=None,
                    with_server=False,
                    server_port=TuneServer.DEFAULT_PORT,
                    verbose=True,
                    queue_trials=False,
                    trial_executor=None,
                    raise_on_failed_trial=True):
    """Runs and blocks until all trials finish.

    Args:
        experiments (Experiment | list | dict): Experiments to run. Will be
            passed to `search_alg` via `add_configurations`.
        search_alg (SearchAlgorithm): Search Algorithm. Defaults to
            BasicVariantGenerator.
        scheduler (TrialScheduler): Scheduler for executing
            the experiment. Choose among FIFO (default), MedianStopping,
            AsyncHyperBand, and HyperBand.
        with_server (bool): Starts a background Tune server. Needed for
            using the Client API.
        server_port (int): Port number for launching TuneServer.
        verbose (bool): How much output should be printed for each trial.
        queue_trials (bool): Whether to queue trials when the cluster does
            not currently have enough resources to launch one. This should
            be set to True when running on an autoscaling cluster to enable
            automatic scale-up.
        trial_executor (TrialExecutor): Manage the execution of trials.
        raise_on_failed_trial (bool): Raise TuneError if there exists failed
            trial (of ERROR state) when the experiments complete.

    Examples:
        >>> experiment_spec = Experiment("experiment", my_func)
        >>> run_experiments(experiments=experiment_spec)

        >>> experiment_spec = {"experiment": {"run": my_func}}
        >>> run_experiments(experiments=experiment_spec)

        >>> run_experiments(
        ...     experiments=experiment_spec,
        ...     scheduler=MedianStoppingRule(...))

        >>> run_experiments(
        ...     experiments=experiment_spec,
        ...     search_alg=SearchAlgorithm(),
        ...     scheduler=MedianStoppingRule(...))

    Returns:
        List of Trial objects, holding data for each executed trial.

    Raises:
        TuneError: If ``raise_on_failed_trial`` is True and any trial is not
            in the TERMINATED state once the runner has finished.
    """
    if scheduler is None:
        scheduler = FIFOScheduler()

    if search_alg is None:
        search_alg = BasicVariantGenerator()

    # Register the experiment specs with the search algorithm, which is then
    # handed to the TrialRunner below.
    search_alg.add_configurations(experiments)

    runner = TrialRunner(
        search_alg,
        scheduler=scheduler,
        launch_web_server=with_server,
        server_port=server_port,
        verbose=verbose,
        queue_trials=queue_trials,
        trial_executor=trial_executor)

    # Initial status dump. The very large max_debug presumably disables the
    # usual truncation of the trial listing (NOTE(review): confirm in
    # TrialRunner.debug_string).
    logger.info(runner.debug_string(max_debug=99999))

    last_debug = 0
    while not runner.is_finished():
        runner.step()
        # Throttle periodic status logging to at most once every
        # DEBUG_PRINT_INTERVAL seconds of wall-clock time.
        if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
            logger.info(runner.debug_string())
            last_debug = time.time()

    # Final status dump once every trial has finished.
    logger.info(runner.debug_string(max_debug=99999))

    wait_for_log_sync()

    # Once the runner is finished, any trial not in the TERMINATED state is
    # treated as failed.
    errored_trials = [
        trial for trial in runner.get_trials()
        if trial.status != Trial.TERMINATED
    ]
    if errored_trials:
        if raise_on_failed_trial:
            raise TuneError("Trials did not complete", errored_trials)
        else:
            logger.error("Trials did not complete: %s", errored_trials)

    return runner.get_trials()