Skip to content

Commit

Permalink
New modular metric interface (#2528)
Browse files Browse the repository at this point in the history
* new base structure

* missing packages

* updated interface

* revert some changes

* fixes

* add changelog

* fix bug

* added description

* test for picklable

* fixing test

* fixing test

* fix pickle issue

* reduceop typehints back

* remove redundant module arg

* add save/load test

* add aggregate method

* text clarification

* fix doctest

* Apply suggestions from code review

* change test to results obj

* fix docs

* formatting

Co-authored-by: Rohit Gupta <[email protected]>

* formatting

* pep

* Update CHANGELOG.md

* suggestions

* fix tests

* fix pep8

* fix tests

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
5 people authored Aug 26, 2020
1 parent 0112355 commit 17d8773
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))

### Changed

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor
import os

from pytorch_lightning.metrics.converters import _sync_ddp_if_available
from pytorch_lightning.metrics.converters import sync_ddp_if_available


class Result(Dict):
Expand Down Expand Up @@ -124,7 +124,7 @@ def log(

# sync across ddp
if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
value = _sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)
value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)

if 'meta' not in self:
self.__setitem__('meta', {})
Expand Down
64 changes: 49 additions & 15 deletions pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class ReduceOp:
SUM = None

rank_zero_warn('Unsupported `ReduceOp` for distributed computing.')
rank_zero_warn('Unsupported `ReduceOp` for distributed computing')


def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
Expand Down Expand Up @@ -86,28 +86,30 @@ def new_func(*args, **kwargs):
return decorator_fn


def _convert_to_tensor(data: Any) -> Any:
def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
"""
Maps all kind of collections and numbers to tensors.
Args:
data: the data to convert to tensor
dtype: data type to convert to
device: device to cast to
Return:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
return torch.tensor([data], dtype=dtype, device=device)
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data)
return torch.from_numpy(data).to(device=device, dtype=dtype)
elif isinstance(data, torch.Tensor):
return data
return data.to(device=device, dtype=dtype)

raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.
Args:
Expand Down Expand Up @@ -137,7 +139,7 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)


def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -150,7 +152,7 @@ def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
return _apply_to_outputs(convert_to_tensor)(func_to_decorate)


def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -184,7 +186,7 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)


def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -198,7 +200,7 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
_convert_to_tensor)(func_to_decorate)
convert_to_tensor)(func_to_decorate)


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -238,10 +240,10 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable
return _tensor_collection_metric_output_conversion(func_convert_inputs)


def _sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None,
) -> torch.Tensor:
def sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
Expand Down Expand Up @@ -278,6 +280,38 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
return result


def gather_all_tensors_if_available(result: Union[torch.Tensor],
                                    group: Optional[Any] = None):
    """
    Function to gather all tensors from several ddp processes onto a list that
    is broadcast to all processes

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i.
            When torch.distributed is unavailable or not initialized, ``result``
            is returned unchanged (a tensor, not a list).
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if group is None:
            group = torch.distributed.group.WORLD

        world_size = torch.distributed.get_world_size(group)

        # Allocate one distinct buffer per rank. ``world_size * [tensor]`` would
        # repeat a reference to a single tensor, so every slot of the output
        # list would alias the same storage and all entries would end up equal.
        gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

        # sync and broadcast all
        torch.distributed.barrier(group=group)
        torch.distributed.all_gather(gathered_result, result, group)

        result = gathered_result

    return result


def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
Expand All @@ -294,7 +328,7 @@ def sync_ddp(group: Optional[Any] = None,

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
_sync_ddp_if_available, group=group,
sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)

return decorator_fn
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def confusion_matrix(
"""
num_classes = get_num_classes(pred, target, None)

unique_labels = target.view(-1) * num_classes + pred.view(-1)
unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int)

bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
cm = bins.reshape(num_classes, num_classes).squeeze().float()
Expand Down
Loading

0 comments on commit 17d8773

Please sign in to comment.