Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions unsloth/_gpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ def is_bf16_supported():
# set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()

# Backwards compatibility: some notebooks import `unsloth.is_bf16_supported`.
# Ensure it exists on all backends (HIP / XPU) and has a stable signature.
if "is_bf16_supported" not in globals():

def is_bf16_supported(including_emulation = False):
return SUPPORTS_BFLOAT16


# For Gradio HF Spaces?
# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
import triton
Expand Down
5 changes: 4 additions & 1 deletion unsloth/import_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,10 @@ def fix_vllm_aimv2_issue():
spec = importlib.util.find_spec("vllm")
if spec is None:
return
vllm_version = importlib_version("vllm")
try:
vllm_version = importlib_version("vllm")
except Exception:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can hide unexpected errors. It's better to catch a more specific exception. importlib.metadata.version raises importlib.metadata.PackageNotFoundError when package metadata is not found, which is a subclass of ImportError. Using except ImportError: would be more specific and safer here, and it doesn't require adding a new import.

Suggested change
except Exception:
except ImportError:

return
if Version(vllm_version) < Version("0.10.1"):
vllm_location = spec.origin
if vllm_location is None:
Expand Down
45 changes: 30 additions & 15 deletions unsloth/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,15 @@ def calculate_settings(


HAS_CUDA_STREAM = False
import bitsandbytes as bnb
try:
import bitsandbytes as bnb

# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr
except Exception:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to catch a more specific exception than the general Exception. For a failed import, ImportError is the correct exception to catch. This prevents masking other potential issues during the setup of bnb.

Suggested change
except Exception:
except ImportError:

bnb = None
get_ptr = None

if DEVICE_TYPE == "xpu":
HAS_XPU_STREAM = True
Expand Down Expand Up @@ -236,21 +240,32 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
WEIGHT_BUFFERS = []
ABSMAX_BUFFERS = []

# Bitsandbytes operations
# Bitsandbytes operations (optional)
ctypes_c_int = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

if DEVICE_TYPE == "xpu":
# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115
# for xpu, inference gemv using above link
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16
if bnb is not None:
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

if DEVICE_TYPE == "xpu":
# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16
else:
cgemm_4bit_inference_naive_fp16 = (
bnb.functional.lib.cgemm_4bit_inference_naive_fp16
)
cgemm_4bit_inference_naive_bf16 = (
bnb.functional.lib.cgemm_4bit_inference_naive_bf16
)
else:
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
cdequantize_blockwise_fp32 = None
cdequantize_blockwise_fp16_nf4 = None
cdequantize_blockwise_bf16_nf4 = None
cgemm_4bit_inference_naive_fp16 = None
cgemm_4bit_inference_naive_bf16 = None


torch_device_stream = (
Expand Down
9 changes: 8 additions & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,7 +2540,14 @@ def patch_tokenizer(model, tokenizer):


def patch_fast_lora():
import peft.tuners.lora.bnb
try:
import peft.tuners.lora.bnb
except Exception as e:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can hide unexpected errors. Since this block is intended to handle a failed import of peft.tuners.lora.bnb, it's better to catch the more specific ImportError.

Suggested change
except Exception as e:
except ImportError as e:

print(
"Unsloth: bitsandbytes/peft bnb not available - skipping 4bit LoRA patch.",
repr(e),
)
return
from ..kernels.fast_lora import fast_lora_forward

peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward
Expand Down
16 changes: 13 additions & 3 deletions unsloth/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@
LlamaLinearScalingRotaryEmbedding,
)
from .mistral import *
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit

try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except Exception:
Bnb_Linear4bit = None

try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except Exception:
Peft_Linear4bit = None
Comment on lines +34 to +42

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to catch a more specific exception than the general Exception. For failed imports, ImportError is the correct exception to catch. This applies to both try-except blocks here.

Suggested change
try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except Exception:
Bnb_Linear4bit = None
try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except Exception:
Peft_Linear4bit = None
try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except ImportError:
Bnb_Linear4bit = None
try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except ImportError:
Peft_Linear4bit = None


_BNB_LINEAR_TYPES = tuple(t for t in (Bnb_Linear4bit, Peft_Linear4bit) if t is not None)

try:
from transformers.models.granite.modeling_granite import (
Expand Down Expand Up @@ -599,7 +609,7 @@ def post_patch(model, tokenizer, correct_dtype = None):
correct_dtype = lm_head.weight.dtype

for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
if _BNB_LINEAR_TYPES and isinstance(module, _BNB_LINEAR_TYPES):
weight = module.weight
quant_state = weight.quant_state

Expand Down
12 changes: 12 additions & 0 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3208,6 +3208,18 @@ def get_peft_model(
if not SUPPORTS_RSLORA:
del arguments["use_rslora"]

# PEFT API compatibility: only pass kwargs supported by the installed peft version.
try:
import inspect as _inspect

if (
"ensure_weight_tying"
not in _inspect.signature(LoraConfig.__init__).parameters
):
arguments.pop("ensure_weight_tying", None)
except Exception:
arguments.pop("ensure_weight_tying", None)
Comment on lines +3220 to +3221

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can hide unexpected errors. It's better to catch a more specific set of exceptions that you might expect from the introspection logic, such as ImportError, AttributeError, or TypeError. This makes the code more robust against unforeseen issues.

Suggested change
except Exception:
arguments.pop("ensure_weight_tying", None)
except (ImportError, AttributeError, TypeError, ValueError):
arguments.pop("ensure_weight_tying", None)


_saved_temp_tokenizer = model._saved_temp_tokenizer

lora_config = LoraConfig(**arguments)
Expand Down
7 changes: 7 additions & 0 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,13 @@ def masked_batch_mean(x):
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
return x.mean()
else:
# Align mask/coef lengths when left-padding adds extra tokens.
if x.shape[1] != completion_mask.shape[1]:
min_len = min(x.shape[1], completion_mask.shape[1])
x = x[:, -min_len:]
cm = completion_mask[:, -min_len:]
denom = cm.sum().clamp(min = 1.0)
return (x * cm).sum() / denom
return (x * completion_mask).sum() / completion_token_count

if advantages.dim() == 1:
Expand Down
18 changes: 15 additions & 3 deletions unsloth/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,16 @@

IS_WINDOWS = sys.platform == "win32"
LLAMA_CPP_DEFAULT_DIR = "llama.cpp"
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit

try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except Exception:
Bnb_Linear4bit = None

try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except Exception:
Peft_Linear4bit = None
Comment on lines +37 to +45

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to catch a more specific exception than the general Exception. For failed imports, ImportError is the correct exception to catch. This applies to both try-except blocks here.

Suggested change
try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except Exception:
Bnb_Linear4bit = None
try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except Exception:
Peft_Linear4bit = None
try:
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
except ImportError:
Bnb_Linear4bit = None
try:
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
except ImportError:
Peft_Linear4bit = None

from peft.tuners.lora import Linear as Peft_Linear
from typing import Optional, Callable, Union, List
import sys
Expand Down Expand Up @@ -68,6 +76,10 @@
from pathlib import Path
from peft import PeftModelForCausalLM, PeftModel

_MERGE_LORA_LINEAR_TYPES = tuple(
t for t in (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear) if t is not None
)

__all__ = [
"print_quantization_methods",
"unsloth_save_model",
Expand Down Expand Up @@ -381,7 +393,7 @@ def _free_cached_model(model):

def _merge_lora(layer, name):
bias = getattr(layer, "bias", None)
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
if _MERGE_LORA_LINEAR_TYPES and isinstance(layer, _MERGE_LORA_LINEAR_TYPES):
# Is LoRA so we need to merge!
W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
if quant_state is not None:
Expand Down
Loading