From d605fdf15a51ebfa47df55b64dba6c904064f47f Mon Sep 17 00:00:00 2001 From: sstamenk Date: Thu, 18 Dec 2025 17:24:34 +0100 Subject: [PATCH 1/9] Enable 4-bit quant on Radeon --- unsloth/device_type.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index adc09b05df..c9b0f125d8 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -25,7 +25,6 @@ import torch import functools from unsloth_zoo.utils import Version -import inspect @functools.cache @@ -78,19 +77,21 @@ def get_device_count(): DEVICE_COUNT: int = get_device_count() -# Check blocksize for 4bit -> 64 for CUDA, 128 for AMD -# If AMD, we cannot load pre-quantized models for now :( +# 4-bit quantization requires a block size of 64 +# | Device Type | Warp Size | Block Size | +# |-----------------|-----------|------------| +# | CUDA | 32 | 64 | +# | Radeon (Navi) | 32 | 64 | +# | Instinct (MI) | 64 | 128 | ALLOW_PREQUANTIZED_MODELS: bool = True # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB ALLOW_BITSANDBYTES: bool = True if DEVICE_TYPE == "hip": try: - from bitsandbytes.nn.modules import Params4bit + from bitsandbytes.cextension import ROCM_WARP_SIZE_64 + + ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64 - if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource( - Params4bit - ): - ALLOW_PREQUANTIZED_MODELS = False import bitsandbytes ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0") From ca27f37962d3d5a72eac57c63c5f8a664e0161e7 Mon Sep 17 00:00:00 2001 From: sstamenk Date: Thu, 18 Dec 2025 17:40:43 +0100 Subject: [PATCH 2/9] Fix table centering --- unsloth/device_type.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index c9b0f125d8..d38f7a0315 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -80,9 +80,9 @@ def get_device_count(): # 4-bit quantization 
requires a block size of 64 # | Device Type | Warp Size | Block Size | # |-----------------|-----------|------------| -# | CUDA | 32 | 64 | -# | Radeon (Navi) | 32 | 64 | -# | Instinct (MI) | 64 | 128 | +# | CUDA | 32 | 64 | +# | Radeon (Navi) | 32 | 64 | +# | Instinct (MI) | 64 | 128 | ALLOW_PREQUANTIZED_MODELS: bool = True # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB ALLOW_BITSANDBYTES: bool = True From 0cfa49e66c9eacf205e3ff7104121880b636d8b1 Mon Sep 17 00:00:00 2001 From: sstamenk Date: Thu, 18 Dec 2025 17:57:54 +0100 Subject: [PATCH 3/9] Update comments for clarity --- unsloth/models/loader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6148b1783e..a1e5756060 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -249,7 +249,7 @@ def from_pretrained( model_name = new_model_name # Check if pre-quantized models are allowed - # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64 + # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith( ("-unsloth-bnb-4bit", "-bnb-4bit") ): @@ -383,7 +383,7 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) # Check if pre-quantized models are allowed - # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64 + # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith( ("-unsloth-bnb-4bit", "-bnb-4bit") ): @@ -790,7 +790,7 @@ def from_pretrained( model_name = new_model_name # Check if pre-quantized models are allowed - # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64 + # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 if not ALLOW_PREQUANTIZED_MODELS and 
model_name.lower().endswith( ("-unsloth-bnb-4bit", "-bnb-4bit") ): @@ -1056,7 +1056,7 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) # Check if pre-quantized models are allowed - # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64 + # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith( ("-unsloth-bnb-4bit", "-bnb-4bit") ): From fbe409b0575317985882b32ec267e6ef0c34404f Mon Sep 17 00:00:00 2001 From: sstamenk Date: Thu, 18 Dec 2025 18:10:19 +0100 Subject: [PATCH 4/9] Handle failure to import Bitsandbytes --- unsloth/device_type.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index d38f7a0315..b0edba6e82 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -95,5 +95,9 @@ def get_device_count(): import bitsandbytes ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0") - except: - pass + except Exception: + print( + "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!" 
+ ) + ALLOW_PREQUANTIZED_MODELS = False + ALLOW_BITSANDBYTES = False \ No newline at end of file From b88f43e2c02b1cf54ed2f0d010705e625b900e18 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:11:31 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/device_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index b0edba6e82..41da5d9d6a 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -100,4 +100,4 @@ def get_device_count(): "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!" ) ALLOW_PREQUANTIZED_MODELS = False - ALLOW_BITSANDBYTES = False \ No newline at end of file + ALLOW_BITSANDBYTES = False From a711bd8cda6759de899cd942a169370bd6cb2dcd Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Thu, 18 Dec 2025 18:17:05 +0100 Subject: [PATCH 6/9] Update device_type.py --- unsloth/device_type.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index 41da5d9d6a..ee2c6beb05 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -78,6 +78,7 @@ def get_device_count(): DEVICE_COUNT: int = get_device_count() # 4-bit quantization requires a block size of 64 +# this is not supported on AMD Instinct GPUs currently # | Device Type | Warp Size | Block Size | # |-----------------|-----------|------------| # | CUDA | 32 | 64 | From 79cc6667f4bf3424b9ae2a4c9cd36a561907660a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 19 Dec 2025 19:12:41 -0800 Subject: [PATCH 7/9] Apply suggestion from @danielhanchen --- unsloth/device_type.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index ee2c6beb05..d434a2f02e 100644 --- a/unsloth/device_type.py +++ 
b/unsloth/device_type.py @@ -84,6 +84,7 @@ def get_device_count(): # | CUDA | 32 | 64 | # | Radeon (Navi) | 32 | 64 | # | Instinct (MI) | 64 | 128 | +# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748 ALLOW_PREQUANTIZED_MODELS: bool = True # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB ALLOW_BITSANDBYTES: bool = True From 4fd13600820bb583dda5ed3d753b107fbda93792 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 19 Dec 2025 19:22:00 -0800 Subject: [PATCH 8/9] Update device_type.py --- unsloth/device_type.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index d434a2f02e..095dfb6657 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -84,22 +84,39 @@ def get_device_count(): # | CUDA | 32 | 64 | # | Radeon (Navi) | 32 | 64 | # | Instinct (MI) | 64 | 128 | +# +# Since bitsandbytes 0.49.0, pre-quantized models with blocksize 64 now work +# on Radeon GPUs, but not yet on Instinct GPUs (e.g. MI300X) [WIP] # See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748 + ALLOW_PREQUANTIZED_MODELS: bool = True # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB ALLOW_BITSANDBYTES: bool = True if DEVICE_TYPE == "hip": try: - from bitsandbytes.cextension import ROCM_WARP_SIZE_64 - - ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64 - import bitsandbytes - - ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0") - except Exception: + except: print( - "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!" + "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works." 
) ALLOW_PREQUANTIZED_MODELS = False ALLOW_BITSANDBYTES = False + if ALLOW_BITSANDBYTES: + ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0") + if Version(bitsandbytes.__version__) > Version("0.49.0"): + try: + # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU + from bitsandbytes.cextension import ROCM_WARP_SIZE_64 + ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64 + except Exception as e: + print( + "Unsloth: Checking `from bitsandbytes.cextension import ROCM_WARP_SIZE_64` had error = \n" + f"{str(e)}\n" + "4bit QLoRA disabled for now, but 16bit and full finetuning works." + ) + ALLOW_PREQUANTIZED_MODELS = False + ALLOW_BITSANDBYTES = False + elif ALLOW_BITSANDBYTES: + from bitsandbytes.nn.modules import Params4bit + if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit): + ALLOW_PREQUANTIZED_MODELS = False From a6fe3525dc42169caa9ace26577cf8440d57655e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 Dec 2025 03:22:08 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/device_type.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/device_type.py b/unsloth/device_type.py index 095dfb6657..68038de679 100644 --- a/unsloth/device_type.py +++ b/unsloth/device_type.py @@ -107,6 +107,7 @@ def get_device_count(): try: # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU from bitsandbytes.cextension import ROCM_WARP_SIZE_64 + ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64 except Exception as e: print( @@ -118,5 +119,8 @@ def get_device_count(): ALLOW_BITSANDBYTES = False elif ALLOW_BITSANDBYTES: from bitsandbytes.nn.modules import Params4bit - if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit): + + if "blocksize = 64 if not 
HIP_ENVIRONMENT else 128" in inspect.getsource( + Params4bit + ): ALLOW_PREQUANTIZED_MODELS = False