Merged

Changes from all commits (70 commits):
8ae9bb2
add memory efficient backward
dbaranchuk Aug 23, 2022
1753aa0
refactoring
dbaranchuk Aug 23, 2022
656de8e
minor fixes
dbaranchuk Aug 23, 2022
876387d
minor fixes
dbaranchuk Aug 23, 2022
ef2936a
delete CxB from state
dbaranchuk Aug 23, 2022
4d6174b
memory efficient fp16 backward
dbaranchuk Aug 25, 2022
b3fee1e
add dtype <-> fp16 cast
dbaranchuk Aug 26, 2022
8d34d36
req_gradA for casted & more efficient and accurate fp16 backward
dbaranchuk Aug 28, 2022
843ad06
Merge pull request #1 from TimDettmers/main
dbaranchuk Sep 11, 2022
42b5fc9
add memory effcient backward option
dbaranchuk Sep 11, 2022
ee325f0
clarified an exception message
dbaranchuk Sep 11, 2022
d358999
refactoring
dbaranchuk Sep 11, 2022
4dd475c
refactoring
dbaranchuk Sep 11, 2022
e2a7576
bug fix
dbaranchuk Sep 11, 2022
3634fc7
Merge branch 'TimDettmers:main' into memory-efficient-backward
justheuristic Sep 17, 2022
cc4858c
some kind of warning or something when this is first executed to make…
justheuristic Sep 17, 2022
469d5a6
test_bf16
justheuristic Sep 17, 2022
a9c7953
cast to half before double_quant
justheuristic Sep 17, 2022
140cdbe
check dtypes first
justheuristic Sep 17, 2022
9379df8
check dtypes first
justheuristic Sep 17, 2022
e29c5f5
clearer assertions
justheuristic Sep 17, 2022
fc4a135
clearer assertions
justheuristic Sep 17, 2022
a9fe0ff
recast to fp16
justheuristic Sep 17, 2022
eac9aca
cast bias too
justheuristic Sep 17, 2022
7facedd
copypaste tolerances
justheuristic Sep 17, 2022
d9ca0ed
un-fuse bias
justheuristic Sep 17, 2022
56a074f
un-fuse bias
justheuristic Sep 17, 2022
e9b8711
un-fuse bias
justheuristic Sep 17, 2022
0de1a44
change order
justheuristic Sep 17, 2022
647c976
change order
justheuristic Sep 17, 2022
210b9ed
debug assert
justheuristic Sep 17, 2022
85bf529
debug assert
justheuristic Sep 17, 2022
e2b523d
change typecast behavior
justheuristic Sep 17, 2022
d6e25b5
change typecast behavior
justheuristic Sep 17, 2022
1145589
change typecast behavior
justheuristic Sep 17, 2022
1da4880
change typecast behavior
justheuristic Sep 17, 2022
5b169f1
change typecast behavior
justheuristic Sep 17, 2022
14048a3
safer cast
justheuristic Sep 17, 2022
a214824
matmul -1- addmm
justheuristic Sep 17, 2022
702cc72
debug asset
justheuristic Sep 17, 2022
45dc198
cast properly
justheuristic Sep 17, 2022
577275b
cast properly
justheuristic Sep 17, 2022
e35e2c6
cast properly
justheuristic Sep 17, 2022
cbfdf0b
cast edge case
justheuristic Sep 17, 2022
ab9dee0
cast edge case
justheuristic Sep 17, 2022
fa8e07c
more lenient threshold
justheuristic Sep 17, 2022
f667032
bump threshold to 0.21
justheuristic Sep 17, 2022
18f142e
addmm_
justheuristic Sep 17, 2022
76ece2c
rollback
justheuristic Sep 17, 2022
579b8c7
reduce diff
justheuristic Sep 17, 2022
591f603
add memory efficient backward
justheuristic Sep 17, 2022
2cd047e
run backward
justheuristic Sep 17, 2022
7906dc4
debugpritn
justheuristic Sep 17, 2022
4b4a9ef
debugprint
justheuristic Sep 17, 2022
4da2227
debug
justheuristic Sep 17, 2022
5d65817
debug
justheuristic Sep 17, 2022
d9b8789
debug
justheuristic Sep 17, 2022
6a826c4
pre-cast
justheuristic Sep 17, 2022
37f805b
debug
justheuristic Sep 17, 2022
95dafc6
cast before allclose
justheuristic Sep 17, 2022
28a9313
cast before allclose
justheuristic Sep 17, 2022
725cc72
cast device
justheuristic Sep 17, 2022
e4086a2
cast device
justheuristic Sep 17, 2022
01b4c6a
cast device
justheuristic Sep 17, 2022
32a9a88
cast device
justheuristic Sep 17, 2022
cff3a71
cast device
justheuristic Sep 17, 2022
9b7d307
review
TimDettmers Sep 20, 2022
a07825a
review
justheuristic Sep 20, 2022
292a478
set threshold
TimDettmers Sep 20, 2022
76ce9aa
try fp32
justheuristic Sep 20, 2022
76 changes: 44 additions & 32 deletions bitsandbytes/autograd/_functions.py
@@ -1,4 +1,6 @@
import operator
import warnings

import torch
import bitsandbytes.functional as F

@@ -184,6 +186,7 @@ class MatmulLtState:
idx = None
is_training = True
has_fp16_weights = True
memory_efficient_backward = False
use_pool = False
formatB = F.get_special_format_str()

@@ -209,31 +212,29 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)

# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
requires_gradA = A.requires_grad
requires_gradB = B.requires_grad
requires_gradBias = bias is not None and bias.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
assert (
A.dtype == torch.float16
), f"The input data type needs to be fp16 but {A.dtype} was found!"

# Cast A to fp16
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A, threshold=state.threshold
A.to(torch.float16), threshold=state.threshold
)

if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B)
) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -290,7 +291,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.half()
.to(A.dtype)
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)

if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)

# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
@justheuristic (Contributor) commented on Sep 17, 2022:
Curiously, when I tried to replace the next line from
output += torch.matmul(subA, state.subB)

to
output.addmm_(subA, state.subB)

the precision would drop and the tests would fail.
I have no idea why - the dtypes of output, subA and subB are always equal (tested).

Collaborator reply:
I cannot remember if I stumbled upon the same thing. I remember trying to make this matrix multiplication more efficient but failed. What is the increase that you see in errors?

It does not make much sense to me, since in cuBLAS you perform (A @ B) + D = C and the result of A @ B is in fp32, so the entire operation should be more precise. The same goes for fused multiply-add in general, which is more precise than multiplication followed by addition. It might be some weird tensor core issue, but it makes no sense to me.

If the error is only smaller some of the time and it has more variance, it would still be okay to have this. I believe it would be a good chunk faster.
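
As a reference point for the thread above, here is a minimal standalone sketch of the comparison; the shapes, device, and random data are illustrative assumptions, not values taken from the test suite:

```python
# Compare accumulating the outlier matmul with `+=` versus the fused in-place
# `addmm_`, measuring both against an fp32 reference. Illustrative shapes only.
import torch

torch.manual_seed(0)
output = torch.randn(256, 1024, dtype=torch.float16, device="cuda")
subA = torch.randn(256, 8, dtype=torch.float16, device="cuda")
subB = torch.randn(8, 1024, dtype=torch.float16, device="cuda")

ref = output.float() + subA.float() @ subB.float()   # fp32 reference

out_add = output.clone()
out_add += torch.matmul(subA, subB)                  # current code path

out_fused = output.clone()
out_fused.addmm_(subA, subB)                         # proposed fused path

print("mean abs err, += matmul:", (out_add.float() - ref).abs().mean().item())
print("mean abs err, addmm_   :", (out_fused.float() - ref).abs().mean().item())
```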

@@ -318,42 +325,43 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):

ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

if requires_gradA or requires_gradB:
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)


clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
#clone_func = torch.clone
return clone_func(output.view(output_shape))

@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, req_gradBias = ctx.req_grads
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert (
state.has_fp16_weights
), "Backprop only supported for fp16 weights."
grad_A = grad_B = grad_bias = None

if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.view(
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()

grad_A = grad_B = grad_bias = None

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ def backward(ctx, grad_output):
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

if req_gradBias:
grad_bias = grad_output.sum(0)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')

return grad_A, grad_B, None, grad_bias, None
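
For context, the new `elif state.CB is not None` branch above avoids keeping a transposed int8 weight copy (`CBt`/`CxBt`) alive for backward; it dequantizes the row-major `CB` with its per-row scales `SCB` and falls back to a plain matmul. A rough standalone sketch of that step, using made-up shapes and stand-in tensors rather than the real `state` object:

```python
# Dequantize-then-matmul fallback for grad_A, mirroring
# `CB.to(dtype).mul_(SCB.unsqueeze(1) / 127)` followed by `grad_output @ CB`.
# All tensors here are illustrative stand-ins.
import torch

out_features, in_features, batch = 64, 32, 16
CB = torch.randint(-127, 128, (out_features, in_features), dtype=torch.int8, device="cuda")
SCB = torch.rand(out_features, device="cuda") * 0.1                      # per-row absmax scales
grad_output = torch.randn(batch, out_features, dtype=torch.float16, device="cuda")

W = CB.to(torch.float16) * (SCB.unsqueeze(1).to(torch.float16) / 127.0)  # dequantized weight
grad_A = torch.matmul(grad_output, W)                                    # dL/dA = dL/dY @ W
print(grad_A.shape)                                                      # (batch, in_features)
```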

21 changes: 15 additions & 6 deletions bitsandbytes/nn/modules.py
@@ -221,6 +221,7 @@ def __init__(
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
@@ -232,10 +233,13 @@ def __init__(

self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True

self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)

def init_8bit_state(self):
self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ def forward(self, x):

out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
if not self.state.has_fp16_weights:
if not self.state.memory_efficient_backward and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB

return out
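
A short usage sketch of the flag wired up here, assuming a bitsandbytes build that includes this branch; the dimensions and threshold are illustrative:

```python
# A frozen int8 layer that still produces an input gradient: with
# memory_efficient_backward=True, CxB is dropped after forward and backward
# dequantizes CB instead. Sketch only; assumes this PR's bitsandbytes build.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(
    32, 64,
    has_fp16_weights=False,            # weights are quantized to int8 and frozen
    memory_efficient_backward=True,    # keep CB so backward can dequantize it
    threshold=6.0,
).cuda().half()

x = torch.randn(16, 32, device="cuda", dtype=torch.float16, requires_grad=True)
out = layer(x)
out.sum().backward()
print(out.dtype, x.grad.shape)         # gradient reaches the input; weights stay int8
```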

9 changes: 6 additions & 3 deletions tests/test_autograd.py
@@ -253,7 +253,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2)
) = bnb.functional.double_quant(B2.to(torch.float16))
B2 = state.CB

if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@ def test_matmullt(
if has_bias:
out_torch += bias

assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

n = out_bnb.numel()
err = torch.abs(out_bnb - out_torch).mean().item()
# print(f'abs error {err:.4f}')

idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() <= n * 0.0175
assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx == 0).sum().item() <= n * 0.001

46 changes: 35 additions & 11 deletions tests/test_modules.py
@@ -14,13 +14,15 @@ def __init__(self, initial_data):


class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)

def forward(self, x):
@@ -451,9 +453,12 @@ def test_linear8bitlt_accumulated_gradient():


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
bnb.nn.Linear8bitLt(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.cuda()
.half()
)
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8

mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.half()
.to("cuda")
)
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"

mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.to(torch.float16)
.to("cuda")
)
mlp = MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
mlp = mlp.cuda().half() # and this line triggers quantization

for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None

assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"

if memory_efficient_backward:
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
o1 = mlp(b1)
assert o1.dtype == torch.float16
assert o1.requires_grad
grad_proj = torch.randn_like(o1)

mlp.zero_grad()
(o1 * grad_proj).sum().backward()
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()

torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005
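
For reference, the closed form this check compares against, assuming `MLP8bit.forward(x)` is `fc2(fc1(x))` with no activation (the forward body is collapsed in this view) and PyTorch's `[out_features, in_features]` weight layout, with `w1 = fc1.weight` and `w2 = fc2.weight`:

```latex
y = x W_1^{\top} W_2^{\top}
\qquad\Longrightarrow\qquad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\, W_2\, W_1
```

which is what `grad_proj.flatten(2) @ w2.half() @ w1.half()` computes.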


def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically
Expand Down