diff --git a/ignite/contrib/handlers/mlflow_logger.py b/ignite/contrib/handlers/mlflow_logger.py index 45dfad2cafa..61cb590c077 100644 --- a/ignite/contrib/handlers/mlflow_logger.py +++ b/ignite/contrib/handlers/mlflow_logger.py @@ -180,6 +180,20 @@ def global_step_transform(*args, **kwargs): global_step_transform=global_step_transform ) + Another example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta`` + are also logged along with the NLL and Accuracy after each iteration: + + .. code-block:: python + + mlflow_logger.attach_output_handler( + trainer, + event_name=Events.ITERATION_COMPLETED, + tag="training", + metrics=["nll", "accuracy"], + state_attributes=["alpha", "beta"], + ) + + Args: tag: common title for all produced plots. For example, 'training' metric_names: list of metric names to plot or a string "all" to plot all available @@ -193,6 +207,7 @@ def global_step_transform(*args, **kwargs): Default is None, global_step based on attached engine. If provided, uses function output as global_step. To setup global step from another engine, please use :meth:`~ignite.contrib.handlers.mlflow_logger.global_step_from_engine`. + state_attributes: list of attributes of the ``trainer.state`` to plot. Note: @@ -203,6 +218,8 @@ def global_step_transform(*args, **kwargs): def global_step_transform(engine, event_name): return engine.state.get_event_attrib_value(event_name) + .. versionchanged:: 0.5.0 + accepts an optional list of `state_attributes` """ def __init__( @@ -211,8 +228,11 @@ def __init__( metric_names: Optional[Union[str, List[str]]] = None, output_transform: Optional[Callable] = None, global_step_transform: Optional[Callable] = None, + state_attributes: Optional[List[str]] = None, ) -> None: - super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform) + super(OutputHandler, self).__init__( + tag, metric_names, output_transform, global_step_transform, state_attributes + ) def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None: diff --git a/tests/ignite/contrib/handlers/test_mlflow_logger.py b/tests/ignite/contrib/handlers/test_mlflow_logger.py index 2b225641037..d907b9a873c 100644 --- a/tests/ignite/contrib/handlers/test_mlflow_logger.py +++ b/tests/ignite/contrib/handlers/test_mlflow_logger.py @@ -188,6 +188,25 @@ def test_output_handler_with_global_step_from_engine(): ) +def test_output_handler_state_attrs(): + wrapper = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"]) + mock_logger = MagicMock(spec=MLflowLogger) + mock_logger.log_metrics = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.iteration = 5 + mock_engine.state.alpha = 3.899 + mock_engine.state.beta = torch.tensor(12.21) + mock_engine.state.gamma = torch.tensor([21.0, 6.0]) + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + mock_logger.log_metrics.assert_called_once_with( + {"tag alpha": 3.899, "tag beta": torch.tensor(12.21).item(), "tag gamma 0": 21.0, "tag gamma 1": 6.0,}, step=5, + ) + + def test_optimizer_params_handler_wrong_setup(): with pytest.raises(TypeError):