Use argmax when topk=1 (#419)
* argmax for k=1
* changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
SkafteNicki and pre-commit-ci[bot] authored Aug 3, 2021
1 parent 21fe0ca commit 5a2388c
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Pearson metrics now only store 6 statistics instead of all predictions and targets ([#380](https://github.com/PyTorchLightning/metrics/pull/380))


- Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419))


### Deprecated

- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
5 changes: 4 additions & 1 deletion torchmetrics/utilities/data.py
@@ -94,7 +94,10 @@ def select_topk(prob_tensor: Tensor, topk: int = 1, dim: int = 1) -> Tensor:
                 [1, 1, 0]], dtype=torch.int32)
     """
     zeros = torch.zeros_like(prob_tensor)
-    topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0)
+    if topk == 1:  # argmax has better performance than topk
+        topk_tensor = zeros.scatter(dim, prob_tensor.argmax(dim=dim, keepdim=True), 1.0)
+    else:
+        topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0)
     return topk_tensor.int()
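
For reference, a minimal standalone sketch (not part of the commit; the example tensor values are made up for illustration) showing that the new argmax fast path and the previous topk path produce the same one-hot result when k=1:

import torch

prob_tensor = torch.tensor([[0.1, 0.7, 0.2],
                            [0.5, 0.3, 0.2]])
zeros = torch.zeros_like(prob_tensor)

# New fast path: argmax gives one index per row; keepdim=True keeps it as a column so scatter accepts it
via_argmax = zeros.scatter(1, prob_tensor.argmax(dim=1, keepdim=True), 1.0).int()

# Previous path: topk with k=1 selects the same indices (the PR notes argmax is faster here)
via_topk = zeros.scatter(1, prob_tensor.topk(k=1, dim=1).indices, 1.0).int()

assert torch.equal(via_argmax, via_topk)
print(via_argmax)
# tensor([[0, 1, 0],
#         [1, 0, 0]], dtype=torch.int32)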


