@@ -1,6 +1,9 @@
 """FBResearch logger and its helper handlers."""
 
 import datetime
+from typing import Any, Optional
+
+# from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -33,14 +36,16 @@ class FBResearchLogger:
         logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
     """
 
-    def __init__(self, logger, delimiter=" ", show_output=False):
+    def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False):
         self.delimiter = delimiter
-        self.logger = logger
-        self.iter_timer = None
-        self.data_timer = None
-        self.show_output = show_output
-
-    def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
+        self.logger: Any = logger
+        self.iter_timer: Timer = Timer(average=True)
+        self.data_timer: Timer = Timer(average=True)
+        self.show_output: bool = show_output
+
+    def attach(
+        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+    ) -> None:
         """Attaches all the logging handlers to the given engine.
 
         Args:
@@ -54,15 +59,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
         engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)
 
-        self.iter_timer = Timer(average=True)
+        self.iter_timer.reset()
         self.iter_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED,
         )
-        self.data_timer = Timer(average=True)
+        self.data_timer.reset()
         self.data_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
@@ -71,14 +76,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
             step=Events.GET_BATCH_COMPLETED,
         )
 
-    def log_every(self, engine, optimizer=None):
+    def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = None) -> None:
         """
         Logs the training progress at regular intervals.
 
         Args:
             engine (Engine): The training engine.
             optimizer (torch.optim.Optimizer, optional): The optimizer used for training. Defaults to None.
         """
+        assert engine.state.epoch_length is not None
         cuda_max_mem = ""
         if torch.cuda.is_available():
             cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"
@@ -89,12 +95,12 @@ def log_every(self, engine, optimizer=None):
         eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
 
         outputs = []
-        if self.show_output:
+        if self.show_output and engine.state.output is not None:
             output = engine.state.output
             if isinstance(output, dict):
                 outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
             else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]
+                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
 
         lrs = ""
         if optimizer is not None:
@@ -120,7 +126,7 @@ def log_every(self, engine, optimizer=None):
         )
         self.logger.info(msg)
 
-    def log_epoch_started(self, engine, name):
+    def log_epoch_started(self, engine: Engine, name: str) -> None:
         """
         Logs the start of an epoch.
 
@@ -132,37 +138,44 @@ def log_epoch_started(self, engine, name):
         msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
         self.logger.info(msg)
 
-    def log_epoch_completed(self, engine, name):
+    def log_epoch_completed(self, engine: Engine, name: str) -> None:
         """
         Logs the completion of an epoch.
 
         Args:
-            engine (Engine): The engine object.
-            name (str): The name of the epoch.
+            engine (Engine): The engine object that triggered the event.
+            name (str): The name of the event.
 
+        Returns:
+            None
         """
         epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
-        epoch_info = f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]" if engine.state.max_epochs > 1 else ""
+        epoch_info = (
+            f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
+            if engine.state.max_epochs > 1  # type: ignore
+            else ""
+        )
         msg = self.delimiter.join(
             [
                 f"{name}: {epoch_info}",
-                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",
-                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",
+                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",  # type: ignore
+                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",  # type: ignore
             ]
         )
         self.logger.info(msg)
 
-    def log_completed(self, engine, name):
+    def log_completed(self, engine: Engine, name: str) -> None:
         """
         Logs the completion of a run.
 
         Args:
-            engine (Engine): The engine object.
+            engine (Engine): The engine object representing the training/validation loop.
             name (str): The name of the run.
 
         """
-        if engine.state.max_epochs > 1:
+        if engine.state.max_epochs and engine.state.max_epochs > 1:
             total_time = engine.state.times[Events.COMPLETED.name]
+            assert total_time is not None
             msg = self.delimiter.join(
                 [
                     f"{name}: run completed",