Skip to content

Commit

Permalink
Extend support for logging a collection (#7771)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Jun 1, 2021
1 parent 9a001fe commit 1dd61e4
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 51 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


- Raise `ValueError` when a `None` value is `self.log`-ed ([#7771](https://github.com/PyTorchLightning/pytorch-lightning/pull/7771))


- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


Expand Down
33 changes: 17 additions & 16 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
Expand Down Expand Up @@ -261,7 +261,7 @@ def forward(self, x):
def log(
self,
name: str,
value: Any,
value: _METRIC_COLLECTION,
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
Expand Down Expand Up @@ -324,6 +324,9 @@ def log(
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)

# check for none values
apply_to_collection(value, type(None), partial(self.__check_none, name, value))

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
Expand All @@ -335,14 +338,15 @@ def log(
if "/dataloader_idx_" in name:
raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")

value = self.__sync(
value,
sync_fn = partial(
self.__sync,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn)

assert self._results is not None
self._results.log(
Expand All @@ -359,7 +363,7 @@ def log(

def log_dict(
self,
dictionary: dict,
dictionary: Dict[str, _METRIC_COLLECTION],
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
Expand Down Expand Up @@ -416,29 +420,26 @@ def log_dict(

@staticmethod
def __sync(
value: _METRIC,
value: Union[torch.Tensor, numbers.Number],
sync_fn: Optional[Callable] = None,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
device: torch.device = None,
) -> _METRIC:
) -> torch.Tensor:
"""Sync across workers when using distributed training"""
if not isinstance(value, (torch.Tensor, numbers.Number)):
return value

if isinstance(value, numbers.Number):
value = torch.tensor(value, device=device, dtype=torch.float)
sync_fn = sync_fn or sync_ddp_if_available
dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
if not sync_dist or not dist_available:
return value

# TODO: Find a way to make the reduction only once, so we don't need to clone.
if isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

@staticmethod
def __check_none(name: str, value: Any, _) -> Any:
raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged')

def write_prediction(
self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
):
Expand Down
20 changes: 10 additions & 10 deletions pytorch_lightning/plugins/training_type/ddp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import torch

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Plugin(DDPPlugin):
Expand All @@ -34,26 +35,25 @@ def setup(self, model):
self.task_idx = self.cluster_environment.local_rank()
# the difference to DDP is that we don't call children processes here

def reduce(self, tensor, *args, **kwargs):
def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
Reduces a tensor from all processes to one aggregated tensor.
Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
In DDP2, the reduction here is only across local devices within the node.
Args:
tensor: the tensor to sync and reduce
collection: The collection of tensors to sync and reduce.
*args: ignored for DDP2
**kwargs: ignored for DDP2
Return:
reduced value, except when the input was not a tensor the output remains is unchanged
Reduced tensor values or the same value if it was not or did not contain a tensor.
"""
if isinstance(tensor, Result):
tensor.dp_reduce()

elif isinstance(tensor, torch.Tensor):
tensor = tensor.mean()
def mean(t: torch.Tensor) -> torch.Tensor:
original_dtype = t.dtype
return t.float().mean().to(original_dtype)

return tensor
return apply_to_collection(collection, torch.Tensor, mean)

@property
def root_device(self):
Expand Down
24 changes: 9 additions & 15 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DataParallelPlugin(ParallelPlugin):
Expand Down Expand Up @@ -52,30 +52,24 @@ def setup(self, model):
model.to(self.root_device)
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)

def reduce(self, tensor, *args, **kwargs):
def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
Reduces a tensor from all parallel processes to one aggregated tensor.
Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
Args:
tensor: the tensor to sync and reduce
collection: The collection of tensors to sync and reduce.
*args: ignored for DP
**kwargs: ignored for DP
Return:
reduced value, except when the input was not a tensor the output remains is unchanged
Reduced tensor values or the same value if it was not or did not contain a tensor.
"""
if isinstance(tensor, Result):
tensor.dp_reduce()

else:
def mean(t: torch.Tensor) -> torch.Tensor:
original_dtype = t.dtype
return t.float().mean().to(original_dtype)

def _reduce(t: torch.Tensor):
dtype_tensor = t.dtype
return t.float().mean().type(dtype_tensor)

tensor = apply_to_collection(tensor, torch.Tensor, _reduce)

return tensor
return apply_to_collection(collection, torch.Tensor, mean)

@property
def root_device(self):
Expand Down
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,6 @@ def reduce_across_time(cls, time_outputs):
result['meta'] = meta
return result

def dp_reduce(self):
for k, value in self.items():
if k == 'meta' or isinstance(value, Metric):
continue

if isinstance(value, list):
value = torch.tensor(value)

self[k] = value.mean(dim=-1)

@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from torchmetrics import Metric

_METRIC = Union[Metric, torch.Tensor, Number]
# real type is `Union[_METRIC, Dict[str, '_METRIC_COLLECTION']]` but Sphinx fails with `RecursionError`
_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]]
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
EPOCH_OUTPUT = List[STEP_OUTPUT]
_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader
Expand Down
14 changes: 14 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,17 @@ def validation_step(self, batch, batch_idx):

assert trainer.callback_metrics["val_acc"] == 8 / 32.
assert "train_loss" in trainer.callback_metrics


@pytest.mark.parametrize('value', [None, {'a': {'b': None}}])
def test_log_none_raises(tmpdir, value):

class TestModel(BoringModel):

def training_step(self, *args):
self.log("foo", value)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
model = TestModel()
with pytest.raises(ValueError, match=rf"self.log\(foo, {value}\)` was called"):
trainer.fit(model)

0 comments on commit 1dd61e4

Please sign in to comment.