Memory efficient backward #33
Changes to the 8-bit matmul autograd function (`MatMul8bitLt.forward` / `backward`):

```diff
@@ -1,4 +1,6 @@
 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
@@ -210,32 +212,29 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)

         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()

         # Cast A to fp16
-        A_dtype = A.dtype
-        A = A.to(torch.float16)
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )

         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -271,7 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
             else:
                 has_grad = False
@@ -292,7 +291,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
                     (outliers * state.SCB.view(-1, 1) / 127.0)
                     .t()
                     .contiguous()
-                    .half()
+                    .to(A.dtype)
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -309,7 +308,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -320,18 +325,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
             ctx.tensors = [None, None]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
-
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
@@ -341,24 +344,24 @@ def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
+        grad_A = grad_B = grad_bias = None

-        # Cast grad_output to fp16
-        grad_output_dtype = grad_output.dtype
-        grad_output = grad_output.to(torch.float16)
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()

-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -375,21 +378,14 @@ def backward(ctx, grad_output):
                         state.CBt, to_order=formatB, transpose=True
                     )
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

             elif state.CB is not None:
-                CB = state.CB.half()
-                SCB = (state.SCB.unsqueeze(1) / 127.0).half()
-                CB *= SCB
-                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
                 raise Exception('State must contain either CBt or CB matrix for backward')

-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
-
-        # Cast grad_A back to grad_output_dtype
-        grad_output = grad_output.to(grad_output_dtype)
-
         return grad_A, grad_B, None, grad_bias, None
```
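The memory-efficient branch of `backward` above rebuilds `grad_A` directly from the int8 weight matrix `state.CB` and its per-row scales `state.SCB`, so no full-precision copy of the weights has to be kept for the backward pass. Below is a small self-contained sketch of that dequantize-then-matmul pattern; it is not part of the PR, the helper `rowwise_int8_quant` is hypothetical, and everything is kept in fp32 on CPU for simplicity (the library does this in fp16/int8 on GPU):

```python
import torch

def rowwise_int8_quant(W):
    # symmetric row-wise absmax quantization: one int8 row plus one scale per output row
    scale = W.abs().amax(dim=1)                                                # [out]
    CB = torch.clamp((W * 127.0 / scale.unsqueeze(1)).round(), -127, 127).to(torch.int8)
    return CB, scale

torch.manual_seed(0)
W = torch.randn(64, 32)            # weight of a Linear(32 -> 64)
CB, SCB = rowwise_int8_quant(W)    # int8 weights + per-row scales: all that backward needs to keep

grad_output = torch.randn(16, 64)  # gradient w.r.t. the layer output

# what the memory-efficient branch does: dequantize CB on the fly, then one matmul
CB_deq = CB.to(grad_output.dtype) * SCB.unsqueeze(1) / 127.0
grad_A = torch.matmul(grad_output, CB_deq)          # gradient w.r.t. the layer input

# reference gradient computed from the original weights
grad_A_ref = torch.matmul(grad_output, W)
print((grad_A - grad_A_ref).abs().max().item())     # small, dominated by int8 rounding error
```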
Changes to the module tests (`MLP8bit` and `test_linear8bitlt_no_fp16_weights`):

```diff
@@ -14,13 +14,15 @@ def __init__(self, initial_data):
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )

     def forward(self, x):
@@ -451,9 +453,12 @@ def test_linear8bitlt_accumulated_gradient():
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(
+        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+    )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,30 @@ def test_linear8bitlt_no_fp16_weights(threshold):
         assert mlp.fc1.state.idx is not None
     if threshold > 0:
         assert mlp.fc2.state.idx is not None

     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+        assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+
+
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
```
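Based on the new test above, a minimal usage sketch of the flag might look like the following. This assumes a CUDA device and this PR's branch of bitsandbytes; `threshold=6.0` is just an illustrative value, not something the PR prescribes:

```python
import torch
import bitsandbytes as bnb

# a frozen int8 linear layer that can still pass gradients to its inputs
layer = bnb.nn.Linear8bitLt(
    32, 64,
    has_fp16_weights=False,            # keep the weights quantized instead of training them
    memory_efficient_backward=True,    # compute grad_input from the int8 weights in backward
    threshold=6.0,
).cuda().half()                        # moving to CUDA triggers the int8 quantization

x = torch.randn(16, 8, 32, device="cuda", dtype=torch.half, requires_grad=True)
y = layer(x)                           # fp16 output, as asserted in the test above
y.sum().backward()                     # gradient flows into x even though the weights are frozen
print(x.grad.shape, x.grad.dtype)
```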
Curiously, when I tried to replace the next line from `output += torch.matmul(subA, state.subB)` to `output.addmm_(subA, state.subB)`, the precision would drop and the tests would fail.
I have no idea why - the dtypes of `output`, `subA` and `state.subB` are always equal (I checked).
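For reference, one way to compare the two variants in isolation could look like this; the shapes are made up and this is not the PR's benchmark, just a quick probe against an fp64 reference:

```python
import torch

torch.manual_seed(0)
out  = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
subA = torch.randn(4096, 32,   dtype=torch.float16, device="cuda")
subB = torch.randn(32,   1024, dtype=torch.float16, device="cuda")

ref = out.double() + subA.double() @ subB.double()   # fp64 reference

res_matmul = out + torch.matmul(subA, subB)          # current code path
res_addmm  = out.clone().addmm_(subA, subB)          # fused variant that was tried

print("matmul+add max error:", (res_matmul.double() - ref).abs().max().item())
print("addmm_     max error:", (res_addmm.double()  - ref).abs().max().item())
```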
I cannot remember if I stumbled upon the same thing. I remember trying to make this matrix multiplication more efficient but failing. How large is the increase in error that you see?
It does not make much sense to me, since in cuBLAS you perform `(A @ B) + D = C` and the result of `A @ B` is in fp32, so the entire operation should be more precise. The same goes for fused multiply-add in general, which is more precise than multiplication followed by addition. It might be some weird tensor core issue, but it makes no sense to me.
Even if the error is only smaller some of the time and has more variance, it would still be okay to have this change. I believe it would be a good chunk faster.
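As a toy illustration of that argument (not from this PR): with separate fp16 operations the product is rounded to fp16 before the addition, whereas accumulating in fp32 rounds only once at the end, which matters most when the addition cancels most of the leading digits:

```python
import torch

a = torch.tensor(1.0009765625, dtype=torch.float16)   # 1 + 2**-10, exactly representable in fp16
b = torch.tensor(3.0,          dtype=torch.float16)
d = torch.tensor(-3.0,         dtype=torch.float16)

exact = float(a) * float(b) + float(d)                 # 0.0029296875, representable in fp16

separate = a * b + d                                   # product rounded to fp16, then added
fused    = (a.float() * b.float() + d.float()).half()  # accumulate in fp32, round once at the end

print(exact, separate.item(), fused.item())
```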