Merge branch 'release/1.2-dev' into ppo
sidhantls authored Jan 7, 2021
2 parents 4e279c6 + 5f94900 commit a2ce948
Showing 32 changed files with 340 additions and 187 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))


- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))


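For readers following the `stat_scores` change above: computing stat scores "over all classes" means returning true/false positives and negatives (plus support) for every class instead of a single chosen one. A minimal sketch of that idea in plain PyTorch, using an illustrative helper name rather than the library's actual implementation:

```python
import torch

def per_class_stat_scores(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return a (num_classes, 5) tensor of [tp, fp, tn, fn, support] per class."""
    scores = torch.zeros(num_classes, 5, dtype=torch.long)
    for c in range(num_classes):
        pred_c, target_c = preds == c, target == c
        scores[c] = torch.stack([
            (pred_c & target_c).sum(),    # true positives
            (pred_c & ~target_c).sum(),   # false positives
            (~pred_c & ~target_c).sum(),  # true negatives
            (~pred_c & target_c).sum(),   # false negatives
            target_c.sum(),               # support
        ])
    return scores

preds = torch.tensor([0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 1, 2, 0])
print(per_class_stat_scores(preds, target, num_classes=3))  # one row per class
```
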
1 change: 0 additions & 1 deletion pytorch_lightning/accelerators/ddp_accelerator.py
@@ -223,7 +223,6 @@ def ddp_train(self, process_idx, model):
Args:
process_idx:
mp_queue: multiprocessing queue
model:
Returns:
1 change: 0 additions & 1 deletion pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -112,7 +112,6 @@ def ddp_train(self, process_idx, model):
Args:
process_idx:
mp_queue: multiprocessing queue
model:
Returns:
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -86,7 +86,7 @@ def train(self):
self.__recover_child_process_weights(model, best_path, last_path)
return results

def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
def ddp_train(self, process_idx, mp_queue, model, is_master: bool = False, proc_offset: int = 0):
"""
Entry point for ddp
1 change: 0 additions & 1 deletion pytorch_lightning/accelerators/tpu_accelerator.py
@@ -171,7 +171,6 @@ def to_device(self, batch):
Args:
batch: A tensor or collection of tensors.
tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
Return:
the tensor on the TPU device.
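
As background for the `to_device` docstring trimmed above: moving "a tensor or collection of tensors" onto a device is typically a recursive transfer over nested containers. A self-contained sketch of that pattern, not the accelerator's actual code (a TPU run would pass an XLA device instead of the CPU device used here):

```python
import torch

def move_to_device(batch, device: torch.device):
    """Recursively move tensors inside dicts/lists/tuples onto the given device."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    return batch  # non-tensor leaves pass through unchanged

batch = {"x": torch.randn(2, 3), "ids": [torch.tensor(0), torch.tensor(1)]}
batch = move_to_device(batch, torch.device("cpu"))
```
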
13 changes: 1 addition & 12 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,14 +19,12 @@
Monitor a metric and stop training when it stops improving.
"""
import numbers

import numpy as np
import torch

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn


class EarlyStopping(Callback):
@@ -196,15 +194,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
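
The hunk above removes the Metric/number-to-tensor conversion that used to precede the improvement check, leaving only the comparison against `best_score`. That comparison boils down to a comparison operator plus a signed `min_delta`; a small standalone sketch of the same logic, with simplified names rather than the callback itself:

```python
import torch

# 'min' mode: an improvement means the metric dropped by more than |min_delta|
mode = "min"
monitor_op = torch.lt if mode == "min" else torch.gt
# sign convention matches the comparison below: negative delta for 'min' mode
min_delta = torch.tensor(-0.01) if mode == "min" else torch.tensor(0.01)

best_score = torch.tensor(0.50)
wait_count = 0

for current in [torch.tensor(0.47), torch.tensor(0.465)]:
    if monitor_op(current - min_delta, best_score):
        best_score = current   # improved by more than the delta
        wait_count = 0
    else:
        wait_count += 1        # no real improvement; patience is running out

print(best_score, wait_count)  # tensor(0.4700) 1
```
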
10 changes: 1 addition & 9 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,6 @@
"""

import numbers
import os
import re
from copy import deepcopy
@@ -33,7 +32,6 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -554,12 +552,6 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
epoch = metrics.get("epoch")
step = metrics.get("step")

if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if self.check_monitor_top_k(current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
@@ -587,7 +579,7 @@ def _update_best_and_save(
self.best_k_models.pop(del_filepath)

# do not save nan, replace with +/- inf
if torch.isnan(current):
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
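
The `isinstance` guard added in the last hunk matters because `torch.isnan` only accepts tensors, and with the automatic cast removed the monitored value can now be a plain Python number. A small illustration of the guarded NaN replacement, assuming a `mode` of `'min'` or `'max'` as in the callback:

```python
import torch

def sanitize_monitor_value(current, mode: str = "min"):
    """Replace a NaN tensor with +/-inf so it never ranks as a 'best' checkpoint."""
    if isinstance(current, torch.Tensor) and torch.isnan(current):
        return torch.tensor(float("inf" if mode == "min" else "-inf"))
    return current

print(sanitize_monitor_value(torch.tensor(float("nan"))))  # tensor(inf)
print(sanitize_monitor_value(0.123))                       # plain float passes through
```
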
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/average_precision.py
@@ -67,7 +67,7 @@ def average_precision(
which for binary problem is translate to 1. For multiclass problems
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
sample_weight: sample weights for each data point
sample_weights: sample weights for each data point
Returns:
tensor with average precision. If multiclass will return list
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/explained_variance.py
@@ -62,7 +62,7 @@ def explained_variance(
Computes explained variance.
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
multioutput: Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is `'uniform_average'`.):
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/f_beta.py
@@ -75,7 +75,7 @@ def fbeta(
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
pred: estimated probabilities
preds: estimated probabilities
target: ground-truth labels
num_classes: Number of classes in the dataset.
beta: Beta coefficient in the F measure.
@@ -128,7 +128,7 @@ def f1(
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
pred: estimated probabilities
preds: estimated probabilities
target: ground-truth labels
num_classes: Number of classes in the dataset.
threshold:
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/mean_squared_error.py
@@ -34,7 +34,7 @@ def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Computes mean squared error
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
Return:
pytorch_lightning/metrics/functional/mean_squared_log_error.py
@@ -34,7 +34,7 @@ def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Computes mean squared log error
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
Return:
pytorch_lightning/metrics/functional/precision_recall_curve.py
@@ -173,7 +173,7 @@ def precision_recall_curve(
which for binary problem is translate to 1. For multiclass problems
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
sample_weight: sample weights for each data point
sample_weights: sample weights for each data point
Returns: 3-element tuple containing
2 changes: 0 additions & 2 deletions pytorch_lightning/metrics/functional/psnr.py
@@ -46,8 +46,6 @@ def psnr(
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
return_state: returns a internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with PSNR score
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/r2score.py
@@ -98,7 +98,7 @@ def r2score(
be provided as the ``adjusted`` argument.
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
adjusted: number of independent regressors for calculating adjusted r2 score.
Default 0 (standard r2 score).
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/roc.py
@@ -98,7 +98,7 @@ def roc(
which for binary problem is translate to 1. For multiclass problems
this argument should not be set as we iteratively change it in the
range [0,num_classes-1]
sample_weight: sample weights for each data point
sample_weights: sample weights for each data point
Returns: 3-element tuple containing
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/ssim.py
@@ -125,7 +125,7 @@ def ssim(
Computes Structual Similarity Index Measure
Args:
pred: estimated image
preds: estimated image
target: ground truth image
kernel_size: size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/utils.py
@@ -232,7 +232,7 @@ def class_reduce(
Args:
num: numerator tensor
decom: denominator tensor
denom: denominator tensor
weights: weights for each class
class_reduction: reduction method for multiclass problems
17 changes: 12 additions & 5 deletions pytorch_lightning/overrides/data_parallel.py
@@ -16,10 +16,12 @@
import threading
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Optional

import torch
from torch import Tensor
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel
from torch.nn import DataParallel, Module
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

@@ -222,15 +224,20 @@ def warn_missing_output(fx_called):
warning_cache.warn("Your training_step returned None. Make sure that was your intention!")


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no-cover
def parallel_apply(
modules: Module,
inputs: Tensor,
kwargs_tup: Optional[tuple] = None,
devices: Optional[list] = None,
): # pragma: no-cover
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.
Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices
modules: modules to be parallelized
inputs: inputs to the modules
devices: CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
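
For readers unfamiliar with `parallel_apply`, the idea the docstring describes is: run each module on its own inputs and device, one thread per replica, and collect the results in order. A stripped-down sketch of that pattern, with error handling and keyword arguments left out; this is not the override's actual implementation:

```python
import threading
from typing import List, Optional

import torch
from torch.nn import Module


def simple_parallel_apply(
    modules: List[Module],
    inputs: List[torch.Tensor],
    devices: Optional[List[torch.device]] = None,
) -> list:
    """Run modules[i](inputs[i]) on devices[i], one thread per replica."""
    devices = devices or [torch.device("cpu")] * len(modules)
    results: list = [None] * len(modules)

    def worker(i: int) -> None:
        replica = modules[i].to(devices[i])
        results[i] = replica(inputs[i].to(devices[i]))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(modules))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

# usage sketch: two linear layers applied to two input chunks
layers = [torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)]
chunks = [torch.randn(3, 4), torch.randn(3, 4)]
outputs = simple_parallel_apply(layers, chunks)
```
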
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -345,10 +345,11 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
model.on_hpc_load(checkpoint)

def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
"""List up files in `dir_path` with name_key, then yield maximum suffix number.
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
name_key: file name prefix
Returns:
None if no-corresponding-file else maximum suffix number
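
The docstring above describes scanning a folder for files whose names contain `name_key` and returning the largest numeric suffix, or None when nothing matches. A rough standalone sketch of that behaviour using plain `os`/`re`, not the connector's filesystem-aware code:

```python
import os
import re
from typing import Optional


def max_ckpt_suffix(dir_path: str, name_key: str = "ckpt_") -> Optional[int]:
    """Return the largest number following `name_key` among files in `dir_path`, else None."""
    if not os.path.isdir(dir_path):
        return None
    suffixes = []
    for fname in os.listdir(dir_path):
        match = re.search(rf"{re.escape(name_key)}(\d+)", fname)
        if match:
            suffixes.append(int(match.group(1)))
    return max(suffixes) if suffixes else None

# e.g. a folder holding "hpc_ckpt_1.ckpt" and "hpc_ckpt_3.ckpt" would yield 3
```
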
pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -379,7 +379,7 @@ def update_logger_connector(self) -> None:

if is_train:
# Only log and add to callback epoch step during evaluation, test.
logger_connector.logged_metrics.update(batch_log_metrics)
logger_connector._logged_metrics.update(batch_log_metrics)
callback_metrics.update(batch_pbar_metrics)
callback_metrics.update(batch_log_metrics)
else:
@@ -389,8 +389,8 @@ def update_logger_connector(self) -> None:

# get logged_metrics
epoch_log_metrics = self.get_epoch_log_metrics()
logger_connector.logged_metrics.update(epoch_log_metrics)
logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch)
logger_connector._logged_metrics.update(epoch_log_metrics)
logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch})

# get forked_metrics
forked_metrics = self.get_forked_metrics()
@@ -403,8 +403,8 @@ def update_logger_connector(self) -> None:
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
logger_connector._callback_metrics.update(callback_metrics)
logger_connector._callback_metrics.pop("epoch", None)

batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics
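
These hunks switch the result store to writing into `_logged_metrics` / `_callback_metrics`, the private attributes presumably introduced by the "automatic casting for LoggerConnector metrics" change listed in the CHANGELOG. A plausible sketch of that pattern, where a property setter casts plain numbers to tensors on assignment; the class, attribute, and method names here are illustrative, not the library's exact API:

```python
import numbers

import torch


class MetricsHolder:
    """Tiny sketch: store metrics privately and cast plain numbers to tensors on update."""

    def __init__(self) -> None:
        self._logged_metrics: dict = {}

    @property
    def logged_metrics(self) -> dict:
        return self._logged_metrics

    @logged_metrics.setter
    def logged_metrics(self, metrics: dict) -> None:
        self._logged_metrics = {k: self._cast(v) for k, v in metrics.items()}

    @staticmethod
    def _cast(value):
        if isinstance(value, numbers.Number):
            return torch.tensor(float(value))
        return value


holder = MetricsHolder()
holder.logged_metrics = {"loss": 0.25, "epoch": 3}
print(holder.logged_metrics)  # {'loss': tensor(0.2500), 'epoch': tensor(3.)}
```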