Improve calibration error speed by replacing for loop #769
Conversation
Nice, have you measured the performance? 🐰
Codecov Report
@@           Coverage Diff            @@
##           master    #769      +/-  ##
=========================================
- Coverage      95%     71%      -24%
=========================================
  Files         171     171
  Lines        6908    6926       +18
=========================================
- Hits         6546    4904     -1642
- Misses        362    2022     +1660
@Borda Yes 👌 Only on CPU though, here's the script:

import timeit

import torch


def method_a(confidences, accuracies, bin_boundaries):
    def _method_a():
        conf_bin = torch.zeros_like(bin_boundaries)
        acc_bin = torch.zeros_like(bin_boundaries)
        prop_bin = torch.zeros_like(bin_boundaries)
        for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
            # Calculate confidence and accuracy in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                acc_bin[i] = accuracies[in_bin].float().mean()
                conf_bin[i] = confidences[in_bin].mean()
                prop_bin[i] = prop_in_bin

    return _method_a


def method_b(confidences, accuracies, bin_boundaries):
    def _method_b():
        acc_bin = torch.zeros(len(bin_boundaries) - 1)
        conf_bin = torch.zeros(len(bin_boundaries) - 1)
        count_bin = torch.zeros(len(bin_boundaries) - 1)
        indices = torch.bucketize(confidences, bin_boundaries) - 1
        count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
        conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
        conf_bin = torch.nan_to_num(conf_bin / count_bin)
        acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
        acc_bin = torch.nan_to_num(acc_bin / count_bin)
        prop_bin = count_bin / count_bin.sum()

    return _method_b


n_bins = 20
size = (10000000,)
confidences = torch.rand(size)
accuracies = torch.randint(low=0, high=2, size=size).float()
bin_boundaries = torch.linspace(0, 1, steps=n_bins + 1)

t = timeit.Timer(method_a(confidences, accuracies, bin_boundaries))
print(t.timeit(100))
t = timeit.Timer(method_b(confidences, accuracies, bin_boundaries))
print(t.timeit(100))

The time depends on the size of the inputs.
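For reference, these binned quantities are what the metric ultimately reduces; a minimal sketch of that last step, assuming the standard expected-calibration-error (l1) definition rather than the exact reduction used in torchmetrics:

import torch

def expected_calibration_error(acc_bin, conf_bin, prop_bin):
    # ECE = sum_i prop_bin[i] * |acc_bin[i] - conf_bin[i]|
    # (illustrative helper for this discussion, not part of the PR)
    return torch.sum(prop_bin * torch.abs(acc_bin - conf_bin))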
Can confirm a 30-50x speedup on GPU with the proposed solution (after implementing the suggested changes) :)
* Remove deprecated functions and warnings
* Update links for docstring

Co-authored-by: Daniel Stancl <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Seems that older PyTorch versions do not have bucketize, so we can have a hard switch if needed - older versions will use the loop, newer versions this fancy solution :]
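A rough sketch of such a hard switch, keyed on the installed PyTorch version; the helper names _binning_with_loop / _binning_bucketize are placeholders for the two code paths discussed here, and the 1.8.0 threshold is taken from the "increase to 1.8" the PR eventually settles on:

from packaging.version import Version
import torch

# Assumed threshold: require PyTorch >= 1.8 for the vectorized path.
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__.split("+")[0]) >= Version("1.8.0")

def _binning(confidences, accuracies, bin_boundaries):
    # Dispatch once on a version flag instead of catching AttributeError on every call.
    if _TORCH_GREATER_EQUAL_1_8:
        return _binning_bucketize(confidences, accuracies, bin_boundaries)
    return _binning_with_loop(confidences, accuracies, bin_boundaries)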
Should it be something like this?

acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
prop_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device)
try:
    indices = torch.bucketize(confidences, bin_boundaries) - 1
except AttributeError:
    for i, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
        # Calculate confidence and accuracy in each bin
        in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            acc_bin[i] = accuracies[in_bin].float().mean()
            conf_bin[i] = confidences[in_bin].mean()
            prop_bin[i] = prop_in_bin
else:
    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)
    conf_bin = torch.nan_to_num(conf_bin / count_bin)
    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)
    acc_bin = torch.nan_to_num(acc_bin / count_bin)
    prop_bin = count_bin / count_bin.sum()
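To illustrate why the indices are shifted by -1, here is a tiny standalone example of bucketize + scatter_add_ on made-up values:

import torch

bin_boundaries = torch.linspace(0, 1, steps=5)   # tensor([0.00, 0.25, 0.50, 0.75, 1.00])
confidences = torch.tensor([0.1, 0.3, 0.4, 0.9])
# With the default right=False, bucketize returns i such that
# boundaries[i - 1] < x <= boundaries[i], so subtracting 1 maps values into bins 0..3.
indices = torch.bucketize(confidences, bin_boundaries) - 1
print(indices)      # tensor([0, 1, 1, 3])
count_bin = torch.zeros(len(bin_boundaries) - 1)
count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))
print(count_bin)    # tensor([1., 2., 0., 1.])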
Regarding the fallback to the for loop - @Borda sure about this?
Good point, I did not dive into the details, I just saw that all tests below 1.8 were failing (could also be another reason lol)
@ramonemiliani93 do you think we can finish it so we can include it in the next bug-fix release? 🐰
@Borda should be taken care of now
@Borda Sorry for the late reply! I've had a lot of things on my side 😔 I just saw that @SkafteNicki already solved it. I hope to be able to contribute more next time!
It complains about
* Improve speed by removing for loop and using bucketize + scatter_add.
* fast and slow binning
* Apply suggestions from code review
* cleaning & flake8
* increase to 1.8

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>

(cherry picked from commit 51d952d)
What does this PR do?
Improve calibration error speed by removing the for loop and using bucketize + scatter_add. Removes the for loop in the calibration error and uses bucketize with scatter_add for improved speed (~10x).
Fixes #767
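For anyone arriving from #767, a minimal usage sketch of the affected metric; the import path and the n_bins/norm argument names are assumed from the torchmetrics functional API of this era and may differ in other versions:

import torch
from torchmetrics.functional import calibration_error

preds = torch.rand(10_000)                # predicted confidences in [0, 1)
target = torch.randint(0, 2, (10_000,))   # binary ground-truth labels
# n_bins / norm argument names are an assumption; check the installed version's docs.
ece = calibration_error(preds, target, n_bins=15, norm="l1")
print(ece)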
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Oh yes 😎
Make sure you had fun coding 🙃