Skip to content

Commit

Permalink
Added state attributes for MLFlow Logger (#2160)
Browse files Browse the repository at this point in the history
* added state attributes for MLFlow Logger

* autopep8 fix

* added state attributes in args

Co-authored-by: Ishan-Kumar2 <[email protected]>
  • Loading branch information
Ishan-Kumar2 and Ishan-Kumar2 authored Aug 12, 2021
1 parent f2b812d commit 327cec5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
22 changes: 21 additions & 1 deletion ignite/contrib/handlers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,20 @@ def global_step_transform(*args, **kwargs):
global_step_transform=global_step_transform
)
Another example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
are also logged along with the NLL and Accuracy after each iteration:
.. code-block:: python
mlflow_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED,
tag="training",
metrics=["nll", "accuracy"],
state_attributes=["alpha", "beta"],
)
Args:
tag: common title for all produced plots. For example, 'training'
metric_names: list of metric names to plot or a string "all" to plot all available
Expand All @@ -193,6 +207,7 @@ def global_step_transform(*args, **kwargs):
Default is None, global_step based on attached engine. If provided,
uses function output as global_step. To setup global step from another engine, please use
:meth:`~ignite.contrib.handlers.mlflow_logger.global_step_from_engine`.
state_attributes: list of attributes of the ``trainer.state`` to plot.
Note:
Expand All @@ -203,6 +218,8 @@ def global_step_transform(*args, **kwargs):
def global_step_transform(engine, event_name):
return engine.state.get_event_attrib_value(event_name)
.. versionchanged:: 0.5.0
accepts an optional list of `state_attributes`
"""

def __init__(
Expand All @@ -211,8 +228,11 @@ def __init__(
metric_names: Optional[Union[str, List[str]]] = None,
output_transform: Optional[Callable] = None,
global_step_transform: Optional[Callable] = None,
state_attributes: Optional[List[str]] = None,
) -> None:
super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)
super(OutputHandler, self).__init__(
tag, metric_names, output_transform, global_step_transform, state_attributes
)

def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None:

Expand Down
19 changes: 19 additions & 0 deletions tests/ignite/contrib/handlers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,25 @@ def test_output_handler_with_global_step_from_engine():
)


def test_output_handler_state_attrs():
wrapper = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
mock_logger = MagicMock(spec=MLflowLogger)
mock_logger.log_metrics = MagicMock()

mock_engine = MagicMock()
mock_engine.state = State()
mock_engine.state.iteration = 5
mock_engine.state.alpha = 3.899
mock_engine.state.beta = torch.tensor(12.21)
mock_engine.state.gamma = torch.tensor([21.0, 6.0])

wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

mock_logger.log_metrics.assert_called_once_with(
{"tag alpha": 3.899, "tag beta": torch.tensor(12.21).item(), "tag gamma 0": 21.0, "tag gamma 1": 6.0,}, step=5,
)


def test_optimizer_params_handler_wrong_setup():

with pytest.raises(TypeError):
Expand Down

0 comments on commit 327cec5

Please sign in to comment.