Commit 72377f3

fbr logger: improve types and kwargs supported
1 parent 24e71af commit 72377f3

File tree

4 files changed: +173 -11 lines changed


Diff for: ignite/handlers/fbresearch_logger.py (+26 -8)

@@ -1,18 +1,20 @@
 """FBResearch logger and its helper handlers."""
 
 import datetime
-from typing import Any, Optional
-
-# from typing import Any, Dict, Optional, Union
+from typing import Any, Callable, List, Optional
 
 import torch
 
+from ignite import utils
 from ignite.engine import Engine, Events
 from ignite.handlers import Timer
+from ignite.handlers.utils import global_step_from_engine  # noqa
 
 
 MB = 1024.0 * 1024.0
 
+__all__ = ["FBResearchLogger", "global_step_from_engine"]
+
 
 class FBResearchLogger:
     """Logs training and validation metrics for research purposes.
@@ -98,16 +100,27 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False
         self.show_output: bool = show_output
 
     def attach(
-        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        engine: Engine,
+        name: str,
+        every: int = 1,
+        output_transform: Optional[Callable] = None,
+        state_attributes: Optional[List[str]] = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ) -> None:
         """Attaches all the logging handlers to the given engine.
 
         Args:
             engine: The engine to attach the logging handlers to.
             name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
             every: Frequency of iterations to log information. Logs are generated every 'every' iterations.
+            output_transform: A function to select the value to log.
+            state_attributes: A list of attributes to log.
            optimizer: The optimizer used during training to log current learning rates.
         """
+        self.name = name
+        self.output_transform = output_transform
+        self.state_attributes = state_attributes
         engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
         engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
@@ -151,10 +164,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
         outputs = []
         if self.show_output and engine.state.output is not None:
             output = engine.state.output
-            if isinstance(output, dict):
-                outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
-            else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
+            if self.output_transform is not None:
+                output = self.output_transform(output)
+            outputs = utils.flatten_format_and_include_keys(output)
 
         lrs = ""
         if optimizer is not None:
@@ -164,6 +176,11 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
                 for i, g in enumerate(optimizer.param_groups):
                     lrs += f"lr [g{i}]: {g['lr']:.5f}"
 
+        state_attrs = []
+        if self.state_attributes is not None:
+            state_attrs = utils.flatten_format_and_include_keys(
+                {name: getattr(engine.state, name, None) for name in self.state_attributes}
+            )
         msg = self.delimiter.join(
             [
                 f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
@@ -172,6 +189,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
                 f"{lrs}",
             ]
             + outputs
+            + [" ".join(state_attrs)]
             + [
                 f"Iter time: {iter_avg_time:.4f} s",
                 f"Data prep time: {self.data_timer.value():.4f} s",

Diff for: ignite/utils.py (+67 -0)

@@ -2,6 +2,7 @@
 import functools
 import hashlib
 import logging
+import numbers
 import random
 import shutil
 import warnings
@@ -14,6 +15,7 @@
     "convert_tensor",
     "apply_to_tensor",
     "apply_to_type",
+    "flatten_format_and_include_keys",
     "to_onehot",
     "setup_logger",
     "manual_seed",
@@ -90,6 +92,71 @@ def _tree_map(
     return func(x, key=key)
 
 
+def flatten_format_and_include_keys(data: Any) -> List[str]:
+    """
+    Recursively formats and flattens complex data structures, including keys for dictionaries.
+
+    This function processes nested dictionaries, lists, tuples, numbers, and PyTorch tensors,
+    formatting numbers to four decimal places and handling tensors with special formatting rules.
+    It's particularly useful for logging, debugging, or any scenario where a human-readable
+    representation of complex, nested data structures is required.
+
+    The function handles the following types:
+    - Numbers: Formatted to four decimal places.
+    - PyTorch tensors:
+        - Scalars are formatted to four decimal places.
+        - 1D tensors with more than 10 elements show the first 10 elements followed by an ellipsis.
+        - 1D tensors with 10 or fewer elements are fully listed.
+        - Multi-dimensional tensors display their shape.
+    - Dictionaries: Each key-value pair is included in the output with the key as a prefix.
+    - Lists and tuples: Flattened and included in the output. Empty lists/tuples are represented by an empty string.
+    - None values: Represented by an empty string.
+
+    Args:
+        data: The input data to be flattened and formatted. It can be a nested combination of
+            dictionaries, lists, tuples, numbers, and PyTorch tensors.
+
+    Returns:
+        A list of formatted strings, each representing a part of the input data structure.
+    """
+    formatted_items: List[str] = []
+
+    def format_item(item: Any, prefix: str = "") -> Optional[str]:
+        if isinstance(item, numbers.Number):
+            return f"{prefix}{item:.4f}"
+        elif torch.is_tensor(item):
+            if item.dim() == 0:
+                return f"{prefix}{item.item():.4f}"  # Format scalar tensor without brackets
+            elif item.dim() == 1 and item.size(0) > 10:
+                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item[:10]) + ", ...]"
+            elif item.dim() == 1:
+                return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]"
+            else:
+                return f"{prefix}Shape: {item.shape}"
+        elif isinstance(item, dict):
+            for key, value in item.items():
+                formatted_value = format_item(value, f"{key}: ")
+                if formatted_value is not None:
+                    formatted_items.append(formatted_value)
+        elif isinstance(item, (list, tuple)):
+            if not item:
+                if prefix:
+                    formatted_items.append(f"{prefix}")
+            else:
+                values = [format_item(x) for x in item]
+                values_str = [v for v in values if v is not None]
+                if values_str:
+                    formatted_items.append(f"{prefix}" + ", ".join(values_str))
+        elif item is None:
+            if prefix:
+                formatted_items.append(f"{prefix}")
+        return None
+
+    format_item(data)
+    return formatted_items
+
+
 class _CollectionItem:
     types_as_collection_item: Tuple = (int, float, torch.Tensor)
 
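
As a quick orientation for the new helper, a short sketch of its behaviour; the expected outputs in the comments mirror the parametrized cases added in tests/ignite/test_utils.py further down:

import torch

from ignite.utils import flatten_format_and_include_keys

# Dict keys become prefixes and plain numbers are rendered with four decimal places.
print(flatten_format_and_include_keys({"a": 10, "b": 2.33333}))
# -> ['a: 10.0000', 'b: 2.3333']

# Scalar tensors print like numbers; lists are flattened element-wise under their key.
print(flatten_format_and_include_keys({"x": torch.tensor(0.1234), "y": [1, 2.3567]}))
# -> ['x: 0.1234', 'y: 1.0000, 2.3567']

# Multi-dimensional tensors fall back to reporting their shape.
print(flatten_format_and_include_keys({"large_matrix": torch.randn(5, 5)}))
# -> ['large_matrix: Shape: torch.Size([5, 5])']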

Diff for: tests/ignite/handlers/test_fbresearch_logger.py (+50 -2)

@@ -3,9 +3,13 @@
 from unittest.mock import MagicMock
 
 import pytest
+import torch
+import torch.nn as nn
+import torch.optim as optim
 
-from ignite.engine import Engine, Events
-from ignite.handlers.fbresearch_logger import FBResearchLogger  # Adjust the import path as necessary
+from ignite.engine import create_supervised_trainer, Engine, Events
+from ignite.handlers.fbresearch_logger import FBResearchLogger
+from ignite.utils import setup_logger
 
 
 @pytest.fixture
@@ -56,3 +60,47 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat
 
     actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
     assert re.search(expected_pattern, actual_output)
+
+
+def test_logger_type_support():
+    model = nn.Linear(10, 5)
+    opt = optim.SGD(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss()
+
+    data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]
+
+    trainer = create_supervised_trainer(model, opt, criterion)
+
+    logger = setup_logger("trainer", level=logging.INFO)
+    logger = FBResearchLogger(logger=logger, show_output=True)
+    logger.attach(trainer, name="Train", every=20, optimizer=opt)
+
+    trainer.run(data, max_epochs=4)
+    trainer.state.output = {"loss": 4.2}
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = "4.2"
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = [4.2, 4.2]
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = (4.2, 4.2)
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+
+
+def test_fbrlogger_with_output_transform(mock_logger):
+    trainer = Engine(lambda e, b: 42)
+    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
+    fbr.attach(trainer, "Training", output_transform=lambda x: {"loss": x})
+    trainer.run(data=[10], epoch_length=1, max_epochs=1)
+    assert "loss: 42.0000" in fbr.logger.info.call_args_list[-2].args[0]
+
+
+def test_fbrlogger_with_state_attrs(mock_logger):
+    trainer = Engine(lambda e, b: 42)
+    fbr = FBResearchLogger(logger=mock_logger, show_output=True)
+    fbr.attach(trainer, "Training", state_attributes=["alpha", "beta", "gamma"])
+    trainer.state.alpha = 3.899
+    trainer.state.beta = torch.tensor(12.21)
+    trainer.state.gamma = torch.tensor([21.0, 6.0])
+    trainer.run(data=[10], epoch_length=1, max_epochs=1)
+    attrs = "alpha: 3.8990 beta: 12.2100 gamma: [21.0000, 6.0000]"
+    assert attrs in fbr.logger.info.call_args_list[-2].args[0]

Diff for: tests/ignite/test_utils.py (+30 -1)

@@ -8,7 +8,14 @@
 from packaging.version import Version
 
 from ignite.engine import Engine, Events
-from ignite.utils import convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot
+from ignite.utils import (
+    convert_tensor,
+    deprecated,
+    flatten_format_and_include_keys,
+    hash_checkpoint,
+    setup_logger,
+    to_onehot,
+)
 
 
 def test_convert_tensor():
@@ -55,6 +62,28 @@ def test_convert_tensor():
         convert_tensor(12345)
 
 
+@pytest.mark.parametrize(
+    "input_data,expected",
+    [
+        ([{"a": 15, "b": torch.tensor([2.0])}], ["a: 15.0000", "b: [2.0000]"]),
+        ({"a": 10, "b": 2.33333}, ["a: 10.0000", "b: 2.3333"]),
+        ({"x": torch.tensor(0.1234), "y": [1, 2.3567]}, ["x: 0.1234", "y: 1.0000, 2.3567"]),
+        (({"nested": [3.1415, torch.tensor(0.0001)]},), ["nested: 3.1415, 0.0001"]),
+        (
+            {"large_vector": torch.tensor(range(20))},
+            ["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"],
+        ),
+        ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: torch.Size([5, 5])"]),
+        ({"empty": []}, ["empty: "]),
+        ([], []),
+        ({"none": None}, ["none: "]),
+        ({1: 100, 2: 200}, ["1: 100.0000", "2: 200.0000"]),
+    ],
+)
+def test_flatten_format_and_include_keys(input_data, expected):
+    assert flatten_format_and_include_keys(input_data) == expected
+
+
 def test_to_onehot():
     indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
     actual = to_onehot(indices, 4)

0 commit comments