From d60750fe6e38d52c26de073c9500fe65cb2ad434 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 17 Jun 2025 13:12:50 +0000
Subject: [PATCH] remove check for better performance

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/xpu/ops.py | 28 ++++++----------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py
index 0483161dc..6623d9fb6 100644
--- a/bitsandbytes/backends/xpu/ops.py
+++ b/bitsandbytes/backends/xpu/ops.py
@@ -57,12 +57,12 @@ def _dequantize_4bit_impl(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
-    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-    torch._check(
-        dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
-    )
+    # torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    # torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    # torch._check(
+    #     dtype in [torch.float16, torch.bfloat16, torch.float32],
+    #     lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
+    # )
 
     args = (
         get_ptr(code),
@@ -90,22 +90,6 @@ def _gemv_4bit_impl(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
-    torch._check(
-        A.numel() == A.size(-1),
-        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    )
-    torch._check(
-        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    )
-    torch._check(
-        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    )
-    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
-
     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
     k = ct.c_int32(shapeB[1])
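
For context (not part of the patch above): the commit drops the torch._check validation from the hot dequantize/gemv paths for performance. A minimal sketch of an alternative that keeps the same checks available behind an opt-in flag follows; the BNB_XPU_VALIDATE_INPUTS environment variable and the helper names are hypothetical, not existing bitsandbytes API.

# Sketch only, assuming a hypothetical BNB_XPU_VALIDATE_INPUTS flag; not bitsandbytes API.
import os

import torch

_VALIDATE_INPUTS = os.environ.get("BNB_XPU_VALIDATE_INPUTS", "0") == "1"


def _check_dequantize_blockwise_args(A: torch.Tensor, blocksize: int, dtype: torch.dtype) -> None:
    # The same checks the patch comments out, kept callable for debugging.
    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
        dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
    )


def _dequantize_blockwise_impl_sketch(A: torch.Tensor, blocksize: int, dtype: torch.dtype) -> None:
    # Default path skips validation entirely; set BNB_XPU_VALIDATE_INPUTS=1 to re-enable it.
    if _VALIDATE_INPUTS:
        _check_dequantize_blockwise_args(A, blocksize, dtype)
    # ... kernel dispatch as in the original _dequantize_blockwise_impl ...

This keeps the default path as fast as the patched version while preserving a way to surface dtype/blocksize mistakes during debugging instead of deleting the checks outright.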