From bfb8872126629e741cf39f91464dd2bec630fed1 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sat, 27 Feb 2021 11:55:44 -0600 Subject: [PATCH 01/46] feat(wandb): log models as artifacts --- pytorch_lightning/loggers/wandb.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 285388d6c6765..469a4794e6d9b 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -16,6 +16,8 @@ ------------------------- """ import os +import re +import numbers from argparse import Namespace from typing import Any, Dict, Optional, Union @@ -168,10 +170,6 @@ def experiment(self) -> Run: **self._kwargs ) if wandb.run is None else wandb.run - # save checkpoints in wandb dir to upload on W&B servers - if self._save_dir is None: - self._save_dir = self._experiment.dir - # define default x-axis (for latest wandb versions) if getattr(self._experiment, "define_metric", None): self._experiment.define_metric("trainer/global_step") @@ -215,6 +213,18 @@ def version(self) -> Optional[str]: @rank_zero_only def finalize(self, status: str) -> None: - # upload all checkpoints from saving dir + # save checkpoints as artifacts if self._log_model: - wandb.save(os.path.join(self.save_dir, "*.ckpt")) + # use run name and ensure it's a valid Artifact name + artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) + # gather interesting metadata + metadata = { + k: v + for k, v in dict(self.experiment.summary).items() + if isinstance(v, numbers.Number) and not k.startswith("_") + } + # TODO: see if we can also log data from `trainer.checkpoint_callback` (best_model_path, etc) + artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) + # TODO: we need access to `trainer.checkpoint_callback.dirpath` + artifact.add_dir(trainer.checkpoint_callback.dirpath) + self.experiment.log_artifact(artifact) From 541b001b84c99039cdf9b8924299c3783d0f0a10 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 17:00:00 -0600 Subject: [PATCH 02/46] feat: add Logger.connect --- pytorch_lightning/loggers/base.py | 14 ++++++++++++++ .../logger_connector/logger_connector.py | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 4fdb5e8c437bf..981284f14278f 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -26,6 +26,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning import Trainer def rank_zero_experiment(fn: Callable) -> Callable: @@ -71,6 +72,15 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func + def connect(self, trainer: Optional[Trainer] = None) -> None: + """ + Connect trainer to logger + + Args: + trainer: lightning Trainer + """ + pass + def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, @@ -355,6 +365,10 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] + def connect(self, trainer: Optional[Trainer] = None) -> None: + for logger in self._logger_iterable: + logger.connect(trainer) + def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], 
float]]] = None, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8ebec3238e276..daa2b5599d329 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -143,6 +143,7 @@ def should_update_logs(self): return should_log_every_n_steps or self.trainer.should_stop def configure_logger(self, logger): + # connect logger to trainer if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) @@ -158,6 +159,9 @@ def configure_logger(self, logger): else: self.trainer.logger = logger + # connect trainer to logger + logger.connect(self.trainer) + def cache_training_step_metrics(self, opt_closure_result): """ This function is responsible to update From bbd86331bf42f1697e9038a661a43dbe0b08c337 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 19:13:56 -0600 Subject: [PATCH 03/46] fix: circular ref with type checking --- pytorch_lightning/loggers/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 981284f14278f..ec0bc46040322 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,14 +19,16 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union, TYPE_CHECKING import numpy as np import torch from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning import Trainer + +if TYPE_CHECKING: + from pytorch_lightning.trainer.trainer import Trainer def rank_zero_experiment(fn: Callable) -> Callable: @@ -72,7 +74,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def connect(self, trainer: Optional[Trainer] = None) -> None: + def connect(self, trainer: Optional['Trainer'] = None) -> None: """ Connect trainer to logger @@ -365,7 +367,7 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def connect(self, trainer: Optional[Trainer] = None) -> None: + def connect(self, trainer: Optional['Trainer'] = None) -> None: for logger in self._logger_iterable: logger.connect(trainer) From 3365261c2dc9eba86f06d4fa57bcd6524b4369ca Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 20:48:37 -0600 Subject: [PATCH 04/46] feat(wandb): use connect method --- pytorch_lightning/loggers/wandb.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 469a4794e6d9b..825fe20b47ccd 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,7 +19,8 @@ import re import numbers from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TYPE_CHECKING +from weakref import proxy import torch.nn as nn @@ -28,6 +29,9 @@ from 
pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache +if TYPE_CHECKING: + from pytorch_lightning.trainer.trainer import Trainer + warning_cache = WarningCache() _WANDB_AVAILABLE = _module_available("wandb") @@ -134,6 +138,7 @@ def __init__( self._prefix = prefix self._experiment = experiment self._kwargs = kwargs + self._trainer = None def __getstate__(self): state = self.__dict__.copy() @@ -144,6 +149,10 @@ def __getstate__(self): state['_experiment'] = None return state + def connect(self, trainer: Optional['Trainer'] = None) -> None: + if trainer is not None: + self._trainer = proxy(trainer) + @property @rank_zero_experiment def experiment(self) -> Run: @@ -214,7 +223,7 @@ def version(self) -> Optional[str]: @rank_zero_only def finalize(self, status: str) -> None: # save checkpoints as artifacts - if self._log_model: + if self._log_model and self._trainer is not None and self._trainer.checkpoint_callback is not None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) # gather interesting metadata @@ -223,8 +232,14 @@ def finalize(self, status: str) -> None: for k, v in dict(self.experiment.summary).items() if isinstance(v, numbers.Number) and not k.startswith("_") } - # TODO: see if we can also log data from `trainer.checkpoint_callback` (best_model_path, etc) + metadata['ModelCheckpoint'] = {k: v for k, v in vars( + self._trainer.checkpoint_callback).items() if not callable(v)} artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) - # TODO: we need access to `trainer.checkpoint_callback.dirpath` - artifact.add_dir(trainer.checkpoint_callback.dirpath) + # add relevant checkpoints + for checkpoint in set([ + self._trainer.checkpoint_callback.best_model_path, + self._trainer.checkpoint_callback.last_model_path, + *self._trainer.checkpoint_callback.best_k_models.keys() + ]) - {''}: + artifact.add_file(checkpoint) self.experiment.log_artifact(artifact) From dfd7553bcfa9183017d368f8d68af81cafc3c9df Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 21:15:42 -0600 Subject: [PATCH 05/46] style: pep8 --- pytorch_lightning/loggers/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index ec0bc46040322..c3e9aadc076fc 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,7 +19,8 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union, TYPE_CHECKING +from typing import (Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, + Optional, Sequence, Tuple, Union, TYPE_CHECKING) import numpy as np import torch From 6950d3d61041edf0a10d60bfda10de02196509a7 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 21:18:44 -0600 Subject: [PATCH 06/46] fix(configure_logger): logger can be bool --- .../trainer/connectors/logger_connector/logger_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index daa2b5599d329..d8a47d38a8348 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ 
b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -160,7 +160,8 @@ def configure_logger(self, logger): self.trainer.logger = logger # connect trainer to logger - logger.connect(self.trainer) + if hasattr(self.trainer.logger, 'connect'): + self.trainer.logger.connect(self.trainer) def cache_training_step_metrics(self, opt_closure_result): """ From f9cc20f075b825ff4ff2ee6b2be78e7f8bf2bc3e Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 28 Feb 2021 21:26:29 -0600 Subject: [PATCH 07/46] feat(connect): Trainer is not optional --- pytorch_lightning/loggers/base.py | 4 ++-- pytorch_lightning/loggers/wandb.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index c3e9aadc076fc..6c2b8e7b3d351 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -75,7 +75,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def connect(self, trainer: Optional['Trainer'] = None) -> None: + def connect(self, trainer: 'Trainer') -> None: """ Connect trainer to logger @@ -368,7 +368,7 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def connect(self, trainer: Optional['Trainer'] = None) -> None: + def connect(self, trainer: 'Trainer') -> None: for logger in self._logger_iterable: logger.connect(trainer) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 825fe20b47ccd..3e0e8044994a2 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -149,7 +149,7 @@ def __getstate__(self): state['_experiment'] = None return state - def connect(self, trainer: Optional['Trainer'] = None) -> None: + def connect(self, trainer: 'Trainer') -> None: if trainer is not None: self._trainer = proxy(trainer) From c518d710debad2e5fc989852ff0b19562a6c4774 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 10:48:17 -0600 Subject: [PATCH 08/46] feat(configure_logger): make trainer a proxy --- pytorch_lightning/loggers/wandb.py | 2 +- .../trainer/connectors/logger_connector/logger_connector.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 3e0e8044994a2..b50bfe532f441 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -151,7 +151,7 @@ def __getstate__(self): def connect(self, trainer: 'Trainer') -> None: if trainer is not None: - self._trainer = proxy(trainer) + self._trainer = trainer @property @rank_zero_experiment diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d8a47d38a8348..a0ea4faddf258 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,6 +14,7 @@ import os from copy import deepcopy from pprint import pprint +from weakref import proxy from typing import Dict, Iterable, Optional, Union import torch @@ -161,7 +162,7 @@ def configure_logger(self, logger): # connect trainer to logger if hasattr(self.trainer.logger, 'connect'): - self.trainer.logger.connect(self.trainer) + 
self.trainer.logger.connect(proxy(self.trainer)) def cache_training_step_metrics(self, opt_closure_result): """ From 9b9aaa631c246d96d69c0db7adc55fa922d972a3 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 10:50:42 -0600 Subject: [PATCH 09/46] =?UTF-8?q?fix:=C2=A0unused=20import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/loggers/wandb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index b50bfe532f441..86499a5f1e1a0 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -20,7 +20,6 @@ import numbers from argparse import Namespace from typing import Any, Dict, Optional, Union, TYPE_CHECKING -from weakref import proxy import torch.nn as nn From eb2080dd3bc2acea1044d4e64320dc380f47a698 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 11:01:20 -0600 Subject: [PATCH 10/46] =?UTF-8?q?docs:=C2=A0more=20explicit=20doc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/loggers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 6c2b8e7b3d351..722c4207f1ae2 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -80,7 +80,7 @@ def connect(self, trainer: 'Trainer') -> None: Connect trainer to logger Args: - trainer: lightning Trainer + trainer: the trainer instance to connect to """ pass From 7d98a99ac566ba29baf7ba630d583946bba60b91 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 11:32:00 -0600 Subject: [PATCH 11/46] doc: update docstring --- pytorch_lightning/loggers/wandb.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 86499a5f1e1a0..6cf94bfa3fb51 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -61,7 +61,7 @@ class WandbLogger(LightningLoggerBase): version: Same as id. anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. - log_model: Save checkpoints in wandb dir to upload on W&B servers. + log_model: Save checkpoints as W&B artifacts. prefix: A string to put at the beginning of metric keys. experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by @@ -80,9 +80,6 @@ class WandbLogger(LightningLoggerBase): wandb_logger = WandbLogger() trainer = Trainer(logger=wandb_logger) - Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`, - make sure to use `commit=False` so the logging step does not increase. 
- See Also: - `Tutorial `__ on how to use W&B with PyTorch Lightning From a6ad9aa3e9e9f08c2909a9d1327294a813314689 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 12:59:24 -0600 Subject: [PATCH 12/46] feat: ModelCheckpoint metadata --- pytorch_lightning/loggers/wandb.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 6cf94bfa3fb51..b2aafb2e642d0 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,6 +19,7 @@ import re import numbers from argparse import Namespace +from pathlib import Path from typing import Any, Dict, Optional, Union, TYPE_CHECKING import torch.nn as nn @@ -222,14 +223,23 @@ def finalize(self, status: str) -> None: if self._log_model and self._trainer is not None and self._trainer.checkpoint_callback is not None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) - # gather interesting metadata + # gather summary metrics metadata = { k: v for k, v in dict(self.experiment.summary).items() if isinstance(v, numbers.Number) and not k.startswith("_") } - metadata['ModelCheckpoint'] = {k: v for k, v in vars( - self._trainer.checkpoint_callback).items() if not callable(v)} + # add interesting "ModelCheckpoint" data (mainly non-default values) + metadata['ModelCheckpoint'] = {k: getattr(self._trainer.checkpoint_callback, k) for k, ignore_val in [ + ('monitor', ''), ('mode', ''), # save also default values + ('current_score', None), ('best_model_score', None), ('best_model_path', ''), ('last_model_path', ''), + ('save_last', None), ('save_top_k', None), ('save_weights_only', False), ('period', 1), + ('_last_global_step_saved', 0)] + if getattr(self._trainer.checkpoint_callback, k, ignore_val) != ignore_val} + if getattr(self._trainer.checkpoint_callback, 'best_k_models', None): + # keep only filename + metadata['ModelCheckpoint']['best_k_models'] = { + Path(k).name: v for k, v in self._trainer.checkpoint_callback.best_k_models.items()} artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) # add relevant checkpoints for checkpoint in set([ From 52b642fefc8ead99a80543ac4179f1d8a293bb7a Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Wed, 3 Mar 2021 18:03:51 -0600 Subject: [PATCH 13/46] feat: 1 checkpoint = 1 artifact --- pytorch_lightning/loggers/wandb.py | 47 +++++++++++++----------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index b2aafb2e642d0..8a41fa24cfcef 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -17,7 +17,6 @@ """ import os import re -import numbers from argparse import Namespace from pathlib import Path from typing import Any, Dict, Optional, Union, TYPE_CHECKING @@ -223,29 +222,23 @@ def finalize(self, status: str) -> None: if self._log_model and self._trainer is not None and self._trainer.checkpoint_callback is not None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) - # gather summary metrics - metadata = { - k: v - for k, v in dict(self.experiment.summary).items() - if isinstance(v, numbers.Number) and not k.startswith("_") - } - # add interesting "ModelCheckpoint" data (mainly non-default values) - metadata['ModelCheckpoint'] = {k: getattr(self._trainer.checkpoint_callback, 
k) for k, ignore_val in [ - ('monitor', ''), ('mode', ''), # save also default values - ('current_score', None), ('best_model_score', None), ('best_model_path', ''), ('last_model_path', ''), - ('save_last', None), ('save_top_k', None), ('save_weights_only', False), ('period', 1), - ('_last_global_step_saved', 0)] - if getattr(self._trainer.checkpoint_callback, k, ignore_val) != ignore_val} - if getattr(self._trainer.checkpoint_callback, 'best_k_models', None): - # keep only filename - metadata['ModelCheckpoint']['best_k_models'] = { - Path(k).name: v for k, v in self._trainer.checkpoint_callback.best_k_models.items()} - artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) - # add relevant checkpoints - for checkpoint in set([ - self._trainer.checkpoint_callback.best_model_path, - self._trainer.checkpoint_callback.last_model_path, - *self._trainer.checkpoint_callback.best_k_models.keys() - ]) - {''}: - artifact.add_file(checkpoint) - self.experiment.log_artifact(artifact) + # get checkpoints to be saved with associated score + checkpoints = { + self._trainer.checkpoint_callback.last_model_path: self._trainer.checkpoint_callback.current_score, + self._trainer.checkpoint_callback.best_model_path: self._trainer.checkpoint_callback.best_model_score, + **self._trainer.checkpoint_callback.best_k_models} + checkpoints.pop('', None) + ordered_checkpoints = sorted([(Path(p).stat().st_mtime, p, s) + for p, s in checkpoints.items() if Path(p).is_file()]) + # log iteratively all checkpoints + for _, p, s in ordered_checkpoints: + metadata = {'score': s, 'original_filename': Path(p).name, + 'ModelCheckpoint': {k: getattr(self._trainer.checkpoint_callback, k) for k in [ + 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' + ]}} + artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) + artifact.add_file(p, name='model.ckpt') + self.experiment.log_artifact( + artifact, + aliases=["latest", "best"] if p == self._trainer.checkpoint_callback.best_model_path + else ["latest"]) From 765d081d909115b5ab0a60dfa3c77cd07e24e554 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 09:40:53 -0600 Subject: [PATCH 14/46] feat: proxy typing + apply suggestions --- pytorch_lightning/loggers/base.py | 5 +++-- pytorch_lightning/loggers/wandb.py | 6 +++--- .../trainer/connectors/logger_connector/logger_connector.py | 1 - 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 722c4207f1ae2..aa66ca35cf36f 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer + from weakref import ReferenceType def rank_zero_experiment(fn: Callable) -> Callable: @@ -75,7 +76,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def connect(self, trainer: 'Trainer') -> None: + def connect(self, trainer: 'ReferenceType[Trainer]') -> None: """ Connect trainer to logger @@ -368,7 +369,7 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def connect(self, trainer: 'Trainer') -> None: + def connect(self, trainer: 'ReferenceType[Trainer]') -> None: for logger in self._logger_iterable: logger.connect(trainer) diff --git 
a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 8a41fa24cfcef..8f20c987734ba 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer + from weakref import ReferenceType warning_cache = WarningCache() @@ -145,9 +146,8 @@ def __getstate__(self): state['_experiment'] = None return state - def connect(self, trainer: 'Trainer') -> None: - if trainer is not None: - self._trainer = trainer + def connect(self, trainer: 'ReferenceType[Trainer]') -> None: + self._trainer = trainer @property @rank_zero_experiment diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a0ea4faddf258..5aa36a7e4af4f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -144,7 +144,6 @@ def should_update_logs(self): return should_log_every_n_steps or self.trainer.should_stop def configure_logger(self, logger): - # connect logger to trainer if logger is True: version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id) From 4a55e4676c7344dc4f44d1962b0f723f2b3884ef Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 12:15:05 -0600 Subject: [PATCH 15/46] feat: don't log same model twice --- pytorch_lightning/loggers/wandb.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 8f20c987734ba..eded38ed7907d 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -132,6 +132,7 @@ def __init__( self._anonymous = 'allow' if anonymous else None self._project = project self._log_model = log_model + self._logged_model_time = {} self._prefix = prefix self._experiment = experiment self._kwargs = kwargs @@ -227,11 +228,13 @@ def finalize(self, status: str) -> None: self._trainer.checkpoint_callback.last_model_path: self._trainer.checkpoint_callback.current_score, self._trainer.checkpoint_callback.best_model_path: self._trainer.checkpoint_callback.best_model_score, **self._trainer.checkpoint_callback.best_k_models} - checkpoints.pop('', None) - ordered_checkpoints = sorted([(Path(p).stat().st_mtime, p, s) - for p, s in checkpoints.items() if Path(p).is_file()]) - # log iteratively all checkpoints - for _, p, s in ordered_checkpoints: + checkpoints = sorted([(Path(p).stat().st_mtime, p, s) + for p, s in checkpoints.items() if Path(p).is_file()]) + checkpoints = [c for c in checkpoints + if c[1] not in self._logged_models.keys() or self._logged_models[c[1]] < c[0]] + + # log iteratively all new checkpoints + for t, p, s in checkpoints: metadata = {'score': s, 'original_filename': Path(p).name, 'ModelCheckpoint': {k: getattr(self._trainer.checkpoint_callback, k) for k in [ 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' @@ -242,3 +245,5 @@ def finalize(self, status: str) -> None: artifact, aliases=["latest", "best"] if p == self._trainer.checkpoint_callback.best_model_path else ["latest"]) + # remember logged models + self._logged_model_time[p] = t From f16231cf8573d63ef97fae2624b164aa1006a0a6 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 12:22:31 -0600 Subject: [PATCH 16/46] =?UTF-8?q?fix:=C2=A0typo?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index eded38ed7907d..1f9a3ea6273de 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -231,7 +231,7 @@ def finalize(self, status: str) -> None: checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()]) checkpoints = [c for c in checkpoints - if c[1] not in self._logged_models.keys() or self._logged_models[c[1]] < c[0]] + if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]] # log iteratively all new checkpoints for t, p, s in checkpoints: From cbbf8ffc8f35b429f47204eb4664294c4f1010da Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 15:13:31 -0600 Subject: [PATCH 17/46] feat: log artifacts during training --- .../callbacks/model_checkpoint.py | 4 + pytorch_lightning/loggers/base.py | 13 ++- pytorch_lightning/loggers/wandb.py | 84 +++++++++++-------- .../logger_connector/logger_connector.py | 4 - 4 files changed, 59 insertions(+), 46 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d552560191a35..510ae735ec780 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -241,6 +241,10 @@ def save_checkpoint(self, trainer, pl_module): # Mode 2: save the last checkpoint self._save_last_checkpoint(trainer, monitor_candidates) + # notify loggers + if trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): + trainer.logger.after_save_checkpoint(self) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. 
Must be None or >= -1') diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index aa66ca35cf36f..5c1492d56f615 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -29,8 +29,7 @@ from pytorch_lightning.utilities import rank_zero_only if TYPE_CHECKING: - from pytorch_lightning.trainer.trainer import Trainer - from weakref import ReferenceType + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint def rank_zero_experiment(fn: Callable) -> Callable: @@ -76,12 +75,12 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def connect(self, trainer: 'ReferenceType[Trainer]') -> None: + def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: """ - Connect trainer to logger + Called after model checkpoint callback saves a new checkpoint Args: - trainer: the trainer instance to connect to + model_checkpoint: the model checkpoint callback instance """ pass @@ -369,9 +368,9 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def connect(self, trainer: 'ReferenceType[Trainer]') -> None: + def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: for logger in self._logger_iterable: - logger.connect(trainer) + logger.after_save_checkpoint(checkpoint_callback) def update_agg_funcs( self, diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 1f9a3ea6273de..91031135c69c3 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -29,8 +29,7 @@ from pytorch_lightning.utilities.warnings import WarningCache if TYPE_CHECKING: - from pytorch_lightning.trainer.trainer import Trainer - from weakref import ReferenceType + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint warning_cache = WarningCache() @@ -63,6 +62,13 @@ class WandbLogger(LightningLoggerBase): anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. log_model: Save checkpoints as W&B artifacts. + if ``log_model == 'all'`` or + :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` == -1, + all checkpoints are saved during training. + if ``log_model == True``, checkpoints are saved at the end of training based on + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` parameters + (last checkpoint by default). + if ``log_model == False`` (Default), no checkpoint is saved. prefix: A string to put at the beginning of metric keys. experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. 
used by @@ -132,11 +138,11 @@ def __init__( self._anonymous = 'allow' if anonymous else None self._project = project self._log_model = log_model - self._logged_model_time = {} self._prefix = prefix self._experiment = experiment self._kwargs = kwargs - self._trainer = None + self._logged_model_time = {} + self._checkpoint_callback = None def __getstate__(self): state = self.__dict__.copy() @@ -147,9 +153,6 @@ def __getstate__(self): state['_experiment'] = None return state - def connect(self, trainer: 'ReferenceType[Trainer]') -> None: - self._trainer = trainer - @property @rank_zero_experiment def experiment(self) -> Run: @@ -217,33 +220,44 @@ def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id + @rank_zero_only + def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: + # log checkpoints as artifacts + if self._log_model == 'all' or checkpoint_callback.save_top_k == -1: + self._scan_and_log_checkpoints(checkpoint_callback) + elif self._log_model is True: + self._checkpoint_callback = checkpoint_callback + @rank_zero_only def finalize(self, status: str) -> None: - # save checkpoints as artifacts - if self._log_model and self._trainer is not None and self._trainer.checkpoint_callback is not None: - # use run name and ensure it's a valid Artifact name - artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) - # get checkpoints to be saved with associated score - checkpoints = { - self._trainer.checkpoint_callback.last_model_path: self._trainer.checkpoint_callback.current_score, - self._trainer.checkpoint_callback.best_model_path: self._trainer.checkpoint_callback.best_model_score, - **self._trainer.checkpoint_callback.best_k_models} - checkpoints = sorted([(Path(p).stat().st_mtime, p, s) - for p, s in checkpoints.items() if Path(p).is_file()]) - checkpoints = [c for c in checkpoints - if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]] - - # log iteratively all new checkpoints - for t, p, s in checkpoints: - metadata = {'score': s, 'original_filename': Path(p).name, - 'ModelCheckpoint': {k: getattr(self._trainer.checkpoint_callback, k) for k in [ - 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' - ]}} - artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) - artifact.add_file(p, name='model.ckpt') - self.experiment.log_artifact( - artifact, - aliases=["latest", "best"] if p == self._trainer.checkpoint_callback.best_model_path - else ["latest"]) - # remember logged models - self._logged_model_time[p] = t + # log checkpoints as artifacts + if self._checkpoint_callback: + self._scan_and_log_checkpoints(self._checkpoint_callback) + + def _scan_and_log_checkpoints(self, checkpoint_callback: 'ModelCheckpoint') -> None: + # use run name and ensure it's a valid Artifact name + artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) + # get checkpoints to be saved with associated score + checkpoints = { + checkpoint_callback.last_model_path: checkpoint_callback.current_score, + checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, + **checkpoint_callback.best_k_models} + checkpoints = sorted([(Path(p).stat().st_mtime, p, s) + for p, s in checkpoints.items() if Path(p).is_file()]) + checkpoints = [c for c in checkpoints + if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]] + + # log iteratively all new 
checkpoints + for t, p, s in checkpoints: + metadata = {'score': s, 'original_filename': Path(p).name, + 'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ + 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' + ]}} + artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) + artifact.add_file(p, name='model.ckpt') + self.experiment.log_artifact( + artifact, + aliases=["latest", "best"] if p == checkpoint_callback.best_model_path + else ["latest"]) + # remember logged models + self._logged_model_time[p] = t diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bd85088d1ab0a..761776aeefb42 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -158,10 +158,6 @@ def configure_logger(self, logger): else: self.trainer.logger = logger - # connect trainer to logger - if hasattr(self.trainer.logger, 'connect'): - self.trainer.logger.connect(proxy(self.trainer)) - def cache_training_step_metrics(self, opt_closure_result): """ This function is responsible to update From 123cd8807fdda6fd39ba143331e06bf5e8003204 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 15:30:10 -0600 Subject: [PATCH 18/46] fix: docs build --- pytorch_lightning/loggers/base.py | 10 ++++------ pytorch_lightning/loggers/wandb.py | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 5c1492d56f615..21bf66efd3d1a 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -20,16 +20,14 @@ from argparse import Namespace from functools import wraps from typing import (Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, - Optional, Sequence, Tuple, Union, TYPE_CHECKING) + Optional, Sequence, Tuple, Union) import numpy as np import torch from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only - -if TYPE_CHECKING: - from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint def rank_zero_experiment(fn: Callable) -> Callable: @@ -75,7 +73,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: """ Called after model checkpoint callback saves a new checkpoint @@ -368,7 +366,7 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 91031135c69c3..930ce34a87cbf 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -19,7 +19,7 @@ import re from argparse import Namespace from pathlib import Path -from typing import Any, 
Dict, Optional, Union, TYPE_CHECKING +from typing import Any, Dict, Optional, Union import torch.nn as nn @@ -27,9 +27,7 @@ from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache - -if TYPE_CHECKING: - from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint warning_cache = WarningCache() @@ -221,7 +219,7 @@ def version(self) -> Optional[str]: return self._experiment.id if self._experiment else self._id @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: 'ModelCheckpoint') -> None: + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # log checkpoints as artifacts if self._log_model == 'all' or checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) @@ -234,7 +232,7 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: 'ModelCheckpoint') -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) # get checkpoints to be saved with associated score From 0822d5d72dc108c878f509817deaeedf3c717718 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 15:36:54 -0600 Subject: [PATCH 19/46] feat: use proxy ref --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- pytorch_lightning/loggers/base.py | 5 +++-- pytorch_lightning/loggers/wandb.py | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 510ae735ec780..11cadb8d5e5e0 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -23,6 +23,7 @@ import re from copy import deepcopy from pathlib import Path +from weakref import proxy from typing import Any, Dict, Optional, Union import numpy as np @@ -243,7 +244,7 @@ def save_checkpoint(self, trainer, pl_module): # notify loggers if trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): - trainer.logger.after_save_checkpoint(self) + trainer.logger.after_save_checkpoint(proxy(self)) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 21bf66efd3d1a..3ae93ec0328db 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,6 +19,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps +from weakref import ReferenceType from typing import (Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union) @@ -73,7 +74,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + def after_save_checkpoint(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: """ Called after model checkpoint callback saves a new checkpoint @@ -366,7 +367,7 @@ def __init__(self, logger_iterable: 
Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + def after_save_checkpoint(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 930ce34a87cbf..95f4f0c52b6f3 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -20,6 +20,7 @@ from argparse import Namespace from pathlib import Path from typing import Any, Dict, Optional, Union +from weakref import ReferenceType import torch.nn as nn @@ -219,7 +220,7 @@ def version(self) -> Optional[str]: return self._experiment.id if self._experiment else self._id @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + def after_save_checkpoint(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: # log checkpoints as artifacts if self._log_model == 'all' or checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) @@ -232,7 +233,7 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) # get checkpoints to be saved with associated score From 947ab7aff287677087d67e28c76106a92a020df4 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 16:14:48 -0600 Subject: [PATCH 20/46] fix: mypy --- pytorch_lightning/loggers/base.py | 4 ++-- pytorch_lightning/loggers/wandb.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 3ae93ec0328db..b2b47ad675b3f 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -74,7 +74,7 @@ def __init__( self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {} self._agg_default_func = agg_default_func - def after_save_checkpoint(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: """ Called after model checkpoint callback saves a new checkpoint @@ -367,7 +367,7 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): def __getitem__(self, index: int) -> LightningLoggerBase: return [logger for logger in self._logger_iterable][index] - def after_save_checkpoint(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 95f4f0c52b6f3..f5e35391aad74 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -220,7 +220,7 @@ def version(self) -> Optional[str]: return self._experiment.id if self._experiment else self._id @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: 
ReferenceType[ModelCheckpoint]) -> None: + def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: # log checkpoints as artifacts if self._log_model == 'all' or checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) @@ -233,7 +233,7 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: ReferenceType[ModelCheckpoint]) -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: # use run name and ensure it's a valid Artifact name artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) # get checkpoints to be saved with associated score From 03af2c3f3eb6b3b5e112cd36e5538258e193719c Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 16:23:45 -0600 Subject: [PATCH 21/46] fix: unused import --- .../trainer/connectors/logger_connector/logger_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 761776aeefb42..45cdecfdc8515 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,6 @@ import os from copy import deepcopy from pprint import pprint -from weakref import proxy from typing import Dict, Iterable, Optional, Union import torch From 743903c42ca5ffb1018b7e03348e9370493edc36 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 4 Mar 2021 16:25:23 -0600 Subject: [PATCH 22/46] fix: continuous logging logic --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index f5e35391aad74..9c5efe8f9dc0f 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -222,7 +222,7 @@ def version(self) -> Optional[str]: @rank_zero_only def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: # log checkpoints as artifacts - if self._log_model == 'all' or checkpoint_callback.save_top_k == -1: + if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback From 363b3acd4eb6c0534d94a9268904ab1d7e400714 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 5 Mar 2021 09:33:26 -0600 Subject: [PATCH 23/46] fix: formatting --- pytorch_lightning/loggers/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index b2b47ad675b3f..e930701a4a512 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -20,8 +20,7 @@ from argparse import Namespace from functools import wraps from weakref import ReferenceType -from typing import (Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, - Optional, Sequence, Tuple, Union) +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import numpy as np import torch From 7e331c119767ad9caa7b0b6e7e63a8b6e04be7cb Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 5 Mar 2021 
13:04:14 -0600 Subject: [PATCH 24/46] docs: update log_model --- pytorch_lightning/loggers/wandb.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 9c5efe8f9dc0f..a90c0f0bcae83 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -60,14 +60,15 @@ class WandbLogger(LightningLoggerBase): version: Same as id. anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. - log_model: Save checkpoints as W&B artifacts. - if ``log_model == 'all'`` or - :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` == -1, - all checkpoints are saved during training. - if ``log_model == True``, checkpoints are saved at the end of training based on - :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` parameters - (last checkpoint by default). - if ``log_model == False`` (Default), no checkpoint is saved. + log_model: Save checkpoints as W&B artifacts as created by + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + + * if ``log_model == 'all'``, checkpoints are logged during training. + * if ``log_model == True``, checkpoints are logged at the end of training, except when + :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` == -1 + which also saves every checkpoint during training. + * if ``log_model == False`` (default), no checkpoint is logged. + prefix: A string to put at the beginning of metric keys. experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by From b4389404abb3bc4acf489eeca05fe239902d94bf Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 5 Mar 2021 13:18:00 -0600 Subject: [PATCH 25/46] docs(wandb): improve log_model --- pytorch_lightning/loggers/wandb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index a90c0f0bcae83..925d1a47c60d4 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -60,13 +60,13 @@ class WandbLogger(LightningLoggerBase): version: Same as id. anonymous: Enables or explicitly disables anonymous logging. project: The name of the project to which this run will belong. - log_model: Save checkpoints as W&B artifacts as created by - :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` + as W&B artifacts. * if ``log_model == 'all'``, checkpoints are logged during training. * if ``log_model == True``, checkpoints are logged at the end of training, except when - :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` == -1 - which also saves every checkpoint during training. + :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1`` + which also logs every checkpoint during training. * if ``log_model == False`` (default), no checkpoint is logged. prefix: A string to put at the beginning of metric keys. 
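The log_model behavior documented in the patch above comes down to three modes. A minimal, hypothetical usage sketch of the option (the project name and monitored metric are placeholder assumptions, not taken from the patches):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import WandbLogger

    # checkpoint callback whose saved files the logger picks up via after_save_checkpoint
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3)

    # log_model='all'  -> every checkpoint is uploaded as a W&B artifact during training
    # log_model=True   -> checkpoints are uploaded at the end of training
    #                     (or during training when save_top_k == -1)
    # log_model=False  -> (default) no checkpoint is uploaded
    wandb_logger = WandbLogger(project="my-project", log_model="all")

    trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
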
From 0dc78cc4e1b10efefd90b01089da41f945473cce Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 5 Mar 2021 13:18:22 -0600 Subject: [PATCH 26/46] feat(wandb): more explicit artifact name --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 925d1a47c60d4..799156b253568 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -253,7 +253,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe 'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' ]}} - artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) + artifact = wandb.Artifact(name=f"model-{artifact_name}", type="model", metadata=metadata) artifact.add_file(p, name='model.ckpt') self.experiment.log_artifact( artifact, From 78cfc7c1b3d2143e432bb8ecbb9ee6baa3250493 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 5 Mar 2021 14:09:31 -0600 Subject: [PATCH 27/46] feat(wandb): simplify artifact name --- pytorch_lightning/loggers/wandb.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 799156b253568..456260f58d74d 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -16,7 +16,6 @@ ------------------------- """ import os -import re from argparse import Namespace from pathlib import Path from typing import Any, Dict, Optional, Union @@ -235,8 +234,6 @@ def finalize(self, status: str) -> None: self._scan_and_log_checkpoints(self._checkpoint_callback) def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: - # use run name and ensure it's a valid Artifact name - artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name) # get checkpoints to be saved with associated score checkpoints = { checkpoint_callback.last_model_path: checkpoint_callback.current_score, @@ -253,7 +250,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe 'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' ]}} - artifact = wandb.Artifact(name=f"model-{artifact_name}", type="model", metadata=metadata) + artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name='model.ckpt') self.experiment.log_artifact( artifact, From eeed466e2b2e093725e361a18354c5b02ccf97f7 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 16:00:42 -0600 Subject: [PATCH 28/46] docs(wandb): improve documentation --- docs/source/common/loggers.rst | 13 +++++++++---- pytorch_lightning/loggers/wandb.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/docs/source/common/loggers.rst b/docs/source/common/loggers.rst index c6c5f0d8653c7..a42e1394599bb 100644 --- a/docs/source/common/loggers.rst +++ b/docs/source/common/loggers.rst @@ -202,7 +202,7 @@ The :class:`~pytorch_lightning.loggers.TestTubeLogger` is available anywhere exc Weights and Biases ================== -`Weights and Biases `_ is a third-party logger. +`Weights and Biases `_ is a third-party logger. To use :class:`~pytorch_lightning.loggers.WandbLogger` as your logger do the following. 
First, install the package: @@ -215,9 +215,14 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. code-block:: python from pytorch_lightning.loggers import WandbLogger - wandb_logger = WandbLogger(offline=True) + + # instrument experiment with W&B + wandb_logger = WandbLogger(project='MNIST', log_model='all') trainer = Trainer(logger=wandb_logger) + # log gradients and model topology + WandbLogger.watch(model) + The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your :class:`~pytorch_lightning.core.lightning.LightningModule`. @@ -226,8 +231,8 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except class MyModule(LightningModule): def any_lightning_module_function_or_hook(self): some_img = fake_image() - self.logger.experiment.log({ - "generated_images": [wandb.Image(some_img, caption="...")] + self.log({ + "generated_images": [wandb.Image(some_img, caption="...")] }) .. seealso:: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 456260f58d74d..ecd8055edda63 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -43,7 +43,7 @@ class WandbLogger(LightningLoggerBase): r""" - Log using `Weights and Biases `_. + Log using `Weights and Biases `__ - on how to use W&B with PyTorch Lightning + - `Demo in Google Colab `__ with model logging - `W&B Documentation `__ """ From cc0fcd6b680829624c0d4620fa3ed5601e8761aa Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 16:46:10 -0600 Subject: [PATCH 29/46] test: after_save_checkpoint called --- tests/loggers/test_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index c48fef5e04b49..2215b287c60fb 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -60,6 +60,7 @@ def __init__(self): self.hparams_logged = None self.metrics_logged = {} self.finalized = False + self.after_save_checkpoint_called = False @property def experiment(self): @@ -93,6 +94,9 @@ def name(self): def version(self): return "1" + def after_save_checkpoint(self, checkpoint_callback): + self.after_save_checkpoint_called = True + def test_custom_logger(tmpdir): @@ -116,6 +120,7 @@ def training_step(self, batch, batch_idx): assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" assert logger.hparams_logged == model.hparams assert logger.metrics_logged != {} + assert logger.after_save_checkpoint_called assert logger.finalized_status == "success" From a71603d24828687d19cfa30f4f2dd28a208f8401 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 17:27:01 -0600 Subject: [PATCH 30/46] docs(wandb): fix typo --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index ecd8055edda63..189500910c7be 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -43,7 +43,7 @@ class WandbLogger(LightningLoggerBase): r""" - Log using `Weights and Biases `_. 
Install it with pip: From ded7204e994680a224708e50d095facd06eb8345 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 17:27:17 -0600 Subject: [PATCH 31/46] test(wandb): test log_model --- tests/loggers/test_wandb.py | 41 +++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 0eefb9625ddc7..421849a30922c 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -25,14 +25,8 @@ from tests.helpers import BoringModel -def get_warnings(recwarn): - warnings_text = '\n'.join(str(w.message) for w in recwarn.list) - recwarn.clear() - return warnings_text - - @mock.patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_logger_init(wandb, recwarn): +def test_wandb_logger_init(wandb): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here.""" @@ -127,10 +121,8 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): # mock return values of experiment wandb.run = None - wandb.init().step = 0 logger.experiment.id = '1' logger.experiment.project_name.return_value = 'project' - logger.experiment.step = 0 for _ in range(2): _ = logger.experiment @@ -151,6 +143,37 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): assert trainer.log_dir == logger.save_dir +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_log_model(wandb, tmpdir): + """ Test that the logger creates the folders and files in the right place. """ + + wandb.run = None + model = BoringModel() + + # test log_model=True + logger = WandbLogger(log_model=True) + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + wandb.init().log_artifact.assert_called_once() + + # test log_model='all' + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model='all') + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + assert wandb.init().log_artifact.call_count == 2 + + # test log_model=False + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + logger = WandbLogger(log_model=False) + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + wandb.init().log_artifact.asser() + assert not wandb.init().log_artifact.called + + def test_wandb_sanitize_callable_params(tmpdir): """ Callback function are not serializiable. 
Therefore, we get them a chance to return From 1b88a5e50be681492a9cbfa28563d09e509bb70f Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 17:54:13 -0600 Subject: [PATCH 32/46] feat(wandb): min version --- pytorch_lightning/loggers/wandb.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 189500910c7be..fd4023e0992c1 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -20,11 +20,13 @@ from pathlib import Path from typing import Any, Dict, Optional, Union from weakref import ReferenceType +import operator import torch.nn as nn from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only +from pytorch_lightning.utilities.imports import _compare_version from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -36,6 +38,7 @@ try: import wandb from wandb.wandb_run import Run + _WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") except ImportError: # needed for test mocks, these tests shall be updated wandb, Run = None, None @@ -127,6 +130,13 @@ def __init__( 'Hint: Set `offline=False` to log your model.' ) + if log_model and not _WANDB_GREATER_EQUAL_0_10_22: + warning_cache.warn( + f'Providing log_model={log_model} requires wandb version >= 0.10.22' + ' for logging associated model metadata.\n' + 'Hint: Upgrade with `pip install --ugrade wandb`.' + ) + if sync_step is not None: warning_cache.warn( "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5." 
@@ -253,7 +263,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe metadata = {'score': s, 'original_filename': Path(p).name, 'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' - ]}} + ]}} if _WANDB_GREATER_EQUAL_0_10_22 else None artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name='model.ckpt') self.experiment.log_artifact( From 4f3581344f2cfa293a177eb399cc9b7edc2cc908 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 17:59:34 -0600 Subject: [PATCH 33/46] test(wandb): fix directory creation --- tests/loggers/test_wandb.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 421849a30922c..c70ae4f9ac9f0 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -152,6 +152,8 @@ def test_wandb_log_model(wandb, tmpdir): # test log_model=True logger = WandbLogger(log_model=True) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) trainer.fit(model) wandb.init().log_artifact.assert_called_once() @@ -160,6 +162,8 @@ def test_wandb_log_model(wandb, tmpdir): wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() logger = WandbLogger(log_model='all') + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) trainer.fit(model) assert wandb.init().log_artifact.call_count == 2 @@ -168,6 +172,8 @@ def test_wandb_log_model(wandb, tmpdir): wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() logger = WandbLogger(log_model=False) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) trainer.fit(model) wandb.init().log_artifact.asser() From 876dbee08dcfdb04e3632807b72356f0b258630a Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 18:12:44 -0600 Subject: [PATCH 34/46] docs: update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed7eec7cff7f9..208922739620e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231)) + + ### Deprecated From ba1e9376c0ec3faa14c2e7e7e52be6761eee6bfa Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 7 Mar 2021 18:49:16 -0600 Subject: [PATCH 35/46] test(wandb): fix variable not defined --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index fd4023e0992c1..680696ff7caf4 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -34,11 +34,11 @@ warning_cache = WarningCache() _WANDB_AVAILABLE = _module_available("wandb") +_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") try: import wandb from wandb.wandb_run import Run - _WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") except ImportError: # needed for test mocks, these tests shall be updated wandb, Run = None, None From fe98f4f9ec4b9b519ede44d7fca9dd1ff4e0daf5 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Tue, 9 Mar 2021 07:49:04 -0600 Subject: [PATCH 36/46] feat: after_save_checkpoint on rank 0 only --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- pytorch_lightning/loggers/wandb.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 442382775b8d8..112e0100c3fd1 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -250,7 +250,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): self._save_last_checkpoint(trainer, monitor_candidates) # notify loggers - if trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): + if trainer.is_global_zero and trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): trainer.logger.after_save_checkpoint(proxy(self)) def __validate_init_configuration(self): diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 680696ff7caf4..886a8a5597a23 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -233,7 +233,6 @@ def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id - @rank_zero_only def after_save_checkpoint(self, checkpoint_callback: 'ReferenceType[ModelCheckpoint]') -> None: # log checkpoints as artifacts if self._log_model == 'all' or self._log_model is True and checkpoint_callback.save_top_k == -1: From aa904ce3dfe5e8fa4263e0b5a67ec94243e4e38e Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 12 Mar 2021 12:06:22 -0600 Subject: [PATCH 37/46] feat: handle new args of ModelCheckpoint --- pytorch_lightning/loggers/wandb.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 886a8a5597a23..feb42faec06c5 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -261,13 +261,15 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe for t, p, s in checkpoints: metadata = {'score': s, 'original_filename': Path(p).name, 
'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ - 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', 'period' - ]}} if _WANDB_GREATER_EQUAL_0_10_22 else None + 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps', + '_every_n_val_epochs'] + # ensure it does not break if `ModelCheckpoint` args change + if hasattr(checkpoint_callback, k)}} if _WANDB_GREATER_EQUAL_0_10_22 else None artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name='model.ckpt') self.experiment.log_artifact( artifact, aliases=["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]) - # remember logged models + # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t From 27c49eb57f6da05d29b7cccab703388678b7479a Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Fri, 12 Mar 2021 12:06:33 -0600 Subject: [PATCH 38/46] test(wandb): check correct metadata --- tests/loggers/test_wandb.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index c70ae4f9ac9f0..3058870f56ed1 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -176,9 +176,27 @@ def test_wandb_log_model(wandb, tmpdir): logger.experiment.project_name.return_value = 'project' trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) trainer.fit(model) - wandb.init().log_artifact.asser() assert not wandb.init().log_artifact.called + # test correct metadata + import pytorch_lightning.loggers.wandb as pl_wandb + pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True + wandb.init().log_artifact.reset_mock() + wandb.init.reset_mock() + wandb.Artifact.reset_mock() + logger = pl_wandb.WandbLogger(log_model=True) + logger.experiment.id = '1' + logger.experiment.project_name.return_value = 'project' + trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) + trainer.fit(model) + wandb.Artifact.assert_called_once_with(name='model-1', type='model', + metadata={'score': None, 'original_filename': 'epoch=1-step=5-v3.ckpt', + 'ModelCheckpoint': {'monitor': None, 'mode': 'min', + 'save_last': None, 'save_top_k': None, + 'save_weights_only': False, + '_every_n_train_steps': 0, + '_every_n_val_epochs': 1}}) + def test_wandb_sanitize_callable_params(tmpdir): """ @@ -211,7 +229,7 @@ def wrapper_something(): assert params["wrapper_something_wo_name"] == "" -@mock.patch('pytorch_lightning.loggers.wandb.wandb') +@ mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_wandb_logger_offline_log_model(wandb, tmpdir): """ Test that log_model=True raises an error in offline mode """ with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): From e0a95785189a4bec38810b64ac29df4f15ef3cdf Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 14 Mar 2021 14:21:09 -0500 Subject: [PATCH 39/46] tests(wandb): unused fixture --- tests/loggers/test_wandb.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 3058870f56ed1..f56c80dad3487 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -229,8 +229,7 @@ def wrapper_something(): assert params["wrapper_something_wo_name"] == "" -@ 
mock.patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_logger_offline_log_model(wandb, tmpdir): +def test_wandb_logger_offline_log_model(tmpdir): """ Test that log_model=True raises an error in offline mode """ with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) From 58193e874a6a51e9cf70968dcd90168e4701adf3 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 14 Mar 2021 14:37:12 -0500 Subject: [PATCH 40/46] feat: logger.after_save_checkpoint always exists --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index fb5f0aede7007..2a2e6acd0a5af 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -285,7 +285,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): self._save_last_checkpoint(trainer, monitor_candidates) # notify loggers - if trainer.is_global_zero and trainer.logger and hasattr(trainer.logger, 'after_save_checkpoint'): + if trainer.is_global_zero and trainer.logger: trainer.logger.after_save_checkpoint(proxy(self)) def _should_skip_saving_checkpoint(self, trainer) -> bool: From fda377f6c63417ee8a8c6e4a519b8912d3e37044 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Sun, 14 Mar 2021 14:59:21 -0500 Subject: [PATCH 41/46] test: wandb fixture required --- tests/loggers/test_wandb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index f56c80dad3487..6c53ca92e6c58 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -229,7 +229,8 @@ def wrapper_something(): assert params["wrapper_something_wo_name"] == "" -def test_wandb_logger_offline_log_model(tmpdir): +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_logger_offline_log_model(wandb, tmpdir): """ Test that log_model=True raises an error in offline mode """ with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) From 0b7bb3979f374ce4347b28478567834fdf36b3a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 May 2021 00:25:31 +0000 Subject: [PATCH 42/46] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../callbacks/model_checkpoint.py | 2 +- pytorch_lightning/loggers/base.py | 4 +- pytorch_lightning/loggers/wandb.py | 41 +++++++++++-------- tests/loggers/test_wandb.py | 24 +++++++---- 4 files changed, 44 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7443e133eec6b..295dbb6133ed2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -23,8 +23,8 @@ import re from copy import deepcopy from pathlib import Path -from weakref import proxy from typing import Any, Callable, Dict, Optional, Union +from weakref import proxy import numpy as np import torch diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index f3a81b3edd5bc..7736ed24baefe 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py 
@@ -19,15 +19,15 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from weakref import ReferenceType from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union +from weakref import ReferenceType import numpy as np import torch +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint def rank_zero_experiment(fn: Callable) -> Callable: diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 1c8f991c570c3..9695af2f16f0d 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -15,21 +15,21 @@ Weights and Biases Logger ------------------------- """ +import operator import os from argparse import Namespace from pathlib import Path from typing import Any, Dict, Optional, Union from weakref import ReferenceType -import operator import torch.nn as nn +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only -from pytorch_lightning.utilities.imports import _compare_version from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _compare_version from pytorch_lightning.utilities.warnings import WarningCache -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint warning_cache = WarningCache() @@ -251,25 +251,32 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe checkpoints = { checkpoint_callback.last_model_path: checkpoint_callback.current_score, checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, - **checkpoint_callback.best_k_models} - checkpoints = sorted([(Path(p).stat().st_mtime, p, s) - for p, s in checkpoints.items() if Path(p).is_file()]) - checkpoints = [c for c in checkpoints - if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]] + **checkpoint_callback.best_k_models + } + checkpoints = sorted([(Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()]) + checkpoints = [ + c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0] + ] # log iteratively all new checkpoints for t, p, s in checkpoints: - metadata = {'score': s, 'original_filename': Path(p).name, - 'ModelCheckpoint': {k: getattr(checkpoint_callback, k) for k in [ - 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps', - '_every_n_val_epochs'] - # ensure it does not break if `ModelCheckpoint` args change - if hasattr(checkpoint_callback, k)}} if _WANDB_GREATER_EQUAL_0_10_22 else None + metadata = { + 'score': s, + 'original_filename': Path(p).name, + 'ModelCheckpoint': { + k: getattr(checkpoint_callback, k) + for k in [ + 'monitor', 'mode', 'save_last', 'save_top_k', 'save_weights_only', '_every_n_train_steps', + '_every_n_val_epochs' + ] + # ensure it does not break if `ModelCheckpoint` args change + if hasattr(checkpoint_callback, k) + } + } if _WANDB_GREATER_EQUAL_0_10_22 else None artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) 
artifact.add_file(p, name='model.ckpt') self.experiment.log_artifact( - artifact, - aliases=["latest", "best"] if p == checkpoint_callback.best_model_path - else ["latest"]) + artifact, aliases=["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] + ) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 98a81fc519302..50e703c98de02 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -202,13 +202,23 @@ def test_wandb_log_model(wandb, tmpdir): logger.experiment.project_name.return_value = 'project' trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3) trainer.fit(model) - wandb.Artifact.assert_called_once_with(name='model-1', type='model', - metadata={'score': None, 'original_filename': 'epoch=1-step=5-v3.ckpt', - 'ModelCheckpoint': {'monitor': None, 'mode': 'min', - 'save_last': None, 'save_top_k': None, - 'save_weights_only': False, - '_every_n_train_steps': 0, - '_every_n_val_epochs': 1}}) + wandb.Artifact.assert_called_once_with( + name='model-1', + type='model', + metadata={ + 'score': None, + 'original_filename': 'epoch=1-step=5-v3.ckpt', + 'ModelCheckpoint': { + 'monitor': None, + 'mode': 'min', + 'save_last': None, + 'save_top_k': None, + 'save_weights_only': False, + '_every_n_train_steps': 0, + '_every_n_val_epochs': 1 + } + } + ) def test_wandb_sanitize_callable_params(tmpdir): From c06fc8f96df5e6137574a454bd9464d8d934c771 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Thu, 13 May 2021 19:50:29 -0500 Subject: [PATCH 43/46] test(wandb): parameter unset --- tests/loggers/test_wandb.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 50e703c98de02..27185b911b6d0 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -45,8 +45,6 @@ def test_wandb_logger_init(wandb): run = wandb.init() logger = WandbLogger(experiment=run) assert logger.experiment - assert run.dir is not None - assert logger.save_dir == run.dir # test wandb.init not called if there is a W&B run wandb.init().log.reset_mock() From 0ca6abb52c07e9e8e48125c155aa565871f0dbee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 11:30:52 +0200 Subject: [PATCH 44/46] formatting --- pytorch_lightning/loggers/wandb.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 9695af2f16f0d..0a1e17b258953 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -275,8 +275,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: 'ReferenceType[ModelChe } if _WANDB_GREATER_EQUAL_0_10_22 else None artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name='model.ckpt') - self.experiment.log_artifact( - artifact, aliases=["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] - ) + aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] + self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t From f6f8f616675019048e85a4aee4790a49a08d818b Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 11:31:22 +0200 Subject: [PATCH 45/46] typo fix --- pytorch_lightning/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 0a1e17b258953..c127fa037ed6b 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -91,7 +91,7 @@ class WandbLogger(LightningLoggerBase): trainer = Trainer(logger=wandb_logger) # log gradients and model topology - WandbLogger.watch(model) + wandb_logger.watch(model) See Also: - `Demo in Google Colab `__ with model logging From 1faa3896d1d1a615bfa9171469d43a962bbe4829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 May 2021 11:36:14 +0200 Subject: [PATCH 46/46] fix typo in docs --- docs/source/common/loggers.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/loggers.rst b/docs/source/common/loggers.rst index a42e1394599bb..5b1f13dbf4b8c 100644 --- a/docs/source/common/loggers.rst +++ b/docs/source/common/loggers.rst @@ -221,7 +221,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. trainer = Trainer(logger=wandb_logger) # log gradients and model topology - WandbLogger.watch(model) + wandb_logger.watch(model) The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your :class:`~pytorch_lightning.core.lightning.LightningModule`.
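Checkpoints logged through this series end up in a W&B artifact named ``model-<run id>``, with each file stored as ``model.ckpt`` and aliased ``latest`` (plus ``best`` for the best checkpoint), so they can be pulled back down with the standard wandb artifact API. A minimal retrieval sketch; the entity, project, run id, and ``MyLightningModule`` are placeholders:

.. code-block:: python

    from pathlib import Path

    import wandb

    # fetch the checkpoint currently aliased "best" for a given run
    run = wandb.init(project="my-project")
    artifact = run.use_artifact("my-entity/my-project/model-1abc234:best", type="model")
    artifact_dir = artifact.download()

    # the logger stores each checkpoint under the fixed name `model.ckpt`
    ckpt_path = str(Path(artifact_dir) / "model.ckpt")
    model = MyLightningModule.load_from_checkpoint(ckpt_path)  # your own LightningModule class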