Skip to content

Commit

Permalink
Mocking Loggers (part 4a, mlflow) (#3884)
Browse files Browse the repository at this point in the history
* extensive mlflow test

* revert accidental commits
  • Loading branch information
awaelchli authored Oct 6, 2020
1 parent b34c7ad commit 0823cdd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
4 changes: 1 addition & 3 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@
try:
import mlflow
from mlflow.tracking import MlflowClient
_MLFLOW_AVAILABLE = True
except ModuleNotFoundError: # pragma: no-cover
mlflow = None
MlflowClient = None
_MLFLOW_AVAILABLE = False


from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -83,7 +81,7 @@ def __init__(
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns'
):
if not _MLFLOW_AVAILABLE:
if mlflow is None:
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
' install it with `pip install mlflow`.')
super().__init__()
Expand Down
54 changes: 49 additions & 5 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,66 @@
import os

from unittest import mock
from unittest.mock import MagicMock

from mlflow.tracking import MlflowClient

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from tests.base import EvalModelTemplate


def test_mlflow_logger_exists(tmpdir):
""" Test launching two independent loggers. """
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
def test_mlflow_logger_exists(client, mlflow, tmpdir):
""" Test launching three independent loggers with either same or different experiment name. """

run1 = MagicMock()
run1.info.run_id = "run-id-1"

run2 = MagicMock()
run2.info.run_id = "run-id-2"

run3 = MagicMock()
run3.info.run_id = "run-id-3"

# simulate non-existing experiment creation
client.return_value.get_experiment_by_name = MagicMock(return_value=None)
client.return_value.create_experiment = MagicMock(return_value="exp-id-1") # experiment_id
client.return_value.create_run = MagicMock(return_value=run1)

logger = MLFlowLogger('test', save_dir=tmpdir)
assert logger._experiment_id is None
assert logger._run_id is None
_ = logger.experiment
assert logger.experiment_id == "exp-id-1"
assert logger.run_id == "run-id-1"
assert logger.experiment.create_experiment.asset_called_once()
client.reset_mock(return_value=True)

# simulate existing experiment returns experiment id
exp1 = MagicMock()
exp1.experiment_id = "exp-id-1"
client.return_value.get_experiment_by_name = MagicMock(return_value=exp1)
client.return_value.create_run = MagicMock(return_value=run2)

# same name leads to same experiment id, but different runs get recorded
logger2 = MLFlowLogger('test', save_dir=tmpdir)
assert logger.experiment_id == logger2.experiment_id
assert logger.run_id != logger2.run_id
assert logger2.experiment_id == logger.experiment_id
assert logger2.run_id == "run-id-2"
assert logger2.experiment.create_experiment.call_count == 0
assert logger2.experiment.create_run.asset_called_once()
client.reset_mock(return_value=True)

# simulate a 3rd experiment with new name
client.return_value.get_experiment_by_name = MagicMock(return_value=None)
client.return_value.create_experiment = MagicMock(return_value="exp-id-3")
client.return_value.create_run = MagicMock(return_value=run3)

# logger with new experiment name causes new experiment id and new run id to be created
logger3 = MLFlowLogger('new', save_dir=tmpdir)
assert logger3.experiment_id != logger.experiment_id
assert logger3.experiment_id == "exp-id-3" != logger.experiment_id
assert logger3.run_id == "run-id-3"


def test_mlflow_logger_dirs_creation(tmpdir):
Expand Down

0 comments on commit 0823cdd

Please sign in to comment.