-
Notifications
You must be signed in to change notification settings - Fork 7
/
config_parser.py
202 lines (182 loc) · 6.95 KB
/
config_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
# For compatibility under py2 to consider unicode as str
from six import string_types
from ray.tune import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Resources, Trial
from ray.tune.logger import _SafeFallbackEncoder
def json_to_resources(data):
    """Builds a `Resources` tuple from a dict or a JSON-encoded string.

    Args:
        data: None, the literal string "null", a JSON string, or an
            already-decoded mapping of resource name to amount.

    Returns:
        None when `data` is None or "null"; otherwise a `Resources`
        built from the "cpu", "gpu", "extra_cpu" and "extra_gpu" entries
        (defaulting to 1, 0, 0 and 0 respectively).

    Raises:
        TuneError: If a key is one of the removed `driver_*_limit` fields
            or is not a field of `Resources`.
    """
    if data is None or data == "null":
        return None

    # Strings are treated as JSON documents and decoded first.
    if isinstance(data, string_types):
        data = json.loads(data)

    removed_fields = ("driver_cpu_limit", "driver_gpu_limit")
    for key in data:
        # Removed fields get a dedicated hint before the generic check.
        if key in removed_fields:
            message = (
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(key))
            raise TuneError(message)
        if key not in Resources._fields:
            message = "Unknown resource type {}, must be one of {}".format(
                key, Resources._fields)
            raise TuneError(message)

    cpu = data.get("cpu", 1)
    gpu = data.get("gpu", 0)
    extra_cpu = data.get("extra_cpu", 0)
    extra_gpu = data.get("extra_gpu", 0)
    return Resources(cpu, gpu, extra_cpu, extra_gpu)
def resources_to_json(resources):
    """Converts a resources object into a JSON-serializable dict.

    Args:
        resources: None, or an object exposing `cpu`, `gpu`, `extra_cpu`
            and `extra_gpu` attributes.

    Returns:
        None if `resources` is None; otherwise a dict with exactly those
        four keys mapped to the corresponding attribute values.
    """
    if resources is None:
        return None
    field_names = ("cpu", "gpu", "extra_cpu", "extra_gpu")
    return {name: getattr(resources, name) for name in field_names}
def make_parser(parser_creator=None, **kwargs):
    """Returns a base argument parser for the ray.tune tool.

    Args:
        parser_creator: A constructor for the parser class. When omitted
            (or falsy), `argparse.ArgumentParser` is used.
        kwargs: Non-positional args to be passed into the
            parser class constructor.

    Returns:
        The constructed parser with all base tune flags registered.
    """
    if parser_creator:
        parser = parser_creator(**kwargs)
    else:
        parser = argparse.ArgumentParser(**kwargs)

    # Note: keep this in sync with rllib/train.py
    parser.add_argument(
        "--run",
        default=None,
        type=str,
        help="The algorithm or model to train. This may refer to the name "
        "of a built-in algorithm (e.g. RLLib's DQN or PPO), or a "
        "user-defined trainable function or class registered in the "
        "tune registry.")
    # JSON-valued flags use a string default; argparse runs string defaults
    # through `type`, so the parsed default is the decoded JSON value.
    parser.add_argument(
        "--stop",
        default="{}",
        type=json.loads,
        help="The stopping criteria, specified in JSON. The keys may be any "
        "field returned by 'train()' e.g. "
        "'{\"time_total_s\": 600, \"training_iteration\": 100000}' to stop "
        "after 600 seconds or 100k iterations, whichever is reached first.")
    parser.add_argument(
        "--config",
        default="{}",
        type=json.loads,
        help="Algorithm-specific configuration (e.g. env, hyperparams), "
        "specified in JSON.")
    parser.add_argument(
        "--trial-resources",
        default=None,
        type=json_to_resources,
        help="Override the machine resources to allocate per trial, e.g. "
        "'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
        "unless you specify them here. For RLlib, you probably want to "
        "leave this alone and use RLlib configs to control parallelism.")
    parser.add_argument(
        "--num-samples",
        default=1,
        type=int,
        help="Number of times to repeat each trial.")
    parser.add_argument(
        "--local-dir",
        default=DEFAULT_RESULTS_DIR,
        type=str,
        help="Local dir to save training results to. Defaults to '{}'.".format(
            DEFAULT_RESULTS_DIR))
    parser.add_argument(
        "--upload-dir",
        default="",
        type=str,
        help="Optional URI to sync training results to (e.g. s3://bucket).")
    parser.add_argument(
        "--checkpoint-freq",
        default=0,
        type=int,
        help="How many training iterations between checkpoints. "
        "A value of 0 (default) disables checkpointing.")
    parser.add_argument(
        "--checkpoint-at-end",
        action="store_true",
        help="Whether to checkpoint at the end of the experiment. "
        "Default is False.")
    parser.add_argument(
        "--max-failures",
        default=3,
        type=int,
        help="Try to recover a trial from its last checkpoint at least this "
        "many times. Only applies if checkpointing is enabled.")
    parser.add_argument(
        "--scheduler",
        default="FIFO",
        type=str,
        help="FIFO (default), MedianStopping, AsyncHyperBand, "
        "HyperBand, or HyperOpt.")
    parser.add_argument(
        "--scheduler-config",
        default="{}",
        type=json.loads,
        help="Config options to pass to the scheduler.")

    # Note: this currently only makes sense when running a single trial
    parser.add_argument(
        "--restore",
        default=None,
        type=str,
        help="If specified, restore from this checkpoint.")

    return parser
def to_argv(config):
    """Converts configuration to a command line argument format.

    Args:
        config (dict): Maps option names (using '_' rather than '-') to
            values.

    Returns:
        A list of argv tokens. Each key becomes a `--dashed-flag`. True
        booleans emit the bare flag, False booleans emit nothing, strings
        are passed through as-is, and anything else is JSON-encoded.

    Raises:
        ValueError: If a key contains '-'.
    """
    tokens = []
    for name, value in config.items():
        if "-" in name:
            raise ValueError("Use '_' instead of '-' in `{}`".format(name))

        flag = "--{}".format(name.replace("_", "-"))

        # Booleans map onto argparse store_true flags: present or absent,
        # never followed by a value token.
        if isinstance(value, bool):
            if value:
                tokens.append(flag)
            continue

        tokens.append(flag)
        if isinstance(value, string_types):
            tokens.append(value)
        else:
            tokens.append(json.dumps(value, cls=_SafeFallbackEncoder))
    return tokens
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. Its keys should
            correspond to the command line flags defined in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the parser exits while parsing the spec.
    """
    argv = to_argv(spec)
    try:
        args = parser.parse_args(argv)
    except SystemExit:
        # argparse has already printed its own diagnostic before exiting.
        raise TuneError("Error parsing args, see above message", spec)

    if "trial_resources" in spec:
        resources = json_to_resources(spec["trial_resources"])
        trial_kwargs["resources"] = resources

    # These values are read from the raw spec rather than the parsed args,
    # so the caller's original objects are forwarded as-is (a missing
    # restore path stays None).
    trainable_name = spec["run"]
    trial_config = spec.get("config", {})
    trial_dir = os.path.join(args.local_dir, output_path)
    stop_criteria = spec.get("stop", {})
    restore_from = spec.get("restore")

    return Trial(
        trainable_name=trainable_name,
        config=trial_config,
        local_dir=trial_dir,
        stopping_criterion=stop_criteria,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        restore_path=restore_from,
        upload_dir=args.upload_dir,
        max_failures=args.max_failures,
        **trial_kwargs)