Commit 1f6f20d

add some mypy fixes

1 parent cc88528

1 file changed: +35 −22 lines

Diff for: ignite/handlers/fbresearch_logger.py (+35 −22)
@@ -1,6 +1,9 @@
 """FBResearch logger and its helper handlers."""

 import datetime
+from typing import Any, Optional
+
+# from typing import Any, Dict, Optional, Union

 import torch

@@ -33,14 +36,16 @@ class FBResearchLogger:
             logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
     """

-    def __init__(self, logger, delimiter=" ", show_output=False):
+    def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False):
         self.delimiter = delimiter
-        self.logger = logger
-        self.iter_timer = None
-        self.data_timer = None
-        self.show_output = show_output
-
-    def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
+        self.logger: Any = logger
+        self.iter_timer: Timer = Timer(average=True)
+        self.data_timer: Timer = Timer(average=True)
+        self.show_output: bool = show_output
+
+    def attach(
+        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+    ) -> None:
         """Attaches all the logging handlers to the given engine.

         Args:
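The main mypy fix in this hunk: iter_timer and data_timer used to be None until attach() ran, which types the attributes as Optional[Timer] and forces a None check at every use site. Constructing the timers eagerly keeps the attribute type a plain Timer. A minimal sketch of the difference, assuming ignite's Timer API (the LazyLogger/EagerLogger names are illustrative, not part of ignite):

    from typing import Optional

    from ignite.handlers import Timer

    class LazyLogger:
        def __init__(self) -> None:
            # Optional attribute: every use site must narrow it first.
            self.iter_timer: Optional[Timer] = None

        def log(self) -> None:
            # Without this assert, mypy reports:
            #   Item "None" of "Optional[Timer]" has no attribute "value"
            assert self.iter_timer is not None
            print(self.iter_timer.value())

    class EagerLogger:
        def __init__(self) -> None:
            # Plain Timer attribute: no narrowing needed anywhere.
            self.iter_timer: Timer = Timer(average=True)

        def log(self) -> None:
            print(self.iter_timer.value())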
@@ -54,15 +59,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
         engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)

-        self.iter_timer = Timer(average=True)
+        self.iter_timer.reset()
         self.iter_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED,
         )
-        self.data_timer = Timer(average=True)
+        self.data_timer.reset()
         self.data_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
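With the timers now created in __init__, attach() only needs to reset() them rather than rebuild them, so re-attaching the same logger reuses the existing Timer objects. A small sketch of the pattern, assuming ignite's Timer API:

    from ignite.handlers import Timer

    timer = Timer(average=True)

    # Old pattern: rebind to a fresh Timer; since the attribute starts as
    # None, its inferred type widens to Optional[Timer].
    # self.iter_timer = Timer(average=True)

    # New pattern: clear accumulated measurements in place. The attribute
    # type stays Timer, and existing references remain valid.
    timer.reset()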
@@ -71,14 +76,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
             step=Events.GET_BATCH_COMPLETED,
         )

-    def log_every(self, engine, optimizer=None):
+    def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = None) -> None:
         """
         Logs the training progress at regular intervals.

         Args:
             engine (Engine): The training engine.
             optimizer (torch.optim.Optimizer, optional): The optimizer used for training. Defaults to None.
         """
+        assert engine.state.epoch_length is not None
         cuda_max_mem = ""
         if torch.cuda.is_available():
             cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"
@@ -89,12 +95,12 @@ def log_every(self, engine, optimizer=None):
         eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)

         outputs = []
-        if self.show_output:
+        if self.show_output and engine.state.output is not None:
             output = engine.state.output
             if isinstance(output, dict):
                 outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
             else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]
+                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore

         lrs = ""
         if optimizer is not None:
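The show_output guard gains an "is not None" check, which is both a runtime guard and a narrowing step: after it, engine.state.output can no longer be None. Because state.output is loosely typed in ignite, iterating it in the else branch still needs a type: ignore rather than full narrowing. An isolated sketch with an illustrative union type (not ignite's actual annotation):

    from typing import Dict, List, Optional, Union

    Output = Union[Dict[str, float], List[float]]

    def format_output(output: Optional[Output]) -> List[str]:
        # The None check narrows Optional[Output] to Output.
        if output is None:
            return []
        if isinstance(output, dict):
            return [f"{k}: {v:.4f}" for k, v in output.items()]
        # With a precise union, mypy knows this is List[float] and no
        # "type: ignore" is needed, unlike the loosely typed state.output.
        return [f"{v:.4f}" for v in output]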
@@ -120,7 +126,7 @@ def log_every(self, engine, optimizer=None):
         )
         self.logger.info(msg)

-    def log_epoch_started(self, engine, name):
+    def log_epoch_started(self, engine: Engine, name: str) -> None:
         """
         Logs the start of an epoch.

@@ -132,37 +138,44 @@ def log_epoch_started(self, engine, name):
         msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
         self.logger.info(msg)

-    def log_epoch_completed(self, engine, name):
+    def log_epoch_completed(self, engine: Engine, name: str) -> None:
         """
         Logs the completion of an epoch.

         Args:
-            engine (Engine): The engine object.
-            name (str): The name of the epoch.
+            engine (Engine): The engine object that triggered the event.
+            name (str): The name of the event.

+        Returns:
+            None
         """
         epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
-        epoch_info = f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]" if engine.state.max_epochs > 1 else ""
+        epoch_info = (
+            f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
+            if engine.state.max_epochs > 1
+            else ""  # type: ignore
+        )
         msg = self.delimiter.join(
             [
                 f"{name}: {epoch_info}",
-                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",
-                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",
+                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",  # type: ignore
+                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",  # type: ignore
             ]
         )
         self.logger.info(msg)

-    def log_completed(self, engine, name):
+    def log_completed(self, engine: Engine, name: str) -> None:
         """
         Logs the completion of a run.

         Args:
-            engine (Engine): The engine object.
+            engine (Engine): The engine object representing the training/validation loop.
             name (str): The name of the run.

         """
-        if engine.state.max_epochs > 1:
+        if engine.state.max_epochs and engine.state.max_epochs > 1:
             total_time = engine.state.times[Events.COMPLETED.name]
+            assert total_time is not None
             msg = self.delimiter.join(
                 [
                     f"{name}: run completed",
