diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 9a89261be7de..faaef8c6a153 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -423,6 +423,8 @@ def rsync(config_file, source, target, override_cluster_name, down): override_cluster_name: set the name of the cluster down: whether we're syncing remote -> local """ + assert bool(source) == bool(target), ( + "Must either provide both or neither source and target.") config = yaml.load(open(config_file).read()) if override_cluster_name is not None: @@ -448,7 +450,12 @@ def rsync(config_file, source, target, override_cluster_name, down): rsync = updater.rsync_down else: rsync = updater.rsync_up - rsync(source, target, check_error=False) + + if source and target: + rsync(source, target, check_error=False) + else: + updater.sync_file_mounts(rsync) + finally: provider.cleanup() diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 9fff0c767467..c86750fe399d 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -183,25 +183,9 @@ def wait_for_ssh(self, deadline): return False - def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) - - deadline = time.time() + NODE_START_WAIT_S - self.set_ssh_ip_if_required() - - # Wait for SSH access - with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): - ssh_ok = self.wait_for_ssh(deadline) - assert ssh_ok, "Unable to SSH to node" - + def sync_file_mounts(self, sync_cmd): # Rsync file mounts - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "syncing-files"}) for remote_path, local_path in self.file_mounts.items(): - logger.info("NodeUpdater: " - "{}: Syncing {} to {}...".format( - self.node_id, local_path, remote_path)) assert os.path.exists(local_path), local_path if os.path.isdir(local_path): if not local_path.endswith("/"): @@ -217,7 +201,23 @@ def do_update(self): "mkdir -p {}".format(os.path.dirname(remote_path)), redirect=redirect, ) - self.rsync_up(local_path, remote_path, redirect=redirect) + sync_cmd(local_path, remote_path, redirect=redirect) + + def do_update(self): + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) + + deadline = time.time() + NODE_START_WAIT_S + self.set_ssh_ip_if_required() + + # Wait for SSH access + with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): + ssh_ok = self.wait_for_ssh(deadline) + assert ssh_ok, "Unable to SSH to node" + + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "syncing-files"}) + self.sync_file_mounts(self.rsync_up) # Run init commands self.provider.set_node_tags(self.node_id, @@ -236,6 +236,9 @@ def do_update(self): self.ssh_cmd(cmd, redirect=redirect) def rsync_up(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} to {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ @@ -247,6 +250,9 @@ def rsync_up(self, source, target, redirect=None, check_error=True): stderr=redirect or sys.stderr) def rsync_down(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} from {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 1951af208573..e993478c1011 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -529,8 +529,8 @@ def attach(cluster_config_file, start, tmux, cluster_name, new): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", @@ -543,8 +543,8 @@ def rsync_down(cluster_config_file, source, target, cluster_name): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index ac26d0353a98..7163dcfd6a01 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -49,6 +49,11 @@ action="store_true", default=False, help="disables CUDA training") +parser.add_argument( + "--redis-address", + default=None, + type=str, + help="The Redis address of the cluster.") parser.add_argument( "--seed", type=int, @@ -173,7 +178,7 @@ def _restore(self, checkpoint_path): from ray import tune from ray.tune.schedulers import HyperBandScheduler - ray.init() + ray.init(redis_address=args.redis_address) sched = HyperBandScheduler( time_attr="training_iteration", reward_attr="neg_mean_loss") tune.run(