diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index be975f655..702b3c141 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -279,12 +279,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -342,12 +336,9 @@ def backward(ctx, grad_output): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -365,10 +356,24 @@ def backward(ctx, grad_output): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True + if state.has_fp16_weights: + CBt = state.CBt + else: + # Restore CBt from CB + assert state.CBt is None, "CBt should not be stored in state" + CB = state.CB.half() + SCB = state.SCB.unsqueeze(1).half() + SCBt = state.SCBt.unsqueeze(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + # intentionally, do not store CxBt in state + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + else: + CxBt = state.CxBt + gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) if req_gradBias: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54bb..03ffd3b2a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,10 +148,12 @@ def __new__( has_fp16_weights=False, CB=None, SCB=None, + SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None + cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -165,10 +167,10 @@ def cuda(self, device): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt - del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) + setattr(self, "SCBt", SCBt) return self @@ -210,6 +212,7 @@ def to(self, *args, **kwargs): ) new_param.CB = self.CB new_param.SCB = self.SCB + new_param.SCBt = self.SCBt return new_param @@ -240,8 +243,10 @@ def __init__( def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB + self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None + self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training @@ -255,11 +260,10 @@ def forward(self, x): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if not self.state.has_fp16_weights and self.state.CxB is not None: + # In this version, we convert 8-bit row major to turing/ampere format at each inference pass + # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place. + del self.state.CxB return out