
Mistake in parameters' grad norm tracking #2012

Merged: 9 commits merged into master on Jun 2, 2020

Conversation

@ivannz (Contributor) commented May 30, 2020

There is a mistake in grad_norms in core.grads.GradInformation. The mistake affects the reported gradient norm of every individual parameter, but not the total norm. This PR fixes the erroneous computation.

According to the torch docs, torch.norm(tensor, p) computes the vector norm sum(abs(x)**p)**(1./p) (the 2-norm by default), flattening the tensor if required. Now p.grad.data.norm in grads.py#L17 is the same as torch.norm(p.grad.data, ...) and thus already computes the norm. However, on grads.py#L19 the grad tracker takes the p-th root again, essentially making params_norm equal to sum(abs(x)**p)**(1./(p**2)) = norm**(1./p), which is not correct.

The following snippet, which borrows code from grads.py#L17-L19,

import torch
import numpy as np

x = torch.randn(3, 4, 32, 32) * 0.1
for norm_type in [1., 1.5, 2., 4.]:
    # numpy reference value
    np_val = np.linalg.norm(x.numpy().ravel(), norm_type)

    # according to docs (and definition of p-norm)
    tr_val_naive = (abs(x) ** norm_type).sum().pow(1./norm_type)

    # almost verbatim from grads.py#L17-L19
    param_norm = x.norm(norm_type)
    grad_norm = param_norm ** (1./norm_type)

    print(f'numpy.norm = {float(np_val):4g}, '
          f'torch.naive = {float(tr_val_naive):4g}, '
          f'grad_norm = {float(grad_norm):4g}')

produces

numpy.norm = 972.305, torch.naive = 972.305, grad_norm = 972.306
numpy.norm = 47.8891, torch.naive = 47.8891, grad_norm = 13.1874
numpy.norm = 11.0455, torch.naive = 11.0455, grad_norm = 3.32348
numpy.norm = 1.38376, torch.naive = 1.38376, grad_norm = 1.08459

which is clearly incorrect for every norm_type other than 1 (for p = 1 the extra root is a no-op).
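
For reference, a minimal sketch of the corrected computation (an illustrative helper under assumed names, not the exact code merged in this PR): torch.norm is applied once per parameter, and the total is the norm of the vector of per-parameter norms, which equals the norm over the concatenation of all gradients.

import torch

def grad_norms_fixed(named_parameters, norm_type: float) -> dict:
    # hypothetical helper illustrating the fix; names are not taken from the PR
    norms, all_norms = {}, []
    for name, p in named_parameters:
        if p.grad is None:
            continue
        # torch.norm already returns sum(|x|**p)**(1/p); no extra root is needed
        param_norm = float(p.grad.data.norm(norm_type))
        norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
        all_norms.append(param_norm)
    # the norm of the per-parameter norms equals the norm of all gradients stacked together
    total_norm = float(torch.tensor(all_norms).norm(norm_type))
    norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
    return norms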

mergify bot requested a review from a team on May 30, 2020 10:28
@ivannz (Contributor, Author) commented May 30, 2020

If necessary, this fix can be extended to support infinity-norm computation (see the sketch after the list):

  1. replace total_norm with a list and accumulate individual norms in it
  2. use total_norm = torch.tensor(total_norm).norm(norm_type)
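
A quick illustrative check that this list-then-reduce approach also covers the infinity norm (the tensors below are made up for the example, not taken from the PR):

import torch

grads = [torch.randn(3, 3), torch.randn(5)]
for norm_type in (1.0, 2.0, float('inf')):
    # accumulate individual norms in a list, then reduce them with torch.norm
    per_param = [float(g.norm(norm_type)) for g in grads]
    total = torch.tensor(per_param).norm(norm_type)
    # reference: the norm over all gradient entries at once
    reference = torch.cat([g.flatten() for g in grads]).norm(norm_type)
    assert torch.isclose(total, reference), (norm_type, total, reference)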

@Borda added the bug (Something isn't working) label on May 30, 2020
@Borda added this to the 0.8.0 milestone on May 30, 2020
@williamFalcon (Contributor) commented:

@ivannz excellent! mind adding a test to make sure we're matching what numpy expects?

@ivannz force-pushed the grad-track-fix branch 2 times, most recently from c0332c6 to 944bbd6 on May 31, 2020 10:41
@ivannz (Contributor, Author) commented May 31, 2020

I can add a unit test, but I have a question: since the latest code relies completely on torch.norm, is it necessary to test whether torch.norm coincides with np.linalg.norm? It would seem that PyTorch's ATen already tests its implementation against BLAS.

@codecov (bot) commented May 31, 2020

Codecov Report

Merging #2012 into master will increase coverage by 0%.
The diff coverage is 89%.

@@          Coverage Diff           @@
##           master   #2012   +/-   ##
======================================
  Coverage      88%     88%           
======================================
  Files          74      74           
  Lines        4666    4664    -2     
======================================
- Hits         4084    4083    -1     
+ Misses        582     581    -1     

@williamFalcon (Contributor) commented:

Yes, but we need to make sure the overall result of the gradient clipping and the norm calculation is correct... we've addressed this issue multiple times already, and at this point we just need more rigorous testing so that we don't have to revisit it...

@ivannz (Contributor, Author) commented May 31, 2020

I have added a test, @williamFalcon.

Due to the explicit rounding to 3 decimal places in grad_norm, the test uses a 5e-3 relative tolerance when checking the norms.
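
For context, a sketch of the kind of comparison such a test performs (illustrative code, not the actual test added in this PR): values reported by the tracker are rounded to 3 decimal places, so they are compared against a full-precision numpy reference with a relative tolerance slightly looser than that rounding.

import numpy as np
import torch

model = torch.nn.Linear(16, 4)
model(torch.randn(8, 16)).sum().backward()

norm_type = 2.0
for name, p in model.named_parameters():
    logged = round(float(p.grad.norm(norm_type)), 3)  # tracker-style rounded value
    reference = np.linalg.norm(p.grad.numpy().ravel(), norm_type)
    assert np.isclose(logged, reference, rtol=5e-3), (name, logged, reference)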

@awaelchli (Contributor) left a review:

I like this, good work :)
The norm computation is correct, thanks for the fix.
I think the test can be simplified, see my comments.

mergify bot requested a review from a team on June 1, 2020 05:39
@ivannz (Contributor, Author) commented Jun 1, 2020

To mirror clip_grad_norm in spirit, I added support for an explicit 'inf' string in the trainer's track_grad_norm argument: it can be an int, a float, or the string 'inf' for the infinity norm (or negative to disable norm tracking). The value is type-checked in Trainer.__init__ (raising a MisconfigurationException) and cast to float on every check in Trainer.run_training_batch().
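
A rough sketch of the validation and cast described above (hypothetical, simplified code; the actual implementation lives in the Trainer and raises MisconfigurationException):

def _check_track_grad_norm(track_grad_norm):
    # accept a number, or the string 'inf' for the infinity norm
    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise ValueError(f"track_grad_norm must be a number or 'inf', got {track_grad_norm!r}")
    return track_grad_norm

# on each training batch the value is cast to float ('inf' becomes float('inf'));
# negative values disable gradient-norm tracking
norm_type = float(_check_track_grad_norm('inf'))
should_track = norm_type >= 0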

@Borda requested a review from awaelchli on June 1, 2020 12:08
mergify bot requested a review from a team on June 1, 2020 12:41
@ivannz (Contributor, Author) commented Jun 1, 2020

@awaelchli and @Borda thank you for the code review!

@Borda Thank you for your edits! As for your docstring style fix, I apologise for not having read CONTRIBUTING.md carefully.

@Borda (Member) commented Jun 1, 2020

@Borda Thank you for your edits! As for your docstring style fix, I apologise for not having read CONTRIBUTING.md carefully.

that is fine, we thank you for your contribution :]

@williamFalcon merged commit e85a646 into Lightning-AI:master on Jun 2, 2020
justusschock pushed a commit that referenced this pull request Jun 29, 2020
* fix grad norm formula

* grad-norm tracker test

* fixed seed and explicit rtol in grad norm tracking test

* a docstring for grad-norms and forced cast to float of norm_type

* support for inf-norm

* renamed the grad norm test

* docs

* fixed language in docstring

* Apply suggestions from code review

Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Labels: bug (Something isn't working)
Projects: none
Participants: 4