diff --git a/unsloth/device_type.py b/unsloth/device_type.py
index adc09b05df..68038de679 100644
--- a/unsloth/device_type.py
+++ b/unsloth/device_type.py
@@ -78,21 +78,54 @@ def get_device_count():
 
 DEVICE_COUNT: int = get_device_count()
 
-# Check blocksize for 4bit -> 64 for CUDA, 128 for AMD
-# If AMD, we cannot load pre-quantized models for now :(
+# Pre-quantized 4-bit models use a block size of 64, which is
+# not currently supported on AMD Instinct GPUs:
+# | Device Type     | Warp Size | Block Size |
+# |-----------------|-----------|------------|
+# | CUDA            |    32     |     64     |
+# | Radeon (Navi)   |    32     |     64     |
+# | Instinct (MI)   |    64     |    128     |
+#
+# Since bitsandbytes 0.49.0, pre-quantized models with block size 64 work on
+# Radeon GPUs, but not yet on Instinct GPUs such as the MI300X [WIP].
+# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748
+
 ALLOW_PREQUANTIZED_MODELS: bool = True
 # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
 ALLOW_BITSANDBYTES: bool = True
 if DEVICE_TYPE == "hip":
     try:
-        from bitsandbytes.nn.modules import Params4bit
-
-        if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(
-            Params4bit
-        ):
-            ALLOW_PREQUANTIZED_MODELS = False
         import bitsandbytes
-
-        ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0")
-    except:
-        pass
+    except Exception:
+        print(
+            "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works."
+        )
+        ALLOW_PREQUANTIZED_MODELS = False
+        ALLOW_BITSANDBYTES = False
+    if ALLOW_BITSANDBYTES:
+        ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0")
+        # NOTE(review): strict `>` means 0.49.0 itself takes the legacy `elif`
+        # path below, although the header comment says "since 0.49.0" - confirm.
+        if Version(bitsandbytes.__version__) > Version("0.49.0"):
+            try:
+                # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU
+                from bitsandbytes.cextension import ROCM_WARP_SIZE_64
+
+                ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64
+            except Exception as e:
+                print(
+                    "Unsloth: Checking `from bitsandbytes.cextension import ROCM_WARP_SIZE_64` had error = \n"
+                    f"{str(e)}\n"
+                    "4bit QLoRA disabled for now, but 16bit and full finetuning works."
+                )
+                ALLOW_PREQUANTIZED_MODELS = False
+                ALLOW_BITSANDBYTES = False
+        elif ALLOW_BITSANDBYTES:
+            # bitsandbytes in (0.48.2.dev0, 0.49.0]: fall back to inspecting the
+            # Params4bit source for the hard-coded HIP blocksize of 128.
+            from bitsandbytes.nn.modules import Params4bit
+
+            if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(
+                Params4bit
+            ):
+                ALLOW_PREQUANTIZED_MODELS = False
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 6148b1783e..a1e5756060 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -249,7 +249,7 @@ def from_pretrained(
                 model_name = new_model_name
 
         # Check if pre-quantized models are allowed
-        # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
+        # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
         if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
             ("-unsloth-bnb-4bit", "-bnb-4bit")
         ):
@@ -383,7 +383,7 @@ def from_pretrained(
             if not use_exact_model_name:
                 model_name = get_model_name(model_name, load_in_4bit)
             # Check if pre-quantized models are allowed
-            # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
+            # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
             if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
                 ("-unsloth-bnb-4bit", "-bnb-4bit")
             ):
@@ -790,7 +790,7 @@ def from_pretrained(
                 model_name = new_model_name
 
         # Check if pre-quantized models are allowed
-        # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
+        # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
         if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
             ("-unsloth-bnb-4bit", "-bnb-4bit")
         ):
@@ -1056,7 +1056,7 @@ def from_pretrained(
             if not use_exact_model_name:
                 model_name = get_model_name(model_name, load_in_4bit)
             # Check if pre-quantized models are allowed
-            # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
+            # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
             if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
                 ("-unsloth-bnb-4bit", "-bnb-4bit")
             ):