diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index 41dc3b6cdd26..b88282aefae5 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -43,6 +43,7 @@ def __init__(self, if connect: redis_password = head_node_args.get("redis_password") output_info = ray.init( + ignore_reinit_error=True, redis_address=self.redis_address, redis_password=redis_password) logger.info(output_info) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index d2c79e6d871a..a7de6d5f96ba 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -71,19 +71,19 @@ def _init(self): self._log_syncer = get_syncer(self.logdir, self.uri) def on_result(self, result): - for logger in self._loggers: - logger.on_result(result) + for _logger in self._loggers: + _logger.on_result(result) self._log_syncer.set_worker_ip(result.get(NODE_IP)) self._log_syncer.sync_if_needed() def close(self): - for logger in self._loggers: - logger.close() + for _logger in self._loggers: + _logger.close() self._log_syncer.sync_now(force=True) def flush(self): - for logger in self._loggers: - logger.flush() + for _logger in self._loggers: + _logger.flush() self._log_syncer.sync_now(force=True) self._log_syncer.wait() @@ -104,7 +104,7 @@ def _init(self): sort_keys=True, cls=_SafeFallbackEncoder) local_file = os.path.join(self.logdir, "result.json") - self.local_out = open(local_file, "w") + self.local_out = open(local_file, "a") def on_result(self, result): json.dump(result, self, cls=_SafeFallbackEncoder) @@ -114,6 +114,9 @@ def write(self, b): self.local_out.write(b) self.local_out.flush() + def flush(self): + self.local_out.flush() + def close(self): self.local_out.close() @@ -133,6 +136,7 @@ def to_tf_values(result, path): class _TFLogger(Logger): def _init(self): + # TODO(rliaw): Implement a proper resume functionality for this. self._file_writer = tf.summary.FileWriter(self.logdir) def on_result(self, result): @@ -140,7 +144,8 @@ def on_result(self, result): for k in [ "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION ]: - del tmp[k] # not useful to tf log these + if k in tmp: + del tmp[k] # not useful to tf log these values = to_tf_values(tmp, ["ray", "tune"]) train_stats = tf.Summary(value=values) t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] @@ -163,15 +168,21 @@ class _VisKitLogger(Logger): def _init(self): """CSV outputted with Headers as first set of results.""" # Note that we assume params.json was already created by JsonLogger - self._file = open(os.path.join(self.logdir, "progress.csv"), "w") + progress_file = os.path.join(self.logdir, "progress.csv") + self._continuing = os.path.exists(progress_file) + self._file = open(progress_file, "a") self._csv_out = None def on_result(self, result): if self._csv_out is None: self._csv_out = csv.DictWriter(self._file, result.keys()) - self._csv_out.writeheader() + if not self._continuing: + self._csv_out.writeheader() self._csv_out.writerow(result.copy()) + def flush(self): + self._file.flush() + def close(self): self._file.close() diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 6b107b17c82f..3c3a5c509d6b 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -19,8 +19,8 @@ class RayTrialExecutor(TrialExecutor): """An implemention of TrialExecutor based on Ray.""" - def __init__(self, queue_trials=False): - super(RayTrialExecutor, self).__init__(queue_trials) + def __init__(self, queue_trials=False, track_checkpoints=False): + super(RayTrialExecutor, self).__init__(queue_trials, track_checkpoints) self._running = {} # Since trial resume after paused should not run # trial.train.remote(), thus no more new remote object id generated. @@ -60,7 +60,7 @@ def _train(self, trial): def _start_trial(self, trial, checkpoint=None): prior_status = trial.status - trial.status = Trial.RUNNING + self.set_status(trial, Trial.RUNNING) trial.runner = self._setup_runner(trial) if not self.restore(trial, checkpoint): return @@ -88,9 +88,9 @@ def _stop_trial(self, trial, error=False, error_msg=None, """ if error: - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) else: - trial.status = Trial.TERMINATED + self.set_status(trial, Trial.TERMINATED) try: trial.write_error_log(error_msg) @@ -103,7 +103,7 @@ def _stop_trial(self, trial, error=False, error_msg=None, stop_tasks, num_returns=2, timeout=250) except Exception: logger.exception("Error stopping runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) finally: trial.runner = None @@ -302,7 +302,7 @@ def restore(self, trial, checkpoint=None): return True if trial.runner is None: logger.error("Unable to restore - no runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) return False try: value = checkpoint.value @@ -316,5 +316,5 @@ def restore(self, trial, checkpoint=None): return True except Exception: logger.exception("Error restoring runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) return False diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index 59f12181b8ff..8019929d1593 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -2,7 +2,10 @@ from __future__ import division from __future__ import print_function +import inspect import json +import time +import os import pytest try: import pytest_timeout @@ -10,14 +13,39 @@ pytest_timeout = None import ray +from ray import tune from ray.rllib import _register_all from ray.test.cluster_utils import Cluster +from ray.test.test_utils import run_string_as_driver_nonblocking from ray.tune.error import TuneError from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import BasicVariantGenerator +def register_fail_trainable(): + class _Fail(tune.Trainable): + """Fails on the 4th iteration.""" + + def _setup(self, config): + self.state = {"hi": 0} + + def _train(self): + self.state["hi"] += 1 + time.sleep(0.5) + if self.state["hi"] >= 4: + assert False + return {} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + tune.register_trainable("test", _Fail) + + def _start_new_cluster(): cluster = Cluster( initialize_head=True, @@ -66,7 +94,6 @@ def start_connected_emptyhead_cluster(): def test_counting_resources(start_connected_cluster): """Tests that Tune accounting is consistent with actual cluster.""" - cluster = start_connected_cluster nodes = [] assert ray.global_state.cluster_resources()["CPU"] == 1 @@ -240,3 +267,156 @@ def test_trial_requeue(start_connected_emptyhead_cluster): with pytest.raises(TuneError): runner.step() + + +def test_cluster_down_simple(start_connected_cluster, tmpdir): + """Tests that TrialRunner save/restore works on cluster shutdown.""" + cluster = start_connected_cluster + cluster.add_node(resources=dict(CPU=1)) + dirpath = str(tmpdir) + runner = TrialRunner( + BasicVariantGenerator(), checkpoint_freq=2, checkpoint_dir=dirpath) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "checkpoint_freq": 1, + "max_failures": 1 + } + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() # start + runner.step() # start2 + runner.step() # step + assert all(t.status == Trial.RUNNING for t in runner.get_trials()) + runner.save() + + cluster.shutdown() + ray.shutdown() + + cluster = _start_new_cluster() + runner = TrialRunner(BasicVariantGenerator()) + runner.restore(dirpath) + runner.step() # start + runner.step() # start2 + + for i in range(3): + runner.step() + + with pytest.raises(TuneError): + runner.step() + + assert all(t.status == Trial.TERMINATED for t in runner.get_trials()) + cluster.shutdown() + + +def test_cluster_down_full(start_connected_cluster, tmpdir): + """Tests that run_experiment restoring works on cluster shutdown.""" + cluster = start_connected_cluster + dirpath = str(tmpdir) + + exp1_args = dict( + run="__fake", + stop=dict(training_iteration=3), + checkpoint_freq=1) + exp2_args = dict(run="__fake", stop=dict(training_iteration=3)) + exp3_args = dict( + run="__fake", + stop=dict(training_iteration=3), + config=dict(mock_error=True)) + exp4_args = dict( + run="__fake", + stop=dict(training_iteration=3), + config=dict(mock_error=True), + checkpoint_freq=1) + + tune.run_experiments( + dict(exp1=exp1_args, exp2=exp2_args, exp3=exp3_args, exp4=exp4_args), + checkpoint_dir=dirpath, + checkpoint_freq=2, + raise_on_failed_trial=False) + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + + # Check that last_result.iteration = 1 + runner = TrialRunner(BasicVariantGenerator()) + runner.restore(dirpath) + trials = runner.get_trials() + trials = tune.run_experiments(restore_from_path=dirpath) + assert len(trials) == 2 + assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials) + cluster.shutdown() + + +def test_cluster_interrupt(start_connected_cluster, tmpdir): + """Tests run_experiment on cluster shutdown even with atypical trial. + + The trial fails on the 4th step, and the checkpointing happens on + the 3rd step, so restoring should actually launch the trial again. + """ + cluster = start_connected_cluster + dirpath = str(tmpdir) + script = """ +import time +import ray +from ray import tune + +ray.init(redis_address="{redis_address}") + +{register_trainable_fn} +{run_register_trainable_fn}() + +kwargs = dict( + run="test", + stop=dict(training_iteration=5), + checkpoint_freq=1, + max_failures=1) + +# This will save to disk on step 0 and step 3 +tune.run_experiments( + dict(experiment1=kwargs), + checkpoint_dir="{checkpoint_dir}", + checkpoint_freq=3, + raise_on_failed_trial=False) +""".format( + redis_address=cluster.redis_address, + checkpoint_dir=dirpath, + register_trainable_fn=inspect.getsource(register_fail_trainable), + run_register_trainable_fn=register_fail_trainable.__name__) + run_string_as_driver_nonblocking(script) + + # Wait until the right checkpoint is saved. + # The trainable returns every 0.5 seconds, so this should not miss + # the checkpoint. + for i in range(30): + if os.path.exists(os.path.join(dirpath, "experiment.state")): + # Inspect the internal trialrunner + runner = TrialRunner(BasicVariantGenerator()) + runner.restore(dirpath) + trials = runner.get_trials() + last_res = trials[0].last_result + if last_res is not None and last_res["training_iteration"] == 3: + break + time.sleep(0.2) + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + register_fail_trainable() + + # Inspect the internal trialrunner just in case + runner = TrialRunner(BasicVariantGenerator()) + runner.restore(dirpath) + trials = runner.get_trials() + assert trials[0].last_result["training_iteration"] == 3 + assert trials[0].status == Trial.PENDING + + # Restore properly from checkpoint + trials = tune.run_experiments( + restore_from_path=dirpath, raise_on_failed_trial=False) + assert all(t.status == Trial.ERROR for t in trials) + cluster.shutdown() diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 8e4aa2cea148..d60ca7ee2b03 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -3,13 +3,14 @@ from __future__ import print_function import os +import shutil import sys +import tempfile import time import unittest import ray from ray.rllib import _register_all - from ray.tune import Trainable, TuneError from ray.tune import register_env, register_trainable, run_experiments from ray.tune.ray_trial_executor import RayTrialExecutor @@ -586,6 +587,25 @@ def train(config, reporter): self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99) + def testSimultaneousExperimentRestore(self): + tmpdir = tempfile.mkdtemp() + def train(config, reporter): + for i in range(100): + reporter(timesteps_total=i) + + register_trainable("f1", train) + exp1 = Experiment(**{ + "name": "foo", + "run": "f1", + "config": { + "script_min_iter_time_s": 0 + } + }) + self.assertRaises( + AssertionError, lambda: run_experiments( + exp1, restore_from_path=tmpdir)) + shutil.rmtree(tmpdir) + def testExperimentList(self): def train(config, reporter): for i in range(100): @@ -1555,6 +1575,97 @@ def _suggest(self, trial_id): self.assertTrue(searcher.is_finished()) self.assertRaises(TuneError, runner.step) + def testSaveRestore(self): + """Creates trials of different status to test runner.save/restore.""" + ray.init(num_cpus=3) + tmpdir = tempfile.mkdtemp() + default_resources = Resources(cpu=1, gpu=0) + + runner = TrialRunner( + BasicVariantGenerator(), checkpoint_dir=tmpdir, checkpoint_freq=1) + trials = [ + Trial( + "__fake", + trial_id="trial_terminate", + stopping_criterion={"training_iteration": 1}, + checkpoint_freq=1, + resources=default_resources) + ] + runner.add_trial(trials[0]) + runner.step() # start + runner.step() + self.assertEquals(trials[0].status, Trial.TERMINATED) + + trials += [ + Trial( + "__fake", + trial_id="trial_fail", + stopping_criterion={"training_iteration": 3}, + checkpoint_freq=1, + config={"mock_error": True}, + resources=default_resources) + ] + runner.add_trial(trials[1]) + runner.step() + runner.step() + runner.step() + self.assertEquals(trials[1].status, Trial.ERROR) + + trials += [ + Trial( + "__fake", + trial_id="trial_succ", + stopping_criterion={"training_iteration": 2}, + checkpoint_freq=1, + resources=default_resources) + ] + runner.add_trial(trials[2]) + runner.step() + self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3) + self.assertEquals(trials[2].status, Trial.RUNNING) + + runner2 = TrialRunner(BasicVariantGenerator()) + runner2.restore(tmpdir) + for tid in ["trial_terminate", "trial_fail"]: + original_trial = runner.get_trial(tid) + restored_trial = runner2.get_trial(tid) + self.assertEqual(original_trial.status, restored_trial.status) + + restored_trial = runner2.get_trial("trial_succ") + self.assertEqual(Trial.PENDING, restored_trial.status) + + runner2.step() + runner2.step() + runner2.step() + self.assertRaises(TuneError, runner2.step) + shutil.rmtree(tmpdir) + + def testNoSave(self): + """Check that non-checkpointing trials are not saved.""" + ray.init(num_cpus=3) + tmpdir = tempfile.mkdtemp() + default_resources = Resources(cpu=1, gpu=0) + + runner = TrialRunner( + BasicVariantGenerator(), checkpoint_dir=tmpdir, checkpoint_freq=1) + trials = [ + Trial( + "__fake", + trial_id="trial_terminate", + stopping_criterion={"training_iteration": 2}, + resources=default_resources) + ] + runner.add_trial(trials[0]) + runner.step() # start + runner.step() + + runner2 = TrialRunner(BasicVariantGenerator()) + runner2.restore(tmpdir) + self.assertEquals(len(runner2.get_trials()), 0) + runner2.step() + self.assertRaises(TuneError, runner2.step) + shutil.rmtree(tmpdir) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f60fd25f2dba..74449aa3feda 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -327,3 +327,27 @@ def __str__(self): if self.experiment_tag: identifier += "_" + self.experiment_tag return identifier + + def __getstate__(self): + if not self._checkpoint.storage == Checkpoint.DISK: + raise ValueError("Most recent checkpoint cannot be in-memory.") + state = self.__dict__.copy() + + if state["status"] == Trial.RUNNING: + state["status"] = Trial.PENDING + # Remove the unpicklable entries. + if state["result_logger"]: + state["result_logger"].flush() + state["_logger_started"] = True + else: + state["_logger_started"] = False + + state["result_logger"] = None + state["runner"] = None + return state + + def __setstate__(self, state): + logger_started = state.pop("_logger_started") + self.__dict__.update(state) + if logger_started: + self.init_logger() diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 063129780b47..7c0de0caed30 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -4,6 +4,7 @@ from __future__ import print_function import logging +import pickle from ray.tune.trial import Trial, Checkpoint @@ -15,7 +16,7 @@ class TrialExecutor(object): and starting/stopping trials. """ - def __init__(self, queue_trials=False): + def __init__(self, queue_trials=False, track_checkpoints=False): """Initializes a new TrialExecutor. Args: @@ -25,6 +26,44 @@ def __init__(self, queue_trials=False): automatic scale-up. """ self._queue_trials = queue_trials + self._track_checkpoints = track_checkpoints + self._checkpoints = {} + + def set_status(self, trial, status): + """Sets status and checkpoints metadata if needed. + + Only checkpoints metadata if trial status is a terminal condition. + PENDING, PAUSED, and RUNNING switches have checkpoints taken care of + in the TrialRunner. + + Args: + trial (Trial): Trial to checkpoint. + status (Trial.status): Status to set trial to. + """ + trial.status = status + if status in [Trial.TERMINATED, Trial.ERROR]: + self.try_checkpoint_metadata(trial) + + def try_checkpoint_metadata(self, trial): + """Checkpoints metadata if current session and trial allow. + + Args: + trial (Trial): Trial to checkpoint. + """ + if self._track_checkpoints and trial.checkpoint_freq > 0: + if trial._checkpoint.storage == Checkpoint.MEMORY: + logger.debug("Not saving data for trial w/ memory checkpoint.") + return + try: + logger.debug("Saving trial metadata.") + metadata = pickle.dumps(trial) + self._checkpoints[trial.trial_id] = metadata + except ValueError: + logger.exception("Error checkpointing trial metadata.") + + def get_checkpoints(self): + """Returns a copy of mapping of the trial ID to pickled metadata.""" + return self._checkpoints.copy() def has_resources(self, resources): """Returns whether this runner has at least the specified resources.""" @@ -71,15 +110,15 @@ def pause_trial(self, trial): try: self.save(trial, Checkpoint.MEMORY) self.stop_trial(trial, stop_logger=False) - trial.status = Trial.PAUSED + self.set_status(trial, Trial.PAUSED) except Exception: logger.exception("Error pausing runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) def unpause_trial(self, trial): """Sets PAUSED trial to pending to allow scheduler to start.""" assert trial.status == Trial.PAUSED, trial.status - trial.status = Trial.PENDING + self.set_status(trial, Trial.PENDING) def resume_trial(self, trial): """Resumes PAUSED trials. This is a blocking call.""" diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 84457ff8d9e9..80a89901ce9f 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -5,6 +5,7 @@ import collections import logging import os +import pickle import re import time import traceback @@ -53,6 +54,8 @@ def __init__(self, search_alg, scheduler=None, launch_web_server=False, + checkpoint_dir=None, + checkpoint_freq=0, server_port=TuneServer.DEFAULT_PORT, verbose=True, queue_trials=False, @@ -64,6 +67,9 @@ def __init__(self, Trial objects. scheduler (TrialScheduler): Defaults to FIFOScheduler. launch_web_server (bool): Flag for starting TuneServer + checkpoint_dir (str): Path where global checkpoints are stored. + checkpoint_freq (int): How many steps between global + checkpoints. A value of 0 (default) disables checkpointing. server_port (int): Port number for launching TuneServer verbose (bool): Flag for verbosity. If False, trial results will not be output. @@ -77,19 +83,68 @@ def __init__(self, self._scheduler_alg = scheduler or FIFOScheduler() self._trials = [] self.trial_executor = trial_executor or \ - RayTrialExecutor(queue_trials=queue_trials) + RayTrialExecutor(queue_trials=queue_trials, + track_checkpoints=checkpoint_freq > 0) # For debugging, it may be useful to halt trials after some time has # elapsed. TODO(ekl) consider exposing this in the API. self._global_time_limit = float( os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf'))) self._total_time = 0 + self._iteration = 0 self._server = None if launch_web_server: self._server = TuneServer(self, server_port) self._stop_queue = [] self._verbose = verbose self._queue_trials = queue_trials + self._checkpoint_dir = checkpoint_dir + self._checkpoint_freq = checkpoint_freq + self._trial_checkpoints = {} + + def save(self): + """Saves all trial checkpoints to `self._checkpoint_dir.`""" + checkpoint_dir = self._checkpoint_dir + if not os.path.exists(checkpoint_dir): + logger.debug("Checkpoint directory newly created.") + os.makedirs(checkpoint_dir) + logger.warning("Search Algorithm and Scheduler not checkpointed.") + # search_alg_checkpoint = self._search_alg.save(checkpoint_dir) + # scheduler_alg_checkpoint = self._scheduler_alg.save(checkpoint_dir) + runner_state = { + "checkpoints": list( + self.trial_executor.get_checkpoints().values()), + "total_time": self._total_time, + "stop_queue": self._stop_queue + } + with open(os.path.join(checkpoint_dir, "experiment.state"), "wb") as f: + pickle.dump(runner_state, f) + + return checkpoint_dir + + def restore(self, checkpoint_dir): + """Restores all checkpointed trials from previous run. + + Requires user to manually re-register their objects. Also stops + all ongoing trials. + + Args: + checkpoint_dir (str): Path to checkpoint (previously specified). + """ + logger.debug("Stopping all trials.") + for trial in self._trials: + self.stop_trial(trial) + + with open(os.path.join(checkpoint_dir, "experiment.state"), "rb") as f: + runner_state = pickle.load(f) + + logger.info("Replacing all trials with checkpoint state.") + for ckpt in runner_state["checkpoints"]: + trial = pickle.loads(ckpt) + self.add_trial(trial) + + self._total_time = runner_state["total_time"] + self._stop_queue = runner_state["stop_queue"] def is_finished(self): """Returns whether all trials have finished running.""" @@ -136,6 +191,13 @@ def step(self): "There are paused trials, but no more pending " "trials with sufficient resources.") + if self._checkpoint_freq: + if (self._iteration % self._checkpoint_freq == 0 + or self.is_finished()): + self.save() + + self._iteration += 1 + if self._server: self._process_requests() @@ -165,6 +227,7 @@ def add_trial(self, trial): """ trial.set_verbose(self._verbose) self._scheduler_alg.on_trial_add(self, trial) + self._checkpoint_if_needed(trial) self._trials.append(trial) def debug_string(self, max_debug=MAX_DEBUG_TRIALS): @@ -310,6 +373,7 @@ def _checkpoint_if_needed(self, trial): # Save trial runtime if possible if hasattr(trial, "runner") and trial.runner: self.trial_executor.save(trial, storage=Checkpoint.DISK) + self.trial_executor.try_checkpoint_metadata(trial) def _try_recover(self, trial, error_msg): """Tries to recover trial. @@ -344,11 +408,11 @@ def _try_recover(self, trial, error_msg): def _requeue_trial(self, trial): """Notification to TrialScheduler and requeue trial. - This does not notify the SearchAlgorithm because - the function evaluation is still in progress. + This does not notify the SearchAlgorithm because the function + evaluation is still in progress. """ self._scheduler_alg.on_trial_error(self, trial) - trial.status = Trial.PENDING + self.trial_executor.set_status(trial, Trial.PENDING) self._scheduler_alg.on_trial_add(self, trial) def _update_trial_queue(self, blocking=False, timeout=600): diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 335660ecb836..fa773572c110 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -3,6 +3,7 @@ from __future__ import print_function import logging +import os import time from ray.tune.error import TuneError @@ -35,6 +36,9 @@ def _make_scheduler(args): def run_experiments(experiments=None, search_alg=None, scheduler=None, + restore_from_path=None, + checkpoint_dir=None, + checkpoint_freq=0, with_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, @@ -51,6 +55,12 @@ def run_experiments(experiments=None, scheduler (TrialScheduler): Scheduler for executing the experiment. Choose among FIFO (default), MedianStopping, AsyncHyperBand, and HyperBand. + restore_from_path (str): Restores experiment execution state to + given checkpoint path. + checkpoint_dir (str): Path at which experiment checkpoints are stored. + Defaults to DEFAULT_RESULTS_DIR. + checkpoint_freq (int): How many trial results between + checkpoints. A value of 0 (default) disables checkpointing. with_server (bool): Starts a background Tune server. Needed for using the Client API. server_port (int): Port number for launching TuneServer. @@ -95,14 +105,24 @@ def run_experiments(experiments=None, runner = TrialRunner( search_alg, scheduler=scheduler, + checkpoint_dir=checkpoint_dir or DEFAULT_RESULTS_DIR, + checkpoint_freq=checkpoint_freq, launch_web_server=with_server, server_port=server_port, verbose=verbose, queue_trials=queue_trials, trial_executor=trial_executor) - logger.info(runner.debug_string(max_debug=99999)) + if restore_from_path: + if not os.path.exists(restore_from_path): + raise ValueError("Provided path invalid: %s" % restore_from_path) + assert experiments is None, ( + "Simultaneous starting experiments and restoring not supported.") + runner.restore(restore_from_path) + else: + search_alg.add_configurations(experiments) + logger.info(runner.debug_string(max_debug=99999)) last_debug = 0 while not runner.is_finished(): runner.step()