4 changes: 2 additions & 2 deletions bitsandbytes/_ops.py
@@ -4,7 +4,7 @@

import torch

from .cextension import ipex_cpu, ipex_xpu
from .utils import ipex_cpu

_IS_TORCH_GTE_24 = False

@@ -331,7 +331,7 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


if ipex_cpu or ipex_xpu:
if ipex_cpu:
# Register the dequantize_nf4_ipex implementation
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
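For context, `dequantize_nf4_ipex` is only declared when the optional IPEX CPU extension imports cleanly. A minimal, self-contained sketch of that optional-dependency gating pattern (the op name `mylib::echo` is hypothetical and not part of bitsandbytes):

```python
import torch

try:
    # Optional dependency: only used when the Intel extension is present.
    import intel_extension_for_pytorch as ipex
    ipex_cpu = ipex if ipex._C._has_cpu() else None
except BaseException:
    ipex_cpu = None

if ipex_cpu:
    # Declare and register a custom op only when IPEX CPU is usable,
    # mirroring the `if ipex_cpu:` gate around dequantize_nf4_ipex above.
    torch.library.define("mylib::echo", "(Tensor x) -> Tensor")

    @torch.library.impl("mylib::echo", "CPU")
    def _(x: torch.Tensor) -> torch.Tensor:
        return x.clone()
```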
23 changes: 11 additions & 12 deletions bitsandbytes/autograd/_functions.py
@@ -8,7 +8,6 @@
from typing_extensions import deprecated

import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -426,7 +425,7 @@ def matmul(
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)

@@ -440,16 +439,16 @@ def matmul_4bit(
):
assert quant_state is not None

#if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
# if getattr(quant_state, "ipex", False):
# # IPEX CPU will change weight to 4D so don't need transpose
# B = B.t() if B.dim() == 2 else B
# out = F.gemv_4bit(A, B, out, state=quant_state)
# if bias is not None:
# out += bias
# return out
# else:
# return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.device.type == "cpu" and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
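The re-enabled branch in matmul_4bit routes CPU inference on IPEX-repacked weights through F.gemv_4bit and keeps the autograd MatMul4Bit path for everything else. A hypothetical restatement of that gate (not a bitsandbytes helper, just the predicate spelled out):

```python
import torch

def uses_ipex_gemv(A: torch.Tensor, quant_state) -> bool:
    # CPU tensor, no gradient tracking, and a quant_state that
    # _enable_ipex_fusion has already repacked (its "ipex" flag is set).
    return (
        A.device.type == "cpu"
        and not A.requires_grad
        and getattr(quant_state, "ipex", False)
    )
```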
7 changes: 6 additions & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -1,13 +1,14 @@
from collections.abc import Sequence
import ctypes as ct
import warnings

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import ipex_cpu
from ...utils import ipex_cpu

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -118,3 +119,7 @@ def _(
shape,
dtype,
)
else:
warnings.warn(
"You can install intel_extension_for_pytorch to get better performance on NF4 if you are using Intel CPUs."
)
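The new `else` branch emits a standard Python `UserWarning` when IPEX is not installed, apparently at kernel-registration (import) time. If the hint is unwanted, for example on non-Intel CPUs, it can be filtered like any other warning; a sketch assuming the message text stays as written above and that the filter is installed before bitsandbytes is imported:

```python
import warnings

warnings.filterwarnings(
    "ignore",
    message="You can install intel_extension_for_pytorch.*",
    category=UserWarning,
)

import bitsandbytes  # noqa: E402  (imported after the filter on purpose)
```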
10 changes: 0 additions & 10 deletions bitsandbytes/backends/utils.py
@@ -3,16 +3,6 @@
from packaging import version
import torch

try:
# to support Intel CPU/XPU (IPEX) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None

try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
50 changes: 24 additions & 26 deletions bitsandbytes/backends/xpu/ops.py
@@ -8,18 +8,16 @@

from ..._ops import register_kernel
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
from ..utils import ipex_xpu, triton_available
from ..utils import triton_available

# _int_mm is available in torch starting from 2.7 version,
# but currently it's don't have xpu implementation.
if ipex_xpu and torch.__version__ >= (2, 7):

@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])
# TODO: Enable _int_mm in torch
# if torch.__version__ >= (2, 9):
# @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
# def _(A: torch.Tensor, B: torch.Tensor):
# return torch._int_mm(
# A.reshape(-1, A.shape[-1]),
# B.t(),
# ).reshape(*A.shape[:-1], B.shape[0])


def _dequantize_4bit_impl(
@@ -92,21 +90,21 @@ def _gemv_4bit_impl(
blocksize: int,
out: torch.Tensor,
) -> None:
# torch._check_is_size(blocksize)
# torch._check(
# A.numel() == A.size(-1),
# lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
# )
# torch._check(
# A.dtype in [torch.float16, torch.bfloat16, torch.float32],
# lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
# )
# torch._check(
# B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
# lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
# )
# torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
# torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
torch._check_is_size(blocksize)
torch._check(
A.numel() == A.size(-1),
lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
)
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

m = ct.c_int32(shapeB[0])
n = ct.c_int32(1)
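The torch._check guards in _gemv_4bit_impl, previously commented out, are now active. As a quick illustration of the main shape contract they enforce, A must be a single vector, i.e. all leading dimensions equal 1 (the shapes below are arbitrary examples):

```python
import torch

A_ok = torch.randn(1, 1, 4096, dtype=torch.float16)   # leading dims are all 1
A_bad = torch.randn(2, 4096, dtype=torch.float16)     # has a real batch dimension

# The predicate behind "A must be a vector with leading dimensions of 1":
assert A_ok.numel() == A_ok.size(-1)
assert A_bad.numel() != A_bad.size(-1)  # this shape would now trip torch._check
```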
20 changes: 4 additions & 16 deletions bitsandbytes/cextension.py
@@ -289,26 +289,14 @@ def get_native_library() -> BNBNativeLibrary:
return BNBNativeLibrary(dll)


try:
# to support Intel CPU/GPU (XPU) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None


try:
lib = get_native_library()
except Exception as e:
error_msg = str(e)
if not (ipex_cpu or ipex_xpu):
logger.error(
f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops",
exc_info=True,
)
logger.error(
f"bitsandbytes library load error: {error_msg}",
exc_info=True,
)

# create a mock with error messaging as fallback
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
69 changes: 43 additions & 26 deletions bitsandbytes/functional.py
@@ -15,7 +15,7 @@

from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import ipex_cpu, ipex_xpu, lib
from .cextension import lib

name2qmap = {}

@@ -1039,6 +1039,16 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()

# IPEX format is different, we need extra process.
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
return torch.ops.bitsandbytes.dequantize_nf4_ipex(
A,
absmax,
quant_state.blocksize,
quant_state.shape,
quant_state.dtype,
)

if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
@@ -1607,6 +1617,25 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset

if getattr(state, "ipex", False) and state.quant_type == "nf4":
# compute_dtype: 1 indicates fp16, 2 indicates bf16
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
out = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
compute_dtype,
1,
state.compensation,
)
return out

if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
@@ -2308,31 +2337,19 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state.nested = False
delattr(quant_state, "state2")

if x.device.type == "cpu" and ipex_cpu:
converted_weight = _reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # batch_size
quant_state.blocksize,
2,
)
elif x.device.type == "xpu" and ipex_xpu:
new_weight = _reverse_4bit_compress_format(linear.weight.data)
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
new_scales = list(new_scales)
if not linear.training and not x.requires_grad:
new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
else:
raise ValueError(
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7"
)
assert x.device.type == "cpu"
converted_weight = _reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # batch_size
quant_state.blocksize,
2,
)

linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
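Putting the functional.py changes together, the intended CPU flow is: quantize a Linear4bit layer to NF4, run it in inference mode, let the first forward call repack the weight via _enable_ipex_fusion, after which gemv_4bit dispatches to torch.ops.torch_ipex.woq_linear and dequantize_4bit to dequantize_nf4_ipex. A rough usage sketch with illustrative sizes; it assumes intel_extension_for_pytorch is installed, and the exact point at which weights are quantized may differ by bitsandbytes version:

```python
import torch
import bitsandbytes as bnb

# NF4-quantized linear layer on CPU (the 4096x4096 shape is just an example).
linear = bnb.nn.Linear4bit(
    4096, 4096, bias=False, quant_type="nf4", compute_dtype=torch.bfloat16
)
linear = linear.to("cpu")  # weight quantization is expected to happen on placement

x = torch.randn(1, 4096, dtype=torch.bfloat16)
with torch.inference_mode():  # no grad: matches the requires_grad == False gate
    y = linear(x)             # first call may trigger _enable_ipex_fusion
```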
19 changes: 10 additions & 9 deletions bitsandbytes/nn/modules.py
@@ -11,12 +11,13 @@
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
from bitsandbytes.functional import QuantState, _enable_ipex_fusion
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
_reverse_4bit_compress_format,
ipex_cpu,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -472,8 +473,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
self.weight, "nf4", self.weight.quant_state.shape, 2
)
self.weight.data = _reverse_4bit_compress_format(original_weight.data)
elif self.weight.device.type == "xpu":
self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

self.weight.quant_state.ipex = False
self.ipex_linear_is_set = False
@@ -490,15 +489,17 @@ def set_ipex_linear(self, x: torch.Tensor):
and self.weight.data.dtype == torch.uint8
and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
and self.weight.quant_state.quant_type == "nf4"
and x.device.type == "cpu"
and not self.training
and not x.requires_grad
):
if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
_enable_ipex_fusion(self, x)
_enable_ipex_fusion(self, x)

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
#if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu):
# self.set_ipex_linear(x)
# self.ipex_linear_is_set = True
if not self.ipex_linear_is_set and ipex_cpu:
self.set_ipex_linear(x)
self.ipex_linear_is_set = True

fix_4bit_weight_quant_state_from_module(self)

@@ -671,7 +672,7 @@ def to(self, *args, **kwargs):
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
elif self.data.dtype == torch.int8 and device.type == "cpu" and ipex_cpu:
self.CB = self.data

new_param = Int8Params(
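For reference, the conditions visible in set_ipex_linear under which forward() now switches a layer onto the IPEX CPU kernel can be restated as a standalone predicate; this is a hypothetical illustration, not a bitsandbytes API:

```python
import torch

def ipex_linear_eligible(layer, x: torch.Tensor) -> bool:
    qs = layer.weight.quant_state
    return (
        layer.weight.data.dtype == torch.uint8   # packed 4-bit storage
        and qs.shape[1] % qs.blocksize == 0      # absmax blocks divide each row
        and qs.quant_type == "nf4"               # only NF4 has an IPEX packing
        and x.device.type == "cpu"
        and not layer.training                   # inference only
        and not x.requires_grad
    )
```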
8 changes: 8 additions & 0 deletions bitsandbytes/utils.py
@@ -4,6 +4,14 @@

import torch

try:
# to support Intel CPU backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
except BaseException:
ipex_cpu = None


def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
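With IPEX detection now living in one place, downstream modules import the flag from bitsandbytes.utils, and user code can do the same to check availability (the flag is the imported module when usable on CPU, otherwise None):

```python
from bitsandbytes.utils import ipex_cpu

if ipex_cpu is None:
    print("intel_extension_for_pytorch not found; using the default CPU NF4 kernels")
else:
    print("IPEX CPU backend available; NF4 ops can use the fused path")
```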