
Improve calibration error speed by replacing for loop #769

Merged
merged 26 commits into from
Jan 20, 2022

Conversation

ramonemiliani93
Contributor

@ramonemiliani93 ramonemiliani93 commented Jan 17, 2022

What does this PR do?

Removes the for loop in the calibration error and uses bucketize + scatter_add for improved speed (~10x).

Fixes #767
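For readers unfamiliar with the trick, here is a minimal, self-contained sketch of the idea with toy values (not the PR's actual code): `torch.bucketize` maps each confidence to a bin index, and `scatter_add_` accumulates per-bin statistics without a Python loop.

```python
import torch

# Toy illustration of the bucketize + scatter_add idea (values are made up).
confidences = torch.tensor([0.05, 0.15, 0.17, 0.95])
bin_boundaries = torch.linspace(0, 1, steps=11)  # 10 equal-width bins

# Each confidence gets the index of the bin it falls into (0..9).
indices = torch.bucketize(confidences, bin_boundaries) - 1

# Accumulate per-bin sample counts in a single vectorized pass.
count_bin = torch.zeros(10)
count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
# count_bin: bin 0 holds 1 sample, bin 1 holds 2, bin 9 holds 1
```

The same `scatter_add_` pattern with `src=confidences` or `src=accuracies` yields per-bin sums, which divided by the counts give the per-bin means used in the benchmark below.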

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Oh yes 😎

Make sure you had fun coding 🙃

@Borda Borda added this to the v0.7 milestone Jan 17, 2022
@Borda Borda added the enhancement New feature or request label Jan 17, 2022
@Borda
Member

Borda commented Jan 17, 2022

Nice, have you measured the performance? 🐰

@codecov

codecov bot commented Jan 17, 2022

Codecov Report

Merging #769 (97be5af) into master (d6c423e) will decrease coverage by 24%.
The diff coverage is 17%.

@@           Coverage Diff           @@
##           master   #769     +/-   ##
=======================================
- Coverage      95%    71%    -24%     
=======================================
  Files         171    171             
  Lines        6908   6926     +18     
=======================================
- Hits         6546   4904   -1642     
- Misses        362   2022   +1660     

@ramonemiliani93
Contributor Author

@Borda Yes 👌 Only on CPU though, here's the script:

import timeit

import torch


def method_a(confidences, accuracies, bin_boundaries):
    def _method_a():
        conf_bin = torch.zeros_like(bin_boundaries)
        acc_bin = torch.zeros_like(bin_boundaries)
        prop_bin = torch.zeros_like(bin_boundaries)
        for i, (bin_lower, bin_upper) in enumerate(
            zip(bin_boundaries[:-1], bin_boundaries[1:])
        ):
            # Compute the confidence and accuracy in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                acc_bin[i] = accuracies[in_bin].float().mean()
                conf_bin[i] = confidences[in_bin].mean()
                prop_bin[i] = prop_in_bin

    return _method_a


def method_b(confidences, accuracies, bin_boundaries):
    def _method_b():
        acc_bin = torch.zeros(len(bin_boundaries) - 1)
        conf_bin = torch.zeros(len(bin_boundaries) - 1)
        count_bin = torch.zeros(len(bin_boundaries) - 1)

        # Map every confidence to its bin index in one vectorized call
        indices = torch.bucketize(confidences, bin_boundaries) - 1

        # Per-bin sample counts
        count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))

        # Per-bin mean confidence (nan_to_num turns 0/0 for empty bins into 0)
        conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
        conf_bin = torch.nan_to_num(conf_bin / count_bin)

        # Per-bin mean accuracy
        acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
        acc_bin = torch.nan_to_num(acc_bin / count_bin)

        # Fraction of samples falling in each bin
        prop_bin = count_bin / count_bin.sum()

    return _method_b


n_bins = 20
size = (10000000,)
confidences = torch.rand(size)
accuracies = torch.randint(low=0, high=2, size=size).float()
bin_boundaries = torch.linspace(0, 1, steps=n_bins + 1)

t = timeit.Timer(method_a(confidences, accuracies, bin_boundaries))
print(t.timeit(100))

t = timeit.Timer(method_b(confidences, accuracies, bin_boundaries))
print(t.timeit(100))

The timing depends on the size and n_bins; for the values above, 100 runs went from 60 s to 11.4 s, roughly a 5x speedup.
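For context, the binned quantities above are typically reduced to the expected calibration error as the bin-mass-weighted gap between accuracy and confidence; a sketch of that final reduction (which is outside the benchmark):

```python
import torch

def expected_calibration_error(acc_bin, conf_bin, prop_bin):
    # ECE: |accuracy - confidence| per bin, weighted by the bin's sample mass.
    return torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)

# Toy values: the first bin is miscalibrated by 0.1, the second is perfect.
acc_bin = torch.tensor([0.2, 0.9])
conf_bin = torch.tensor([0.3, 0.9])
prop_bin = torch.tensor([0.5, 0.5])
ece = expected_calibration_error(acc_bin, conf_bin, prop_bin)  # ~0.05
```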

@SkafteNicki
Member

Can confirm a 30-50x speedup on GPU with the proposed solution (after implementing the suggested changes) :)

Borda and others added 2 commits January 18, 2022 16:19
@Borda Borda requested a review from SkafteNicki January 18, 2022 15:20
@Borda Borda enabled auto-merge (squash) January 18, 2022 15:21
@Borda Borda mentioned this pull request Jan 18, 2022
Borda and others added 2 commits January 18, 2022 19:42
* Remove deprecated functions, and warnings
* Update links for docstring

Co-authored-by: Daniel Stancl <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
@Borda
Member

Borda commented Jan 18, 2022

It seems that older PyTorch versions do not have bucketize; it looks like it was added in 1.8:

AttributeError: module 'torch' has no attribute 'bucketize'

That said, we can have a hard switch if needed: older versions will use the loop, newer versions this fancy solution :]
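One possible shape for such a hard switch (a sketch with illustrative names, probing for the feature once at import time rather than parsing version strings):

```python
import torch

# Probe once at import time; `bin_indices` is an illustrative helper name,
# not the PR's actual code.
_HAS_BUCKETIZE = hasattr(torch, "bucketize")

def bin_indices(confidences, bin_boundaries):
    """Return the bin index of each confidence (1-D tensors assumed)."""
    if _HAS_BUCKETIZE:
        # Fast path on newer PyTorch.
        return torch.bucketize(confidences, bin_boundaries) - 1
    # Fallback for older PyTorch: count boundaries strictly below each value,
    # which matches bucketize's default right-closed bins.
    return (confidences.unsqueeze(1) > bin_boundaries).sum(dim=1) - 1
```

Probing with hasattr avoids paying for an exception on every call, unlike a per-call try/except.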

@ramonemiliani93
Contributor Author

Should a try-except block do it?

    acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
    conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
    count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
    prop_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
    
    try:
        indices = torch.bucketize(confidences, bin_boundaries) - 1

    except AttributeError:
        for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
            # Compute the confidence and accuracy in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                acc_bin[i] = accuracies[in_bin].float().mean()
                conf_bin[i] = confidences[in_bin].mean()
                prop_bin[i] = prop_in_bin
    else:
        count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
        conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
        conf_bin = torch.nan_to_num(conf_bin / count_bin)
        acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
        acc_bin = torch.nan_to_num(acc_bin / count_bin)
        prop_bin = count_bin / count_bin.sum()

@Borda Borda changed the title Improve calibration error speed by removing for loop and using bucketize + scatter_add. Improve calibration error speed by replacing for loop Jan 19, 2022
@SkafteNicki
Member

It seems that older PyTorch versions do not have bucketize; it looks like it was added in 1.8:

AttributeError: module 'torch' has no attribute 'bucketize'

That said, we can have a hard switch if needed: older versions will use the loop, newer versions this fancy solution :]

@Borda are you sure about this?
We have the following code elsewhere also relying on bucketize:
https://github.com/PyTorchLightning/metrics/blob/c519402693a6cb235367614f192d67332cbb4bc0/torchmetrics/functional/classification/auroc.py#L105-L109
From that it seems it was introduced in 1.6.

@Borda
Member

Borda commented Jan 19, 2022

We have the following code elsewhere also relying on bucketize:

good point, I did not dive into the details, I just saw that all tests below 1.8 were failing (could also be another reason lol)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
@Borda
Member

Borda commented Jan 19, 2022

@ramonemiliani93 do you think we can finish it so we can include it in the next bug-fix release? 🐰

@SkafteNicki
Member

@Borda should be taken care of now

@ramonemiliani93
Contributor Author

ramonemiliani93 commented Jan 20, 2022

@Borda Sorry for the late reply! I've had a lot of things on my side 😔 I just saw that @SkafteNicki already solved it. I hope to be able to contribute more next time!

@mergify mergify bot added the ready label Jan 20, 2022
@SkafteNicki
Member

It complains about torch.nan_to_num, which was introduced in 1.8...
I increased the required PyTorch version for the faster path :]
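For reference, on releases that predate torch.nan_to_num, the same empty-bin guard could be expressed with torch.where (a sketch, not the merged code; the helper name is illustrative):

```python
import torch

def safe_mean_per_bin(sum_bin, count_bin):
    # Empty bins would give 0/0 -> NaN; clamp the divisor and zero them out.
    return torch.where(
        count_bin > 0,
        sum_bin / count_bin.clamp(min=1),
        torch.zeros_like(sum_bin),
    )

# Toy check: one populated bin, one empty bin.
means = safe_mean_per_bin(torch.tensor([2.0, 0.0]), torch.tensor([4.0, 0.0]))
# means: [0.5, 0.0]
```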

@Borda Borda disabled auto-merge January 20, 2022 19:10
@Borda Borda enabled auto-merge (squash) January 20, 2022 19:15
@Borda Borda disabled auto-merge January 20, 2022 19:16
@Borda Borda merged commit 51d952d into Lightning-AI:master Jan 20, 2022
Borda pushed a commit that referenced this pull request Jan 20, 2022
* Improve speed by removing for loop and using bucketize + scatter_add.
* fast and slow binning
* Apply suggestions from code review
* cleaning & flake8
* increase to 1.8

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
(cherry picked from commit 51d952d)
Labels
enhancement New feature or request ready

Successfully merging this pull request may close these issues.

Remove for loop on calibration error.
5 participants