diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index e8ce405d9457..dd1e56f86083 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -259,7 +259,7 @@ of a trial, you can additionally set the checkpoint_at_end to True. An example i Recovering From Failures (Experimental) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tune automatically persists the progress of your experiments, so if an experiment crashes or is otherwise cancelled, it can be resumed with ``resume=True``. The default setting of ``resume=False`` creates a new experiment, and ``resume="prompt"`` will cause Tune to prompt you for whether you want to resume. You can always force a new experiment to be created by changing the experiment name. +Tune automatically persists the progress of your experiments, so if an experiment crashes or is otherwise cancelled, it can be resumed by passing one of ``True``, ``False``, ``"LOCAL"``, ``"REMOTE"``, or ``"PROMPT"`` to ``tune.run(resume=...)``. The default setting of ``resume=False`` creates a new experiment. ``resume="LOCAL"`` and ``resume=True`` restore the experiment from ``local_dir/[experiment_name]``. ``resume="REMOTE"`` syncs the upload dir down to the local dir and then restores the experiment from ``local_dir/[experiment_name]``. ``resume="PROMPT"`` will cause Tune to prompt you for whether you want to resume. You can always force a new experiment to be created by changing the experiment name. Note that trials will be restored to their last checkpoint. If trial checkpointing is not enabled, unfinished trials will be restarted from scratch. @@ -399,31 +399,39 @@ An example can be found in `logging_example.py `__. +By default, local syncing requires rsync to be installed. You can customize the sync command with the ``sync_to_driver`` argument in ``tune.run`` by providing either a function or a string. -If a string is provided, then it must include replacement fields ``{local_dir}`` and -``{remote_dir}``, like ``"aws s3 sync {local_dir} {remote_dir}"``. - -Alternatively, a function can be provided with the following signature (and must -be wrapped with ``tune.function``): +If a string is provided, then it must include replacement fields ``{source}`` and ``{target}``, like ``rsync -savz -e "ssh -i ssh_key.pem" {source} {target}``. Alternatively, a function can be provided with the following signature (and must be wrapped with ``tune.function``): .. code-block:: python - def custom_sync_func(local_dir, remote_dir): - sync_cmd = "aws s3 sync {local_dir} {remote_dir}".format( - local_dir=local_dir, - remote_dir=remote_dir) + def custom_sync_func(source, target): + sync_cmd = "rsync {source} {target}".format( + source=source, + target=target) sync_process = subprocess.Popen(sync_cmd, shell=True) sync_process.wait() tune.run( MyTrainableClass, name="experiment_name", - sync_function=tune.function(custom_sync_func) + sync_to_driver=tune.function(custom_sync_func), ) +When syncing results back to the driver, the source would be a path similar to ``ubuntu@192.0.0.1:/home/ubuntu/ray_results/trial1``, and the target would be a local path. +This custom sync command is also used on node failures: in that case, the source argument is the path to the trial directory, the target is a remote path, and ``sync_to_driver`` is invoked to push a checkpoint to the new node so that a queued trial can resume. + +If an upload directory is provided, Tune will automatically sync results to the given directory, natively supporting standard S3/gsutil commands.
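+For example, a minimal sketch of enabling uploads (``s3://your_bucket/path`` is only a placeholder to replace with your own bucket):
+
+.. code-block:: python
+
+    tune.run(
+        MyTrainableClass,
+        name="experiment_name",
+        upload_dir="s3://your_bucket/path",
+    )
+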
+You can customize this to specify arbitrary storage backends with the ``sync_to_cloud`` argument. This argument is similar to ``sync_to_driver`` in that it supports strings with the same replacement fields and arbitrary functions. See `syncer.py `__ for implementation details. + +.. code-block:: python + + tune.run( + MyTrainableClass, + name="experiment_name", + sync_to_cloud=tune.function(custom_sync_func), + ) Tune Client API --------------- diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 6a51416994d6..449c0310505c 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -10,6 +10,7 @@ import ray from ray.tests.cluster_utils import Cluster from ray.tune.config_parser import make_parser +from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.trial import resources_to_json from ray.tune.tune import _make_scheduler, run_experiments @@ -71,6 +72,17 @@ def create_parser(parser_creator=None): default="default", type=str, help="Name of the subdirectory under `local_dir` to put results in.") + parser.add_argument( + "--local-dir", + default=DEFAULT_RESULTS_DIR, + type=str, + help="Local dir to save training results to. Defaults to '{}'.".format( + DEFAULT_RESULTS_DIR)) + parser.add_argument( + "--upload-dir", + default="", + type=str, + help="Optional URI to sync training results to (e.g. s3://bucket).") parser.add_argument( "--resume", action="store_true", diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 139ef6f82bc3..e1386ca6da81 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -10,7 +10,6 @@ from six import string_types from ray.tune import TuneError -from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.trial import Trial, json_to_resources from ray.tune.logger import _SafeFallbackEncoder @@ -65,17 +64,6 @@ def make_parser(parser_creator=None, **kwargs): default=1, type=int, help="Number of times to repeat each trial.") - parser.add_argument( - "--local-dir", - default=DEFAULT_RESULTS_DIR, - type=str, - help="Local dir to save training results to. Defaults to '{}'.".format( - DEFAULT_RESULTS_DIR)) - parser.add_argument( - "--upload-dir", - default="", - type=str, - help="Optional URI to sync training results to (e.g. 
s3://bucket).") parser.add_argument( "--checkpoint-freq", default=0, @@ -183,7 +171,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): trainable_name=spec["run"], # json.load leads to str -> unicode in py2.7 config=spec.get("config", {}), - local_dir=os.path.join(args.local_dir, output_path), + local_dir=os.path.join(spec["local_dir"], output_path), # json.load leads to str -> unicode in py2.7 stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, @@ -193,10 +181,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): export_formats=spec.get("export_formats", []), # str(None) doesn't create None restore_path=spec.get("restore"), - upload_dir=args.upload_dir, trial_name_creator=spec.get("trial_name_creator"), loggers=spec.get("loggers"), # str(None) doesn't create None - sync_function=spec.get("sync_function"), + sync_to_driver_fn=spec.get("sync_to_driver"), max_failures=args.max_failures, **trial_kwargs) diff --git a/python/ray/tune/examples/logging_example.py b/python/ray/tune/examples/logging_example.py index 5c42d5687e90..7f2de54cd10b 100755 --- a/python/ray/tune/examples/logging_example.py +++ b/python/ray/tune/examples/logging_example.py @@ -11,9 +11,8 @@ import numpy as np -import ray from ray import tune -from ray.tune import Trainable, run, Experiment +from ray.tune import Trainable, run class TestLogger(tune.logger.Logger): @@ -60,11 +59,11 @@ def _restore(self, checkpoint_path): parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() - ray.init() - exp = Experiment( + + trials = run( + MyTrainableClass, name="hyperband_test", - run=MyTrainableClass, - num_samples=1, + num_samples=5, trial_name_creator=tune.function(trial_str_creator), loggers=[TestLogger], stop={"training_iteration": 1 if args.smoke_test else 99999}, @@ -73,5 +72,3 @@ def _restore(self, checkpoint_path): lambda spec: 10 + int(90 * random.random())), "height": tune.sample_from(lambda spec: int(100 * random.random())) }) - - trials = run(exp) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 95cb12043f8f..8d64c6aa71b6 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -52,7 +52,6 @@ class Experiment(object): >>> }, >>> num_samples=10, >>> local_dir="~/ray_results", - >>> upload_dir="s3://your_bucket/path", >>> checkpoint_freq=10, >>> max_failures=2) """ @@ -68,7 +67,7 @@ def __init__(self, upload_dir=None, trial_name_creator=None, loggers=None, - sync_function=None, + sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, keep_checkpoints_num=None, @@ -78,18 +77,16 @@ def __init__(self, restore=None, repeat=None, trial_resources=None, - custom_loggers=None): - if sync_function: - assert upload_dir, "Need `upload_dir` if sync_function given." 
- + custom_loggers=None, + sync_function=None): if repeat: _raise_deprecation_note("repeat", "num_samples", soft=False) if trial_resources: _raise_deprecation_note( "trial_resources", "resources_per_trial", soft=False) - if custom_loggers: - _raise_deprecation_note("custom_loggers", "loggers", soft=False) - + if sync_function: + _raise_deprecation_note( + "sync_function", "sync_to_driver", soft=False) run_identifier = Experiment._register_if_needed(run) spec = { "run": run_identifier, @@ -98,10 +95,10 @@ def __init__(self, "resources_per_trial": resources_per_trial, "num_samples": num_samples, "local_dir": os.path.expanduser(local_dir or DEFAULT_RESULTS_DIR), - "upload_dir": upload_dir or "", # argparse converts None to "null" + "upload_dir": upload_dir, "trial_name_creator": trial_name_creator, "loggers": loggers, - "sync_function": sync_function, + "sync_to_driver": sync_to_driver, "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, "keep_checkpoints_num": keep_checkpoints_num, @@ -182,7 +179,13 @@ def local_dir(self): @property def checkpoint_dir(self): - return os.path.join(self.spec["local_dir"], self.name) + if self.local_dir: + return os.path.join(self.local_dir, self.name) + + @property + def remote_checkpoint_dir(self): + if self.spec["upload_dir"]: + return os.path.join(self.spec["upload_dir"], self.name) def convert_to_experiment_list(experiments): diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index 0bb079b48770..6888ae6a8335 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -4,11 +4,6 @@ import distutils.spawn import logging -import os -import subprocess -import tempfile -import time -import types try: # py3 from shlex import quote @@ -17,231 +12,69 @@ import ray from ray.tune.cluster_info import get_ssh_key, get_ssh_user -from ray.tune.error import TuneError -from ray.tune.result import DEFAULT_RESULTS_DIR -from ray.tune.sample import function as tune_function logger = logging.getLogger(__name__) _log_sync_warned = False -# Map from (logdir, remote_dir) -> syncer -_syncers = {} -S3_PREFIX = "s3://" -GCS_PREFIX = "gs://" -ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX) +def log_sync_template(): + """Syncs the local_dir between driver and worker if possible. + Requires ray cluster to be started with the autoscaler. Also requires + rsync to be installed. 
-def get_syncer(local_dir, remote_dir=None, sync_function=None): - if remote_dir: - if not sync_function and not any( - remote_dir.startswith(prefix) - for prefix in ALLOWED_REMOTE_PREFIXES): - raise TuneError("Upload uri must start with one of: {}" - "".format(ALLOWED_REMOTE_PREFIXES)) - - if (remote_dir.startswith(S3_PREFIX) - and not distutils.spawn.find_executable("aws")): - raise TuneError( - "Upload uri starting with '{}' requires awscli tool" - " to be installed".format(S3_PREFIX)) - elif (remote_dir.startswith(GCS_PREFIX) - and not distutils.spawn.find_executable("gsutil")): - raise TuneError( - "Upload uri starting with '{}' requires gsutil tool" - " to be installed".format(GCS_PREFIX)) - - if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"): - rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR) - remote_dir = os.path.join(remote_dir, rel_path) - - key = (local_dir, remote_dir) - if key not in _syncers: - _syncers[key] = _LogSyncer(local_dir, remote_dir, sync_function) - - return _syncers[key] - - -def wait_for_log_sync(): - for syncer in _syncers.values(): - syncer.wait() - - -def validate_sync_function(sync_function): - if sync_function is None: + """ + if not distutils.spawn.find_executable("rsync"): + logger.error("Log sync requires rsync to be installed.") + return + global _log_sync_warned + ssh_key = get_ssh_key() + if ssh_key is None: + if not _log_sync_warned: + logger.error("Log sync requires cluster to be setup with " + "`ray up`.") + _log_sync_warned = True return - elif isinstance(sync_function, str): - assert "{remote_dir}" in sync_function, ( - "Sync template missing '{remote_dir}'.") - assert "{local_dir}" in sync_function, ( - "Sync template missing '{local_dir}'.") - elif not (isinstance(sync_function, types.FunctionType) - or isinstance(sync_function, tune_function)): - raise ValueError("Sync function {} must be string or function".format( - sync_function)) - - -class _LogSyncer(object): - """Log syncer for tune. - This syncs files from workers to the local node, and optionally also from - the local node to a remote directory (e.g. S3). + return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """ + """-o StrictHostKeyChecking=no" {{source}} {{target}}""" + ).format(ssh_key=quote(ssh_key)) - Arguments: - logdir (str): Directory to sync from. - upload_uri (str): Directory to sync to. - sync_function (func|str): Function for syncing the local_dir to - upload_dir. If string, then it must be a string template - for syncer to run and needs to include replacement fields - '{local_dir}' and '{remote_dir}'. - """ - def __init__(self, local_dir, remote_dir=None, sync_function=None): - self.local_dir = local_dir - self.remote_dir = remote_dir - self.logfile = tempfile.NamedTemporaryFile( - prefix="log_sync", dir=self.local_dir, suffix=".log", delete=False) +class NodeSyncMixin(): + """Mixin for syncing files to/from a remote dir to a local dir.""" - # Resolve sync_function into template or function - self.sync_func = None - self.sync_cmd_tmpl = None - if isinstance(sync_function, types.FunctionType) or isinstance( - sync_function, tune_function): - self.sync_func = sync_function - elif isinstance(sync_function, str): - self.sync_cmd_tmpl = sync_function - self.last_sync_time = 0 - self.sync_process = None + def __init__(self): + assert hasattr(self, "_remote_dir"), "Mixin not mixed with Syncer." 
self.local_ip = ray.services.get_node_ip_address() self.worker_ip = None - logger.debug("Created LogSyncer for {} -> {}".format( - local_dir, remote_dir)) - - def close(self): - self.logfile.close() def set_worker_ip(self, worker_ip): """Set the worker ip to sync logs from.""" self.worker_ip = worker_ip - def sync_if_needed(self): - if time.time() - self.last_sync_time > 300: - self.sync_now() - - def sync_to_worker_if_possible(self): - """Syncs the local logdir on driver to worker if possible. - - Requires ray cluster to be started with the autoscaler. Also requires - rsync to be installed. - """ + def _check_valid_worker_ip(self): + if not self.worker_ip: + logger.info("Worker ip unknown, skipping log sync for {}".format( + self._local_dir)) + return False if self.worker_ip == self.local_ip: - return - ssh_key = get_ssh_key() + logger.debug( + "Worker ip is local ip, skipping log sync for {}".format( + self._local_dir)) + return False + return True + + @property + def _remote_path(self): ssh_user = get_ssh_user() global _log_sync_warned - if ssh_key is None or ssh_user is None: + if not self._check_valid_worker_ip(): + return + if ssh_user is None: if not _log_sync_warned: logger.error("Log sync requires cluster to be setup with " "`ray up`.") _log_sync_warned = True return - if not distutils.spawn.find_executable("rsync"): - logger.error("Log sync requires rsync to be installed.") - return - source = "{}/".format(self.local_dir) - target = "{}@{}:{}/".format(ssh_user, self.worker_ip, self.local_dir) - final_cmd = (("""rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """ - """-o StrictHostKeyChecking=no" {} {}""").format( - quote(ssh_key), quote(source), quote(target))) - logger.info("Syncing results to %s", str(self.worker_ip)) - sync_process = subprocess.Popen( - final_cmd, shell=True, stdout=self.logfile) - sync_process.wait() - - def sync_now(self, force=False): - self.last_sync_time = time.time() - if not self.worker_ip: - logger.debug("Worker ip unknown, skipping log sync for {}".format( - self.local_dir)) - return - - if self.worker_ip == self.local_ip: - worker_to_local_sync_cmd = None # don't need to rsync - else: - ssh_key = get_ssh_key() - ssh_user = get_ssh_user() - global _log_sync_warned - if ssh_key is None or ssh_user is None: - if not _log_sync_warned: - logger.error("Log sync requires cluster to be setup with " - "`ray up`.") - _log_sync_warned = True - return - if not distutils.spawn.find_executable("rsync"): - logger.error("Log sync requires rsync to be installed.") - return - source = "{}@{}:{}/".format(ssh_user, self.worker_ip, - self.local_dir) - target = "{}/".format(self.local_dir) - worker_to_local_sync_cmd = (( - """rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """ - """-o StrictHostKeyChecking=no" {} {}""").format( - quote(ssh_key), quote(source), quote(target))) - - if self.remote_dir: - if self.sync_func: - local_to_remote_sync_cmd = None - try: - self.sync_func(self.local_dir, self.remote_dir) - except Exception: - logger.exception("Sync function failed.") - else: - local_to_remote_sync_cmd = self.get_remote_sync_cmd() - else: - local_to_remote_sync_cmd = None - - if self.sync_process: - self.sync_process.poll() - if self.sync_process.returncode is None: - if force: - self.sync_process.kill() - else: - logger.warning("Last sync is still in progress, skipping.") - return - - if worker_to_local_sync_cmd or local_to_remote_sync_cmd: - final_cmd = "" - if worker_to_local_sync_cmd: - final_cmd += worker_to_local_sync_cmd - if local_to_remote_sync_cmd: - if 
final_cmd: - final_cmd += " && " - final_cmd += local_to_remote_sync_cmd - logger.debug("Running log sync: {}".format(final_cmd)) - self.sync_process = subprocess.Popen( - final_cmd, shell=True, stdout=self.logfile) - - def wait(self): - if self.sync_process: - self.sync_process.wait() - - def get_remote_sync_cmd(self): - if self.sync_cmd_tmpl: - local_to_remote_sync_cmd = (self.sync_cmd_tmpl.format( - local_dir=quote(self.local_dir), - remote_dir=quote(self.remote_dir))) - elif self.remote_dir.startswith(S3_PREFIX): - local_to_remote_sync_cmd = ( - "aws s3 sync {local_dir} {remote_dir}".format( - local_dir=quote(self.local_dir), - remote_dir=quote(self.remote_dir))) - elif self.remote_dir.startswith(GCS_PREFIX): - local_to_remote_sync_cmd = ( - "gsutil rsync -r {local_dir} {remote_dir}".format( - local_dir=quote(self.local_dir), - remote_dir=quote(self.remote_dir))) - else: - logger.warning("Remote sync unsupported, skipping.") - local_to_remote_sync_cmd = None - - return local_to_remote_sync_cmd + return "{}@{}:{}/".format(ssh_user, self.worker_ip, self._remote_dir) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 895f4819e0c0..9e2a96aeef3e 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -13,7 +13,7 @@ import numpy as np import ray.cloudpickle as cloudpickle -from ray.tune.log_sync import get_syncer +from ray.tune.syncer import get_log_syncer from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \ TIMESTEPS_TOTAL @@ -33,13 +33,11 @@ class Logger(object): Arguments: config: Configuration passed to all logger creators. logdir: Directory for all logger creators to log to. - upload_uri (str): Optional URI where the logdir is sync'ed to. """ - def __init__(self, config, logdir, upload_uri=None): + def __init__(self, config, logdir): self.config = config self.logdir = logdir - self.uri = upload_uri self._init() def _init(self): @@ -196,24 +194,16 @@ def close(self): class UnifiedLogger(Logger): """Unified result logger for TensorBoard, rllab/viskit, plain json. - This class also periodically syncs output to the given upload uri. - Arguments: config: Configuration passed to all logger creators. logdir: Directory for all logger creators to log to. - upload_uri (str): Optional URI where the logdir is sync'ed to. loggers (list): List of logger creators. Defaults to CSV, Tensorboard, and JSON loggers. sync_function (func|str): Optional function for syncer to run. 
See ray/python/ray/tune/log_sync.py """ - def __init__(self, - config, - logdir, - upload_uri=None, - loggers=None, - sync_function=None): + def __init__(self, config, logdir, loggers=None, sync_function=None): if loggers is None: self._logger_cls_list = DEFAULT_LOGGERS else: @@ -221,24 +211,26 @@ def __init__(self, self._sync_function = sync_function self._log_syncer = None - Logger.__init__(self, config, logdir, upload_uri) + super(UnifiedLogger, self).__init__(config, logdir) def _init(self): self._loggers = [] for cls in self._logger_cls_list: try: - self._loggers.append(cls(self.config, self.logdir, self.uri)) + self._loggers.append(cls(self.config, self.logdir)) except Exception: logger.warning("Could not instantiate {} - skipping.".format( str(cls))) - self._log_syncer = get_syncer( - self.logdir, self.uri, sync_function=self._sync_function) + self._log_syncer = get_log_syncer( + self.logdir, + remote_dir=self.logdir, + sync_function=self._sync_function) def on_result(self, result): for _logger in self._loggers: _logger.on_result(result) self._log_syncer.set_worker_ip(result.get(NODE_IP)) - self._log_syncer.sync_if_needed() + self._log_syncer.sync_down_if_needed() def update_config(self, config): for _logger in self._loggers: @@ -247,13 +239,12 @@ def update_config(self, config): def close(self): for _logger in self._loggers: _logger.close() - self._log_syncer.sync_now(force=False) - self._log_syncer.close() + self._log_syncer.sync_down() def flush(self): for _logger in self._loggers: _logger.flush() - self._log_syncer.sync_now(force=False) + self._log_syncer.sync_down() def sync_results_to_new_location(self, worker_ip): """Sends the current log directory to the remote node. @@ -262,8 +253,13 @@ def sync_results_to_new_location(self, worker_ip): with the Ray autoscaler. """ if worker_ip != self._log_syncer.worker_ip: + logger.info("Syncing (blocking) results to {}".format(worker_ip)) + self._log_syncer.reset() self._log_syncer.set_worker_ip(worker_ip) - self._log_syncer.sync_to_worker_if_possible() + self._log_syncer.sync_up() + # TODO: change this because this is blocking. But failures + # are rare, so maybe this is OK? + self._log_syncer.wait() class _SafeFallbackEncoder(json.JSONEncoder): diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py new file mode 100644 index 000000000000..e83c966a61a8 --- /dev/null +++ b/python/ray/tune/syncer.py @@ -0,0 +1,266 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils +import logging +import os +import subprocess +import tempfile +import time +import types + +try: # py3 + from shlex import quote +except ImportError: # py2 + from pipes import quote + +from ray.tune.sample import function as tune_function +from ray.tune.error import TuneError +from ray.tune.log_sync import log_sync_template, NodeSyncMixin + +logger = logging.getLogger(__name__) + +S3_PREFIX = "s3://" +GS_PREFIX = "gs://" +ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GS_PREFIX) +SYNC_PERIOD = 300 + +_syncers = {} + + +def validate_sync_string(sync_string): + if "{source}" not in sync_string: + raise ValueError("Sync template missing '{source}'.") + if "{target}" not in sync_string: + raise ValueError("Sync template missing '{target}'.") + + +def wait_for_sync(): + for syncer in _syncers.values(): + syncer.wait() + + +class BaseSyncer(object): + def __init__(self, local_dir, remote_dir, sync_function=None): + """Syncs between two directories with the sync_function. 
+ + Arguments: + local_dir (str): Directory to sync. Uniquely identifies the syncer. + remote_dir (str): Remote directory to sync with. + sync_function (func): Function for syncing the local_dir to + remote_dir. Defaults to a Noop. + """ + self._local_dir = (os.path.join(local_dir, "") + if local_dir else local_dir) + self._remote_dir = remote_dir + self.last_sync_up_time = float("-inf") + self.last_sync_down_time = float("-inf") + self._sync_function = sync_function or (lambda source, target: None) + + def sync_function(self, source, target): + """Executes sync between source and target. + + Can be overwritten by subclasses for custom sync procedures. + + Args: + source: Path to source file(s). + target: Path to target file(s). + """ + if self._sync_function: + return self._sync_function(source, target) + + def sync(self, source, target): + if not (source and target): + logger.debug( + "Source or target is empty, skipping log sync for {}".format( + self._local_dir)) + return + + try: + self.sync_function(source, target) + return True + except Exception: + logger.exception("Sync function failed.") + + def sync_up_if_needed(self): + if time.time() - self.last_sync_up_time > SYNC_PERIOD: + self.sync_up() + + def sync_down_if_needed(self): + if time.time() - self.last_sync_down_time > SYNC_PERIOD: + self.sync_down() + + def sync_down(self, *args, **kwargs): + self.sync(self._remote_path, self._local_dir, *args, **kwargs) + self.last_sync_down_time = time.time() + + def sync_up(self, *args, **kwargs): + self.sync(self._local_dir, self._remote_path, *args, **kwargs) + self.last_sync_up_time = time.time() + + def reset(self): + self.last_sync_up_time = float("-inf") + self.last_sync_down_time = float("-inf") + + def wait(self): + pass + + @property + def _remote_path(self): + """Protected method for accessing remote_dir. + + Can be overridden in subclass for custom path. + """ + return self._remote_dir + + +class CommandSyncer(BaseSyncer): + def __init__(self, local_dir, remote_dir, sync_template): + """Syncs between two directories with the given command. + + Arguments: + local_dir (str): Directory to sync. + remote_dir (str): Remote directory to sync with. + sync_template (str): A string template + for syncer to run and needs to include replacement fields + '{source}' and '{target}'. Returned when using + `CommandSyncer.sync_template`, which can be overridden + by subclass. 
+ """ + super(CommandSyncer, self).__init__(local_dir, remote_dir) + if not isinstance(sync_template, str): + raise ValueError("{} is not a string.".format(sync_template)) + validate_sync_string(sync_template) + self._sync_template = sync_template + self.logfile = tempfile.NamedTemporaryFile( + prefix="log_sync", + dir=self._local_dir, + suffix=".log", + delete=False) + + self.sync_process = None + + def sync_function(self, source, target): + self.last_sync_time = time.time() + if self.sync_process: + self.sync_process.poll() + if self.sync_process.returncode is None: + logger.warning("Last sync is still in progress, skipping.") + return + final_cmd = self._sync_template.format( + source=quote(source), target=quote(target)) + logger.debug("Running sync: {}".format(final_cmd)) + self.sync_process = subprocess.Popen( + final_cmd, shell=True, stdout=self.logfile) + return True + + def reset(self): + if self.sync_process: + logger.warning("Sync process still running but resetting anyways.") + self.sync_process = None + super(CommandSyncer, self).reset() + + def wait(self): + if self.sync_process: + self.sync_process.wait() + + +def _get_sync_cls(sync_function): + if not sync_function: + return + if isinstance(sync_function, types.FunctionType) or isinstance( + sync_function, tune_function): + return BaseSyncer + elif isinstance(sync_function, str): + return CommandSyncer + else: + raise ValueError("Sync function {} must be string or function".format( + sync_function)) + + +def get_syncer(local_dir, remote_dir=None, sync_function=None): + """Returns a Syncer depending on given args. + + This syncer is in charge of syncing the local_dir with upload_dir. + + Args: + local_dir: Source directory for syncing. + remote_dir: Target directory for syncing. If None, + returns BaseSyncer with a noop. + sync_function (func | str): Function for syncing the local_dir to + remote_dir. If string, then it must be a string template for + syncer to run. If not provided, it defaults + to standard S3 or gsutil sync commands. + """ + key = (local_dir, remote_dir) + + if key in _syncers: + return _syncers[key] + + if not remote_dir: + _syncers[key] = BaseSyncer(local_dir, remote_dir) + return _syncers[key] + + sync_cls = _get_sync_cls(sync_function) + + if sync_cls: + _syncers[key] = sync_cls(local_dir, remote_dir, sync_function) + return _syncers[key] + + if remote_dir.startswith(S3_PREFIX): + if not distutils.spawn.find_executable("aws"): + raise TuneError( + "Upload uri starting with '{}' requires awscli tool" + " to be installed".format(S3_PREFIX)) + _syncers[key] = CommandSyncer(local_dir, remote_dir, + "aws s3 sync {source} {target}") + elif remote_dir.startswith(GS_PREFIX): + if not distutils.spawn.find_executable("gsutil"): + raise TuneError( + "Upload uri starting with '{}' requires gsutil tool" + " to be installed".format(GS_PREFIX)) + _syncers[key] = CommandSyncer(local_dir, remote_dir, + "gsutil rsync -r {source} {target}") + else: + raise TuneError("Upload uri must start with one of: {}" + "".format(ALLOWED_REMOTE_PREFIXES)) + + return _syncers[key] + + +def get_log_syncer(local_dir, remote_dir=None, sync_function=None): + """Returns a Syncer depending on given args. + + This syncer is in charge of syncing the local_dir with remote local_dir. + + Args: + local_dir: Source directory for syncing. + remote_dir: Target directory for syncing. If None, + returns BaseSyncer with noop. + sync_function (func | str): Function for syncing the local_dir to + remote_dir. 
If string, then it must be a string template for + syncer to run. If not provided, it defaults rsync. + """ + key = (local_dir, remote_dir) + + if key in _syncers: + return _syncers[key] + + sync_cls = None + if sync_function: + sync_cls = _get_sync_cls(sync_function) + else: + sync_cls = CommandSyncer + sync_function = log_sync_template() + + if not remote_dir or sync_function is None: + sync_cls = BaseSyncer + + class MixedSyncer(NodeSyncMixin, sync_cls): + def __init__(self, *args, **kwargs): + sync_cls.__init__(self, *args, **kwargs) + NodeSyncMixin.__init__(self) + + _syncers[key] = MixedSyncer(local_dir, remote_dir, sync_function) + return _syncers[key] diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index e00e5da371c5..8c8274bcb2e8 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -272,8 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): cluster.wait_for_nodes() dirpath = str(tmpdir) - runner = TrialRunner( - BasicVariantGenerator(), metadata_checkpoint_dir=dirpath) + runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath) kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -295,7 +294,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): ray.shutdown() cluster = _start_new_cluster() - runner = TrialRunner.restore(dirpath) + runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=dirpath) runner.step() # start runner.step() # start2 @@ -377,18 +376,19 @@ def test_cluster_rllib_restore(start_connected_cluster, tmpdir): # Wait until the right checkpoint is saved. # The trainable returns every 0.5 seconds, so this should not miss # the checkpoint. - metadata_checkpoint_dir = os.path.join(dirpath, "experiment") + local_checkpoint_dir = os.path.join(dirpath, "experiment") for i in range(100): - if TrialRunner.checkpoint_exists(metadata_checkpoint_dir): + if TrialRunner.checkpoint_exists(local_checkpoint_dir): # Inspect the internal trialrunner - runner = TrialRunner.restore(metadata_checkpoint_dir) + runner = TrialRunner( + resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir) trials = runner.get_trials() last_res = trials[0].last_result if last_res and last_res.get("training_iteration"): break time.sleep(0.3) - if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir): + if not TrialRunner.checkpoint_exists(local_checkpoint_dir): raise RuntimeError("Checkpoint file didn't appear.") ray.shutdown() @@ -469,18 +469,19 @@ def _restore(self, state): # Wait until the right checkpoint is saved. # The trainable returns every 0.5 seconds, so this should not miss # the checkpoint. 
- metadata_checkpoint_dir = os.path.join(dirpath, "experiment") + local_checkpoint_dir = os.path.join(dirpath, "experiment") for i in range(50): - if TrialRunner.checkpoint_exists(metadata_checkpoint_dir): + if TrialRunner.checkpoint_exists(local_checkpoint_dir): # Inspect the internal trialrunner - runner = TrialRunner.restore(metadata_checkpoint_dir) + runner = TrialRunner( + resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir) trials = runner.get_trials() last_res = trials[0].last_result if last_res and last_res.get("training_iteration") == 3: break time.sleep(0.2) - if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir): + if not TrialRunner.checkpoint_exists(local_checkpoint_dir): raise RuntimeError("Checkpoint file didn't appear.") ray.shutdown() @@ -489,7 +490,8 @@ def _restore(self, state): Experiment._register_if_needed(_Mock) # Inspect the internal trialrunner - runner = TrialRunner.restore(metadata_checkpoint_dir) + runner = TrialRunner( + resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir) trials = runner.get_trials() assert trials[0].last_result["training_iteration"] == 3 assert trials[0].status == Trial.PENDING diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 7b613a6fdea2..2697e7838a04 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -114,8 +114,8 @@ def testRunnerData(self): runner_data = self.ea.runner_data() self.assertTrue(isinstance(runner_data, dict)) - self.assertTrue("_metadata_checkpoint_dir" in runner_data) - self.assertEqual(runner_data["_metadata_checkpoint_dir"], + self.assertTrue("_local_checkpoint_dir" in runner_data) + self.assertEqual(runner_data["_local_checkpoint_dir"], os.path.expanduser(self.test_path)) def testBestLogdir(self): diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 64b8e9761488..5c4e2616aea0 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -3,6 +3,7 @@ from __future__ import print_function import copy +import glob import os import shutil import sys @@ -315,21 +316,6 @@ def train(config, reporter): } }) - def testUploadDirNone(self): - def train(config, reporter): - reporter(timesteps_total=1) - - [trial] = run_experiments({ - "foo": { - "run": train, - "upload_dir": None, - "config": { - "a": "b" - }, - } - }) - self.assertFalse(trial.upload_dir) - def testLogdirStartingWithTilde(self): local_dir = "~/ray_results/local_dir" @@ -930,50 +916,190 @@ def testCustomTrialString(self): str(trial), "{}_{}_321".format(trial.trainable_name, trial.trial_id)) - def testSyncFunction(self): - def fail_sync_local(): - [trial] = run_experiments({ - "foo": { - "run": "__fake", + +class TestSyncFunctionality(unittest.TestCase): + def setUp(self): + ray.init() + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + @patch("ray.tune.syncer.S3_PREFIX", "test") + def testNoUploadDir(self): + """No Upload Dir is given.""" + with self.assertRaises(AssertionError): + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ "stop": { "training_iteration": 1 }, - "upload_dir": "test", - "sync_function": "ls {remote_dir}" - } - }) + "sync_to_cloud": "echo {source} {target}" + }) - self.assertRaises(AssertionError, fail_sync_local) - - def fail_sync_remote(): - [trial] = run_experiments({ - "foo": { - "run": "__fake", + 
@patch("ray.tune.syncer.S3_PREFIX", "test") + def testCloudProperString(self): + with self.assertRaises(ValueError): + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ "stop": { "training_iteration": 1 }, "upload_dir": "test", - "sync_function": "ls {local_dir}" - } - }) + "sync_to_cloud": "ls {target}" + }) - self.assertRaises(AssertionError, fail_sync_remote) + with self.assertRaises(ValueError): + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "upload_dir": "test", + "sync_to_cloud": "ls {source}" + }) - def sync_func(local, remote): - with open(os.path.join(local, "test.log"), "w") as f: - f.write(remote) + tmpdir = tempfile.mkdtemp() + logfile = os.path.join(tmpdir, "test.log") - [trial] = run_experiments({ - "foo": { - "run": "__fake", + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ "stop": { "training_iteration": 1 }, "upload_dir": "test", - "sync_function": tune.function(sync_func) - } - }) - self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) + "sync_to_cloud": "echo {source} {target} > " + logfile + }) + with open(logfile) as f: + lines = f.read() + self.assertTrue("test" in lines) + shutil.rmtree(tmpdir) + + def testClusterProperString(self): + """Tests that invalid commands throw..""" + with self.assertRaises(TuneError): + # This raises TuneError because logger is init in safe zone. + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "sync_to_driver": "ls {target}" + }) + + with self.assertRaises(TuneError): + # This raises TuneError because logger is init in safe zone. + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "sync_to_driver": "ls {source}" + }) + + with patch("ray.tune.syncer.CommandSyncer.sync_function" + ) as mock_fn, patch( + "ray.services.get_node_ip_address") as mock_sync: + mock_sync.return_value = "0.0.0.0" + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "sync_to_driver": "echo {source} {target}" + }) + self.assertGreater(mock_fn.call_count, 0) + + def testCloudFunctions(self): + tmpdir = tempfile.mkdtemp() + tmpdir2 = tempfile.mkdtemp() + os.mkdir(os.path.join(tmpdir2, "foo")) + + def sync_func(local, remote): + for filename in glob.glob(os.path.join(local, "*.json")): + shutil.copy(filename, remote) + + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + local_dir=tmpdir, + stop={"training_iteration": 1}, + upload_dir=tmpdir2, + sync_to_cloud=tune.function(sync_func)) + test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json")) + self.assertTrue(test_file_path) + shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir2) + + def testClusterSyncFunction(self): + def sync_func_driver(source, target): + assert ":" in source, "Source not a remote path." + assert ":" not in target, "Target is supposed to be local." 
+ with open(os.path.join(target, "test.log2"), "w") as f: + print("writing to", f.name) + f.write(source) + + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + stop={"training_iteration": 1}, + sync_to_driver=tune.function(sync_func_driver)) + test_file_path = os.path.join(trial.logdir, "test.log2") + self.assertFalse(os.path.exists(test_file_path)) + + with patch("ray.services.get_node_ip_address") as mock_sync: + mock_sync.return_value = "0.0.0.0" + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + stop={"training_iteration": 1}, + sync_to_driver=tune.function(sync_func_driver)) + test_file_path = os.path.join(trial.logdir, "test.log2") + self.assertTrue(os.path.exists(test_file_path)) + os.remove(test_file_path) + + def testNoSync(self): + def sync_func(source, target): + pass + + with patch("ray.tune.syncer.CommandSyncer.sync_function") as mock_sync: + [trial] = tune.run( + "__fake", + name="foo", + max_failures=0, + **{ + "stop": { + "training_iteration": 1 + }, + "upload_dir": "test", + "sync_to_driver": tune.function(sync_func), + "sync_to_cloud": tune.function(sync_func) + }) + self.assertEqual(mock_sync.call_count, 0) class VariantGeneratorTest(unittest.TestCase): @@ -1960,7 +2086,7 @@ def testTrialSaveRestore(self): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir) trials = [ Trial( "__fake", @@ -1999,7 +2125,7 @@ def testTrialSaveRestore(self): self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3) self.assertEquals(trials[2].status, Trial.RUNNING) - runner2 = TrialRunner.restore(tmpdir) + runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) for tid in ["trial_terminate", "trial_fail"]: original_trial = runner.get_trial(tid) restored_trial = runner2.get_trial(tid) @@ -2019,7 +2145,7 @@ def testTrialNoSave(self): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir) runner.add_trial( Trial( @@ -2051,7 +2177,7 @@ def testTrialNoSave(self): runner.step() runner.step() - runner2 = TrialRunner.restore(tmpdir) + runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) new_trials = runner2.get_trials() self.assertEquals(len(new_trials), 3) self.assertTrue( @@ -2074,13 +2200,13 @@ def testCheckpointWithFunction(self): }, checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir) runner.add_trial(trial) for i in range(5): runner.step() # force checkpoint runner.checkpoint() - runner2 = TrialRunner.restore(tmpdir) + runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) new_trial = runner2.get_trials()[0] self.assertTrue("callbacks" in new_trial.config) self.assertTrue("on_episode_start" in new_trial.config["callbacks"]) @@ -2095,7 +2221,7 @@ def count_checkpoints(cdir): ray.init() trial = Trial("__fake", checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir) runner.add_trial(trial) for i in range(5): runner.step() @@ -2103,7 +2229,7 @@ def count_checkpoints(cdir): runner.checkpoint() self.assertEquals(count_checkpoints(tmpdir), 1) - runner2 = TrialRunner.restore(tmpdir) + runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir) for i in range(5): runner2.step() self.assertEquals(count_checkpoints(tmpdir), 
2) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 1a44575c716e..938b5bfaac90 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -19,7 +19,6 @@ import ray from ray.tune import TuneError -from ray.tune.log_sync import validate_sync_function from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not @@ -276,10 +275,9 @@ def __init__(self, checkpoint_score_attr="", export_formats=None, restore_path=None, - upload_dir=None, trial_name_creator=None, loggers=None, - sync_function=None, + sync_to_driver_fn=None, max_failures=0): """Initialize a new trial. @@ -308,10 +306,8 @@ def __init__(self, resources = default_resources self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} - self.upload_dir = upload_dir self.loggers = loggers - self.sync_function = sync_function - validate_sync_function(sync_function) + self.sync_to_driver_fn = sync_to_driver_fn self.verbose = True self.max_failures = max_failures @@ -352,7 +348,7 @@ def __init__(self, self._nonjson_fields = [ "_checkpoint", "loggers", - "sync_function", + "sync_to_driver_fn", "results", "best_result", "param_config", @@ -394,9 +390,8 @@ def init_logger(self): self.result_logger = UnifiedLogger( self.config, self.logdir, - upload_uri=self.upload_dir, loggers=self.loggers, - sync_function=self.sync_function) + sync_function=self.sync_to_driver_fn) def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index dfd809732857..5b20afe3d7f8 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import click import collections from datetime import datetime import json @@ -15,6 +16,7 @@ from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE +from ray.tune.syncer import get_syncer from ray.tune.trial import Trial, Checkpoint from ray.tune.sample import function from ray.tune.schedulers import FIFOScheduler, TrialScheduler @@ -97,12 +99,16 @@ class TrialRunner(object): """ CKPT_FILE_TMPL = "experiment_state-{}.json" + VALID_RESUME_TYPES = [True, "LOCAL", "REMOTE", "PROMPT"] def __init__(self, search_alg=None, scheduler=None, launch_web_server=False, - metadata_checkpoint_dir=None, + local_checkpoint_dir=None, + remote_checkpoint_dir=None, + sync_to_cloud=None, + resume=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, trial_executor=None): @@ -113,15 +119,16 @@ def __init__(self, Trial objects. scheduler (TrialScheduler): Defaults to FIFOScheduler. launch_web_server (bool): Flag for starting TuneServer - metadata_checkpoint_dir (str): Path where + local_checkpoint_dir (str): Path where global checkpoints are stored and restored from. - server_port (int): Port number for launching TuneServer + remote_checkpoint_dir (str): Remote path where + global checkpoints are stored and restored from. Used + if `resume` == REMOTE. + resume (str|False): see `tune.py:run`. + sync_to_cloud (func|str): see `tune.py:run`. + server_port (int): Port number for launching TuneServer. verbose (bool): Flag for verbosity. If False, trial results will not be output. 
- reuse_actors (bool): Whether to reuse actors between different - trials when possible. This can drastically speed up experiments - that start and stop actors often (e.g., PBT in - time-multiplexing mode). trial_executor (TrialExecutor): Defaults to RayTrialExecutor. """ self._search_alg = search_alg or BasicVariantGenerator() @@ -143,12 +150,73 @@ def __init__(self, self._trials = [] self._stop_queue = [] - self._metadata_checkpoint_dir = metadata_checkpoint_dir + self._local_checkpoint_dir = local_checkpoint_dir + + if self._local_checkpoint_dir and not os.path.exists( + self._local_checkpoint_dir): + os.makedirs(self._local_checkpoint_dir) + + self._remote_checkpoint_dir = remote_checkpoint_dir + self._syncer = get_syncer(local_checkpoint_dir, remote_checkpoint_dir, + sync_to_cloud) + + self._resumed = False + + if self._validate_resume(resume_type=resume): + try: + self.resume() + logger.info("Resuming trial.") + self._resumed = True + except Exception: + logger.exception( + "Runner restore failed. Restarting experiment.") + else: + logger.info("Starting a new experiment.") self._start_time = time.time() self._session_str = datetime.fromtimestamp( self._start_time).strftime("%Y-%m-%d_%H-%M-%S") + def _validate_resume(self, resume_type): + """Checks whether to resume experiment. + + Args: + resume_type: One of True, "REMOTE", "LOCAL", "PROMPT". + """ + if not resume_type: + return False + assert resume_type in self.VALID_RESUME_TYPES, ( + "resume_type {} is not one of {}".format(resume_type, + self.VALID_RESUME_TYPES)) + # Not clear if we need this assertion, since we should always have a + # local checkpoint dir. + assert self._local_checkpoint_dir or self._remote_checkpoint_dir + if resume_type in [True, "LOCAL", "PROMPT"]: + if not self.checkpoint_exists(self._local_checkpoint_dir): + raise ValueError("Called resume when no checkpoint exists " + "in local directory.") + elif resume_type == "PROMPT": + if click.confirm("Resume from local directory?"): + return True + + if resume_type in ["REMOTE", "PROMPT"]: + if resume_type == "PROMPT" and not click.confirm( + "Try downloading from remote directory?"): + return False + if not self._remote_checkpoint_dir: + raise ValueError( + "Called resume from remote without remote directory.") + + # Try syncing down the upload directory. + logger.info("Downloading from {}".format( + self._remote_checkpoint_dir)) + self._syncer.sync_down_if_needed() + + if not self.checkpoint_exists(self._local_checkpoint_dir): + raise ValueError("Called resume when no checkpoint exists " + "in remote or local directory.") + return True + @classmethod def checkpoint_exists(cls, directory): if not os.path.exists(directory): @@ -157,17 +225,21 @@ def checkpoint_exists(cls, directory): (fname.startswith("experiment_state") and fname.endswith(".json")) for fname in os.listdir(directory)) + def add_experiment(self, experiment): + if not self._resumed: + self._search_alg.add_configurations([experiment]) + else: + logger.info("TrialRunner resumed, ignoring new add_experiment.") + def checkpoint(self): - """Saves execution state to `self._metadata_checkpoint_dir`. + """Saves execution state to `self._local_checkpoint_dir`. Overwrites the current session checkpoint, which starts when self is instantiated. 
""" - if not self._metadata_checkpoint_dir: + if not self._local_checkpoint_dir: return - metadata_checkpoint_dir = self._metadata_checkpoint_dir - if not os.path.exists(metadata_checkpoint_dir): - os.makedirs(metadata_checkpoint_dir) + runner_state = { "checkpoints": list( self.trial_executor.get_checkpoints().values()), @@ -177,55 +249,37 @@ def checkpoint(self): "timestamp": time.time() } } - tmp_file_name = os.path.join(metadata_checkpoint_dir, + tmp_file_name = os.path.join(self._local_checkpoint_dir, ".tmp_checkpoint") with open(tmp_file_name, "w") as f: json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder) os.rename( tmp_file_name, - os.path.join(metadata_checkpoint_dir, + os.path.join(self._local_checkpoint_dir, TrialRunner.CKPT_FILE_TMPL.format(self._session_str))) - return metadata_checkpoint_dir + self._syncer.sync_up_if_needed() + return self._local_checkpoint_dir - @classmethod - def restore(cls, - metadata_checkpoint_dir, - search_alg=None, - scheduler=None, - trial_executor=None): - """Restores all checkpointed trials from previous run. + def resume(self): + """Resumes all checkpointed trials from previous run. Requires user to manually re-register their objects. Also stops all ongoing trials. - - Args: - metadata_checkpoint_dir (str): Path to metadata checkpoints. - search_alg (SearchAlgorithm): Search Algorithm. Defaults to - BasicVariantGenerator. - scheduler (TrialScheduler): Scheduler for executing - the experiment. - trial_executor (TrialExecutor): Manage the execution of trials. - - Returns: - runner (TrialRunner): A TrialRunner to resume experiments from. """ - newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir) + newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir) with open(newest_ckpt_path, "r") as f: runner_state = json.load(f, cls=_TuneFunctionDecoder) logger.warning("".join([ "Attempting to resume experiment from {}. ".format( - metadata_checkpoint_dir), "This feature is experimental, " + self._local_checkpoint_dir), "This feature is experimental, " "and may not work with all search algorithms. ", "This will ignore any new changes to the specification." 
])) - runner = TrialRunner( - search_alg, scheduler=scheduler, trial_executor=trial_executor) - - runner.__setstate__(runner_state["runner_data"]) + self.__setstate__(runner_state["runner_data"]) trials = [] for trial_cp in runner_state["checkpoints"]: @@ -234,8 +288,7 @@ def restore(cls, trials += [new_trial] for trial in sorted( trials, key=lambda t: t.last_update_time, reverse=True): - runner.add_trial(trial) - return runner + self.add_trial(trial) def is_finished(self): """Returns whether all trials have finished running.""" @@ -626,6 +679,7 @@ def __getstate__(self): "_search_alg", "_scheduler_alg", "trial_executor", + "_syncer", ]: del state[k] state["launch_web_server"] = bool(self._server) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index db302f6bd5e6..3750210940c1 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -2,7 +2,6 @@ from __future__ import division from __future__ import print_function -import click import logging import time @@ -12,7 +11,7 @@ from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.ray_trial_executor import RayTrialExecutor -from ray.tune.log_sync import wait_for_log_sync +from ray.tune.syncer import wait_for_sync from ray.tune.trial_runner import TrialRunner from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, FIFOScheduler, MedianStoppingRule) @@ -36,36 +35,6 @@ def _make_scheduler(args): args.scheduler, _SCHEDULERS.keys())) -def _find_checkpoint_dir(exp): - # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier. - # Right now it is resolved somewhere far down the trial generation process - return exp.checkpoint_dir - - -def _prompt_restore(checkpoint_dir, resume): - restore = False - if TrialRunner.checkpoint_exists(checkpoint_dir): - if resume == "prompt": - msg = ("Found incomplete experiment at {}. " - "Would you like to resume it?".format(checkpoint_dir)) - restore = click.confirm(msg, default=False) - if restore: - logger.info("Tip: to always resume, " - "pass resume=True to run()") - else: - logger.info("Tip: to always start a new experiment, " - "pass resume=False to run()") - elif resume: - restore = True - else: - logger.info("Tip: to resume incomplete experiments, " - "pass resume='prompt' or resume=True to run()") - else: - logger.info( - "Did not find checkpoint file in {}.".format(checkpoint_dir)) - return restore - - def run(run_or_experiment, name=None, stop=None, @@ -76,7 +45,8 @@ def run(run_or_experiment, upload_dir=None, trial_name_creator=None, loggers=None, - sync_function=None, + sync_to_cloud=None, + sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, export_formats=None, @@ -93,7 +63,8 @@ def run(run_or_experiment, trial_executor=None, raise_on_failed_trial=True, return_trials=True, - ray_auto_init=True): + ray_auto_init=True, + sync_function=None): """Executes training. Args: @@ -129,10 +100,15 @@ def run(run_or_experiment, loggers (list): List of logger creators to be used with each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS. See `ray/tune/logger.py`. - sync_function (func|str): Function for syncing the local_dir to - upload_dir. If string, then it must be a string template for - syncer to run. If not provided, the sync command defaults - to standard S3 or gsutil sync comamnds. + sync_to_cloud (func|str): Function for syncing the local_dir to and + from upload_dir. 
If string, then it must be a string template + that includes `{source}` and `{target}` for the syncer to run. + If not provided, the sync command defaults to standard + S3 or gsutil sync comamnds. + sync_to_driver (func|str): Function for syncing trial logdir from + remote node to local. If string, then it must be a string template + that includes `{source}` and `{target}` for the syncer to run. + If not provided, defaults to using rsync. checkpoint_freq (int): How many training iterations between checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the @@ -155,9 +131,12 @@ def run(run_or_experiment, server_port (int): Port number for launching TuneServer. verbose (int): 0, 1, or 2. Verbosity mode. 0 = silent, 1 = only status updates, 2 = status and trial results. - resume (bool|"prompt"): If checkpoint exists, the experiment will - resume from there. If resume is "prompt", Tune will prompt if - checkpoint detected. + resume (str|bool): One of "LOCAL", "REMOTE", "PROMPT", or bool. + LOCAL/True restores the checkpoint from the local_checkpoint_dir. + REMOTE restores the checkpoint from remote_checkpoint_dir. + PROMPT provides CLI feedback. False forces a new + experiment. If resume is set but checkpoint does not exist, + ValueError will be thrown. queue_trials (bool): Whether to queue trials when the cluster does not currently have enough resources to launch one. This should be set to True when running on an autoscaling cluster to enable @@ -172,6 +151,8 @@ def run(run_or_experiment, ray_auto_init (bool): Automatically starts a local Ray cluster if using a RayTrialExecutor (which is the default) and if Ray is not initialized. Defaults to True. + sync_function: Deprecated. See `sync_to_cloud` and + `sync_to_driver`. Returns: List of Trial objects. @@ -199,53 +180,45 @@ def run(run_or_experiment, ray_auto_init=ray_auto_init) experiment = run_or_experiment if not isinstance(run_or_experiment, Experiment): + run_identifier = Experiment._register_if_needed(run_or_experiment) experiment = Experiment( name=name, - run=run_or_experiment, + run=run_identifier, stop=stop, config=config, resources_per_trial=resources_per_trial, num_samples=num_samples, local_dir=local_dir, upload_dir=upload_dir, + sync_to_driver=sync_to_driver, trial_name_creator=trial_name_creator, loggers=loggers, - sync_function=sync_function, checkpoint_freq=checkpoint_freq, checkpoint_at_end=checkpoint_at_end, export_formats=export_formats, max_failures=max_failures, - restore=restore) + restore=restore, + sync_function=sync_function) else: logger.debug("Ignoring some parameters passed into tune.run.") - checkpoint_dir = _find_checkpoint_dir(experiment) - should_restore = _prompt_restore(checkpoint_dir, resume) - - runner = None - if should_restore: - try: - runner = TrialRunner.restore(checkpoint_dir, search_alg, scheduler, - trial_executor) - except Exception: - logger.exception("Runner restore failed. 
Restarting experiment.") - else: - logger.info("Starting a new experiment.") - - if not runner: - scheduler = scheduler or FIFOScheduler() - search_alg = search_alg or BasicVariantGenerator() + if sync_to_cloud: + assert experiment.remote_checkpoint_dir, ( + "Need `upload_dir` if `sync_to_cloud` given.") - search_alg.add_configurations([experiment]) + runner = TrialRunner( + search_alg=search_alg or BasicVariantGenerator(), + scheduler=scheduler or FIFOScheduler(), + local_checkpoint_dir=experiment.checkpoint_dir, + remote_checkpoint_dir=experiment.remote_checkpoint_dir, + sync_to_cloud=sync_to_cloud, + resume=resume, + launch_web_server=with_server, + server_port=server_port, + verbose=bool(verbose > 1), + trial_executor=trial_executor) - runner = TrialRunner( - search_alg=search_alg, - scheduler=scheduler, - metadata_checkpoint_dir=checkpoint_dir, - launch_web_server=with_server, - server_port=server_port, - verbose=bool(verbose > 1), - trial_executor=trial_executor) + runner.add_experiment(experiment) if verbose: print(runner.debug_string(max_debug=99999)) @@ -261,7 +234,7 @@ def run(run_or_experiment, if verbose: print(runner.debug_string(max_debug=99999)) - wait_for_log_sync() + wait_for_sync() errored_trials = [] for trial in runner.get_trials():