Memory efficient backward #33
Conversation
Update main branch
TimDettmers
left a comment
Looks good to me overall. Please also include a test that checks for a minimal difference between fp16 gradients and int8->fp16 gradients. You can add it either to the existing tests or create a new one.
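For reference, a rough sketch of the kind of gradient-comparison test being requested. It uses a simulated int8 round-trip of the weights, so the `quantize_dequantize_rowwise` helper, shapes, and tolerances below are illustrative assumptions rather than bitsandbytes' actual quantizer, and a CUDA device is assumed for the fp16 matmuls:

```python
import torch
import torch.nn.functional as F

def quantize_dequantize_rowwise(W: torch.Tensor) -> torch.Tensor:
    # Simulated int8 round-trip with per-row absmax scaling (illustrative only).
    scale = W.float().abs().amax(dim=1, keepdim=True) / 127.0
    return ((W.float() / scale).round().clamp(-127, 127) * scale).to(W.dtype)

torch.manual_seed(0)
x = torch.randn(32, 64, dtype=torch.float16, device="cuda", requires_grad=True)
W = torch.randn(128, 64, dtype=torch.float16, device="cuda")

# Reference: gradient w.r.t. the input through the fp16 weight.
F.linear(x, W).sum().backward()
grad_ref, x.grad = x.grad.clone(), None

# Gradient through the int8->fp16 round-tripped weight.
F.linear(x, quantize_dequantize_rowwise(W)).sum().backward()
grad_int8 = x.grad.clone()

# Require the two gradients to agree within a small fraction of the typical magnitude.
scale = grad_ref.abs().mean().item()
assert torch.allclose(grad_int8, grad_ref, rtol=0, atol=0.05 * scale)
```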
bitsandbytes/autograd/_functions.py
Outdated
```python
# Cast A to fp16
A_dtype = A.dtype
A = A.to(torch.float16)
```
This is exactly what we talked about; this should take care of bfloat16! One thing we need to think about is emitting some kind of warning when this is first executed, to make people aware that a cast happens and that the quantization operation is performed in fp16.
Additionally, it would be great to have bf16 tests that verify everything works correctly with those inputs. I think you only need to change one line (and check that everything else executes correctly).
- added a `warnings.warn` the first time A is cast to a different dtype
- added FP16 tests; fine-tuned typecasts to fit into almost all thresholds
- increased atol 0.0175 -> 0.02 in one specific case for BF16 (this does not affect FP16)
- e.g. `grad_bias` is now computed on natural `grad_output` before it is cast (and before it loses precision)
- `memory_efficient_backward` will set `Linear8bitLt.weight.requires_grad = False` by default
- added a test for backward pass with `memory_efficient_backward`
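A minimal sketch of the warn-once cast behavior described in the first bullet, assuming a module-level flag; the names `_cast_warned` and `cast_to_fp16` are hypothetical and not the PR's actual implementation:

```python
import warnings

import torch

_cast_warned = False  # hypothetical module-level flag; the PR may track this differently

def cast_to_fp16(A: torch.Tensor) -> torch.Tensor:
    """Cast the input to fp16 for the int8 matmul, warning once on the first cast."""
    global _cast_warned
    if A.dtype != torch.float16:
        if not _cast_warned:
            warnings.warn(
                f"MatMul8bitLt: input is cast from {A.dtype} to float16; "
                "quantization is performed in fp16."
            )
            _cast_warned = True
        A = A.to(torch.float16)
    return A
```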
```python
output = output.to(A.dtype).add_(bias)

# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
```
Curiously, when I tried to replace the next line from
`output += torch.matmul(subA, state.subB)`
to
`output.addmm_(subA, state.subB)`
the precision would drop and the tests would fail.
I have no idea why; the dtypes of output, subA, and subB are always equal (tested).
I cannot remember if I stumbled upon the same thing. I remember trying to make this matrix multiplication more efficient but failing. What is the increase in error that you see?
It does not make much sense to me, since in cuBLAS you perform (A @ B) + D = C and the result of A @ B is in fp32, so the entire operation should be more precise. The same goes for fused multiply-add in general, which is more precise than a multiplication followed by an addition. It might be some weird tensor core issue, but it makes no sense to me.
If the error is only smaller some of the time and has more variance, it would still be okay to have this. I believe it would be a good chunk faster.
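One way to probe this question is to compare both variants against an fp32 reference; the shapes and random data below are illustrative assumptions (not the PR's actual `subA`/`subB`), and a CUDA device is assumed for the fp16 matmuls:

```python
import torch

torch.manual_seed(0)
output = torch.randn(64, 256, dtype=torch.float16, device="cuda")
subA = torch.randn(64, 32, dtype=torch.float16, device="cuda")
subB = torch.randn(32, 256, dtype=torch.float16, device="cuda")

# fp32 reference for the decomposition matmul-and-add
ref = output.float() + subA.float() @ subB.float()

# Variant 1: separate matmul followed by an in-place add (current code)
out1 = output.clone()
out1 += torch.matmul(subA, subB)

# Variant 2: fused in-place addmm_ (the proposed replacement)
out2 = output.clone()
out2.addmm_(subA, subB)

print("matmul + add max abs error:", (out1.float() - ref).abs().max().item())
print("addmm_       max abs error:", (out2.float() - ref).abs().max().item())
```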
TimDettmers
left a comment
Looks all good to me. Let's discuss this briefly tomorrow. I am curious if we can get the .addmm_ to work. Otherwise, just a couple of questions on the test performance. Overall great work! Thank you so much, Yozh!
bitsandbytes/autograd/_functions.py
Outdated
```python
SCB = (state.SCB.unsqueeze(1) / 127.0).half()
CB *= SCB
grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
```
Not sure how PyTorch implements div, but multiplication is about 30x faster than division. Since we apply it over a matrix, this might make a tiny but significant difference. So `.mul(1/127.0)` might be better here.
applied, thanks
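A self-contained sanity check of that suggestion; the shapes and the `CB`/`SCB` stand-ins below are illustrative assumptions rather than the PR's actual state tensors:

```python
import torch

# Stand-ins for state.CB (int8 weights) and state.SCB (per-row absmax scales).
CB = torch.randint(-127, 128, (256, 512), dtype=torch.int8)
SCB = torch.rand(256, dtype=torch.float32) * 10

# Current: per-element division by 127.
deq_div = CB.to(torch.float16, copy=True).mul_(SCB.unsqueeze(1).div(127.0))

# Suggested: multiply by the precomputed reciprocal instead of dividing.
deq_mul = CB.to(torch.float16, copy=True).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))

# Both dequantizations agree to within fp16 rounding.
assert torch.allclose(deq_div, deq_mul, rtol=1e-2, atol=1e-3)
```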
tests/test_modules.py
Outdated
```python
(o1 * grad_proj).sum().backward()
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
```
I remember I had some tests that used the relative difference normalized by the standard deviation, which is similar to this. What is the range of errors that you see? It might also be good to test for a maximum of k elements that exceed a threshold. This helps to differentiate worst-case vs. general performance.
Got it, I've added a separate assert with outliers:

```python
torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005
```
All good! Thank you both!


TODO #1: double-check that both the `memory_efficient_backward` and `has_fp16_weights` options work properly.
TODO #2: write a clearer PR description.
This PR provides two features: