-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathlogger.py
173 lines (137 loc) · 5.5 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
import numpy as np
from torch import Tensor
def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]:
    """Normalize hyperparameters into a plain dictionary.

    An ``argparse.Namespace`` is unpacked with ``vars``, ``None`` becomes an empty dict,
    and any other value is handed back unchanged.

    Args:
        params: Hyperparameters given as a dict, a ``Namespace`` or ``None``.

    Returns:
        The hyperparameters as a dictionary.
    """
    if params is None:
        return {}
    if isinstance(params, Namespace):
        # ``vars`` returns the namespace's own ``__dict__`` (not a copy)
        return vars(params)
    return params
def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Replace callable hyperparameter values with loggable representations.

    Example: ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.

    Args:
        params: Dictionary containing the hyperparameters

    Returns:
        A new dictionary in which every callable value has been sanitized
    """

    def _sanitize_callable(val: Any) -> Any:
        # Classes are callable too, but calling one would instantiate it: report its name only.
        if inspect.isclass(val):
            return val.__name__
        if not callable(val):
            return val
        # Give the callable a chance to produce its own value by invoking it with no
        # arguments; if that yields yet another callable, fall back to the callable's name.
        try:
            produced = val()
            return val.__name__ if callable(produced) else produced
        # todo: specify the possible exception
        except Exception:
            return getattr(val, "__name__", None)

    sanitized = {}
    for key, val in params.items():
        sanitized[key] = _sanitize_callable(val)
    return sanitized
def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.

    Nested ``MutableMapping`` values, dataclass *instances* and ``argparse.Namespace`` objects
    are expanded recursively. Every other value, including dataclass *classes*, is kept as a
    leaf. Keys are converted with ``str``.

    Args:
        params: Dictionary containing the hyperparameters
        delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
        parent_key: Prefix put (followed by ``delimiter``) in front of every key. It is used
            by the recursion. Defaults to ``''`` (no prefix).

    Returns:
        Flattened dict.

    Examples:
        >>> _flatten_dict({'a': {'b': 'c'}})
        {'a/b': 'c'}
        >>> _flatten_dict({'a': {'b': 123}})
        {'a/b': 123}
        >>> _flatten_dict({5: {'a': 123}})
        {'5/a': 123}
    """
    result: Dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        # ``is_dataclass`` is also True for dataclass *types*, which ``asdict`` rejects with a
        # TypeError, so only instances are expanded and classes are kept as leaf values.
        if is_dataclass(v) and not isinstance(v, type):
            v = asdict(v)
        elif isinstance(v, Namespace):
            v = vars(v)
        if isinstance(v, MutableMapping):
            # update in place rather than rebuilding ``result`` for every nested mapping;
            # key order and overwrite semantics are the same as ``{**result, **sub}``
            result.update(_flatten_dict(v, parent_key=new_key, delimiter=delimiter))
        else:
            result[new_key] = v
    return result
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Returns params with non-primitives converted to strings for logging.

    NumPy scalars (``np.bool_``, ``np.integer``, ``np.floating``) become the matching Python
    scalar via ``.item()``. Values whose exact type is ``bool``, ``int``, ``float``, ``str`` or
    ``Tensor`` are kept. Everything else is replaced by ``str(value)``. The input dictionary is
    modified in place and also returned.

    >>> import torch
    >>> params = {"float": 0.3,
    ...           "int": 1,
    ...           "string": "abc",
    ...           "bool": True,
    ...           "list": [1, 2, 3],
    ...           "namespace": Namespace(foo=3),
    ...           "layer": torch.nn.BatchNorm1d}
    >>> import pprint
    >>> pprint.pprint(_sanitize_params(params))  # doctest: +NORMALIZE_WHITESPACE
    {'bool': True,
     'float': 0.3,
     'int': 1,
     'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
     'list': '[1, 2, 3]',
     'namespace': 'Namespace(foo=3)',
     'string': 'abc'}
    """
    kept_types = (bool, int, float, str, Tensor)
    # only values are reassigned (no keys added/removed), so iterating items() is safe
    for key, value in params.items():
        if isinstance(value, (np.bool_, np.integer, np.floating)):
            # keep numeric NumPy scalars numeric instead of stringifying them
            params[key] = value.item()
        elif type(value) not in kept_types:
            params[key] = str(value)
    return params
def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]:
    """Convert non-serializable objects in params to string.

    Builds and returns a new dictionary. Values for which ``_is_json_serializable`` returns
    ``True`` are kept as they are, and every other value is replaced by ``str(value)``.
    """
    converted: Dict[str, Any] = {}
    for key, value in params.items():
        converted[key] = value if _is_json_serializable(value) else str(value)
    return converted
def _is_json_serializable(value: Any) -> bool:
    """Test whether a variable can be encoded as json.

    ``None`` and instances of ``bool``/``int``/``float``/``str`` are accepted directly. Every
    other value is checked by actually running ``json.dumps`` on it. This includes ``list``
    and ``dict`` values, because their elements or keys may not be encodable.

    Args:
        value: Object to test.

    Returns:
        ``True`` if the value can be encoded by ``json.dumps``, ``False`` otherwise.
    """
    # fast path: scalars only. Containers must be encoded to inspect their contents.
    if value is None or isinstance(value, (bool, int, float, str)):
        return True
    try:
        json.dumps(value)
    except (TypeError, OverflowError, ValueError):
        # TypeError: unsupported type, or a dict key that is not a JSON-compatible primitive
        # OverflowError is raised if number is too large to encode
        # ValueError: e.g. a circular reference detected by the encoder
        return False
    return True
def _add_prefix(
    metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
) -> Mapping[str, Union[Tensor, float]]:
    """Prepend ``prefix`` and ``separator`` to every key of ``metrics``.

    Args:
        metrics: Dictionary with metric names as keys and measured quantities as values
        prefix: Text placed in front of each key. An empty prefix disables prefixing.
        separator: Text placed between the prefix and the original key name

    Returns:
        A new dictionary whose keys are ``<prefix><separator><key>`` and whose values are
        unchanged. When ``prefix`` is empty, ``metrics`` itself is returned (not a copy).
    """
    if prefix:
        return {f"{prefix}{separator}{k}": v for k, v in metrics.items()}
    return metrics