diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst
index c4f846401b3c..0cd6572d44fc 100644
--- a/doc/source/tune-usage.rst
+++ b/doc/source/tune-usage.rst
@@ -296,6 +296,31 @@ of a trial, you can additionally set the checkpoint_at_end to True. An example i
     },
 })
 
+Recovering From Failures (Experimental)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Tune automatically persists the progress of your experiments, so if an experiment crashes or is otherwise cancelled, it can be resumed after prompting. The default setting of ``resume=None`` will cause Tune to prompt you about whether you want to resume. Prompting can be turned off by passing ``resume=True``. If ``resume=False``, a new experiment will be created instead. You can always force a new experiment to be created by changing the experiment name.
+
+Note that trials will be restored to their last checkpoint. If trial checkpointing is not enabled, unfinished trials will be restarted from scratch.
+
+For example:
+
+.. code-block:: python
+
+    run_experiments({
+        "my_experiment_name": {
+            "run": my_trainable,
+            "checkpoint_freq": 10,
+            "local_dir": "~/path/to/results"
+        },
+    }, resume=True)
+
+
+Upon a second run, this will restore the entire experiment state from ``~/path/to/results/my_experiment_name``. Importantly, any changes to the experiment specification upon resume will be ignored.
+
+This feature is still experimental, so any provided Trial Scheduler or Search Algorithm will not be preserved across a resume. Only ``FIFOScheduler`` and ``BasicVariantGenerator`` are supported.
+
+
 Handling Large Datasets
 -----------------------
 
diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py
index d9f7cf58e0b4..5e03dfa588e4 100755
--- a/python/ray/rllib/train.py
+++ b/python/ray/rllib/train.py
@@ -9,7 +9,8 @@
 import ray
 from ray.test.cluster_utils import Cluster
-from ray.tune.config_parser import make_parser, resources_to_json
+from ray.tune.config_parser import make_parser
+from ray.tune.trial import resources_to_json
 from ray.tune.tune import _make_scheduler, run_experiments
 
 EXAMPLE_USAGE = """
@@ -70,6 +71,10 @@ def create_parser(parser_creator=None):
         default="default",
         type=str,
         help="Name of the subdirectory under `local_dir` to put results in.")
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Whether to attempt to resume previous Tune experiments.")
     parser.add_argument(
         "--env", default=None, type=str, help="The gym environment to use.")
     parser.add_argument(
@@ -138,7 +143,8 @@ def run(args, parser):
     run_experiments(
         experiments,
         scheduler=_make_scheduler(args),
-        queue_trials=args.queue_trials)
+        queue_trials=args.queue_trials,
+        resume=args.resume)
 
 
 if __name__ == "__main__":
diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py
index aff302efc434..146000daef82 100644
--- a/python/ray/test/cluster_utils.py
+++ b/python/ray/test/cluster_utils.py
@@ -51,7 +51,9 @@ def connect(self, head_node_args):
         assert not self.connected
         redis_password = head_node_args.get("redis_password")
         output_info = ray.init(
-            redis_address=self.redis_address, redis_password=redis_password)
+            ignore_reinit_error=True,
+            redis_address=self.redis_address,
+            redis_password=redis_password)
         logger.info(output_info)
         self.connected = True
 
diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py
index 22adfc397ecc..aa0caa437ae6 100644
--- a/python/ray/tune/config_parser.py
+++ b/python/ray/tune/config_parser.py
@@ -11,40 +11,10 @@
 from ray.tune import TuneError
 from ray.tune.result import
DEFAULT_RESULTS_DIR -from ray.tune.trial import Resources, Trial +from ray.tune.trial import Trial, json_to_resources from ray.tune.logger import _SafeFallbackEncoder -def json_to_resources(data): - if data is None or data == "null": - return None - if isinstance(data, string_types): - data = json.loads(data) - for k in data: - if k in ["driver_cpu_limit", "driver_gpu_limit"]: - raise TuneError( - "The field `{}` is no longer supported. Use `extra_cpu` " - "or `extra_gpu` instead.".format(k)) - if k not in Resources._fields: - raise TuneError( - "Unknown resource type {}, must be one of {}".format( - k, Resources._fields)) - return Resources( - data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0), - data.get("extra_gpu", 0)) - - -def resources_to_json(resources): - if resources is None: - return None - return { - "cpu": resources.cpu, - "gpu": resources.gpu, - "extra_cpu": resources.extra_cpu, - "extra_gpu": resources.extra_gpu, - } - - def make_parser(parser_creator=None, **kwargs): """Returns a base argument parser for the ray.tune tool. diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 5859487527db..4471edd2b311 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -4,11 +4,11 @@ import copy import logging +import os import six import types from ray.tune.error import TuneError -from ray.tune.log_sync import validate_sync_function from ray.tune.registry import register_trainable from ray.tune.result import DEFAULT_RESULTS_DIR @@ -122,7 +122,6 @@ def __init__(self, restore=None, repeat=None, trial_resources=None): - validate_sync_function(sync_function) if sync_function: assert upload_dir, "Need `upload_dir` if sync_function given." @@ -134,16 +133,16 @@ def __init__(self, resources_per_trial = trial_resources spec = { - "run": self._register_if_needed(run), + "run": Experiment._register_if_needed(run), "stop": stop or {}, "config": config or {}, "resources_per_trial": resources_per_trial, "num_samples": num_samples, - "local_dir": local_dir or DEFAULT_RESULTS_DIR, + "local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR), "upload_dir": upload_dir or "", # argparse converts None to "null" "trial_name_creator": trial_name_creator, "custom_loggers": custom_loggers, - "sync_function": sync_function or "", # See `upload_dir`. + "sync_function": sync_function, "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, "max_failures": max_failures, @@ -180,7 +179,8 @@ def from_json(cls, name, spec): raise TuneError("Improper argument from JSON: {}.".format(spec)) return exp - def _register_if_needed(self, run_object): + @classmethod + def _register_if_needed(cls, run_object): """Registers Trainable or Function at runtime. Assumes already registered if run_object is a string. 
Does not diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 972705487efb..3341c3c7601a 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -106,19 +106,19 @@ def _init(self): self.logdir, self.uri, sync_function=self._sync_function) def on_result(self, result): - for logger in self._loggers: - logger.on_result(result) + for _logger in self._loggers: + _logger.on_result(result) self._log_syncer.set_worker_ip(result.get(NODE_IP)) self._log_syncer.sync_if_needed() def close(self): - for logger in self._loggers: - logger.close() + for _logger in self._loggers: + _logger.close() self._log_syncer.sync_now(force=True) def flush(self): - for logger in self._loggers: - logger.flush() + for _logger in self._loggers: + _logger.flush() self._log_syncer.sync_now(force=True) self._log_syncer.wait() @@ -142,7 +142,7 @@ def _init(self): with open(config_pkl, "wb") as f: cloudpickle.dump(self.config, f) local_file = os.path.join(self.logdir, "result.json") - self.local_out = open(local_file, "w") + self.local_out = open(local_file, "a") def on_result(self, result): json.dump(result, self, cls=_SafeFallbackEncoder) @@ -152,6 +152,9 @@ def write(self, b): self.local_out.write(b) self.local_out.flush() + def flush(self): + self.local_out.flush() + def close(self): self.local_out.close() @@ -182,7 +185,8 @@ def on_result(self, result): for k in [ "config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION ]: - del tmp[k] # not useful to tf log these + if k in tmp: + del tmp[k] # not useful to tf log these values = to_tf_values(tmp, ["ray", "tune"]) train_stats = tf.Summary(value=values) t = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] @@ -205,15 +209,21 @@ class _VisKitLogger(Logger): def _init(self): """CSV outputted with Headers as first set of results.""" # Note that we assume params.json was already created by JsonLogger - self._file = open(os.path.join(self.logdir, "progress.csv"), "w") + progress_file = os.path.join(self.logdir, "progress.csv") + self._continuing = os.path.exists(progress_file) + self._file = open(progress_file, "a") self._csv_out = None def on_result(self, result): if self._csv_out is None: self._csv_out = csv.DictWriter(self._file, result.keys()) - self._csv_out.writeheader() + if not self._continuing: + self._csv_out.writeheader() self._csv_out.writerow(result.copy()) + def flush(self): + self._file.flush() + def close(self): self._file.close() diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 6b107b17c82f..3706d00f0a52 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -38,6 +38,8 @@ def _setup_runner(self, trial): num_gpus=trial.resources.gpu)(trial._get_trainable_cls()) trial.init_logger() + # We checkpoint metadata here to try mitigating logdir duplication + self.try_checkpoint_metadata(trial) remote_logdir = trial.logdir def logger_creator(config): @@ -60,7 +62,7 @@ def _train(self, trial): def _start_trial(self, trial, checkpoint=None): prior_status = trial.status - trial.status = Trial.RUNNING + self.set_status(trial, Trial.RUNNING) trial.runner = self._setup_runner(trial) if not self.restore(trial, checkpoint): return @@ -87,10 +89,13 @@ def _stop_trial(self, trial, error=False, error_msg=None, stop_logger (bool): Whether to shut down the trial logger. 
""" + if stop_logger: + trial.close_logger() + if error: - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) else: - trial.status = Trial.TERMINATED + self.set_status(trial, Trial.TERMINATED) try: trial.write_error_log(error_msg) @@ -103,13 +108,10 @@ def _stop_trial(self, trial, error=False, error_msg=None, stop_tasks, num_returns=2, timeout=250) except Exception: logger.exception("Error stopping runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) finally: trial.runner = None - if stop_logger: - trial.close_logger() - def start_trial(self, trial, checkpoint=None): """Starts the trial. @@ -302,7 +304,7 @@ def restore(self, trial, checkpoint=None): return True if trial.runner is None: logger.error("Unable to restore - no runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) return False try: value = checkpoint.value @@ -316,5 +318,5 @@ def restore(self, trial, checkpoint=None): return True except Exception: logger.exception("Error restoring runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) return False diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index 59f12181b8ff..75b1a4c545fa 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -2,7 +2,10 @@ from __future__ import division from __future__ import print_function +import inspect import json +import time +import os import pytest try: import pytest_timeout @@ -10,14 +13,37 @@ pytest_timeout = None import ray +from ray import tune from ray.rllib import _register_all from ray.test.cluster_utils import Cluster +from ray.test.test_utils import run_string_as_driver_nonblocking from ray.tune.error import TuneError +from ray.tune.experiment import Experiment from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import BasicVariantGenerator +class _Fail(tune.Trainable): + """Fails on the 4th iteration.""" + + def _setup(self, config): + self.state = {"hi": 0} + + def _train(self): + self.state["hi"] += 1 + time.sleep(0.5) + if self.state["hi"] >= 4: + assert False + return {} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + def _start_new_cluster(): cluster = Cluster( initialize_head=True, @@ -36,6 +62,7 @@ def _start_new_cluster(): @pytest.fixture def start_connected_cluster(): # Start the Ray processes. + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" cluster = _start_new_cluster() yield cluster # The code after the yield will run as teardown code. 
@@ -47,6 +74,7 @@ def start_connected_cluster(): def start_connected_emptyhead_cluster(): """Starts head with no resources.""" + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" cluster = Cluster( initialize_head=True, connect=True, @@ -66,7 +94,6 @@ def start_connected_emptyhead_cluster(): def test_counting_resources(start_connected_cluster): """Tests that Tune accounting is consistent with actual cluster.""" - cluster = start_connected_cluster nodes = [] assert ray.global_state.cluster_resources()["CPU"] == 1 @@ -240,3 +267,231 @@ def test_trial_requeue(start_connected_emptyhead_cluster): with pytest.raises(TuneError): runner.step() + + +def test_cluster_down_simple(start_connected_cluster, tmpdir): + """Tests that TrialRunner save/restore works on cluster shutdown.""" + cluster = start_connected_cluster + cluster.add_node(resources=dict(CPU=1)) + assert cluster.wait_for_nodes() + + dirpath = str(tmpdir) + runner = TrialRunner( + BasicVariantGenerator(), metadata_checkpoint_dir=dirpath) + kwargs = { + "stopping_criterion": { + "training_iteration": 2 + }, + "checkpoint_freq": 1, + "max_failures": 1 + } + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] + for t in trials: + runner.add_trial(t) + + runner.step() # start + runner.step() # start2 + runner.step() # step + assert all(t.status == Trial.RUNNING for t in runner.get_trials()) + runner.checkpoint() + + cluster.shutdown() + ray.shutdown() + + cluster = _start_new_cluster() + runner = TrialRunner.restore(dirpath) + runner.step() # start + runner.step() # start2 + + for i in range(3): + runner.step() + + with pytest.raises(TuneError): + runner.step() + + assert all(t.status == Trial.TERMINATED for t in runner.get_trials()) + cluster.shutdown() + + +def test_cluster_down_full(start_connected_cluster, tmpdir): + """Tests that run_experiment restoring works on cluster shutdown.""" + cluster = start_connected_cluster + dirpath = str(tmpdir) + + exp1_args = dict( + run="__fake", + stop=dict(training_iteration=3), + local_dir=dirpath, + checkpoint_freq=1) + exp2_args = dict(run="__fake", stop=dict(training_iteration=3)) + exp3_args = dict( + run="__fake", + stop=dict(training_iteration=3), + config=dict(mock_error=True)) + exp4_args = dict( + run="__fake", + stop=dict(training_iteration=3), + config=dict(mock_error=True), + checkpoint_freq=1) + all_experiments = { + "exp1": exp1_args, + "exp2": exp2_args, + "exp3": exp3_args, + "exp4": exp4_args + } + + tune.run_experiments(all_experiments, raise_on_failed_trial=False) + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + + trials = tune.run_experiments( + all_experiments, resume=True, raise_on_failed_trial=False) + assert len(trials) == 4 + assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials) + cluster.shutdown() + + +def test_cluster_rllib_restore(start_connected_cluster, tmpdir): + cluster = start_connected_cluster + dirpath = str(tmpdir) + script = """ +import time +import ray +from ray import tune + +ray.init(redis_address="{redis_address}") + +kwargs = dict( + run="PG", + env="CartPole-v1", + stop=dict(training_iteration=10), + local_dir="{checkpoint_dir}", + checkpoint_freq=1, + max_failures=1) + +tune.run_experiments( + dict(experiment=kwargs), + raise_on_failed_trial=False) +""".format( + redis_address=cluster.redis_address, checkpoint_dir=dirpath) + run_string_as_driver_nonblocking(script) + # Wait until the right checkpoint is saved. + # The trainable returns every 0.5 seconds, so this should not miss + # the checkpoint. 
+ metadata_checkpoint_dir = os.path.join(dirpath, "experiment") + for i in range(50): + if os.path.exists( + os.path.join(metadata_checkpoint_dir, + TrialRunner.CKPT_FILE_NAME)): + # Inspect the internal trialrunner + runner = TrialRunner.restore(metadata_checkpoint_dir) + trials = runner.get_trials() + last_res = trials[0].last_result + if last_res is not None and last_res["training_iteration"]: + break + time.sleep(0.2) + + if not os.path.exists( + os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + raise RuntimeError("Checkpoint file didn't appear.") + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + cluster.wait_for_nodes() + + # Restore properly from checkpoint + trials2 = tune.run_experiments( + { + "experiment": { + "run": "PG", + "checkpoint_freq": 1, + "local_dir": dirpath + } + }, + resume=True) + assert all(t.status == Trial.TERMINATED for t in trials2) + cluster.shutdown() + + +def test_cluster_interrupt(start_connected_cluster, tmpdir): + """Tests run_experiment on cluster shutdown even with atypical trial. + + The trial fails on the 4th step, and the checkpointing happens on + the 3rd step, so restoring should actually launch the trial again. + """ + cluster = start_connected_cluster + dirpath = str(tmpdir) + script = """ +import time +import ray +from ray import tune + +ray.init(redis_address="{redis_address}") + +{fail_class_code} + +kwargs = dict( + run={fail_class}, + stop=dict(training_iteration=5), + local_dir="{checkpoint_dir}", + checkpoint_freq=1, + max_failures=1) + +tune.run_experiments( + dict(experiment=kwargs), + raise_on_failed_trial=False) +""".format( + redis_address=cluster.redis_address, + checkpoint_dir=dirpath, + fail_class_code=inspect.getsource(_Fail), + fail_class=_Fail.__name__) + run_string_as_driver_nonblocking(script) + + # Wait until the right checkpoint is saved. + # The trainable returns every 0.5 seconds, so this should not miss + # the checkpoint. 
+ metadata_checkpoint_dir = os.path.join(dirpath, "experiment") + for i in range(50): + if os.path.exists( + os.path.join(metadata_checkpoint_dir, + TrialRunner.CKPT_FILE_NAME)): + # Inspect the internal trialrunner + runner = TrialRunner.restore(metadata_checkpoint_dir) + trials = runner.get_trials() + last_res = trials[0].last_result + if last_res is not None and last_res["training_iteration"] == 3: + break + time.sleep(0.2) + + if not os.path.exists( + os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + raise RuntimeError("Checkpoint file didn't appear.") + + ray.shutdown() + cluster.shutdown() + cluster = _start_new_cluster() + Experiment._register_if_needed(_Fail) + + # Inspect the internal trialrunner + runner = TrialRunner.restore(metadata_checkpoint_dir) + trials = runner.get_trials() + assert trials[0].last_result["training_iteration"] == 3 + assert trials[0].status == Trial.PENDING + + # Restore properly from checkpoint + trials2 = tune.run_experiments( + { + "experiment": { + "run": _Fail, + "local_dir": dirpath, + "checkpoint_freq": 1 + } + }, + resume=True, + raise_on_failed_trial=False) + assert all(t.status == Trial.ERROR for t in trials2) + assert {t.trial_id for t in trials2} == {t.trial_id for t in trials} + cluster.shutdown() diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 40b6575ce2a0..2faf30c1e278 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -3,7 +3,9 @@ from __future__ import print_function import os +import shutil import sys +import tempfile import time import unittest @@ -36,6 +38,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def setUp(self): + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" ray.init(num_cpus=4, num_gpus=0) def tearDown(self): @@ -541,6 +544,7 @@ def _restore(self, state): class RunExperimentTest(unittest.TestCase): def setUp(self): + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" ray.init() def tearDown(self): @@ -613,29 +617,6 @@ def train(config, reporter): self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99) - def testSpecifyAlgorithm(self): - """Tests run_experiments works without specifying experiment.""" - - def train(config, reporter): - for i in range(100): - reporter(timesteps_total=i) - - register_trainable("f1", train) - - alg = BasicVariantGenerator() - alg.add_configurations({ - "foo": { - "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } - } - }) - trials = run_experiments(search_alg=alg) - for trial in trials: - self.assertEqual(trial.status, Trial.TERMINATED) - self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99) - def testAutoregisterTrainable(self): def train(config, reporter): for i in range(100): @@ -777,6 +758,7 @@ def sync_func(local, remote): class VariantGeneratorTest(unittest.TestCase): def setUp(self): + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" ray.init() def tearDown(self): @@ -966,6 +948,9 @@ def on_trial_complete(self, trial_id, error=False, **kwargs): class TrialRunnerTest(unittest.TestCase): + def setUp(self): + os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" + def tearDown(self): ray.shutdown() _register_all() # re-register the evicted objects @@ -1650,6 +1635,116 @@ def _suggest(self, trial_id): self.assertTrue(searcher.is_finished()) self.assertRaises(TuneError, runner.step) + def testTrialSaveRestore(self): + """Creates different trials to test runner.checkpoint/restore.""" + ray.init(num_cpus=3) + 
tmpdir = tempfile.mkdtemp() + + runner = TrialRunner( + BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + trials = [ + Trial( + "__fake", + trial_id="trial_terminate", + stopping_criterion={"training_iteration": 1}, + checkpoint_freq=1) + ] + runner.add_trial(trials[0]) + runner.step() # start + runner.step() + self.assertEquals(trials[0].status, Trial.TERMINATED) + + trials += [ + Trial( + "__fake", + trial_id="trial_fail", + stopping_criterion={"training_iteration": 3}, + checkpoint_freq=1, + config={"mock_error": True}) + ] + runner.add_trial(trials[1]) + runner.step() + runner.step() + runner.step() + self.assertEquals(trials[1].status, Trial.ERROR) + + trials += [ + Trial( + "__fake", + trial_id="trial_succ", + stopping_criterion={"training_iteration": 2}, + checkpoint_freq=1) + ] + runner.add_trial(trials[2]) + runner.step() + self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3) + self.assertEquals(trials[2].status, Trial.RUNNING) + + runner2 = TrialRunner.restore(tmpdir) + for tid in ["trial_terminate", "trial_fail"]: + original_trial = runner.get_trial(tid) + restored_trial = runner2.get_trial(tid) + self.assertEqual(original_trial.status, restored_trial.status) + + restored_trial = runner2.get_trial("trial_succ") + self.assertEqual(Trial.PENDING, restored_trial.status) + + runner2.step() + runner2.step() + runner2.step() + self.assertRaises(TuneError, runner2.step) + shutil.rmtree(tmpdir) + + def testTrialNoSave(self): + """Check that non-checkpointing trials are not saved.""" + ray.init(num_cpus=3) + tmpdir = tempfile.mkdtemp() + + runner = TrialRunner( + BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + + runner.add_trial( + Trial( + "__fake", + trial_id="non_checkpoint", + stopping_criterion={"training_iteration": 2})) + + while not all(t.status == Trial.TERMINATED + for t in runner.get_trials()): + runner.step() + + runner.add_trial( + Trial( + "__fake", + trial_id="checkpoint", + checkpoint_at_end=True, + stopping_criterion={"training_iteration": 2})) + + while not all(t.status == Trial.TERMINATED + for t in runner.get_trials()): + runner.step() + + runner.add_trial( + Trial( + "__fake", + trial_id="pending", + stopping_criterion={"training_iteration": 2})) + + runner.step() + runner.step() + + runner2 = TrialRunner.restore(tmpdir) + new_trials = runner2.get_trials() + self.assertEquals(len(new_trials), 3) + self.assertTrue( + runner2.get_trial("non_checkpoint").status == Trial.TERMINATED) + self.assertTrue( + runner2.get_trial("checkpoint").status == Trial.TERMINATED) + self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING) + self.assertTrue(runner2.get_trial("pending").last_result is None) + runner2.step() + shutil.rmtree(tmpdir) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index e8e2f20544b4..9d463850f143 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -578,6 +578,7 @@ def __init__(self, i, config): self.logger_running = False self.restored_checkpoint = None self.resources = Resources(1, 0) + self.trial_name = None class PopulationBasedTestingSuite(unittest.TestCase): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 5824c5221ff5..0b50b55216e4 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -10,6 +10,7 @@ import logging import os import pickle +from six import string_types import shutil 
import tempfile import time @@ -216,10 +217,11 @@ def save(self, checkpoint_dir=None): checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, "checkpoint_{}".format(self._iteration)) - os.makedirs(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) checkpoint = self._save(checkpoint_dir) saved_as_dict = False - if isinstance(checkpoint, str): + if isinstance(checkpoint, string_types): if (not checkpoint.startswith(checkpoint_dir) or checkpoint == checkpoint_dir): raise ValueError( @@ -237,7 +239,9 @@ def save(self, checkpoint_dir=None): with open(checkpoint_path, "wb") as f: pickle.dump(checkpoint, f) else: - raise ValueError("Return value from `_save` must be dict or str.") + raise ValueError( + "`_save` must return a dict or string type: {}".format( + str(type(checkpoint)))) pickle.dump({ "experiment_id": self._experiment_id, "iteration": self._iteration, diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 3a766a400f32..66406231a108 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -3,15 +3,22 @@ from __future__ import print_function from collections import namedtuple +import ray.cloudpickle as cloudpickle +import copy from datetime import datetime import logging +import json import time import tempfile import os + +# For compatibility under py2 to consider unicode as str +from six import string_types from numbers import Number import ray from ray.tune import TuneError +from ray.tune.log_sync import validate_sync_function from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not @@ -19,7 +26,7 @@ import ray.tune.registry from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID, TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL) -from ray.utils import random_string, binary_to_hex +from ray.utils import random_string, binary_to_hex, hex_to_binary DEBUG_PRINT_INTERVAL = 5 MAX_LEN_IDENTIFIER = 130 @@ -66,6 +73,36 @@ def gpu_total(self): return self.gpu + self.extra_gpu +def json_to_resources(data): + if data is None or data == "null": + return None + if isinstance(data, string_types): + data = json.loads(data) + for k in data: + if k in ["driver_cpu_limit", "driver_gpu_limit"]: + raise TuneError( + "The field `{}` is no longer supported. Use `extra_cpu` " + "or `extra_gpu` instead.".format(k)) + if k not in Resources._fields: + raise TuneError( + "Unknown resource type {}, must be one of {}".format( + k, Resources._fields)) + return Resources( + data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0), + data.get("extra_gpu", 0)) + + +def resources_to_json(resources): + if resources is None: + return None + return { + "cpu": resources.cpu, + "gpu": resources.gpu, + "extra_cpu": resources.extra_cpu, + "extra_gpu": resources.extra_gpu, + } + + def has_trainable(trainable_name): return ray.tune.registry._global_registry.contains( ray.tune.registry.TRAINABLE_CLASS, trainable_name) @@ -133,12 +170,8 @@ def __init__(self, The args here take the same meaning as the command line flags defined in ray.tune.config_parser. 
""" - if not has_trainable(trainable_name): - # Make sure rllib agents are registered - from ray import rllib # noqa: F401 - if not has_trainable(trainable_name): - raise TuneError("Unknown trainable: " + trainable_name) + Trial._registration_check(trainable_name) # Trial config self.trainable_name = trainable_name self.config = config or {} @@ -149,14 +182,15 @@ def __init__(self, or self._get_trainable_cls().default_resource_request(self.config)) self.stopping_criterion = stopping_criterion or {} self.upload_dir = upload_dir - self.trial_name_creator = trial_name_creator self.custom_loggers = custom_loggers self.sync_function = sync_function + validate_sync_function(sync_function) self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = None + self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end self._checkpoint = Checkpoint( @@ -170,6 +204,18 @@ def __init__(self, self.error_file = None self.num_failures = 0 + self.trial_name = None + if trial_name_creator: + self.trial_name = trial_name_creator(self) + + @classmethod + def _registration_check(cls, trainable_name): + if not has_trainable(trainable_name): + # Make sure rllib agents are registered + from ray import rllib # noqa: F401 + if not has_trainable(trainable_name): + raise TuneError("Unknown trainable: " + trainable_name) + @classmethod def generate_id(cls): return binary_to_hex(random_string())[:8] @@ -180,10 +226,14 @@ def init_logger(self): if not self.result_logger: if not os.path.exists(self.local_dir): os.makedirs(self.local_dir) - self.logdir = tempfile.mkdtemp( - prefix="{}_{}".format( - str(self)[:MAX_LEN_IDENTIFIER], date_str()), - dir=self.local_dir) + if not self.logdir: + self.logdir = tempfile.mkdtemp( + prefix="{}_{}".format( + str(self)[:MAX_LEN_IDENTIFIER], date_str()), + dir=self.local_dir) + elif not os.path.exists(self.logdir): + os.makedirs(self.logdir) + self.result_logger = UnifiedLogger( self.config, self.logdir, @@ -307,6 +357,7 @@ def update_last_result(self, result, terminate=False): pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() self.last_result = result + self.last_update_time = time.time() self.result_logger.on_result(self.last_result) def _get_trainable_cls(self): @@ -327,8 +378,8 @@ def __str__(self): Can be overriden with a custom string creator. """ - if self.trial_name_creator: - return self.trial_name_creator(self) + if self.trial_name: + return self.trial_name if "env" in self.config: env = self.config["env"] @@ -340,3 +391,48 @@ def __str__(self): if self.experiment_tag: identifier += "_" + self.experiment_tag return identifier.replace("/", "_") + + def __getstate__(self): + """Memento generator for Trial. + + Sets RUNNING trials to PENDING, and flushes the result logger. + Note this can only occur if the trial holds a DISK checkpoint. 
+ """ + assert self._checkpoint.storage == Checkpoint.DISK, ( + "Checkpoint must not be in-memory.") + state = self.__dict__.copy() + state["resources"] = resources_to_json(self.resources) + + pickle_data = { + "_checkpoint": self._checkpoint, + "config": self.config, + "custom_loggers": self.custom_loggers, + "sync_function": self.sync_function + } + + for key, value in pickle_data.items(): + state[key] = binary_to_hex(cloudpickle.dumps(value)) + + state["runner"] = None + state["result_logger"] = None + if self.status == Trial.RUNNING: + state["status"] = Trial.PENDING + if self.result_logger: + self.result_logger.flush() + state["__logger_started__"] = True + else: + state["__logger_started__"] = False + return copy.deepcopy(state) + + def __setstate__(self, state): + logger_started = state.pop("__logger_started__") + state["resources"] = json_to_resources(state["resources"]) + for key in [ + "_checkpoint", "config", "custom_loggers", "sync_function" + ]: + state[key] = cloudpickle.loads(hex_to_binary(state[key])) + + self.__dict__.update(state) + Trial._registration_check(self.trainable_name) + if logger_started: + self.init_logger() diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 063129780b47..22d6d85eb78d 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -25,6 +25,41 @@ def __init__(self, queue_trials=False): automatic scale-up. """ self._queue_trials = queue_trials + self._cached_trial_state = {} + + def set_status(self, trial, status): + """Sets status and checkpoints metadata if needed. + + Only checkpoints metadata if trial status is a terminal condition. + PENDING, PAUSED, and RUNNING switches have checkpoints taken care of + in the TrialRunner. + + Args: + trial (Trial): Trial to checkpoint. + status (Trial.status): Status to set trial to. + """ + trial.status = status + if status in [Trial.TERMINATED, Trial.ERROR]: + self.try_checkpoint_metadata(trial) + + def try_checkpoint_metadata(self, trial): + """Checkpoints metadata. + + Args: + trial (Trial): Trial to checkpoint. + """ + if trial._checkpoint.storage == Checkpoint.MEMORY: + logger.debug("Not saving data for trial w/ memory checkpoint.") + return + try: + logger.debug("Saving trial metadata.") + self._cached_trial_state[trial.trial_id] = trial.__getstate__() + except Exception: + logger.exception("Error checkpointing trial metadata.") + + def get_checkpoints(self): + """Returns a copy of mapping of the trial ID to pickled metadata.""" + return self._cached_trial_state.copy() def has_resources(self, resources): """Returns whether this runner has at least the specified resources.""" @@ -71,15 +106,15 @@ def pause_trial(self, trial): try: self.save(trial, Checkpoint.MEMORY) self.stop_trial(trial, stop_logger=False) - trial.status = Trial.PAUSED + self.set_status(trial, Trial.PAUSED) except Exception: logger.exception("Error pausing runner.") - trial.status = Trial.ERROR + self.set_status(trial, Trial.ERROR) def unpause_trial(self, trial): """Sets PAUSED trial to pending to allow scheduler to start.""" assert trial.status == Trial.PAUSED, trial.status - trial.status = Trial.PENDING + self.set_status(trial, Trial.PENDING) def resume_trial(self, trial): """Resumes PAUSED trials. 
This is a blocking call.""" diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 84457ff8d9e9..eddfbc488d8c 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -3,6 +3,7 @@ from __future__ import print_function import collections +import json import logging import os import re @@ -49,10 +50,13 @@ class TrialRunner(object): misleading benchmark results. """ + CKPT_FILE_NAME = "experiment_state.json" + def __init__(self, search_alg, scheduler=None, launch_web_server=False, + metadata_checkpoint_dir=None, server_port=TuneServer.DEFAULT_PORT, verbose=True, queue_trials=False, @@ -64,6 +68,8 @@ def __init__(self, Trial objects. scheduler (TrialScheduler): Defaults to FIFOScheduler. launch_web_server (bool): Flag for starting TuneServer + metadata_checkpoint_dir (str): Path where + global checkpoints are stored and restored from. server_port (int): Port number for launching TuneServer verbose (bool): Flag for verbosity. If False, trial results will not be output. @@ -75,7 +81,6 @@ def __init__(self, """ self._search_alg = search_alg self._scheduler_alg = scheduler or FIFOScheduler() - self._trials = [] self.trial_executor = trial_executor or \ RayTrialExecutor(queue_trials=queue_trials) @@ -84,12 +89,92 @@ def __init__(self, self._global_time_limit = float( os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float('inf'))) self._total_time = 0 + self._iteration = 0 + self._verbose = verbose + self._queue_trials = queue_trials + self._server = None + self._server_port = server_port if launch_web_server: - self._server = TuneServer(self, server_port) + self._server = TuneServer(self, self._server_port) + + self._trials = [] self._stop_queue = [] - self._verbose = verbose - self._queue_trials = queue_trials + self._metadata_checkpoint_dir = metadata_checkpoint_dir + + def checkpoint(self): + """Saves execution state to `self._metadata_checkpoint_dir`.""" + if not self._metadata_checkpoint_dir: + return + metadata_checkpoint_dir = self._metadata_checkpoint_dir + if not os.path.exists(metadata_checkpoint_dir): + os.makedirs(metadata_checkpoint_dir) + runner_state = { + "checkpoints": list( + self.trial_executor.get_checkpoints().values()), + "runner_data": self.__getstate__() + } + tmp_file_name = os.path.join(metadata_checkpoint_dir, + ".tmp_checkpoint") + with open(tmp_file_name, "w") as f: + json.dump(runner_state, f, indent=2) + + os.rename( + tmp_file_name, + os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)) + return metadata_checkpoint_dir + + @classmethod + def restore(cls, + metadata_checkpoint_dir, + search_alg=None, + scheduler=None, + trial_executor=None): + """Restores all checkpointed trials from previous run. + + Requires user to manually re-register their objects. Also stops + all ongoing trials. + + Args: + metadata_checkpoint_dir (str): Path to metadata checkpoints. + search_alg (SearchAlgorithm): Search Algorithm. Defaults to + BasicVariantGenerator. + scheduler (TrialScheduler): Scheduler for executing + the experiment. + trial_executor (TrialExecutor): Manage the execution of trials. + + Returns: + runner (TrialRunner): A TrialRunner to resume experiments from. + """ + with open( + os.path.join(metadata_checkpoint_dir, + TrialRunner.CKPT_FILE_NAME), "r") as f: + runner_state = json.load(f) + + logger.warning("".join([ + "Attempting to resume experiment from {}. ".format( + metadata_checkpoint_dir), "This feature is experimental, " + "and may not work with all search algorithms. 
", + "This will ignore any new changes to the specification." + ])) + + from ray.tune.suggest import BasicVariantGenerator + runner = TrialRunner( + search_alg or BasicVariantGenerator(), + scheduler=scheduler, + trial_executor=trial_executor) + + runner.__setstate__(runner_state["runner_data"]) + + trials = [] + for trial_cp in runner_state["checkpoints"]: + new_trial = Trial(trial_cp["trainable_name"]) + new_trial.__setstate__(trial_cp) + trials += [new_trial] + for trial in sorted( + trials, key=lambda t: t.last_update_time, reverse=True): + runner.add_trial(trial) + return runner def is_finished(self): """Returns whether all trials have finished running.""" @@ -136,6 +221,12 @@ def step(self): "There are paused trials, but no more pending " "trials with sufficient resources.") + try: + self.checkpoint() + except Exception: + logger.exception("Trial Runner checkpointing failed.") + self._iteration += 1 + if self._server: self._process_requests() @@ -165,6 +256,7 @@ def add_trial(self, trial): """ trial.set_verbose(self._verbose) self._scheduler_alg.on_trial_add(self, trial) + self.trial_executor.try_checkpoint_metadata(trial) self._trials.append(trial) def debug_string(self, max_debug=MAX_DEBUG_TRIALS): @@ -279,14 +371,14 @@ def _process_events(self): result, terminate=(decision == TrialScheduler.STOP)) if decision == TrialScheduler.CONTINUE: - self._checkpoint_if_needed(trial) + self._checkpoint_trial_if_needed(trial) self.trial_executor.continue_training(trial) elif decision == TrialScheduler.PAUSE: self.trial_executor.pause_trial(trial) elif decision == TrialScheduler.STOP: # Checkpoint before ending the trial # if checkpoint_at_end experiment option is set to True - self._checkpoint_if_needed(trial) + self._checkpoint_trial_if_needed(trial) self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format( @@ -304,12 +396,13 @@ def _process_events(self): self.trial_executor.stop_trial( trial, error=True, error_msg=error_msg) - def _checkpoint_if_needed(self, trial): + def _checkpoint_trial_if_needed(self, trial): """Checkpoints trial based off trial.last_result.""" if trial.should_checkpoint(): # Save trial runtime if possible if hasattr(trial, "runner") and trial.runner: self.trial_executor.save(trial, storage=Checkpoint.DISK) + self.trial_executor.try_checkpoint_metadata(trial) def _try_recover(self, trial, error_msg): """Tries to recover trial. @@ -344,11 +437,11 @@ def _try_recover(self, trial, error_msg): def _requeue_trial(self, trial): """Notification to TrialScheduler and requeue trial. - This does not notify the SearchAlgorithm because - the function evaluation is still in progress. + This does not notify the SearchAlgorithm because the function + evaluation is still in progress. """ self._scheduler_alg.on_trial_error(self, trial) - trial.status = Trial.PENDING + self.trial_executor.set_status(trial, Trial.PENDING) self._scheduler_alg.on_trial_add(self, trial) def _update_trial_queue(self, blocking=False, timeout=600): @@ -417,3 +510,24 @@ def stop_trial(self, trial): error = True self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg) + + def __getstate__(self): + """Gets state for trial. + + Note that this is not used as a pickling override as + does not have all fields. 
+ """ + state = self.__dict__.copy() + for k in [ + "_trials", "_stop_queue", "_server", "_search_alg", + "_scheduler_alg", "trial_executor" + ]: + del state[k] + state["launch_web_server"] = bool(self._server) + return state + + def __setstate__(self, state): + launch_web_server = state.pop("launch_web_server") + self.__dict__.update(state) + if launch_web_server: + self._server = TuneServer(self, self._server_port) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 335660ecb836..e65e10c7b402 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -2,10 +2,13 @@ from __future__ import division from __future__ import print_function +import click import logging +import os import time from ray.tune.error import TuneError +from ray.tune.experiment import convert_to_experiment_list from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.log_sync import wait_for_log_sync @@ -32,12 +35,30 @@ def _make_scheduler(args): args.scheduler, _SCHEDULERS.keys())) -def run_experiments(experiments=None, +def _find_checkpoint_dir(exp_list): + assert exp_list, "Experiments must be specified via `run_experiments`" + exp = exp_list[0] + # TODO(rliaw): Make sure this is resolved earlier. + return os.path.join(exp.spec["local_dir"], exp.name) + + +def try_restore_runner(checkpoint_dir, search_alg, scheduler, trial_executor): + new_runner = None + try: + new_runner = TrialRunner.restore(checkpoint_dir, search_alg, scheduler, + trial_executor) + except Exception: + logger.exception("Runner restore failed. Restarting experiment.") + return new_runner + + +def run_experiments(experiments, search_alg=None, scheduler=None, with_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, + resume=None, queue_trials=False, trial_executor=None, raise_on_failed_trial=True): @@ -55,6 +76,9 @@ def run_experiments(experiments=None, using the Client API. server_port (int): Port number for launching TuneServer. verbose (bool): How much output should be printed for each trial. + resume (bool|None): If checkpoint exists, the experiment will + resume from there. If resume is None, Tune will prompt if + checkpoint detected. queue_trials (bool): Whether to queue trials when the cluster does not currently have enough resources to launch one. This should be set to True when running on an autoscaling cluster to enable @@ -83,26 +107,64 @@ def run_experiments(experiments=None, List of Trial objects, holding data for each executed trial. """ + # This is important to do this here + # because it schematize the experiments + # and it conducts the implicit registration. + experiments = convert_to_experiment_list(experiments) + checkpoint_dir = _find_checkpoint_dir(experiments) + + runner = None + restore = False + + # TUNE_RESUME_PROMPT_OFF is for testing purposes and defaults + # `resume=False.` + if os.environ.get("TUNE_RESUME_PROMPT_OFF"): + resume = resume or False + + if os.path.exists( + os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + if resume: + restore = True + elif resume is None: + msg = ("Found incomplete experiment at {}. 
" + "Would you like to resume it?".format(checkpoint_dir)) + restore = click.confirm(msg, default=True) + if restore: + logger.info("Tip: to always resume, " + "pass resume=True to run_experiments()") + else: + logger.info("Tip: to always start a new experiment, " + "pass resume=False to run_experiments()") + else: + logger.info( + "Did not find checkpoint file in {}.".format(checkpoint_dir)) + + if restore: + runner = try_restore_runner(checkpoint_dir, search_alg, scheduler, + trial_executor) + else: + logger.info("Starting a new experiment.") - if scheduler is None: - scheduler = FIFOScheduler() + if not runner: + if scheduler is None: + scheduler = FIFOScheduler() - if search_alg is None: - search_alg = BasicVariantGenerator() + if search_alg is None: + search_alg = BasicVariantGenerator() - search_alg.add_configurations(experiments) + search_alg.add_configurations(experiments) - runner = TrialRunner( - search_alg, - scheduler=scheduler, - launch_web_server=with_server, - server_port=server_port, - verbose=verbose, - queue_trials=queue_trials, - trial_executor=trial_executor) + runner = TrialRunner( + search_alg, + scheduler=scheduler, + metadata_checkpoint_dir=checkpoint_dir, + launch_web_server=with_server, + server_port=server_port, + verbose=verbose, + queue_trials=queue_trials, + trial_executor=trial_executor) logger.info(runner.debug_string(max_debug=99999)) - last_debug = 0 while not runner.is_finished(): runner.step() diff --git a/test/multi_node_test.py b/test/multi_node_test.py index b25ea8295b31..e323751b2545 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -394,6 +394,7 @@ def train_func(config, reporter): # add a reporter arg time.sleep(0.1) reporter(timesteps_total=i, mean_accuracy=i+97) # report metrics +os.environ["TUNE_RESUME_PROMPT_OFF"] = "True" ray.init(redis_address="{}") ray.tune.register_trainable("train_func", train_func)