diff --git a/tools/launch.py b/tools/launch.py index de42ea2a7dd3..0908950636e3 100755 --- a/tools/launch.py +++ b/tools/launch.py @@ -36,7 +36,14 @@ def dmlc_opts(opts): '--cluster', opts.launcher, '--host-file', opts.hostfile, '--sync-dst-dir', opts.sync_dst_dir] - args += opts.command; + + # convert to dictionary + dopts = vars(opts) + for key in ['env_server', 'env_worker', 'env']: + for v in dopts[key]: + args.append('--' + key.replace("_","-")) + args.append(v) + args += opts.command try: from dmlc_tracker import opts except ImportError: @@ -64,6 +71,21 @@ def main(): parser.add_argument('--launcher', type=str, default='ssh', choices = ['local', 'ssh', 'mpi', 'sge', 'yarn'], help = 'the launcher to use') + parser.add_argument('--env-server', action='append', default=[], + help = 'Given a pair of environment_variable:value, sets this value of \ + environment variable for the server processes. This overrides values of \ + those environment variable on the machine where this script is run from. \ + Example OMP_NUM_THREADS:3') + parser.add_argument('--env-worker', action='append', default=[], + help = 'Given a pair of environment_variable:value, sets this value of \ + environment variable for the worker processes. This overrides values of \ + those environment variable on the machine where this script is run from. \ + Example OMP_NUM_THREADS:3') + parser.add_argument('--env', action='append', default=[], + help = 'given a environment variable, passes their \ + values from current system to all workers and servers. \ + Not necessary when launcher is local as in that case \ + all environment variables which are set are copied.') parser.add_argument('command', nargs='+', help = 'command for launching the program') args, unknown = parser.parse_known_args()