2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py
@@ -21,7 +21,7 @@

# This is a signal for integrations with transformers/diffusers.
# Eventually we may remove this but it is currently required for compatibility.
features = {"multi-backend"}
features = {"multi_backend"}
supported_torch_devices = {
"cpu",
"cuda", # NVIDIA/AMD GPU
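The multi-backend to multi_backend rename affects downstream feature detection. Below is a minimal, hypothetical check that an integration such as transformers might perform; the attribute name and the "multi_backend" entry come from the change above, while the detection logic itself is illustrative and not part of this diff.

    import bitsandbytes as bnb

    # Hypothetical feature check (illustrative only): integrations can detect the
    # multi-backend build via the renamed flag.
    if "multi_backend" in getattr(bnb, "features", set()):
        print("multi-backend bitsandbytes build detected")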
75 changes: 75 additions & 0 deletions bitsandbytes/backends/cpu/ops.py
@@ -1,3 +1,4 @@
from collections.abc import Sequence
import ctypes as ct
from typing import Optional

@@ -119,6 +120,10 @@ def _(
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()

@@ -140,3 +145,73 @@ def _(
packed = packed.squeeze().view(quant_storage).unsqueeze(1)

return packed, absmax.float()


@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)

A = A.view(-1, 1)

# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)

# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)

# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]

# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])

return blocks.to(dtype)


@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# TODO: We need to determine whether `code` is NF4, FP4, or other.
# Right now we assume NF4, as this is the only one supported on CPU.

B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
B,
absmax,
blocksize,
"nf4",
shape=shapeB,
dtype=A.dtype,
)

# User called gemv with B.t(), so we need to transpose it back.
# if B.shape[0] == 1:
# B_dq = B_dq.t()

return torch.nn.functional.linear(
A,
B_dq,
bias=None,
)
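
For reference, a minimal NF4 round-trip sketch on CPU through the public API, assuming bnb.functional.quantize_4bit and dequantize_4bit dispatch to the kernels registered above (only quant_type="nf4" is accepted on CPU):

    import torch
    import bitsandbytes as bnb

    A = torch.randn(256, 256, dtype=torch.bfloat16)  # 16/32-bit float input required

    # Quantize to packed NF4 (uint8 storage) with per-block absmax, then dequantize.
    packed, quant_state = bnb.functional.quantize_4bit(A, blocksize=64, quant_type="nf4")
    A_dq = bnb.functional.dequantize_4bit(packed, quant_state)

    print(packed.dtype, A_dq.dtype)  # torch.uint8, torch.bfloat16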
39 changes: 0 additions & 39 deletions bitsandbytes/backends/cuda/ops.py
@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
_int8_linear_matmul_impl(A, B, out)


@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
subB = None

if outlier_cols is not None and outlier_cols.numel():
# Extract the inputs with outliers in original precision
subA = A[:, outlier_cols].contiguous()

# Dequantize the corresponding weight columns
subB = (
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
.to(A.dtype)
.t()
)

# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()

else:
# Needed for torch.compile when there are no outliers.
subA = torch.empty(0, device=A.device, dtype=A.dtype)

# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

if subB is not None:
# Add the outlier columns back to the output
output = output.addmm(subA, subB)

return output, subA


def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
A, B = B, A

78 changes: 78 additions & 0 deletions bitsandbytes/backends/default/ops.py
@@ -1,10 +1,50 @@
from math import prod
from typing import Optional

import torch

from ..._ops import register_kernel


@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
A: torch.Tensor,
CA: torch.Tensor,
CB: torch.Tensor,
SCA: torch.Tensor,
SCB: torch.Tensor,
outlier_cols: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
subB = None

if outlier_cols is not None and outlier_cols.numel():
# Extract the inputs with outliers in original precision
subA = A[:, outlier_cols].contiguous()

# Dequantize the corresponding weight columns
subB = (
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
.to(A.dtype)
.t()
)

# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()

else:
# Needed for torch.compile when there are no outliers.
subA = torch.empty(0, device=A.device, dtype=A.dtype)

# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

if subB is not None:
# Add the outlier columns back to the output
output = output.addmm(subA, subB)

return output, subA


@register_kernel("bitsandbytes::int8_scaled_mm", "default")
def _(
A: torch.Tensor,
@@ -41,3 +81,41 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
if out is not None:
result = out.copy_(result)
return result


@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
def _(A: torch.Tensor, threshold=0.0):
rows = prod(A.shape[:-1])
outlier_cols = None

outlier_restore = None

if threshold > 0.0:
outliers = A.abs() >= threshold

if outliers.any():
# Determine which columns contain outliers, and zero out the
# outliers ahead of quantization. We need to keep a backup of these
# outliers to restore them after quantization.
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
outlier_restore = A[outliers].clone()
A[outliers] = 0
else:
# Needed for torch.compile support.
outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

# Get absmax for each row.
row_stats = torch.max(A.abs(), dim=1).values.float()

# Quantize row-wise to int8.
out_row = torch.round(A * (127.0 / row_stats.unsqueeze(-1))).to(torch.int8)

# Zero out values from outlier columns across all rows.
if rows > 1 and outlier_cols is not None:
out_row[:, outlier_cols] = 0

# Restore outliers.
if outlier_restore is not None:
A[outliers] = outlier_restore

return out_row, row_stats, outlier_cols
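
A short usage sketch of the row-wise int8 quantization this kernel implements; the call mirrors how Int8Params uses bnb.functional.int8_vectorwise_quant in the module changes below, and the threshold value is illustrative:

    import torch
    import bitsandbytes as bnb

    W = torch.randn(128, 128, dtype=torch.float16)

    # Row-wise absmax scaling to int8. Columns containing |x| >= threshold are
    # reported as outlier columns and zeroed in the int8 output (for rows > 1).
    CB, SCB, outlier_cols = bnb.functional.int8_vectorwise_quant(W, threshold=6.0)

    print(CB.dtype, SCB.dtype)  # torch.int8, torch.float32
    print(outlier_cols)         # outlier column indices (empty tensor if none found)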
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -779,7 +779,7 @@ def quantize_blockwise(
state2=state2,
)
else:
quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)

# TODO(matthewdouglas): Deprecate out kwarg
out = out.copy_(_out) if out is not None else _out
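The one-line change above keeps the quantization code table on the input's device. A minimal sketch of the resulting behavior, assuming blockwise quantization is available on the CPU backend:

    import torch
    import bitsandbytes as bnb

    A = torch.randn(4096, dtype=torch.float32)  # CPU tensor

    out, quant_state = bnb.functional.quantize_blockwise(A)

    # With code.to(A.device), the code table now lives on the same device as the input.
    assert quant_state.code.device == A.device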
33 changes: 21 additions & 12 deletions bitsandbytes/nn/modules.py
@@ -486,7 +486,7 @@ def forward(self, x: torch.Tensor):

bias = None if self.bias is None else self.bias.to(self.compute_dtype)

return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


class LinearFP4(Linear4bit):
@@ -585,19 +585,28 @@ def __new__(
obj.has_fp16_weights = has_fp16_weights
return obj

def cuda(self, device):
def _quantize(self, device):
if self.has_fp16_weights:
return super().cuda(device)
else:
# We quantize the weight and store in 8bit row-major
B = self.data.contiguous().half().cuda(device)
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
self.data = CB
self.CB = CB
self.SCB = SCB
return super().to(device)

# We quantize the weight and store in 8bit row-major
B = self.data.contiguous().to(device=device, dtype=torch.float16)
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
self.data = CB
self.CB = CB
self.SCB = SCB

return self

def cpu(self):
return self.to(device="cpu")

def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)

def __deepcopy__(self, memo):
# adjust this if new arguments are added to the constructor
new_instance = type(self).__new__(
@@ -627,8 +636,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
return self.cuda(device)
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
return self._quantize(device)
else:
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
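The Int8Params changes replace the CUDA-only cuda() path with a device-agnostic _quantize(), so int8 quantization now also triggers when moving a layer to cpu or xpu. A minimal sketch with hypothetical layer sizes:

    import torch
    import bitsandbytes as bnb

    # With has_fp16_weights=False, the weight is quantized on the first device move.
    layer = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0)
    layer = layer.to("cpu")  # routes through Int8Params.to() -> _quantize("cpu")

    print(layer.weight.data.dtype)  # expected: torch.int8 (row-major CB)
    print(layer.weight.SCB.shape)   # expected: per-row absmax statistics, shape (64,)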
14 changes: 10 additions & 4 deletions tests/test_autograd.py
@@ -32,9 +32,15 @@
def test_matmullt(
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
if device != "cuda" and funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")
if device != "cuda":
if funcs[1] == bnb.research.switchback_bnb:
# TODO: Deprecate/remove?
pytest.skip("switchback_bnb only works on CUDA.")

if req_grad[1]:
# This will be deprecated for CUDA in the future. We don't expect
# this to work on any other device.
pytest.skip("Deprecated feature with CUDA support only.")

dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -171,7 +177,7 @@ def test_matmul_4bit(
quant_type,
):
if device == "cpu" and quant_type == "fp4":
pytest.skip("Only nf4 is supported on CPU")
pytest.xfail("Only nf4 is supported on CPU")

dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)