Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comet fix #481

Merged
merged 8 commits into from
Nov 12, 2019
5 changes: 4 additions & 1 deletion docs/Trainer/Logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
api_key=os.environ["COMET_KEY"],
workspace=os.environ["COMET_KEY"],
workspace=os.environ["COMET_WORKSPACE"],
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
Expand Down
86 changes: 80 additions & 6 deletions pytorch_lightning/logging/comet_logger.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,99 @@
from logging import getLogger

try:
from comet_ml import Experiment as CometExperiment
from comet_ml.papi import API
except ImportError:
raise ImportError('Missing comet_ml package.')

from torch import is_tensor

from .base import LightningLoggerBase, rank_zero_only

logger = getLogger(__name__)


class CometLogger(LightningLoggerBase):
def __init__(self, *args, **kwargs):
super(CometLogger, self).__init__()
self.experiment = CometExperiment(*args, **kwargs)
def __init__(self, api_key, workspace, rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
"""
Initialize a Comet.ml logger

:param str api_key: API key, found on Comet.ml
:param str workspace: Name of workspace for this user
:param str project_name: Optional. Send your experiment to a specific project.
Otherwise will be sent to Uncategorized Experiments.
If project name does not already exists Comet.ml will create a new project.
:param str rest_api_key: Optional. Rest API key found in Comet.ml settings.
This is used to determine version number
:param str experiment_name: Optional. String representing the name for this particular experiment on Comet.ml
"""
super().__init__()
self._experiment = None

self.api_key = api_key
self.workspace = workspace
self.project_name = project_name

self._kwargs = kwargs

if rest_api_key is not None:
# Comet.ml rest API, used to determine version number
self.rest_api_key = rest_api_key
self.comet_api = API(self.rest_api_key)
else:
self.rest_api_key = None
self.comet_api = None

if experiment_name:
try:
self._set_experiment_name(experiment_name)
except TypeError as e:
logger.exception("Failed to set experiment name for comet.ml logger")

@property
def experiment(self):
if self._experiment is not None:
return self._experiment

self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)

return self._experiment

@rank_zero_only
def log_hyperparams(self, params):
self.experiment.log_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step_num):
# self.experiment.set_epoch(self, metrics.get('epoch', 0))
self.experiment.log_metrics(metrics)
def log_metrics(self, metrics, step_num=None):
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
metrics[key] = val.cpu().detach()

self.experiment.log_metrics(metrics, step=step_num)

@rank_zero_only
def finalize(self, status):
self.experiment.end()

@rank_zero_only
def _set_experiment_name(self, experiment_name):
self.experiment.set_name(experiment_name)

@property
def name(self):
return self.experiment.project_name

@property
def version(self):
if self.project_name and self.rest_api_key:
# Determines the number of experiments in this project, and returns the next integer as the version number
nb_exps = len(self.comet_api.get_experiments(self.workspace, self.project_name))
return nb_exps + 1
else:
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if returning None as a version is a good idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can tell, logger.version is only being used in the tqdm progress bar. If version is None, it is ignored, otherwise the version is displayed in the progress bar.

If there are future plans to use this version elsewhere, I'm open to suggestions for a different default return value.

66 changes: 66 additions & 0 deletions tests/test_y_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,72 @@ def test_mlflow_pickle():
testing_utils.clear_save_dir()


def test_comet_logger():
"""
verify that basic functionality of Comet.ml logger works
"""
reset_seed()

try:
from pytorch_lightning.logging import CometLogger
except ModuleNotFoundError:
return

hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

# API key for dummy Comet.ml account
logger = CometLogger(
api_key="KnmgASRHHyxWXOpwUfgrAFz8C",
project_name="general",
workspace="dummy-test",
)

trainer_options = dict(
max_nb_epochs=1,
train_percent_check=0.01,
logger=logger
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)

print('result finished')
assert result == 1, "Training failed"


def test_comet_pickle():
"""
verify that pickling trainer with mlflow logger works
"""
reset_seed()

try:
from pytorch_lightning.logging import CometLogger
except ModuleNotFoundError:
return

hparams = testing_utils.get_hparams()
model = LightningTestModel(hparams)

# API key for dummy Comet.ml account
logger = CometLogger(
api_key="KnmgASRHHyxWXOpwUfgrAFz8C",
project_name="general",
workspace="dummy-test"
)

trainer_options = dict(
max_nb_epochs=1,
logger=logger
)

trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})


def test_custom_logger(tmpdir):

class CustomLogger(LightningLoggerBase):
Expand Down