1
- import pytest
2
1
import logging
3
- from unittest .mock import MagicMock , _Call
2
+ import re
3
+ from unittest .mock import _Call , MagicMock
4
+
5
+ import pytest
6
+
4
7
from ignite .engine import Engine , Events
5
8
from ignite .handlers .fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
6
- import re
9
+
7
10
8
11
@pytest .fixture
9
12
def mock_engine ():
@@ -14,6 +17,7 @@ def mock_engine():
14
17
engine .state .iteration = 50
15
18
return engine
16
19
20
+
17
21
@pytest .fixture
18
22
def mock_logger ():
19
23
return MagicMock (spec = logging .Logger )
@@ -23,6 +27,7 @@ def mock_logger():
23
27
def fb_research_logger (mock_logger ):
24
28
yield FBResearchLogger (logger = mock_logger , show_output = True )
25
29
30
+
26
31
# Test logging with a dictionary output
27
32
def test_logging_dict_output (mock_engine , fb_research_logger ):
28
33
mock_engine .state .output = {"loss" : 0.456 , "accuracy" : 0.789 }
@@ -31,6 +36,7 @@ def test_logging_dict_output(mock_engine, fb_research_logger):
31
36
fb_research_logger .logger .info .assert_called_once ()
32
37
assert "accuracy: 0.7890" in fb_research_logger .logger .info .call_args_list [- 1 ].args [0 ]
33
38
39
+
34
40
# Test logging with a list output
35
41
def test_logging_list_output (mock_engine , fb_research_logger ):
36
42
mock_engine .state .output = [0.456 , 0.789 ]
@@ -39,6 +45,7 @@ def test_logging_list_output(mock_engine, fb_research_logger):
39
45
fb_research_logger .logger .info .assert_called_once ()
40
46
assert "0.456" in fb_research_logger .logger .info .call_args_list [- 1 ].args [0 ]
41
47
48
+
42
49
# Test logging with a tuple output
43
50
def test_logging_tuple_output (mock_engine , fb_research_logger ):
44
51
mock_engine .state .output = (0.456 , 0.789 )
@@ -49,20 +56,13 @@ def test_logging_tuple_output(mock_engine, fb_research_logger):
49
56
50
57
51
58
@pytest .mark .parametrize (
52
- "output,expected_pattern" ,
53
- [
54
- (
55
- {"loss" : 0.456 , "accuracy" : 0.789 },
56
- r"loss. *0.456.*accuracy. *0.789"
57
- ),
58
- (
59
- [0.456 , 0.789 ],
60
- r"0.456.*0.789"
61
-
62
- ),
63
- ]
59
+ "output,expected_pattern" ,
60
+ [
61
+ ({"loss" : 0.456 , "accuracy" : 0.789 }, r"loss. *0.456.*accuracy. *0.789" ),
62
+ ([0.456 , 0.789 ], r"0.456.*0.789" ),
63
+ ],
64
64
)
65
- def test_output_formatting (mock_engine , fb_research_logger ,output ,expected_pattern ):
65
+ def test_output_formatting (mock_engine , fb_research_logger , output , expected_pattern ):
66
66
# Ensure the logger correctly formats and logs the output for each type
67
67
mock_engine .state .output = output
68
68
fb_research_logger .attach (mock_engine , name = "Test" , every = 1 )
0 commit comments