diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index c92b25e2c..85ce52cc5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -639,8 +639,12 @@ def to(self, *args, **kwargs): if device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) - elif device.type == "cpu" and self.data.dtype != torch.int8: - return self.cpu() + elif device.type == "cpu": + if self.data.dtype == torch.int8: + self.CB = self.data + return self + else: + return self.cpu() else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking),