Skip to content

Commit f7fe4a2

Browse files
committed
add fbresearch_logger.py
Add FBResearchLogger class from unmerged branch object-detection-example. Add minimal docs and tests.
1 parent a7246e1 commit f7fe4a2

File tree

3 files changed

+188
-0
lines changed

3 files changed

+188
-0
lines changed

Diff for: docs/source/handlers.rst

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Loggers
5454

5555
visdom_logger
5656
wandb_logger
57+
fbresearch_logger
5758

5859
.. seealso::
5960

Diff for: ignite/handlers/fbresearch_logger.py

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""FBResearch logger and its helper handlers."""
2+
3+
import datetime
import logging
from typing import Any, List, Optional

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Timer
9+
10+
11+
# Bytes per mebibyte; used to convert torch.cuda.max_memory_allocated() to MB for logging.
MB = 1024.0 * 1024.0
12+
13+
14+
class FBResearchLogger:
    """Logs training and validation metrics for research purposes.

    This logger is designed to attach to an Ignite Engine and log various metrics
    and system stats at configurable intervals, including learning rates, iteration
    times, and GPU memory usage.

    Args:
        logger: The logger to use for output.
        delimiter: The delimiter to use between metrics in the log output.
        show_output: Flag to enable logging of the output from the engine's process function.

    Examples:
        .. code-block:: python

            import logging
            from ignite.handlers.fbresearch_logger import *

            logger = FBResearchLogger(logger=logging.Logger(__name__), show_output=True)
            logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
    """

    def __init__(self, logger: logging.Logger, delimiter: str = " ", show_output: bool = False) -> None:
        self.delimiter = delimiter
        self.logger = logger
        # Timers are created lazily in attach() so one FBResearchLogger instance
        # can be constructed before any engine exists.
        self.iter_timer: Optional[Timer] = None
        self.data_timer: Optional[Timer] = None
        self.show_output = show_output

    def attach(self, engine: Engine, name: str, every: int = 1, optimizer: Optional[Any] = None) -> None:
        """Attaches all the logging handlers to the given engine.

        Args:
            engine: The engine to attach the logging handlers to.
            name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
            every: Frequency of iterations to log information. Logs are generated every 'every' iterations.
            optimizer: The optimizer used during training to log current learning rates.
        """
        engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
        engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
        engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)

        # Average wall-clock time of one training iteration (ITERATION_STARTED -> ITERATION_COMPLETED).
        self.iter_timer = Timer(average=True)
        self.iter_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )
        # Average wall-clock time spent fetching a batch from the data loader.
        self.data_timer = Timer(average=True)
        self.data_timer.attach(
            engine,
            start=Events.EPOCH_STARTED,
            resume=Events.GET_BATCH_STARTED,
            pause=Events.GET_BATCH_COMPLETED,
            step=Events.GET_BATCH_COMPLETED,
        )

    def log_every(self, engine: Engine, optimizer: Optional[Any] = None) -> None:
        """Logs the progress, ETA, process-function output, learning rates and timings
        for the current iteration."""
        cuda_max_mem = ""
        if torch.cuda.is_available():
            cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

        # 1-based iteration index within the current epoch.
        # NOTE: the former `iteration % (epoch_length + 1)` is only correct during the
        # first epoch; from epoch 2 onward it is off by one (e.g. iteration
        # epoch_length + 1 displayed as 0 instead of 1).
        current_iter = (engine.state.iteration - 1) % engine.state.epoch_length + 1
        iter_avg_time = self.iter_timer.value()

        # Estimated time to finish the current epoch at the current average pace.
        eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)

        outputs: List[str] = []
        if self.show_output and engine.state.output is not None:
            output = engine.state.output
            if isinstance(output, dict):
                outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
            else:
                try:
                    values = list(output)
                except TypeError:
                    # A scalar output (e.g. a bare loss value) is not iterable;
                    # previously this raised TypeError. Format it as a single value.
                    values = [output]
                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in values]

        lrs = ""
        if optimizer is not None:
            if len(optimizer.param_groups) == 1:
                lrs += f"lr: {optimizer.param_groups[0]['lr']:.5f}"
            else:
                # Separate the groups with a space; previously they were concatenated
                # back-to-back ("...0.00100lr [g1]: ...").
                lrs += " ".join(f"lr [g{i}]: {g['lr']:.5f}" for i, g in enumerate(optimizer.param_groups))

        msg = self.delimiter.join(
            [
                f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
                f"[{current_iter}/{engine.state.epoch_length}]:",
                f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}",
                f"{lrs}",
            ]
            + outputs
            + [
                f"Iter time: {iter_avg_time:.4f} s",
                f"Data prep time: {self.data_timer.value():.4f} s",
                cuda_max_mem,
            ]
        )
        self.logger.info(msg)

    def log_epoch_started(self, engine: Engine, name: str) -> None:
        """Logs that a new epoch has started for the given engine."""
        msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
        self.logger.info(msg)

    def log_epoch_completed(self, engine: Engine, name: str) -> None:
        """Logs the total epoch time and the average seconds per iteration."""
        epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
        # The epoch counter is only informative for multi-epoch runs.
        epoch_info = f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]" if engine.state.max_epochs > 1 else ""
        msg = self.delimiter.join(
            [
                f"{name}: {epoch_info}",
                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",
                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",
            ]
        )
        self.logger.info(msg)

    def log_completed(self, engine: Engine, name: str) -> None:
        """Logs the total run time once the engine finishes, for multi-epoch runs only."""
        if engine.state.max_epochs > 1:
            total_time = engine.state.times[Events.COMPLETED.name]
            msg = self.delimiter.join(
                [
                    f"{name}: run completed",
                    f"Total time: {datetime.timedelta(seconds=int(total_time))}",
                ]
            )
            self.logger.info(msg)

Diff for: tests/ignite/handlers/test_fbresearch_logger.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import logging
2+
import re
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
from ignite.engine import Engine, Events
8+
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
9+
10+
11+
@pytest.fixture
def mock_engine():
    """Engine with a no-op process function and preset state for logger tests."""
    engine = Engine(lambda _engine, _batch: None)
    preset_state = {"epoch": 1, "max_epochs": 10, "epoch_length": 100, "iteration": 50}
    for attr, value in preset_state.items():
        setattr(engine.state, attr, value)
    return engine
19+
20+
21+
@pytest.fixture
def mock_logger():
    # A stand-in for logging.Logger that records calls (e.g. .info) instead of emitting.
    return MagicMock(spec=logging.Logger)
24+
25+
26+
@pytest.fixture
def fb_research_logger(mock_logger):
    """FBResearchLogger wired to the mocked logger, with process-output logging enabled."""
    fb_logger = FBResearchLogger(logger=mock_logger, show_output=True)
    yield fb_logger
29+
30+
31+
@pytest.mark.parametrize(
    "output,expected_pattern",
    [
        ({"loss": 0.456, "accuracy": 0.789}, r"loss. *0.456.*accuracy. *0.789"),
        ((0.456, 0.789), r"0.456.*0.789"),
        ([0.456, 0.789], r"0.456.*0.789"),
    ],
)
def test_output_formatting(mock_engine, fb_research_logger, output, expected_pattern):
    """Each supported output type (dict, tuple, list) is rendered into the logged message."""
    mock_engine.state.output = output
    fb_research_logger.attach(mock_engine, name="Test", every=1)
    mock_engine.fire_event(Events.ITERATION_COMPLETED)

    first_logged_message = fb_research_logger.logger.info.call_args_list[0].args[0]
    assert re.search(expected_pattern, first_logged_message) is not None

0 commit comments

Comments
 (0)