Merge branch 'master' into instance_segmentation_metric
mergify[bot] authored May 24, 2022
2 parents 2a991a7 + 85d798e commit 7006e01
Showing 6 changed files with 178 additions and 4 deletions.
12 changes: 12 additions & 0 deletions tests/bases/test_metric.py
@@ -409,3 +409,15 @@ def get_memory_usage():
metric.update(x.sum())
memory = get_memory_usage()
assert base_memory_level >= memory, "memory increased above base level"


@pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum])
def test_warning_on_not_set_full_state_update(metric_class):
class UnsetProperty(metric_class):
full_state_update = None

with pytest.warns(
UserWarning,
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
UnsetProperty()
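
For context, a minimal sketch (not part of this commit) of a user-defined metric that sets the new class attribute explicitly and therefore does not trigger the warning exercised by the test above; the `SumMetric` name is purely illustrative:

import torch
from torchmetrics import Metric


class SumMetric(Metric):
    # each update only adds to ``total``, so ``update`` never needs the previously
    # accumulated state and the faster ``forward`` path is safe
    full_state_update: bool = False

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total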
2 changes: 2 additions & 0 deletions tests/helpers/testers.py
@@ -570,6 +570,7 @@ def run_differentiability_test(

class DummyMetric(Metric):
name = "Dummy"
full_state_update: Optional[bool] = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -584,6 +585,7 @@ def compute(self):

class DummyListMetric(Metric):
name = "DummyList"
full_state_update: Optional[bool] = True

def __init__(self):
super().__init__()
30 changes: 29 additions & 1 deletion tests/test_utilities.py
@@ -15,7 +15,9 @@
import torch
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics import MeanSquaredError, PearsonCorrCoef
from torchmetrics.utilities import check_forward_no_full_state, rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.checks import _allclose_recursive
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce

@@ -126,3 +128,29 @@ def test_bincount():
# check for correctness
assert torch.allclose(res1, res2)
assert torch.allclose(res1, res3)


@pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, True), (PearsonCorrCoef, False)])
def test_check_full_state_update_fn(metric_class, expected):
"""Test that the check function works as it should."""
out = check_forward_no_full_state(
metric_class=metric_class,
input_args=dict(preds=torch.randn(100), target=torch.randn(100)),
)
assert out == expected


@pytest.mark.parametrize(
"input, expected",
[
((torch.ones(2), torch.ones(2)), True),
((torch.rand(2), torch.rand(2)), False),
(([torch.ones(2) for _ in range(2)], [torch.ones(2) for _ in range(2)]), True),
(([torch.rand(2) for _ in range(2)], [torch.rand(2) for _ in range(2)]), False),
(({f"{i}": torch.ones(2) for i in range(2)}, {f"{i}": torch.ones(2) for i in range(2)}), True),
(({f"{i}": torch.rand(2) for i in range(2)}, {f"{i}": torch.rand(2) for i in range(2)}), False),
],
)
def test_recursive_allclose(input, expected):
res = _allclose_recursive(*input)
assert res == expected
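
As a usage sketch mirroring the parametrized test above (not code from this commit), the new utility can be called directly; when the two code paths produce matching outputs it also prints per-step timings, whose exact numbers vary by machine:

import torch
from torchmetrics import MeanSquaredError, PearsonCorrCoef
from torchmetrics.utilities import check_forward_no_full_state

# MSE accumulates independent per-batch sums, so the reduced-state path is safe
print(check_forward_no_full_state(
    metric_class=MeanSquaredError,
    input_args=dict(preds=torch.randn(100), target=torch.randn(100)),
))  # expected: True

# Pearson correlation keeps running statistics that depend on earlier updates
print(check_forward_no_full_state(
    metric_class=PearsonCorrCoef,
    input_args=dict(preds=torch.randn(100), target=torch.randn(100)),
))  # expected: False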
18 changes: 16 additions & 2 deletions torchmetrics/metric.py
@@ -73,7 +73,7 @@ class Metric(Module, ABC):
__jit_unused_properties__ = ["is_differentiable"]
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
full_state_update: bool = True
full_state_update: Optional[bool] = None

def __init__(
self,
@@ -127,6 +127,20 @@ def __init__(
self._is_synced = False
self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

if self.full_state_update is None:
rank_zero_warn(
f"""Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class ({self.__class__.__name__}). The property determines if `update` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to `False`.
We provide a checking function
`from torchmetrics.utilities import check_forward_no_full_state`
that can be used to check if `full_state_update=True` (the old and potentially slower behaviour,
default for now) is needed or if `full_state_update=False` can be used safely.
""",
UserWarning,
)

@property
def _update_called(self) -> bool:
# Needed for lightning integration
@@ -216,7 +230,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"HINT: Did you forget to call ``unsync`` ?."
)

if self.full_state_update or self.dist_sync_on_step:
if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
self._forward_cache = self._forward_full_state_update(*args, **kwargs)
else:
self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
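
To illustrate the new dispatch in ``forward`` shown above, here is a small sketch (not part of this commit; the subclass names are illustrative) using ``MeanSquaredError``, whose updates are independent: both paths should return the same per-batch and accumulated values, the reduced-state path just skips recomputing over the full accumulated state.

import torch
from torchmetrics import MeanSquaredError


class FullStateMSE(MeanSquaredError):
    full_state_update = True  # old behaviour: ``forward`` recomputes using the full state


class ReducedStateMSE(MeanSquaredError):
    full_state_update = False  # new fast path: ``forward`` only evaluates the current batch


full_mse, fast_mse = FullStateMSE(), ReducedStateMSE()
for _ in range(3):
    preds, target = torch.randn(10), torch.randn(10)
    # per-batch values returned by ``forward`` agree between the two paths
    assert torch.allclose(full_mse(preds, target), fast_mse(preds, target))
# the globally accumulated result agrees as well
assert torch.allclose(full_mse.compute(), fast_mse.compute())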
1 change: 1 addition & 0 deletions torchmetrics/utilities/__init__.py
@@ -1,3 +1,4 @@
from torchmetrics.utilities.checks import check_forward_no_full_state # noqa: F401
from torchmetrics.utilities.data import apply_to_collection # noqa: F401
from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401
from torchmetrics.utilities.prints import _future_warning, rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401
119 changes: 118 additions & 1 deletion torchmetrics/utilities/checks.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from time import perf_counter
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, no_type_check

import torch
from torch import Tensor
@@ -604,3 +605,119 @@ def _check_retrieval_target_and_prediction_types(
preds = preds.float()

return preds.flatten(), target.flatten()


def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool:
"""Utility function for recursively asserting that two results are within a certain tolerance."""
# single output compare
if isinstance(res1, Tensor):
return torch.allclose(res1, res2, atol=atol)
elif isinstance(res1, str):
return res1 == res2
elif isinstance(res1, Sequence):
return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2))
elif isinstance(res1, Mapping):
return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys())
return res1 == res2


@no_type_check
def check_forward_no_full_state(
metric_class,
init_args: Dict[str, Any] = {},
input_args: Dict[str, Any] = {},
num_update_to_compare: Sequence[int] = [10, 100, 1000],
reps: int = 5,
) -> bool:
"""Utility function for checking if the new ``full_state_update`` property can safely be set to ``False`` which
will for most metrics results in a speedup when using ``forward``.
Args:
metric_class: metric class object that should be checked
init_args: dict containing arguments for initializing the metric class
input_args: dict containing arguments to pass to ``forward``
num_update_to_compare: if we successfully detect that the flag is safe to set to ``False``,
we will run some speedup tests. This arg should be a list of integers for how many
steps to compare over.
reps: number of repetitions of speedup test
Example (states in ``update`` are independent, safe to set ``full_state_update=False``):
>>> from torchmetrics import ConfusionMatrix
>>> check_forward_no_full_state(
... ConfusionMatrix,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... ) # doctest: +ELLIPSIS
Full state for 10 steps took: ...
Partial state for 10 steps took: ...
Full state for 100 steps took: ...
Partial state for 100 steps took: ...
Full state for 1000 steps took: ...
Partial state for 1000 steps took: ...
True
Example (states in ``update`` are dependent, meaning that ``full_state_update=True``):
>>> from torchmetrics import ConfusionMatrix
>>> class MyMetric(ConfusionMatrix):
... def update(self, preds, target):
... super().update(preds, target)
... # by construction make future states dependent on prior states
... if self.confmat.sum() > 20:
... self.reset()
>>> check_forward_no_full_state(
... MyMetric,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... )
False
"""

class FullState(metric_class):
full_state_update = True

class PartState(metric_class):
full_state_update = False

fullstate = FullState(**init_args)
partstate = PartState(**init_args)

equal = True
for _ in range(num_update_to_compare[0]):
out1 = fullstate(**input_args)
try:  # if it fails, the code most likely needs access to the full state
out2 = partstate(**input_args)
except RuntimeError:
equal = False
break
equal = equal & _allclose_recursive(out1, out2)

res1 = fullstate.compute()
try:  # if it fails, the code most likely needs access to the full state
    res2 = partstate.compute()
    equal = equal & _allclose_recursive(res1, res2)
except RuntimeError:
    equal = False

if not equal: # we can stop early because the results did not match
return False

# Do timings
res = torch.zeros(2, len(num_update_to_compare), reps)
for i, metric in enumerate([fullstate, partstate]):
for j, t in enumerate(num_update_to_compare):
for r in range(reps):
start = perf_counter()
for _ in range(t):
_ = metric(**input_args)
end = perf_counter()
res[i, j, r] = end - start
metric.reset()

mean = torch.mean(res, -1)
std = torch.std(res, -1)

for t in range(len(num_update_to_compare)):
print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]:0.3f}")
print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}")

return (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading
