-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Make bitsandbytes optional on ROCm and add bf16 helper #4211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7ecb049
e8d8974
3e606ae
c3d86a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -136,11 +136,15 @@ def calculate_settings( | |
|
|
||
|
|
||
| HAS_CUDA_STREAM = False | ||
| import bitsandbytes as bnb | ||
| try: | ||
| import bitsandbytes as bnb | ||
|
|
||
| # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files | ||
| HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") | ||
| get_ptr = bnb.functional.get_ptr | ||
| # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files | ||
| HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") | ||
| get_ptr = bnb.functional.get_ptr | ||
| except Exception: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| bnb = None | ||
| get_ptr = None | ||
|
|
||
| if DEVICE_TYPE == "xpu": | ||
| HAS_XPU_STREAM = True | ||
|
|
@@ -236,21 +240,32 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: | |
| WEIGHT_BUFFERS = [] | ||
| ABSMAX_BUFFERS = [] | ||
|
|
||
| # Bitsandbytes operations | ||
| # Bitsandbytes operations (optional) | ||
| ctypes_c_int = ctypes.c_int | ||
| ctypes_c_int32 = ctypes.c_int32 | ||
| cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 | ||
| cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 | ||
| cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 | ||
|
|
||
| if DEVICE_TYPE == "xpu": | ||
| # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115 | ||
| # for xpu, inference gemv using above link | ||
| cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16 | ||
| cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16 | ||
| if bnb is not None: | ||
| cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 | ||
| cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 | ||
| cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 | ||
|
|
||
| if DEVICE_TYPE == "xpu": | ||
| # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115 | ||
| cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16 | ||
| cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16 | ||
| else: | ||
| cgemm_4bit_inference_naive_fp16 = ( | ||
| bnb.functional.lib.cgemm_4bit_inference_naive_fp16 | ||
| ) | ||
| cgemm_4bit_inference_naive_bf16 = ( | ||
| bnb.functional.lib.cgemm_4bit_inference_naive_bf16 | ||
| ) | ||
| else: | ||
| cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 | ||
| cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 | ||
| cdequantize_blockwise_fp32 = None | ||
| cdequantize_blockwise_fp16_nf4 = None | ||
| cdequantize_blockwise_bf16_nf4 = None | ||
| cgemm_4bit_inference_naive_fp16 = None | ||
| cgemm_4bit_inference_naive_bf16 = None | ||
|
|
||
|
|
||
| torch_device_stream = ( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2540,7 +2540,14 @@ def patch_tokenizer(model, tokenizer): | |
|
|
||
|
|
||
| def patch_fast_lora(): | ||
| import peft.tuners.lora.bnb | ||
| try: | ||
| import peft.tuners.lora.bnb | ||
| except Exception as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| print( | ||
| "Unsloth: bitsandbytes/peft bnb not available - skipping 4bit LoRA patch.", | ||
| repr(e), | ||
| ) | ||
| return | ||
| from ..kernels.fast_lora import fast_lora_forward | ||
|
|
||
| peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,8 +30,18 @@ | |||||||||||||||||||||||||||||||||||||
| LlamaLinearScalingRotaryEmbedding, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| from .mistral import * | ||||||||||||||||||||||||||||||||||||||
| from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| from peft.tuners.lora import Linear4bit as Peft_Linear4bit | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||
| Bnb_Linear4bit = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| from peft.tuners.lora import Linear4bit as Peft_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||
| Peft_Linear4bit = None | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+34
to
+42
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to catch a more specific exception than the general
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| _BNB_LINEAR_TYPES = tuple(t for t in (Bnb_Linear4bit, Peft_Linear4bit) if t is not None) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| from transformers.models.granite.modeling_granite import ( | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -599,7 +609,7 @@ def post_patch(model, tokenizer, correct_dtype = None): | |||||||||||||||||||||||||||||||||||||
| correct_dtype = lm_head.weight.dtype | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for name, module in model.named_modules(): | ||||||||||||||||||||||||||||||||||||||
| if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): | ||||||||||||||||||||||||||||||||||||||
| if _BNB_LINEAR_TYPES and isinstance(module, _BNB_LINEAR_TYPES): | ||||||||||||||||||||||||||||||||||||||
| weight = module.weight | ||||||||||||||||||||||||||||||||||||||
| quant_state = weight.quant_state | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3208,6 +3208,18 @@ def get_peft_model( | |||||||||
| if not SUPPORTS_RSLORA: | ||||||||||
| del arguments["use_rslora"] | ||||||||||
|
|
||||||||||
| # PEFT API compatibility: only pass kwargs supported by the installed peft version. | ||||||||||
| try: | ||||||||||
| import inspect as _inspect | ||||||||||
|
|
||||||||||
| if ( | ||||||||||
| "ensure_weight_tying" | ||||||||||
| not in _inspect.signature(LoraConfig.__init__).parameters | ||||||||||
| ): | ||||||||||
| arguments.pop("ensure_weight_tying", None) | ||||||||||
| except Exception: | ||||||||||
| arguments.pop("ensure_weight_tying", None) | ||||||||||
|
Comment on lines
+3220
to
+3221
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching a broad
Suggested change
|
||||||||||
|
|
||||||||||
| _saved_temp_tokenizer = model._saved_temp_tokenizer | ||||||||||
|
|
||||||||||
| lora_config = LoraConfig(**arguments) | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,8 +33,16 @@ | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| IS_WINDOWS = sys.platform == "win32" | ||||||||||||||||||||||||||||||||||||||
| LLAMA_CPP_DEFAULT_DIR = "llama.cpp" | ||||||||||||||||||||||||||||||||||||||
| from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| from peft.tuners.lora import Linear4bit as Peft_Linear4bit | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||
| Bnb_Linear4bit = None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| from peft.tuners.lora import Linear4bit as Peft_Linear4bit | ||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||
| Peft_Linear4bit = None | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+37
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to catch a more specific exception than the general
Suggested change
|
||||||||||||||||||||||||||||||||||||||
| from peft.tuners.lora import Linear as Peft_Linear | ||||||||||||||||||||||||||||||||||||||
| from typing import Optional, Callable, Union, List | ||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -68,6 +76,10 @@ | |||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||
| from peft import PeftModelForCausalLM, PeftModel | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| _MERGE_LORA_LINEAR_TYPES = tuple( | ||||||||||||||||||||||||||||||||||||||
| t for t in (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear) if t is not None | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| __all__ = [ | ||||||||||||||||||||||||||||||||||||||
| "print_quantization_methods", | ||||||||||||||||||||||||||||||||||||||
| "unsloth_save_model", | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -381,7 +393,7 @@ def _free_cached_model(model): | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _merge_lora(layer, name): | ||||||||||||||||||||||||||||||||||||||
| bias = getattr(layer, "bias", None) | ||||||||||||||||||||||||||||||||||||||
| if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)): | ||||||||||||||||||||||||||||||||||||||
| if _MERGE_LORA_LINEAR_TYPES and isinstance(layer, _MERGE_LORA_LINEAR_TYPES): | ||||||||||||||||||||||||||||||||||||||
| # Is LoRA so we need to merge! | ||||||||||||||||||||||||||||||||||||||
| W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer) | ||||||||||||||||||||||||||||||||||||||
| if quant_state is not None: | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a broad
Exceptioncan hide unexpected errors. It's better to catch a more specific exception.importlib.metadata.versionraisesimportlib.metadata.PackageNotFoundErrorwhen package metadata is not found, which is a subclass ofImportError. Usingexcept ImportError:would be more specific and safer here, and it doesn't require adding a new import.