28 changes: 6 additions & 22 deletions bitsandbytes/backends/xpu/ops.py
@@ -57,12 +57,12 @@ def _dequantize_4bit_impl(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-    torch._check(
-        dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
-    )
+    # torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    # torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    # torch._check(
+    #     dtype in [torch.float16, torch.bfloat16, torch.float32],
+    #     lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
+    # )

     args = (
         get_ptr(code),
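
The hunk above comments out the input validation in _dequantize_blockwise_impl rather than deleting it. As a standalone illustration (not part of the PR) of what those guards did: torch._check raises a RuntimeError when its predicate is False, and the lambda builds the message lazily, so no f-string is formatted on the success path.

import torch

# Sketch of the disabled uint8 guard from the hunk above,
# exercised with a deliberately wrong dtype.
A = torch.zeros(16, dtype=torch.int8)
try:
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
except RuntimeError as err:
    print(err)  # A must be uint8, got torch.int8
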
@@ -90,22 +90,6 @@ def _gemv_4bit_impl(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
-    torch._check(
-        A.numel() == A.size(-1),
-        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    )
-    torch._check(
-        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    )
-    torch._check(
-        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    )
-    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
-
     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
     k = ct.c_int32(shapeB[1])
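
For context on the unchanged lines after the deleted checks: the GEMV path expresses the matrix-vector product as a degenerate GEMM with a single output column, wrapping the dimensions as 32-bit C integers for the native kernel. A minimal sketch, assuming ct is ctypes (as bitsandbytes imports it) and shapeB is a (rows, cols) pair; the sample shape and the dimension-role comments are one plausible reading, not taken from the PR.

import ctypes as ct

# Hypothetical shape, for illustration only.
shapeB = (4096, 11008)
m = ct.c_int32(shapeB[0])  # rows of B, i.e. output length (assumed)
n = ct.c_int32(1)          # one column: A must be a vector, per the removed check
k = ct.c_int32(shapeB[1])  # shared/reduction dimension (assumed)
print(m.value, n.value, k.value)  # 4096 1 11008
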