Merge branch 'master' into patch-1
tchaton authored Mar 1, 2021
2 parents 393dae1 + 15c477e commit fc5dca6
Showing 30 changed files with 283 additions and 247 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))


- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


### Fixed

- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
@@ -66,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))


- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


## [1.2.1] - 2021-02-23

### Fixed
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -366,7 +366,7 @@ def amp_backend(self) -> Optional[LightningEnum]:
return None

@property
def precision(self) -> int:
def precision(self) -> Union[str, int]:
return self.precision_plugin.precision

@property
34 changes: 0 additions & 34 deletions pytorch_lightning/core/lightning.py
@@ -17,7 +17,6 @@
import copy
import inspect
import os
import re
import tempfile
import uuid
from abc import ABC
@@ -1806,39 +1805,6 @@ def hparams_initial(self) -> AttributeDict:
# prevent any change
return copy.deepcopy(self._hparams_initial)

@hparams.setter
def hparams(self, hp: Union[dict, Namespace, Any]):
# TODO: remove this method in v1.3.0.
rank_zero_warn(
"The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be"
" removed in v1.3.0. Replace the assignment `self.hparams = hparams` with "
" `self.save_hyperparameters()`.", DeprecationWarning
)
hparams_assignment_name = self.__get_hparams_assignment_variable()
self._hparams_name = hparams_assignment_name
self._set_hparams(hp)
# this resolves the case when the user does not use `save_hyperparameters` and does a hard assignment in init
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = copy.deepcopy(self._hparams)

def __get_hparams_assignment_variable(self):
"""
looks at the code of the class to figure out what the user named self.hparams
this only happens when the user explicitly sets self.hparams
"""
try:
class_code = inspect.getsource(self.__class__)
lines = class_code.split("\n")
for line in lines:
line = re.sub(r"\s+", "", line, flags=re.UNICODE)
if ".hparams=" in line:
return line.split("=")[1]
# todo: specify the possible exception
except Exception:
return "hparams"

return None

@property
def model_size(self) -> float:
# todo: think about better way without need to dump model to drive
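The deprecation message in the removed setter points users to `self.save_hyperparameters()`. As a minimal sketch of that replacement (the module and argument names below are hypothetical):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def __init__(self, learning_rate: float = 1e-3, hidden_dim: int = 128):
        super().__init__()
        # Replaces the removed `self.hparams = hparams` assignment:
        # captures the __init__ arguments into self.hparams.
        self.save_hyperparameters()


model = LitModel(learning_rate=3e-4)
print(model.hparams.learning_rate)  # 0.0003
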
39 changes: 17 additions & 22 deletions pytorch_lightning/loggers/wandb.py
@@ -26,6 +26,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

_WANDB_AVAILABLE = _module_available("wandb")

try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
project: The name of the project to which this run will belong.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
prefix: A string to put at the beginning of metric keys.
sync_step: Sync Trainer step with wandb step.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -98,7 +99,7 @@ def __init__(
log_model: Optional[bool] = False,
experiment=None,
prefix: Optional[str] = '',
sync_step: Optional[bool] = True,
sync_step: Optional[bool] = None,
**kwargs
):
if wandb is None:
@@ -114,6 +115,12 @@
'Hint: Set `offline=False` to log your model.'
)

if sync_step is not None:
warning_cache.warn(
"`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
" Metrics are now logged separately and automatically synchronized.", DeprecationWarning
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -123,12 +130,8 @@
self._project = project
self._log_model = log_model
self._prefix = prefix
self._sync_step = sync_step
self._experiment = experiment
self._kwargs = kwargs
# logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
self.warning_cache = WarningCache()

def __getstate__(self):
state = self.__dict__.copy()
@@ -165,12 +168,15 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)

return self._experiment

def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

metrics = self._add_prefix(metrics)
if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
self.warning_cache.warn(
'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
' or try logging with `commit=False` when calling manually `wandb.log`.'
)
if self._sync_step:
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
elif step is not None:
self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
if step is not None:
self.experiment.log({**metrics, 'trainer/global_step': step})
else:
self.experiment.log(metrics)

@@ -216,10 +215,6 @@ def version(self) -> Optional[str]:

@rank_zero_only
def finalize(self, status: str) -> None:
# offset future training logged on same W&B run
if self._experiment is not None:
self._step_offset = self._experiment.step

# upload all checkpoints from saving dir
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
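To illustrate the new logging flow (a sketch only; the project and run names are assumptions): `log_metrics` now emits the metrics together with a `trainer/global_step` field, and `wandb.define_metric` tells W&B to use that field as the default x-axis, which is why the `sync_step` flag and the manual step offset were removed.

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Hypothetical project/run names; log_model uploads checkpoints when training finalizes.
wandb_logger = WandbLogger(project="my-project", name="baseline", log_model=True)
trainer = Trainer(logger=wandb_logger, max_epochs=5)

# Inside a LightningModule, self.log("val/loss", loss) ends up as
# experiment.log({"val/loss": ..., "trainer/global_step": step}),
# and W&B plots it against trainer/global_step via define_metric.
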
4 changes: 4 additions & 0 deletions pytorch_lightning/metrics/regression/explained_variance.py
@@ -59,6 +59,10 @@ class ExplainedVariance(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example:
>>> from pytorch_lightning.metrics import ExplainedVariance
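A short usage sketch of the check documented above, assuming the `pytorch_lightning.metrics` API shown in this diff (tensor values are arbitrary):

import torch
from pytorch_lightning.metrics import ExplainedVariance

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

ev = ExplainedVariance(multioutput="uniform_average")
print(ev(preds, target))

# ExplainedVariance(multioutput="mean")  # raises the ValueError documented above
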
4 changes: 4 additions & 0 deletions pytorch_lightning/metrics/regression/psnr.py
@@ -51,6 +51,10 @@ class PSNR(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not given.
Example:
>>> from pytorch_lightning.metrics import PSNR
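Similarly, a brief sketch of the PSNR constraint documented above (shapes and values are arbitrary):

import torch
from pytorch_lightning.metrics import PSNR

preds = torch.rand(4, 3, 16, 16)
target = torch.rand(4, 3, 16, 16)

psnr = PSNR(data_range=1.0, dim=(1, 2, 3))  # data_range is required once dim is set
print(psnr(preds, target))

# PSNR(dim=(1, 2, 3))  # raises the ValueError documented above: data_range not given
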
8 changes: 7 additions & 1 deletion pytorch_lightning/metrics/regression/r2score.py
@@ -66,6 +66,12 @@ class R2Score(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Raises:
ValueError:
If ``adjusted`` parameter is not an integer larger or equal to 0.
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example:
>>> from pytorch_lightning.metrics import R2Score
@@ -102,7 +108,7 @@ def __init__(
self.num_outputs = num_outputs

if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.')
raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.')
self.adjusted = adjusted

allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
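A hedged example of the two validations documented for R2Score (values are illustrative):

import torch
from pytorch_lightning.metrics import R2Score

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

r2 = R2Score(adjusted=2, multioutput="uniform_average")
print(r2(preds, target))

# Both of the following raise the ValueErrors documented above:
# R2Score(adjusted=-1)         # `adjusted` must be an integer >= 0
# R2Score(multioutput="mean")  # must be one of the three allowed strings
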
17 changes: 2 additions & 15 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,26 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Generator, Optional, Sequence, Tuple

from torch.nn import Module
from abc import ABC
from typing import Generator


class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

@abstractmethod
def connect(
self,
model: Module,
*args: Sequence,
**kwargs: Sequence,
) -> Optional[Tuple[Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""

def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""

40 changes: 24 additions & 16 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Tuple
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,37 +22,41 @@
if _APEX_AVAILABLE:
from apex import amp

if TYPE_CHECKING:
from torch.optim import Optimizer


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
def __init__(self, amp_level: str) -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation
Args:
@@ -94,11 +97,11 @@ def backward(

def configure_apex(
self,
amp: object,
amp: Type,
model: LightningModule,
optimizers: List[Optimizer],
optimizers: List['Optimizer'],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple[LightningModule, List['Optimizer']]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Expand Down Expand Up @@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
break

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: LightningModule,
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
@@ -160,6 +168,6 @@ def pre_optimizer_step(
if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

optimizer.step()
optimizer.step(**kwargs)

return False
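
For context, a sketch of overriding `configure_apex` as the docstring above suggests. The subclass name and opt level are assumptions, NVIDIA Apex must be installed, and `amp` is passed into the method by the plugin's `connect` call.

from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin


class MyApexPlugin(ApexMixedPrecisionPlugin):

    def configure_apex(self, amp, model, optimizers, amp_level):
        # Same contract as the default implementation: return the patched
        # model and the list of Apex-wrapped optimizers.
        model, optimizers = amp.initialize(model, optimizers, opt_level="O2")
        return model, optimizers

How the plugin is handed to the Trainer (for example via `amp_backend="apex"` plus the `plugins` argument) depends on the Trainer version and is not shown in this diff.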