diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 4ec50a8f4..c1a3a73a7 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -52,7 +52,7 @@ def forward(self, x): class Int8Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False): + def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None @@ -66,7 +66,7 @@ def cuda(self, device): else: # we store the 8-bit rows-major weight # we convert this weight to the turning/ampere weight during the first inference pass - B = self.data.contiguous().half().cuda() + B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt del SCBt