Skip to content

Commit

Permalink
Comet fix (#481)
Browse files Browse the repository at this point in the history
* Fixing comet ml bug and adding functionality

* Updating documents

* Fixing code style issues in comet_logger

* Changing comet_logger experiment to execute lazily

* Adding tests for comet_logger and addressing comments from @Borda

* Setting step_num to optional keyword argument in log_metrics() to comply to other loggers

* Adding offline logging mode for comet_ml, updating tests and docs

* Switching to MisconfigurationException
  • Loading branch information
rwesterman authored and williamFalcon committed Nov 12, 2019
1 parent ba0a32c commit d1b6b01
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 8 deletions.
1 change: 1 addition & 0 deletions .run_local_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
rm -rf _ckpt_*
rm -rf tests/save_dir*
rm -rf tests/mlruns_*
rm -rf tests/cometruns*
rm -rf tests/tests/*
rm -rf lightning_logs
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules
Expand Down
20 changes: 19 additions & 1 deletion docs/Trainer/Logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,30 @@ def any_lightning_module_function_or_hook(...):

Log using [comet](https://www.comet.ml)

Comet logger can be used in either online or offline mode.
To log in online mode, CometLogger requries an API key:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
api_key=os.environ["COMET_KEY"],
workspace=os.environ["COMET_KEY"],
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
To log in offline mode, CometLogger requires a path to a local directory:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
save_dir=".",
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
Expand Down
113 changes: 107 additions & 6 deletions pytorch_lightning/logging/comet_logger.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,126 @@
from logging import getLogger

try:
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml.papi import API
except ImportError:
raise ImportError('Missing comet_ml package.')

from torch import is_tensor

from .base import LightningLoggerBase, rank_zero_only
from ..utilities.debugging import MisconfigurationException

logger = getLogger(__name__)


class CometLogger(LightningLoggerBase):
def __init__(self, *args, **kwargs):
super(CometLogger, self).__init__()
self.experiment = CometExperiment(*args, **kwargs)
def __init__(self, api_key=None, save_dir=None, workspace=None,
rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
"""
Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)
:param str api_key: Required in online mode. API key, found on Comet.ml
:param str save_dir: Required in offline mode. The path for the directory to save local comet logs
:param str workspace: Optional. Name of workspace for this user
:param str project_name: Optional. Send your experiment to a specific project.
Otherwise will be sent to Uncategorized Experiments.
If project name does not already exists Comet.ml will create a new project.
:param str rest_api_key: Optional. Rest API key found in Comet.ml settings.
This is used to determine version number
:param str experiment_name: Optional. String representing the name for this particular experiment on Comet.ml
"""
super().__init__()
self._experiment = None

# Determine online or offline mode based on which arguments were passed to CometLogger
if save_dir is not None and api_key is not None:
# If arguments are passed for both save_dir and api_key, preference is given to online mode
self.mode = "online"
self.api_key = api_key
elif api_key is not None:
self.mode = "online"
self.api_key = api_key
elif save_dir is not None:
self.mode = "offline"
self.save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")

logger.info(f"CometLogger will be initialized in {self.mode} mode")

self.workspace = workspace
self.project_name = project_name
self._kwargs = kwargs

if rest_api_key is not None:
# Comet.ml rest API, used to determine version number
self.rest_api_key = rest_api_key
self.comet_api = API(self.rest_api_key)
else:
self.rest_api_key = None
self.comet_api = None

if experiment_name:
try:
self.name = experiment_name
except TypeError as e:
logger.exception("Failed to set experiment name for comet.ml logger")

@property
def experiment(self):
if self._experiment is not None:
return self._experiment

if self.mode == "online":
self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)

return self._experiment

@rank_zero_only
def log_hyperparams(self, params):
self.experiment.log_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step_num):
# self.experiment.set_epoch(self, metrics.get('epoch', 0))
self.experiment.log_metrics(metrics)
def log_metrics(self, metrics, step_num=None):
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
metrics[key] = val.cpu().detach()

self.experiment.log_metrics(metrics, step=step_num)

@rank_zero_only
def finalize(self, status):
self.experiment.end()

@property
def name(self):
return self.experiment.project_name

@name.setter
def name(self, value):
self.experiment.set_name(value)

@property
def version(self):
if self.project_name and self.rest_api_key:
# Determines the number of experiments in this project, and returns the next integer as the version number
nb_exps = len(self.comet_api.get_experiments(self.workspace, self.project_name))
return nb_exps + 1
else:
return None
77 changes: 76 additions & 1 deletion tests/test_y_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,83 @@ def test_mlflow_pickle():
testing_utils.clear_save_dir()


def test_custom_logger(tmpdir):
def test_comet_logger():
"""
verify that basic functionality of Comet.ml logger works
"""
reset_seed()

try:
from pytorch_lightning.logging import CometLogger
except ModuleNotFoundError:
return

hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

root_dir = os.path.dirname(os.path.realpath(__file__))
comet_dir = os.path.join(root_dir, "cometruns")

# We test CometLogger in offline mode with local saves
logger = CometLogger(
save_dir=comet_dir,
project_name="general",
workspace="dummy-test",
)

trainer_options = dict(
max_nb_epochs=1,
train_percent_check=0.01,
logger=logger
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)

print('result finished')
assert result == 1, "Training failed"

testing_utils.clear_save_dir()


def test_comet_pickle():
"""
verify that pickling trainer with comet logger works
"""
reset_seed()

try:
from pytorch_lightning.logging import CometLogger
except ModuleNotFoundError:
return

hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

root_dir = os.path.dirname(os.path.realpath(__file__))
comet_dir = os.path.join(root_dir, "cometruns")

# We test CometLogger in offline mode with local saves
logger = CometLogger(
save_dir=comet_dir,
project_name="general",
workspace="dummy-test",
)

trainer_options = dict(
max_nb_epochs=1,
logger=logger
)

trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

testing_utils.clear_save_dir()


def test_custom_logger(tmpdir):
class CustomLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit d1b6b01

Please sign in to comment.