diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 37a91bd98cdac..56a7b77dfac2b 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -22,28 +22,25 @@ from typing import Any, Dict, Optional, Union try: - from comet_ml import BaseExperiment as CometBaseExperiment + import comet_ml + +except ModuleNotFoundError: # pragma: no-cover + comet_ml = None + CometExperiment = None + CometExistingExperiment = None + CometOfflineExperiment = None + API = None + generate_guid = None +else: from comet_ml import ExistingExperiment as CometExistingExperiment from comet_ml import Experiment as CometExperiment from comet_ml import OfflineExperiment as CometOfflineExperiment - from comet_ml import generate_guid try: from comet_ml.api import API except ImportError: # pragma: no-cover # For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300 from comet_ml.papi import API # pragma: no-cover - from comet_ml.config import get_api_key, get_config -except ImportError: # pragma: no-cover - CometExperiment = None - CometExistingExperiment = None - CometOfflineExperiment = None - CometBaseExperiment = None - API = None - generate_guid = None - _COMET_AVAILABLE = False -else: - _COMET_AVAILABLE = True import torch from torch import is_tensor @@ -117,17 +114,17 @@ class CometLogger(LightningLoggerBase): """ def __init__( - self, - api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, - experiment_key: Optional[str] = None, - offline: bool = False, - **kwargs + self, + api_key: Optional[str] = None, + save_dir: Optional[str] = None, + project_name: Optional[str] = None, + rest_api_key: Optional[str] = None, + experiment_name: Optional[str] = None, + experiment_key: Optional[str] = None, + offline: bool = False, + **kwargs ): - if not _COMET_AVAILABLE: + if comet_ml is None: raise ImportError( "You want to use `comet_ml` logger which is not installed yet," " install it with `pip install comet-ml`." @@ -136,7 +133,7 @@ def __init__( self._experiment = None # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or get_api_key(None, get_config()) + api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) if api_key is not None and save_dir is not None: self.mode = "offline" if offline else "online" @@ -173,7 +170,7 @@ def __init__( @property @rank_zero_experiment - def experiment(self) -> CometBaseExperiment: + def experiment(self): r""" Actual Comet object. To use Comet features in your :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. @@ -236,7 +233,6 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti metrics_without_epoch = metrics.copy() epoch = metrics_without_epoch.pop('epoch', None) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) def reset_experiment(self): @@ -284,7 +280,7 @@ def version(self) -> str: return self._future_experiment_key # Pre-generate an experiment key - self._future_experiment_key = generate_guid() + self._future_experiment_key = comet_ml.generate_guid() return self._future_experiment_key diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 0e1199e88d27a..0e7bbabaf4e9e 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -15,27 +15,24 @@ def _patch_comet_atexit(monkeypatch): monkeypatch.setattr(atexit, "register", lambda _: None) -def test_comet_logger_online(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_online(comet): """Test comet online with mocks.""" # Test api_key given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test both given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general') _ = logger.experiment - comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') - - # Test neither given - with pytest.raises(MisconfigurationException): - CometLogger(workspace='dummy-test', project_name='general') + comet_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general') # Test already exists with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing: @@ -61,52 +58,73 @@ def test_comet_logger_online(): api.assert_called_once_with('rest') -def test_comet_logger_experiment_name(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_no_api_key_given(comet): + """ Test that CometLogger fails to initialize if both api key and save_dir are missing. """ + with pytest.raises(MisconfigurationException): + comet.config.get_api_key.return_value = None + CometLogger(workspace='dummy-test', project_name='general') + + +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_experiment_name(comet): """Test that Comet Logger experiment name works correctly.""" api_key = "key" experiment_name = "My Name" # Test api_key given - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment: logger = CometLogger(api_key=api_key, experiment_name=experiment_name,) assert logger._experiment is None _ = logger.experiment - comet.assert_called_once_with(api_key=api_key, project_name=None) + comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet().set_name.assert_called_once_with(experiment_name) + comet_experiment().set_name.assert_called_once_with(experiment_name) -def test_comet_logger_dirs_creation(tmpdir, monkeypatch): +@patch('pytorch_lightning.loggers.comet.CometOfflineExperiment') +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch): """ Test that the logger creates the folders and files in the right place. """ _patch_comet_atexit(monkeypatch) + comet.config.get_api_key.return_value = None + comet.generate_guid.return_value = "4321" + logger = CometLogger(project_name='test', save_dir=tmpdir) assert not os.listdir(tmpdir) assert logger.mode == 'offline' assert logger.save_dir == tmpdir + assert logger.name == 'test' + assert logger.version == "4321" _ = logger.experiment - version = logger.version - assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'} + + comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test') + + # mock return values of experiment + logger.experiment.id = '1' + logger.experiment.project_name = 'test' model = EvalModelTemplate() trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) trainer.fit(model) - assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints') + assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} -def test_comet_name_default(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_default(comet): """ Test that CometLogger.name don't create an Experiment and returns a default value. """ api_key = "key" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key) assert logger._experiment is None @@ -116,13 +134,14 @@ def test_comet_name_default(): assert logger._experiment is None -def test_comet_name_project_name(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_name_project_name(comet): """ Test that CometLogger.name does not create an Experiment and returns project name if passed. """ api_key = "key" project_name = "My Project Name" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key, project_name=project_name) assert logger._experiment is None @@ -132,13 +151,15 @@ def test_comet_name_project_name(): assert logger._experiment is None -def test_comet_version_without_experiment(): +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_version_without_experiment(comet): """ Test that CometLogger.version does not create an Experiment. """ api_key = "key" experiment_name = "My Name" + comet.generate_guid.return_value = "1234" - with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet: + with patch('pytorch_lightning.loggers.comet.CometExperiment'): logger = CometLogger(api_key=api_key, experiment_name=experiment_name) assert logger._experiment is None @@ -154,15 +175,16 @@ def test_comet_version_without_experiment(): logger.reset_experiment() - second_version = logger.version + second_version = logger.version == "1234" assert second_version is not None assert second_version != first_version -def test_comet_epoch_logging(tmpdir, monkeypatch): +@patch("pytorch_lightning.loggers.comet.CometExperiment") +@patch('pytorch_lightning.loggers.comet.comet_ml') +def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch): """ Test that CometLogger removes the epoch key from the metrics dict and passes it as argument. """ _patch_comet_atexit(monkeypatch) - with patch("pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics") as log_metrics: - logger = CometLogger(project_name="test", save_dir=tmpdir) - logger.log_metrics({"test": 1, "epoch": 1}, step=123) - log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + logger = CometLogger(project_name="test", save_dir=tmpdir) + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)