
Commit

Fix bug in metric tracker (#1306)
* fix + test

* changelog
SkafteNicki authored Nov 3, 2022
1 parent 669bca0 commit 8fc6de8
Showing 3 changed files with 17 additions and 11 deletions.
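
The root cause is the return order of `torch.max` / `torch.min` when called with a dimension argument: both return a `(values, indices)` named tuple, but the old code unpacked it as `idx, best`, so the best value and the step at which it occurred were swapped. A minimal sketch of the behavior (the tensor values here are made up for illustration):

```python
import torch

# torch.max(tensor, dim) returns a named tuple (values, indices)
scores = torch.tensor([0.25, 0.75, 0.5])  # hypothetical per-step metric values
out = torch.max(scores, 0)
print(out.values.item())   # 0.75 -> the best value
print(out.indices.item())  # 1    -> the step at which it occurred

# The old unpacking bound the values tensor to `idx` and the indices tensor
# to `best`, which is exactly the swap this commit fixes:
idx, best = torch.max(scores, 0)
assert idx.item() == 0.75 and best.item() == 1
```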
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -52,6 +52,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed restrictive dtype checking in `spearman_corrcoef` when used with autocast ([#1303](https://github.com/Lightning-AI/metrics/pull/1303))
 
 
+- Fixed bug in `MetricTracker.best_metric` when `return_step=False` ([#1306](https://github.com/Lightning-AI/metrics/pull/1306))
+
+
 ## [0.10.1] - 2022-10-21
 
 ### Fixed
22 changes: 11 additions & 11 deletions src/torchmetrics/wrappers/tracker.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union
 
@@ -21,6 +20,7 @@
 
 from torchmetrics.collections import MetricCollection
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.prints import rank_zero_warn
 
 
 class MetricTracker(ModuleList):
@@ -171,12 +171,12 @@ def best_metric(
         if isinstance(self._base_metric, Metric):
             fn = torch.max if self.maximize else torch.min
             try:
-                idx, best = fn(self.compute_all(), 0)
+                value, idx = fn(self.compute_all(), 0)
                 if return_step:
-                    return idx.item(), best.item()
-                return best.item()
+                    return value.item(), idx.item()
+                return value.item()
             except ValueError as error:
-                warnings.warn(
+                rank_zero_warn(
                     f"Encountered the following error when trying to get the best metric: {error}"
                     "this is probably due to the 'best' not being defined for this metric."
                     "Returning `None` instead.",
@@ -189,24 +189,24 @@
         else:  # this is a metric collection
             res = self.compute_all()
             maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize]
-            idx, best = {}, {}
+            value, idx = {}, {}
             for i, (k, v) in enumerate(res.items()):
                 try:
                     fn = torch.max if maximize[i] else torch.min
                     out = fn(v, 0)
-                    idx[k], best[k] = out[0].item(), out[1].item()
+                    value[k], idx[k] = out[0].item(), out[1].item()
                 except ValueError as error:
-                    warnings.warn(
+                    rank_zero_warn(
                         f"Encountered the following error when trying to get the best metric for metric {k}:"
                         f"{error} this is probably due to the 'best' not being defined for this metric."
                         "Returning `None` instead.",
                         UserWarning,
                     )
-                    idx[k], best[k] = None, None
+                    value[k], idx[k] = None, None
 
         if return_step:
-            return idx, best
-        return best
+            return value, idx
+        return value
 
     def _check_for_increment(self, method: str) -> None:
         if not self._increment_called:
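
With the fix, `best_metric` returns the best value first and the step second, and `return_step=False` yields the value alone. A usage sketch, assuming `BinaryAccuracy` as an arbitrary base metric (any metric tracked over several steps would do):

```python
import torch
from torchmetrics import MetricTracker
from torchmetrics.classification import BinaryAccuracy

tracker = MetricTracker(BinaryAccuracy(), maximize=True)
for _ in range(5):
    tracker.increment()  # start a new tracked timestep
    preds, target = torch.rand(10), torch.randint(2, (10,))
    tracker.update(preds, target)

best_value, best_step = tracker.best_metric(return_step=True)
# The fixed return_step=False path now returns the value, not the step:
assert tracker.best_metric(return_step=False) == best_value
```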
3 changes: 3 additions & 0 deletions tests/unittests/wrappers/test_tracker.py
@@ -126,6 +126,9 @@ def test_tracker(base_metric, metric_input, maximize):
     assert val != 0.0
     assert idx in list(range(5))
 
+    val2 = tracker.best_metric(return_step=False)
+    assert val == val2
+
 
 @pytest.mark.parametrize(
     "base_metric",
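
For completeness: the second hunk in tracker.py applies the same rename to the `MetricCollection` branch, where `best_metric` returns dictionaries keyed by metric name. A sketch of that path, assuming two arbitrary binary metrics from torchmetrics:

```python
import torch
from torchmetrics import MetricCollection, MetricTracker
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

tracker = MetricTracker(MetricCollection([BinaryAccuracy(), BinaryPrecision()]), maximize=True)
for _ in range(3):
    tracker.increment()
    tracker.update(torch.rand(10), torch.randint(2, (10,)))

# After the fix, the values dict comes first and the steps dict second:
values, steps = tracker.best_metric(return_step=True)
print(values)  # e.g. {"BinaryAccuracy": ..., "BinaryPrecision": ...}
```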
