
Commit cd42f96

fbr logger: improve types and kwargs supported
1 parent 0ae7ce8 commit cd42f96

File tree: 2 files changed, +107 −28 lines

Diff for: ignite/handlers/fbresearch_logger.py

+57 −26

@@ -1,18 +1,29 @@
 """FBResearch logger and its helper handlers."""
 
 import datetime
-from typing import Any, Optional
-
-# from typing import Any, Dict, Optional, Union
+import numbers
+from typing import Any, Callable, List, Optional
 
 import torch
 
+from ignite import utils
 from ignite.engine import Engine, Events
 from ignite.handlers import Timer
+from ignite.handlers.utils import global_step_from_engine  # noqa
 
 
 MB = 1024.0 * 1024.0
 
+__all__ = ["FBResearchLogger", "global_step_from_engine"]
+
+
+def is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
 
 class FBResearchLogger:
     """Logs training and validation metrics for research purposes.
@@ -60,32 +71,32 @@ class FBResearchLogger:
     .. code-block:: text
 
         2024-04-22 12:05:47,843 trainer INFO: Train: start epoch [1/4]
-        2024-04-22 12:05:47,861 trainer INFO: Epoch [1/4] [20/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5999 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,877 trainer INFO: Epoch [1/4] [40/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9297 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,893 trainer INFO: Epoch [1/4] [60/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9985 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,910 trainer INFO: Epoch [1/4] [80/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9785 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,925 trainer INFO: Epoch [1/4] [100/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6211 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,861 trainer INFO: Epoch [1/4] [20/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5999 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,877 trainer INFO: Epoch [1/4] [40/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9297 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,893 trainer INFO: Epoch [1/4] [60/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9985 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,910 trainer INFO: Epoch [1/4] [80/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9785 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,925 trainer INFO: Epoch [1/4] [100/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6211 Iter time: 0.0008 s Data prep time: 0.0000 s
         2024-04-22 12:05:47,927 trainer INFO: Train: Epoch [1/4] Total time: 0:00:00 (0.0008 s / it)
         2024-04-22 12:05:47,930 trainer INFO: Train: start epoch [2/4]
-        2024-04-22 12:05:47,949 trainer INFO: Epoch [2/4] [19/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5981 Iter time: 0.0009 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,965 trainer INFO: Epoch [2/4] [39/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9013 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,981 trainer INFO: Epoch [2/4] [59/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9811 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:47,997 trainer INFO: Epoch [2/4] [79/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9434 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,016 trainer INFO: Epoch [2/4] [99/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6116 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,949 trainer INFO: Epoch [2/4] [19/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5981 Iter time: 0.0009 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,965 trainer INFO: Epoch [2/4] [39/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9013 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,981 trainer INFO: Epoch [2/4] [59/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9811 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:47,997 trainer INFO: Epoch [2/4] [79/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9434 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,016 trainer INFO: Epoch [2/4] [99/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6116 Iter time: 0.0008 s Data prep time: 0.0000 s
         2024-04-22 12:05:48,017 trainer INFO: Train: Epoch [2/4] Total time: 0:00:00 (0.0009 s / it)
         2024-04-22 12:05:48,020 trainer INFO: Train: start epoch [3/4]
-        2024-04-22 12:05:48,038 trainer INFO: Epoch [3/4] [18/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5972 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,055 trainer INFO: Epoch [3/4] [38/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8753 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,076 trainer INFO: Epoch [3/4] [58/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9657 Iter time: 0.0009 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,092 trainer INFO: Epoch [3/4] [78/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9112 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,108 trainer INFO: Epoch [3/4] [98/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6035 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,038 trainer INFO: Epoch [3/4] [18/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5972 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,055 trainer INFO: Epoch [3/4] [38/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8753 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,076 trainer INFO: Epoch [3/4] [58/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9657 Iter time: 0.0009 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,092 trainer INFO: Epoch [3/4] [78/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9112 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,108 trainer INFO: Epoch [3/4] [98/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.6035 Iter time: 0.0008 s Data prep time: 0.0000 s
         2024-04-22 12:05:48,109 trainer INFO: Train: Epoch [3/4] Total time: 0:00:00 (0.0009 s / it)
         2024-04-22 12:05:48,112 trainer INFO: Train: start epoch [4/4]
-        2024-04-22 12:05:48,129 trainer INFO: Epoch [4/4] [17/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5969 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,145 trainer INFO: Epoch [4/4] [37/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8516 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,161 trainer INFO: Epoch [4/4] [57/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9521 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,181 trainer INFO: Epoch [4/4] [77/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8816 Iter time: 0.0008 s Data prep time: 0.0000 s
-        2024-04-22 12:05:48,205 trainer INFO: Epoch [4/4] [97/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5966 Iter time: 0.0009 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,129 trainer INFO: Epoch [4/4] [17/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5969 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,145 trainer INFO: Epoch [4/4] [37/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8516 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,161 trainer INFO: Epoch [4/4] [57/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.9521 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,181 trainer INFO: Epoch [4/4] [77/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.8816 Iter time: 0.0008 s Data prep time: 0.0000 s
+        2024-04-22 12:05:48,205 trainer INFO: Epoch [4/4] [97/100]: ETA: 0:00:00 lr: 0.00100 total_loss: 1.5966 Iter time: 0.0009 s Data prep time: 0.0000 s
         2024-04-22 12:05:48,207 trainer INFO: Train: Epoch [4/4] Total time: 0:00:00 (0.0009 s / it)
         2024-04-22 12:05:48,209 trainer INFO: Train: run completed Total time: 0:00:00
     """
@@ -98,16 +109,27 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False
         self.show_output: bool = show_output
 
     def attach(
-        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        engine: Engine,
+        name: str,
+        every: int = 1,
+        output_transform: Optional[Callable] = None,
+        state_attributes: Optional[List[str]] = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ) -> None:
        """Attaches all the logging handlers to the given engine.
 
        Args:
            engine: The engine to attach the logging handlers to.
            name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
            every: Frequency of iterations to log information. Logs are generated every 'every' iterations.
+           output_transform: A function to select the value to log.
+           state_attributes: A list of attributes to log.
            optimizer: The optimizer used during training to log current learning rates.
        """
+        self.name = name
+        self.output_transform = output_transform
+        self.state_attributes = state_attributes
         engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
         engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
@@ -151,10 +173,15 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
         outputs = []
         if self.show_output and engine.state.output is not None:
             output = engine.state.output
-            if isinstance(output, dict):
+            if self.output_transform is not None:
+                outputs.append(str(self.output_transform(output)))
+            elif isinstance(output, dict):
                 outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
+            elif isinstance(output, str):
+                outputs.append(output)
             else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
+                # allow numbers or nested iterables of numbers
+                outputs.extend(utils.apply_to_type(output, numbers.Number, lambda x: f"{x:.4f}"))
 
         lrs = ""
         if optimizer is not None:
@@ -164,6 +191,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
             for i, g in enumerate(optimizer.param_groups):
                 lrs += f"lr [g{i}]: {g['lr']:.5f}"
 
+        state_attrs = []
+        if self.state_attributes is not None:
+            state_attrs.append(str({name: getattr(engine.state, name, None) for name in self.state_attributes}))
         msg = self.delimiter.join(
             [
                 f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
@@ -172,6 +202,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
                 f"{lrs}",
             ]
             + outputs
+            + state_attrs
             + [
                 f"Iter time: {iter_avg_time:.4f} s",
                 f"Data prep time: {self.data_timer.value():.4f} s",

Diff for: tests/ignite/handlers/test_fbresearch_logger.py

+50 −2

@@ -3,9 +3,13 @@
 from unittest.mock import MagicMock
 
 import pytest
+import torch
+import torch.nn as nn
+import torch.optim as optim
 
-from ignite.engine import Engine, Events
-from ignite.handlers.fbresearch_logger import FBResearchLogger  # Adjust the import path as necessary
+from ignite.engine import create_supervised_trainer, Engine, Events
+from ignite.handlers.fbresearch_logger import FBResearchLogger
+from ignite.utils import setup_logger
 
 
 @pytest.fixture
@@ -56,3 +60,47 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pattern):
 
     actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
     assert re.search(expected_pattern, actual_output)
+
+
+def test_logger_type_support():
+    model = nn.Linear(10, 5)
+    opt = optim.SGD(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss()
+
+    data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]
+
+    trainer = create_supervised_trainer(model, opt, criterion)
+
+    logger = setup_logger("trainer", level=logging.INFO)
+    logger = FBResearchLogger(logger=logger, show_output=True)
+    logger.attach(trainer, name="Train", every=20, optimizer=opt)
+
+    trainer.run(data, max_epochs=4)
+    trainer.state.output = {"loss": 4.2}
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = "4.2"
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = [4.2, 4.2]
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = (4.2, 4.2)
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+
+
+def test_fbrlogger_with_output_transform(mock_logger):
+    trainer = Engine(lambda e, b: 42)
+    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
+    fbr.attach(trainer, "Training", output_transform=lambda x: {"loss": x})
+    trainer.run(data=[10], epoch_length=1, max_epochs=1)
+    assert "{'loss': 42}" in fbr.logger.info.call_args_list[-2].args[0]
+
+
+def test_fbrlogger_with_state_attrs(mock_logger):
+    trainer = Engine(lambda e, b: 42)
+    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
+    fbr.attach(trainer, "Training", state_attributes=["alpha", "beta", "gamma"])
+    trainer.state.alpha = 3.899
+    trainer.state.beta = torch.tensor(12.21)
+    trainer.state.gamma = torch.tensor([21.0, 6.0])
+    trainer.run(data=[10], epoch_length=1, max_epochs=1)
+    attrs = "{'alpha': 3.899, 'beta': 12.2100, 'gamma': [21., 6.]}"
+    assert attrs in fbr.logger.info.call_args_list[-2].args[0]
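
Note how test_logger_type_support exercises the widened output handling: dict, str, list, and tuple outputs each take a different branch in log_every. For the iterable case, formatting is delegated to ignite.utils.apply_to_type, which applies the given function to every numbers.Number it finds while preserving the container structure. A small illustration of that behavior; the sample values are made up:

    import numbers

    from ignite import utils


    def fmt(x):
        return f"{x:.4f}"


    # Each number inside the sequence is rendered to four decimal places;
    # the container type (list, tuple, ...) is preserved.
    print(utils.apply_to_type([4.2, 4.2], numbers.Number, fmt))  # ['4.2000', '4.2000']
    print(utils.apply_to_type((4.2, 4.2), numbers.Number, fmt))  # ('4.2000', '4.2000')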
