"""FBResearch logger and its helper handlers."""

import datetime
- from typing import Any, Optional
-
- # from typing import Any, Dict, Optional, Union
+ import numbers
+ from typing import Any, Callable, List, Optional

import torch

+ from ignite import utils
from ignite.engine import Engine, Events
from ignite.handlers import Timer
+ from ignite.handlers.utils import global_step_from_engine  # noqa


MB = 1024.0 * 1024.0

+ __all__ = ["FBResearchLogger", "global_step_from_engine"]
+
+
+ def is_iterable(obj):
+     try:
+         iter(obj)
+         return True
+     except TypeError:
+         return False
+


class FBResearchLogger:
    """Logs training and validation metrics for research purposes.
@@ -60,32 +71,32 @@ class FBResearchLogger:
        .. code-block:: text

            2024-04-22 12:05:47,843 trainer INFO: Train: start epoch [1/4]
-             2024-04-22 12:05:47,861 trainer INFO: Epoch [1/4] [20/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5999 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,877 trainer INFO: Epoch [1/4] [40/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9297 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,893 trainer INFO: Epoch [1/4] [60/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9985 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,910 trainer INFO: Epoch [1/4] [80/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9785 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,925 trainer INFO: Epoch [1/4] [100/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6211 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,861 trainer INFO: Epoch [1/4] [20/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5999 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,877 trainer INFO: Epoch [1/4] [40/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9297 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,893 trainer INFO: Epoch [1/4] [60/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9985 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,910 trainer INFO: Epoch [1/4] [80/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9785 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,925 trainer INFO: Epoch [1/4] [100/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6211 Iter time: 0.0008 s Data prep time: 0.0000 s
            2024-04-22 12:05:47,927 trainer INFO: Train: Epoch [1/4] Total time: 0:00:00 (0.0008 s / it)
            2024-04-22 12:05:47,930 trainer INFO: Train: start epoch [2/4]
-             2024-04-22 12:05:47,949 trainer INFO: Epoch [2/4] [19/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5981 Iter time: 0.0009 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,965 trainer INFO: Epoch [2/4] [39/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9013 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,981 trainer INFO: Epoch [2/4] [59/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9811 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:47,997 trainer INFO: Epoch [2/4] [79/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9434 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,016 trainer INFO: Epoch [2/4] [99/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6116 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,949 trainer INFO: Epoch [2/4] [19/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5981 Iter time: 0.0009 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,965 trainer INFO: Epoch [2/4] [39/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9013 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,981 trainer INFO: Epoch [2/4] [59/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9811 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:47,997 trainer INFO: Epoch [2/4] [79/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9434 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,016 trainer INFO: Epoch [2/4] [99/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6116 Iter time: 0.0008 s Data prep time: 0.0000 s
            2024-04-22 12:05:48,017 trainer INFO: Train: Epoch [2/4] Total time: 0:00:00 (0.0009 s / it)
            2024-04-22 12:05:48,020 trainer INFO: Train: start epoch [3/4]
-             2024-04-22 12:05:48,038 trainer INFO: Epoch [3/4] [18/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5972 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,055 trainer INFO: Epoch [3/4] [38/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8753 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,076 trainer INFO: Epoch [3/4] [58/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9657 Iter time: 0.0009 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,092 trainer INFO: Epoch [3/4] [78/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9112 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,108 trainer INFO: Epoch [3/4] [98/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6035 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,038 trainer INFO: Epoch [3/4] [18/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5972 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,055 trainer INFO: Epoch [3/4] [38/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8753 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,076 trainer INFO: Epoch [3/4] [58/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9657 Iter time: 0.0009 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,092 trainer INFO: Epoch [3/4] [78/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9112 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,108 trainer INFO: Epoch [3/4] [98/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6035 Iter time: 0.0008 s Data prep time: 0.0000 s
            2024-04-22 12:05:48,109 trainer INFO: Train: Epoch [3/4] Total time: 0:00:00 (0.0009 s / it)
            2024-04-22 12:05:48,112 trainer INFO: Train: start epoch [4/4]
-             2024-04-22 12:05:48,129 trainer INFO: Epoch [4/4] [17/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5969 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,145 trainer INFO: Epoch [4/4] [37/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8516 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,161 trainer INFO: Epoch [4/4] [57/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9521 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,181 trainer INFO: Epoch [4/4] [77/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8816 Iter time: 0.0008 s Data prep time: 0.0000 s
-             2024-04-22 12:05:48,205 trainer INFO: Epoch [4/4] [97/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5966 Iter time: 0.0009 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,129 trainer INFO: Epoch [4/4] [17/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5969 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,145 trainer INFO: Epoch [4/4] [37/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8516 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,161 trainer INFO: Epoch [4/4] [57/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9521 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,181 trainer INFO: Epoch [4/4] [77/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8816 Iter time: 0.0008 s Data prep time: 0.0000 s
+             2024-04-22 12:05:48,205 trainer INFO: Epoch [4/4] [97/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5966 Iter time: 0.0009 s Data prep time: 0.0000 s
            2024-04-22 12:05:48,207 trainer INFO: Train: Epoch [4/4] Total time: 0:00:00 (0.0009 s / it)
            2024-04-22 12:05:48,209 trainer INFO: Train: run completed Total time: 0:00:00
    """
@@ -98,16 +109,27 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False
        self.show_output: bool = show_output

    def attach(
-         self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+         self,
+         engine: Engine,
+         name: str,
+         every: int = 1,
+         output_transform: Optional[Callable] = None,
+         state_attributes: Optional[List[str]] = None,
+         optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        """Attaches all the logging handlers to the given engine.

        Args:
            engine: The engine to attach the logging handlers to.
            name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
            every: Frequency of iterations to log information. Logs are generated every 'every' iterations.
+             output_transform: A callable applied to ``engine.state.output`` to select the value(s) to log.
+             state_attributes: A list of ``engine.state`` attribute names whose values are logged.
            optimizer: The optimizer used during training to log current learning rates.
        """
+         self.name = name
+         self.output_transform = output_transform
+         self.state_attributes = state_attributes
        engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
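For orientation, here is a minimal sketch of how the extended `attach` signature could be exercised. It is not part of the diff; the `update` function, the `"trainer"` logger name, and the chosen `every`/`output_transform`/`state_attributes` values are illustrative only.

import logging

from ignite.engine import Engine
from ignite.handlers.fbresearch_logger import FBResearchLogger

logging.basicConfig(level=logging.INFO)

# Hypothetical no-op training step that returns a dict output.
def update(engine, batch):
    return {"total_loss": 0.5}

trainer = Engine(update)
logger = FBResearchLogger(logger=logging.getLogger("trainer"), show_output=True)

# Log every 20 iterations; select a single scalar from the output dict and
# report engine.state.epoch_length next to it on each logged line.
logger.attach(
    trainer,
    name="Train",
    every=20,
    output_transform=lambda out: out["total_loss"],
    state_attributes=["epoch_length"],
)

trainer.run(data=range(100), max_epochs=4)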
@@ -151,10 +173,15 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
        outputs = []
        if self.show_output and engine.state.output is not None:
            output = engine.state.output
-             if isinstance(output, dict):
+             if self.output_transform is not None:
+                 outputs.append(str(self.output_transform(output)))
+             elif isinstance(output, dict):
                outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
+             elif isinstance(output, str):
+                 outputs.append(output)
            else:
-                 outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
+                 # allow numbers or nested iterables of numbers
+                 outputs.extend(utils.apply_to_type(output, numbers.Number, lambda x: f"{x:.4f}"))

        lrs = ""
        if optimizer is not None:
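A note on the new fallback branch: `ignite.utils.apply_to_type` walks nested containers and applies the given function to every element of the given type, so numbers are formatted wherever they sit in the output structure. A tiny standalone illustration (the sample values are made up):

import numbers

from ignite import utils

# Formats each number to 4 decimal places while preserving container shapes.
formatted = utils.apply_to_type((1.0, [0.25, 3]), numbers.Number, lambda x: f"{x:.4f}")
print(formatted)  # ('1.0000', ['0.2500', '3.0000'])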
@@ -164,6 +191,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
                for i, g in enumerate(optimizer.param_groups):
                    lrs += f"lr [g{i}]: {g['lr']:.5f} "

+         state_attrs = []
+         if self.state_attributes is not None:
+             state_attrs.append(str({name: getattr(engine.state, name, None) for name in self.state_attributes}))
        msg = self.delimiter.join(
            [
                f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
@@ -172,6 +202,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
                f"{lrs}",
            ]
            + outputs
+             + state_attrs
            + [
                f"Iter time: {iter_avg_time:.4f} s",
                f"Data prep time: {self.data_timer.value():.4f} s",