-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
argparse.py
302 lines (248 loc) · 10.8 KB
/
argparse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from typing import Any, Dict, List, Tuple, Union
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
    """Create an instance of ``cls`` from CLI arguments.

    Environment variables named "PL_<CLASS-NAME>_<CLASS_ARGUMENT_NAME>" may eventually
    be used as well; that lookup is not performed by this function itself.

    Args:
        cls: Lightning class
        args: The parser or namespace to take arguments from. Only known arguments will be
            parsed and passed to the :class:`Trainer`.
        **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
            These must be valid Trainer arguments.

    Returns:
        The object produced by calling ``cls`` with the selected keyword arguments.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parser = ArgumentParser(add_help=False)
        >>> parser = Trainer.add_argparse_args(parser)
        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
        >>> args = Trainer.parse_argparser(parser.parse_args(""))
        >>> trainer = Trainer.from_argparse_args(args, logger=False)
    """
    if isinstance(args, ArgumentParser):
        # A raw parser still has to be parsed; delegate to the class hook.
        args = cls.parse_argparser(args)

    params = vars(args)

    # we only want to pass in valid Trainer args, the rest may be user specific
    valid_kwargs = inspect.signature(cls.__init__).parameters
    trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
    # Explicit keyword arguments win over values taken from the namespace.
    trainer_kwargs.update(**kwargs)

    return cls(**trainer_kwargs)
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Parse CLI arguments and resolve boolean flags that were given without a value.

    Args:
        cls: Class whose signature (as reported by ``get_init_arguments_and_types``)
            provides the argument types and defaults.
        arg_parser: Either a parser, which is parsed here, or an already parsed namespace.

    Returns:
        A new ``Namespace`` with the same keys, where ``None`` values of boolean
        arguments with a boolean default have been replaced by ``True``.
    """
    if isinstance(arg_parser, ArgumentParser):
        namespace = arg_parser.parse_args()
    else:
        namespace = arg_parser

    signature_info = {
        name: (types, default) for name, types, default in get_init_arguments_and_types(cls)
    }

    def _resolve(name, value):
        # Only a ``None`` on a known argument needs attention; it may stem from
        # ``nargs="?"`` (flag passed without a value) or from the default itself.
        if value is not None or name not in signature_info:
            return value
        types, default = signature_info[name]
        if bool in types and isinstance(default, bool):
            # A bare flag always means True, whatever the default is. Disabling an
            # argument that defaults to True therefore requires the long form,
            # e.g. "--a_default_true_arg False".
            return True
        return value

    resolved = {name: _resolve(name, value) for name, value in vars(namespace).items()}
    return Namespace(**resolved)
def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Collect arguments of ``cls`` from environment variables, if they are set.

    For every argument reported by ``get_init_arguments_and_types`` the variable name is
    built from ``template`` using the upper-cased class name and argument name. Unset or
    empty variables are ignored.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_GPUS"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> parse_env_variables(Trainer)
        Namespace(gpus=42)
        >>> del os.environ["PL_TRAINER_GPUS"]
    """
    collected = {}
    for name, _types, _default in get_init_arguments_and_types(cls):
        env_key = template % {'cls_name': cls.__name__.upper(), 'cls_argument': name.upper()}
        raw = os.environ.get(env_key)
        if raw is None or raw == '':
            continue

        parsed = raw
        # NOTE(review): ``eval`` executes whatever expression the environment variable
        # holds. If the environment can be influenced by an untrusted party this is a
        # code-execution risk; ``ast.literal_eval`` would cover int/float/bool literals.
        # todo: specify the possible exception
        with suppress(Exception):
            # converting to native types like int/float/bool; on failure the raw string is kept
            parsed = eval(raw)
        collected[name] = parsed
    return Namespace(**collected)
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
    r"""Inspect the signature of ``cls`` and list its parameters.

    Returns:
        One tuple per parameter, in signature order:
        (argument name, tuple with argument types, argument default value).
        The types are the ``__args__`` of the annotation when it has that attribute,
        otherwise a one-element tuple holding the annotation itself.

    Examples:
        >>> from pytorch_lightning import Trainer
        >>> args = get_init_arguments_and_types(Trainer)
    """

    def _annotation_types(annotation) -> Tuple:
        # Annotations exposing ``__args__`` are expanded into their members;
        # anything else is wrapped as-is.
        try:
            return tuple(annotation.__args__)
        except AttributeError:
            return (annotation, )

    return [
        (name, _annotation_types(parameter.annotation), parameter.default)
        for name, parameter in inspect.signature(cls).parameters.items()
    ]
def _get_abbrev_qualified_cls_name(cls):
    """Return a display name for ``cls``.

    Classes whose module starts with ``pytorch_lightning.`` are abbreviated to
    ``pl.<name>``; every other class gets ``<module>.<qualname>``.
    """
    assert isinstance(cls, type), repr(cls)
    module_name = cls.__module__
    if not module_name.startswith("pytorch_lightning."):
        # Fully qualified.
        return f"{module_name}.{cls.__qualname__}"
    # Abbreviate.
    return f"pl.{cls.__name__}"
def add_argparse_args(
    cls,
    parent_parser: ArgumentParser,
    *,
    use_argument_group=True,
) -> ArgumentParser:
    r"""Extends existing argparse by default attributes for ``cls``.

    Args:
        cls: Lightning class
        parent_parser:
            The custom cli arguments parser, which will be extended by
            the class's default arguments.
        use_argument_group:
            By default, this is True, and uses ``add_argument_group`` to add
            a new group.
            If False, this will use old behavior.

    Returns:
        If use_argument_group is True, returns ``parent_parser`` to keep old
        workflows. If False, will return the new ArgumentParser object.

    Raises:
        RuntimeError: If ``parent_parser`` is an argument group instead of a parser.

    Only arguments of the allowed types (str, float, int, bool) will
    extend the ``parent_parser``.

    Examples:

        # Option 1: Default usage.
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser)
        >>> args = parser.parse_args([])

        # Option 2: Disable use_argument_group (old behavior).
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False)
        >>> args = parser.parse_args([])
    """
    if isinstance(parent_parser, _ArgumentGroup):
        raise RuntimeError("Please only pass an ArgumentParser instance.")

    if use_argument_group:
        parser = parent_parser.add_argument_group(_get_abbrev_qualified_cls_name(cls))
    else:
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

    skipped_names = ['self', 'args', 'kwargs']
    if hasattr(cls, "get_deprecated_arg_names"):
        skipped_names.extend(cls.get_deprecated_arg_names())

    supported_types = (str, int, float, bool)

    # Get symbols from cls or init function: the first one that yields any
    # non-ignored argument wins.
    for symbol in (cls, cls.__init__):
        candidates = [
            entry for entry in get_init_arguments_and_types(symbol) if entry[0] not in skipped_names
        ]
        if candidates:
            break

    help_by_arg = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

    for arg, annotated_types, arg_default in candidates:
        # Keep only supported types, ordered as in ``supported_types``.
        arg_types = [kind for kind in supported_types if kind in annotated_types]
        if not arg_types:
            # skip argument with not supported type
            continue

        extra_kwargs = {}
        accepts_bool = bool in arg_types
        if accepts_bool:
            # Allow the option to be used as a bare flag.
            extra_kwargs.update(nargs="?", const=True)

        # The special cases come first because they take precedence over the
        # generic, annotation-driven choice below.
        if arg == 'track_grad_norm':
            # hack for track_grad_norm
            use_type = float
        elif set(arg_types) == {int, float}:
            # hack for types in (int, float)
            use_type = _int_or_float_type
        elif arg in ('gpus', 'tpu_cores'):
            use_type = _gpus_allowed_type
        elif not accepts_bool:
            use_type = arg_types[0]
        elif len(arg_types) == 1:
            # the only arg type is bool
            use_type = str_to_bool
        elif int in arg_types:
            use_type = str_to_bool_or_int
        elif str in arg_types:
            use_type = str_to_bool_or_str
        else:
            # filter out the bool as we need to use more general
            use_type = next(kind for kind in arg_types if kind is not bool)

        parser.add_argument(
            f'--{arg}',
            dest=arg,
            default=arg_default,
            type=use_type,
            help=help_by_arg.get(arg),
            **extra_kwargs,
        )

    return parent_parser if use_argument_group else parser
def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
    """Extract per-argument descriptions from the argument section of a docstring.

    The section starts at a line beginning with ``Args:``, ``Arguments:`` or
    ``Parameters:``; argument entries are expected exactly 4 columns deeper than that
    header, in the form ``name: description``. Deeper-indented lines are appended to
    the description of the preceding argument, and the first shallower line ends the
    section.

    Args:
        docstring: The docstring text to scan; may be empty.

    Returns:
        Mapping from argument name to its description. Empty when no argument section
        is found.
    """
    arg_block_indent = None
    current_arg = None
    parsed = {}
    for line in docstring.split("\n"):
        stripped = line.lstrip()
        if not stripped:
            continue

        line_indent = len(line) - len(stripped)
        if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')):
            arg_block_indent = line_indent + 4
        elif arg_block_indent is None:
            # Still before the argument section.
            continue
        elif line_indent < arg_block_indent:
            # Dedent below the entries: the argument section is over.
            break
        elif line_indent == arg_block_indent:
            name, separator, description = stripped.partition(':')
            if not separator:
                # Not a "name: description" entry. Skip it instead of failing, and
                # forget the current argument so that following continuation lines
                # are not attached to an unrelated entry.
                current_arg = None
                continue
            current_arg = name
            parsed[current_arg] = description.lstrip()
        elif current_arg is not None:
            # Continuation line of the current argument's description.
            parsed[current_arg] += f' {stripped}'
    return parsed
def _gpus_allowed_type(x) -> Union[int, str]:
    """Convert a CLI value: keep it as a string if it contains a comma, else parse an int."""
    return str(x) if ',' in x else int(x)
def _gpus_arg_default(x) -> Union[int, str]:  # pragma: no-cover
    """Do nothing and return ``None``.

    Unused, but kept for backward compatibility with old checkpoints that need to be
    able to unpickle the function from the checkpoint, as it was not filtered out in
    versions < 1.2.8.
    See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
    """
    return None
def _int_or_float_type(x) -> Union[int, float]:
    """Convert a CLI value to ``float`` or ``int``.

    Args:
        x: The value to convert, typically the raw command-line string.

    Returns:
        A ``float`` when the text form of ``x`` contains a "." or is not a valid
        integer literal (for example "1e-3"); otherwise an ``int``.

    Raises:
        ValueError: If ``x`` can be parsed neither as an int nor as a float.
    """
    if '.' in str(x):
        return float(x)
    try:
        return int(x)
    except ValueError:
        # Float notations without a "." (such as "1e-3") are not valid int
        # literals; parse them as floats instead of rejecting them.
        return float(x)