Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New modular metric interface #2528

Merged
merged 42 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0ba7d63
new base structure
Jun 26, 2020
6481b6b
missing packages
Jun 26, 2020
f6a0a4d
updated interface
Jun 30, 2020
9368ac6
revert some changes
Jul 6, 2020
5d76528
fixes
Jul 6, 2020
6709952
Merge branch 'master' into new_metric_interface
SkafteNicki Jul 6, 2020
4958cb2
add changelog
Jul 6, 2020
0a3849e
fix bug
Jul 6, 2020
883efa9
added description
Jul 6, 2020
451d4b6
'merge'
SkafteNicki Aug 5, 2020
ea0dfc6
merge
SkafteNicki Aug 5, 2020
cdf2dbd
merge
SkafteNicki Aug 7, 2020
d99821e
test for pickable
SkafteNicki Aug 7, 2020
e8a6a7b
fixing test
SkafteNicki Aug 7, 2020
9626b22
fixing test
SkafteNicki Aug 7, 2020
1f458f6
fix pickle issue
SkafteNicki Aug 7, 2020
bdf9364
'mergeø
SkafteNicki Aug 10, 2020
be70ac8
reduceop typehints back
SkafteNicki Aug 10, 2020
fd6e719
remove redundant module arg
SkafteNicki Aug 11, 2020
9d16c69
add save/load test
SkafteNicki Aug 11, 2020
810d3dd
add aggregate method
SkafteNicki Aug 11, 2020
4923e34
text clarification
SkafteNicki Aug 11, 2020
ae60d3d
fix doctest
SkafteNicki Aug 11, 2020
a30a033
merge
SkafteNicki Aug 13, 2020
ffb4ed7
merge
SkafteNicki Aug 18, 2020
796d913
Apply suggestions from code review
awaelchli Aug 22, 2020
f14ec8f
Merge branch 'master' into new_metric_interface
awaelchli Aug 22, 2020
8a3a128
change test to results obj
SkafteNicki Aug 24, 2020
26b4bb4
fix docs
SkafteNicki Aug 24, 2020
8289a41
merge
SkafteNicki Aug 24, 2020
8aae0be
formatting
Borda Aug 25, 2020
b783ec5
formatting
rohitgr7 Aug 25, 2020
e0ba557
formatting
rohitgr7 Aug 25, 2020
da4e94d
formatting
rohitgr7 Aug 25, 2020
a27aebd
formatting
rohitgr7 Aug 25, 2020
9641f84
formatting
rohitgr7 Aug 25, 2020
9fabc89
pep
rohitgr7 Aug 25, 2020
619ffc2
Update CHANGELOG.md
Borda Aug 25, 2020
b03ce8f
suggestions
Aug 26, 2020
ac169df
fix tests
Aug 26, 2020
c1a1389
fix pep8
Aug 26, 2020
847d0ed
fix tests
Aug 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))

### Changed

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor
import os

from pytorch_lightning.metrics.converters import _sync_ddp_if_available
from pytorch_lightning.metrics.converters import sync_ddp_if_available


class Result(Dict):
Expand Down Expand Up @@ -124,7 +124,7 @@ def log(

# sync across ddp
if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
value = _sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)
value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)

if 'meta' not in self:
self.__setitem__('meta', {})
Expand Down
64 changes: 49 additions & 15 deletions pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class ReduceOp:
SUM = None

rank_zero_warn('Unsupported `ReduceOp` for distributed computing.')
rank_zero_warn('Unsupported `ReduceOp` for distributed computing')


def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
Expand Down Expand Up @@ -86,28 +86,30 @@ def new_func(*args, **kwargs):
return decorator_fn


def _convert_to_tensor(data: Any) -> Any:
def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
"""
Maps all kind of collections and numbers to tensors.

Args:
data: the data to convert to tensor
dtype: data type to convert to
device: device to cast to

Return:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
return torch.tensor([data], dtype=dtype, device=device)
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data)
return torch.from_numpy(data).to(device=device, dtype=dtype)
elif isinstance(data, torch.Tensor):
return data
return data.to(device=device, dtype=dtype)

raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.

Args:
Expand Down Expand Up @@ -137,7 +139,7 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)


def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -150,7 +152,7 @@ def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
return _apply_to_outputs(convert_to_tensor)(func_to_decorate)


def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -184,7 +186,7 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)


def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Expand All @@ -198,7 +200,7 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
_convert_to_tensor)(func_to_decorate)
convert_to_tensor)(func_to_decorate)


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
Expand Down Expand Up @@ -238,10 +240,10 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable
return _tensor_collection_metric_output_conversion(func_convert_inputs)


def _sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None,
) -> torch.Tensor:
def sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Expand Down Expand Up @@ -278,6 +280,38 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
return result


def gather_all_tensors_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes

Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)

Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i

"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if group is None:
group = torch.distributed.group.WORLD

world_size = torch.distributed.get_world_size(group)

gathered_result = world_size * [torch.zeros_like(result)]

# sync and broadcast all
torch.distributed.barrier(group=group)
torch.distributed.all_gather(gathered_result, result, group)

result = gathered_result

return result


def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
Expand All @@ -294,7 +328,7 @@ def sync_ddp(group: Optional[Any] = None,

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
_sync_ddp_if_available, group=group,
sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)

return decorator_fn
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def confusion_matrix(
"""
num_classes = get_num_classes(pred, target, None)

unique_labels = target.view(-1) * num_classes + pred.view(-1)
unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int)

bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
cm = bins.reshape(num_classes, num_classes).squeeze().float()
Expand Down
Loading