Refactor/remove double forward #984
Conversation
Codecov Report

@@            Coverage Diff            @@
##            master     #984    +/-  ##
=========================================
  Coverage       95%      95%
=========================================
  Files          180      180
  Lines         7666     7823    +157
=========================================
+ Hits          7276     7430    +154
- Misses         390      393      +3
Can we also include the speed-up chart in the PR? 🐰
This is a substantial change. While it does indeed prevent two update calls, I feel that from a user-experience perspective it might complicate the process of implementing custom metrics.
Not saying that I am in general against merging this one, but I don't want to rush it here... I feel like we should discuss the implications a bit more first.
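For context, the custom-metric workflow this refers to currently only requires implementing update and compute. A minimal sketch (MyAccuracy below is a hypothetical example, not code from this PR):

import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    # Hypothetical custom metric: fraction of exactly correct predictions.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.correct.float() / self.total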
@justusschock I completely agree that this is not a trivial change and should only be done with care. We can also consider whether this should be an opt-in feature. With the changes we did to the additional metric arguments, collapsing them all into **kwargs, it could look something like:

class Metric(nn.Module):
    def __init__(self, **kwargs):
        ...
        self.use_fast_forward = kwargs.pop('use_fast_forward', False)  # find better name, what should the default be?
        ...

    def forward(self, *args, **kwargs):
        if self.use_fast_forward:
            return self.forward_method_that_only_calls_update_once(*args, **kwargs)
        else:
            return self.forward_as_it_already_is(*args, **kwargs)

@Borda any opinions?
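To make the opt-in idea a bit more concrete, here is a self-contained toy (all names are placeholders, not torchmetrics API) showing the kwargs-based flag and the difference in update calls per forward:

import torch.nn as nn

class ToyMetric(nn.Module):
    # Toy class only; it just counts how often `update` runs per `forward` call.
    def __init__(self, **kwargs):
        super().__init__()
        # Opt-in flag popped from kwargs, defaulting to the current behaviour.
        self.use_fast_forward = kwargs.pop("use_fast_forward", False)
        self.num_updates = 0

    def update(self):
        self.num_updates += 1

    def forward(self):
        if self.use_fast_forward:
            self.update()   # single update per forward
        else:
            self.update()   # update for the accumulated global state ...
            self.update()   # ... plus a second update to get the batch value
        return self.num_updates

old = ToyMetric()
new = ToyMetric(use_fast_forward=True)
print(old(), new())  # 2 vs 1 update after a single forward call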
@SkafteNicki ready to go? 🎉
Code should be ready. Let me run some speed tests before we merge to make sure that we actually get the speedup that we expect :]
Sweet, could you please also include it here in the PR? :)
Yes will do :] |
@justusschock and @Borda, finally created an updated figure. The TLDR is that metrics where ...

Code to create figure:

from time import perf_counter
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
import torchmetrics
NUM_REPS = 5  # repetitions per timing measurement
NUM_CALLS = [1, 10, 25, 50, 75, 100, 250, 500, 750, 1000, 2500]  # number of forward calls to time
metrics = [
torchmetrics.MeanSquaredError,
torchmetrics.CosineSimilarity,
torchmetrics.Accuracy,
torchmetrics.ConfusionMatrix,
torchmetrics.StructuralSimilarityIndexMeasure,
torchmetrics.audio.sdr.SignalDistortionRatio,
torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity,
torchmetrics.SQuAD,
torchmetrics.WordErrorRate,
torchmetrics.AUROC,
]
metric_args = [{}, {}, {"num_classes": 10}, {"num_classes": 10}, {}, {}, {"net_type": "alex"}, {}, {}, {"num_classes": 3}]
inputs = [
(torch.randn(100,), torch.randn(100,)),
(torch.randn(100,), torch.randn(100,)),
(torch.randn(100, 10).softmax(dim=-1), torch.randint(10, (100,))),
(torch.randn(100, 10).softmax(dim=-1), torch.randint(10, (100,))),
(torch.rand(5, 3, 25, 25), torch.rand(5, 3, 25, 25)),
(torch.randn(1, 8000), torch.randn(1, 8000)),
(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 32, 32)),
([{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}], [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]),
(["this is the prediction", "there is an other sample"], ["this is the reference", "there is another one"]),
(torch.tensor([[0.90, 0.05, 0.05], [0.05, 0.90, 0.05], [0.05, 0.05, 0.90], [0.85, 0.05, 0.10], [0.10, 0.10, 0.80]]), torch.tensor([0, 1, 1, 2, 2]))
]
def get_metric_classes(base_class):
    # Create two subclasses that only differ in `full_state_update`:
    # `Old` keeps the current (double update) behaviour, `New` uses the proposed single update.
    class Old(base_class):
        full_state_update = True

    class New(base_class):
        full_state_update = False

    return [Old, New]
if __name__ == "__main__":
    # res[new_standard][metric_name] holds a (len(NUM_CALLS), NUM_REPS) array of timings
    res = {True: {}, False: {}}
    for base_metric_class, init_kwargs, args in zip(metrics, metric_args, inputs):
        print(f"Testing {base_metric_class}")
        name = base_metric_class.__name__
        OldClass, NewClass = get_metric_classes(base_metric_class)
        for metric, new_standard in zip([OldClass(**init_kwargs), NewClass(**init_kwargs)], [False, True]):
            res[new_standard][name] = np.zeros((len(NUM_CALLS), NUM_REPS))
            for i, s in tqdm(enumerate(NUM_CALLS), total=len(NUM_CALLS)):
                for r in range(NUM_REPS):
                    # time `s` consecutive forward calls on the same batch
                    start = perf_counter()
                    for _ in range(s):
                        val = metric(*args)
                    end = perf_counter()
                    metric.reset()
                    res[new_standard][name][i, r] = end - start
    # one subplot per metric: 2 rows x 5 columns
    fig, ax = plt.subplots(nrows=2, ncols=5)
    for count, metric in enumerate(metrics):
        i, j = count // 5, count % 5  # row/column index of the subplot
        name = metric.__name__
        mean_old = res[False][name].mean(axis=-1)
        std_old = res[False][name].std(axis=-1)
        mean_new = res[True][name].mean(axis=-1)
        std_new = res[True][name].std(axis=-1)
        ax[i, j].plot(NUM_CALLS, mean_old, label="Old standard")
        ax[i, j].fill_between(NUM_CALLS, mean_old + std_old, mean_old - std_old, alpha=0.1)
        ax[i, j].plot(NUM_CALLS, mean_new, label="New standard")
        ax[i, j].fill_between(NUM_CALLS, mean_new + std_new, mean_new - std_new, alpha=0.1)
        if i == 1:
            ax[i, j].set_xlabel("Number of forward calls", fontsize=10)
        if j == 0:
            ax[i, j].set_ylabel("Time (sec)", fontsize=10)
        ax[i, j].set_title(name, fontsize=10, fontweight="bold")
        plt.setp(ax[i, j].get_yticklabels(), fontsize=5)
        ax[i, j].legend(loc="upper left", fontsize=10)
    plt.show()
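For anyone skimming the thread, the intuition for the speed-up can be shown with a toy example. This is a simplified, self-contained sketch of the two forward strategies (old: full_state_update=True, new: full_state_update=False) and not the actual torchmetrics implementation; ToySum and its merge step are made up purely for illustration:

from copy import deepcopy

import torch


class ToySum:
    # Toy metric accumulating a running sum, only used to illustrate the two code paths.

    def __init__(self, full_state_update: bool = True):
        self.full_state_update = full_state_update
        self.total = torch.tensor(0.0)

    def update(self, x: torch.Tensor) -> None:
        self.total = self.total + x.sum()

    def compute(self) -> torch.Tensor:
        return self.total

    def reset(self) -> None:
        self.total = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.full_state_update:
            # Old standard: one update for the global state ...
            self.update(x)
            global_state = deepcopy(self.total)
            # ... and a second update on a freshly reset state to get the batch value.
            self.reset()
            self.update(x)
            batch_val = self.compute()
            self.total = global_state
            return batch_val
        # New standard: cache the global state, update once on the batch ...
        global_state = deepcopy(self.total)
        self.reset()
        self.update(x)
        batch_val = self.compute()
        # ... and merge the batch state back into the global state (here simply addition).
        self.total = global_state + self.total
        return batch_val


if __name__ == "__main__":
    m = ToySum(full_state_update=False)
    for batch in torch.randn(4, 10):
        print(m.forward(batch))  # per-batch value
    print(m.compute())           # value accumulated over all batches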
Sweet!
What does this PR do?
Redo of #612
Fixes part of #344 (needs a review afterwards to decide whether the issue can be closed)
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃