diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py
index b905a1e74..b434daf0b 100644
--- a/unsloth_zoo/compiler.py
+++ b/unsloth_zoo/compiler.py
@@ -878,7 +878,22 @@ def create_new_function(
         overwrite = True
     pass
     if os.environ.get("UNSLOTH_COMPILE_OVERWRITE", "1") == "0":
-        overwrite = False
+        # Even with OVERWRITE disabled, force recompile on transformers version mismatch
+        if file_source is not None and "__UNSLOTH_VERSIONING__" in file_source:
+            cached_versions = file_source[:file_source.find("__UNSLOTH_VERSIONING__")]
+            cached_lines = [line.strip() for line in cached_versions.strip().strip('"').split("\n") if line.strip()]
+            # Format: [unsloth_zoo_version, unsloth_version, transformers_version, trl_version]
+            cached_tf_version = cached_lines[2] if len(cached_lines) > 2 else "0"
+            if cached_tf_version != transformers_version:
+                logger.warning_once(
+                    f"Unsloth: UNSLOTH_COMPILE_OVERWRITE=0 is set, but transformers version changed "
+                    f"({cached_tf_version} -> {transformers_version}). Forcing recompile of {name}."
+                )
+                # Don't set overwrite = False; keep overwrite = True from version mismatch detection
+            else:
+                overwrite = False
+        else:
+            overwrite = False
 
     # Check location
     def write_file(function_location, write_new_source):
@@ -3165,15 +3180,26 @@ def replaced_tqdm(*args, **kwargs):
                 bad_torch_modules.add(module)
             pass
 
-            # Check if creating arrays in inside the function
-            # Error: DataDependentOutputException: aten._local_scalar_dense.default
+            # Check for data-dependent control flow that breaks torch.compile(fullgraph=True)
+            # Tier 1: Direct data escapes from tensor to Python
+            #   .nonzero() -> data-dependent output shape (variable-length)
+            #   .tolist() -> materializes tensor values into Python list
+            #   .item() -> materializes tensor scalar into Python
+            # Tier 2: MoE expert dispatch via torch.where + index_add
+            #   1-arg torch.where returns data-dependent indices; combined with
+            #   index_add this is the standard MoE routing loop pattern
             if (
-                "torch.arange(" in source
-                or "torch.zeros(" in source
-                or "torch.ones(" in source
+                ".nonzero()" in source
+                or ".tolist()" in source
+                or ".item()" in source
             ):
                 print(
-                    f"Unsloth: Failed compiling function {module} since array creations are done."
+                    f"Unsloth: Will not compile {module} since data-dependent operations are done."
+                )
+                bad_torch_modules.add(module)
+            elif "torch.where(" in source and ".index_add" in source:
+                print(
+                    f"Unsloth: Will not compile {module} since data-dependent routing is done."
                 )
                 bad_torch_modules.add(module)
             pass
diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py
index 57c7f3969..6bc6721c1 100644
--- a/unsloth_zoo/dataset_utils.py
+++ b/unsloth_zoo/dataset_utils.py
@@ -340,21 +340,50 @@ def _train_on_responses_only(examples):
             # Set it to int(memory_gb_left) so 16Gb = 16
             num_proc = min(num_proc, int(memory_gb_left))
 
+    # In transformers 5.0+, VLM models skip dataset preparation in SFTTrainer.__init__
+    # (skip_prepare_dataset=True when _is_vlm=True). This means the dataset may not be
+    # tokenized yet. We need to tokenize it before applying _train_on_responses_only.
+    def _maybe_tokenize_dataset(dataset):
+        if dataset is None:
+            return dataset
+        sample = next(iter(dataset), None)
+        if sample is None or "input_ids" in sample:
+            return dataset  # Empty (nothing to tokenize) or already tokenized
+        # Need to tokenize - get the processing class from trainer
+        _tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer
+        # Get the actual tokenizer (not processor) for tokenization
+        if hasattr(_tokenizer, "tokenizer"):
+            _tok = _tokenizer.tokenizer
+        else:
+            _tok = _tokenizer
+        max_length = getattr(trainer.args, "max_length", None) or getattr(trainer.args, "max_seq_length", 2048)
+        text_field = getattr(trainer.args, "dataset_text_field", "text")
+        def _tokenize_fn(examples):
+            texts = examples.get(text_field) or examples.get("text", [])
+            return _tok(texts, truncation=True, max_length=max_length, padding=False)
+        _map_kwargs = {"batched": True, "num_proc": num_proc}
+        if isinstance(dataset, IterableDataset):
+            _map_kwargs = {"batched": True}
+        return dataset.map(_tokenize_fn, **_map_kwargs)
+    pass
+
     if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
         if not hasattr(trainer.train_dataset, "map"):
             raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
+        trainer.train_dataset = _maybe_tokenize_dataset(trainer.train_dataset)
         if isinstance(trainer.train_dataset, IterableDataset):
             trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batch_size = trainer.train_dataset._ex_iterable.batch_size, batched = True)
         else:
             trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
     pass
-    
+
     if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
         # Eval datasets could be a dict!
         if type(trainer.eval_dataset) is dict:
             for key, value in trainer.eval_dataset.items():
                 if not hasattr(value, "map"):
                     raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
+                value = _maybe_tokenize_dataset(value)
                 if isinstance(value, IterableDataset):
                     trainer.eval_dataset[key] = value.map(_train_on_responses_only, batch_size = value._ex_iterable.batch_size, batched = True)
                 else:
@@ -362,6 +391,7 @@ def _train_on_responses_only(examples):
         else:
             if not hasattr(trainer.eval_dataset, "map"):
                 raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
+            trainer.eval_dataset = _maybe_tokenize_dataset(trainer.eval_dataset)
             if isinstance(trainer.eval_dataset, IterableDataset):
                 trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True)
             else:
diff --git a/unsloth_zoo/temporary_patches/bitsandbytes.py b/unsloth_zoo/temporary_patches/bitsandbytes.py
index 38037b707..f88f0a426 100644
--- a/unsloth_zoo/temporary_patches/bitsandbytes.py
+++ b/unsloth_zoo/temporary_patches/bitsandbytes.py
@@ -49,10 +49,17 @@ def patch_bitsandbytes_linear4bit_forward():
         return raise_error("bitsandbytes.Linear4bit", e)
 
     def forward(self, x: torch.Tensor):
-        fix_4bit_weight_quant_state_from_module(self)
+        # In transformers 5.0+, weights may not be in packed format yet during init
+        if self.weight.shape[-1] == 1:
+            fix_4bit_weight_quant_state_from_module(self)
+
+        # Some layers may not be quantized (no quant_state) - fall back to regular matmul
+        quant_state = getattr(self.weight, "quant_state", None)
+        if quant_state is None:
+            return torch.nn.functional.linear(x, self.weight, self.bias)
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        
+
         # ** Errors out in torch.compile so remove it
         # if self.bias is not None and self.bias.dtype != x.dtype:
         #     self.bias.data = self.bias.data.to(x.dtype)
@@ -72,7 +79,7 @@ def forward(self, x: torch.Tensor):
         # Cannot do .t() on Params4bit, instead do it on torch.Tensor
         weight = self.weight.data.t()
 
-        return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)
 
     patch_function(bitsandbytes.nn.modules.Linear4bit, "forward", forward)
     try:
diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py
index 806252cd3..c4ff97629 100644
--- a/unsloth_zoo/temporary_patches/misc.py
+++ b/unsloth_zoo/temporary_patches/misc.py
@@ -489,6 +489,43 @@ def return_attention_mask(*args, **kwargs):
 TEMPORARY_PATCHES.append(patch_transformers_masks)
 
 
+def patch_modernbert_attention_mask():
+    """Fix ModernBERT attn_bias stride alignment for SDPA backward pass.
+
+    The attention mask created by _prepare_4d_attention_mask uses .expand()
+    which creates non-contiguous strides. The SDPA compiled backward kernel
+    requires strides to be multiples of 4. Fix: patch _update_attention_mask
+    on ModernBertModel to return contiguous masks BEFORE they enter
+    torch.compile regions, so the inductor backward graph uses aligned strides.
+    """
+    try:
+        import transformers.models.modernbert.modeling_modernbert as modernbert_module
+    except Exception:
+        return  # ModernBERT not available, skip
+
+    ModernBertModel = getattr(modernbert_module, "ModernBertModel", None)
+    if ModernBertModel is None:
+        return
+
+    original_update = getattr(ModernBertModel, "_update_attention_mask", None)
+    if original_update is None:
+        return
+
+    def _update_attention_mask_contiguous(self, attention_mask, output_attentions=False):
+        global_attention_mask, sliding_window_mask = original_update(self, attention_mask, output_attentions=output_attentions)
+        # Make masks contiguous so SDPA backward (including compiled graphs)
+        # gets strides that are multiples of 4
+        if global_attention_mask is not None and not global_attention_mask.is_contiguous():
+            global_attention_mask = global_attention_mask.contiguous()
+        if sliding_window_mask is not None and not sliding_window_mask.is_contiguous():
+            sliding_window_mask = sliding_window_mask.contiguous()
+        return global_attention_mask, sliding_window_mask
+
+    ModernBertModel._update_attention_mask = _update_attention_mask_contiguous
+pass
+TEMPORARY_PATCHES.append(patch_modernbert_attention_mask)
+
+
 def patch_CsmForConditionalGeneration_merge():
     try:
         import transformers.models.csm.modeling_csm
@@ -583,6 +620,106 @@ def _merge_input_ids_with_input_values(
 TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_merge)
 
 
+def patch_causal_conv1d_cuda_probe():
+    """Probe causal_conv1d CUDA kernels and force slow path if broken.
+
+    On GPUs whose compute capability is not supported by pre-built causal_conv1d
+    CUDA kernels (e.g. sm_100 on B200), `import causal_conv1d` succeeds but calling
+    `causal_conv1d_fn(...)` fails at runtime with "no kernel image is available".
+    This probe runs a tiny forward pass at startup to detect the failure, then
+    nullifies causal_conv1d_fn/causal_conv1d_update everywhere so all Mamba-family
+    models fall back to their pure-PyTorch slow paths.
+    """
+    try:
+        import causal_conv1d
+        from causal_conv1d import causal_conv1d_fn
+        from causal_conv1d import causal_conv1d_update
+    except ImportError:
+        return  # Package not installed, transformers already handles this
+    pass
+
+    if causal_conv1d_fn is None:
+        return  # Already nullified
+    pass
+
+    if not torch.cuda.is_available():
+        return
+    pass
+
+    # Probe: try a tiny CUDA forward pass
+    try:
+        device = torch.device("cuda", torch.cuda.current_device())
+        x = torch.randn(1, 4, 8, device=device, dtype=torch.float16)
+        w = torch.randn(4, 4, device=device, dtype=torch.float16)
+        b = torch.zeros(4, device=device, dtype=torch.float16)
+        _ = causal_conv1d_fn(x, w, b, activation="silu")
+        del x, w, b
+        return  # CUDA kernels work fine
+    except Exception:
+        pass  # Fall through to disable
+    pass
+
+    print(
+        "Unsloth: causal_conv1d CUDA kernels not compatible with this GPU. "
+        "Using PyTorch slow path for Mamba models."
+    )
+
+    import sys
+
+    # 1. Nullify the package exports themselves
+    for mod_name in ("causal_conv1d", "causal_conv1d.causal_conv1d_interface"):
+        mod = sys.modules.get(mod_name)
+        if mod is not None:
+            if hasattr(mod, "causal_conv1d_fn"):
+                mod.causal_conv1d_fn = None
+            if hasattr(mod, "causal_conv1d_update"):
+                mod.causal_conv1d_update = None
+        pass
+    pass
+
+    # 2. Patch is_causal_conv1d_available to return False
+    try:
+        import transformers.utils.import_utils
+        transformers.utils.import_utils.is_causal_conv1d_available = lambda: False
+    except Exception:
+        pass
+    pass
+
+    # 3. Dynamically scan all loaded modules and nullify broken causal_conv1d
+    # references. Uses identity checks (is) against the original function objects
+    # to avoid clobbering vllm's independent Triton-based causal_conv1d_fn/update.
+    _original_fn = causal_conv1d_fn
+    _original_update = causal_conv1d_update
+
+    def _disabled_lazy_load():
+        return (None, None)
+    pass
+
+    for mod in list(sys.modules.values()):
+        if mod is None:
+            continue
+        # Only nullify references that point to the causal_conv1d package's functions
+        touched = False
+        if getattr(mod, "causal_conv1d_fn", None) is _original_fn:
+            mod.causal_conv1d_fn = None
+            touched = True
+        if getattr(mod, "causal_conv1d_update", None) is _original_update:
+            mod.causal_conv1d_update = None
+            touched = True
+        # is_fast_path_available = all((causal_conv1d_fn, ...)) -- must be False
+        # Only touch it on modules where we just nullified causal_conv1d refs
+        if touched and getattr(mod, "is_fast_path_available", False):
+            mod.is_fast_path_available = False
+        # Replace lazy load stubs (Pattern B: mamba, falcon_mamba)
+        if hasattr(mod, "_lazy_load_causal_conv1d"):
+            mod._lazy_load_causal_conv1d = _disabled_lazy_load
+        if hasattr(mod, "_causal_conv1d_cache"):
+            mod._causal_conv1d_cache = (None, None)
+    pass
+pass
+TEMPORARY_PATCHES.append(patch_causal_conv1d_cuda_probe)
+
+
 def patch_GraniteMoeHybridMambaLayer_cuda_kernels_forward():
     try:
         import transformers.models.granitemoehybrid.modeling_granitemoehybrid
@@ -939,3 +1076,197 @@ def forward(
     patch_function(transformers.models.siglip.modeling_siglip.SiglipEncoderLayer, "forward", forward)
 pass
 TEMPORARY_PATCHES.append(patch_SiglipEncoderLayer)
+
+
+def patch_Lfm2VlMultiModalProjector():
+    """Fix Lfm2VlMultiModalProjector unconditionally creating LayerNorm.
+
+    transformers 4.57.6 ignores config.projector_use_layernorm and always
+    creates nn.LayerNorm + applies it in forward. The model checkpoint for
+    LFM2.5-VL-1.6B has projector_use_layernorm=False and ships no layer_norm
+    weights, so the LayerNorm gets randomly initialized and corrupts features.
+    Fixed in transformers 5.0.0. This patch backports the fix.
+    """
+    try:
+        import transformers.models.lfm2_vl.modeling_lfm2_vl as lfm2_vl_module
+    except Exception:
+        return
+
+    Projector = getattr(lfm2_vl_module, "Lfm2VlMultiModalProjector", None)
+    if Projector is None:
+        return
+
+    # Already patched, or upstream already has the conditional logic (transformers >= 5.0.0).
+    # `self.use_layer_norm` is an attribute name, so it is recorded in co_names;
+    # co_varnames only lists arguments and local variables and would never match.
+    init_code = getattr(Projector.__init__, "__code__", None)
+    if hasattr(Projector, "_unsloth_patched") or "use_layer_norm" in getattr(init_code, "co_names", ()):
+        return
+
+    original_init = Projector.__init__
+
+    def patched_init(self, config, *args, **kwargs):
+        original_init(self, config, *args, **kwargs)
+        self.use_layer_norm = getattr(config, "projector_use_layernorm", True)
+        if not self.use_layer_norm:
+            self.layer_norm = None
+
+    def patched_forward(self, image_features):
+        image_features = self.pixel_unshuffle(image_features)
+        if getattr(self, "use_layer_norm", True) and self.layer_norm is not None:
+            image_features = self.layer_norm(image_features)
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+    Projector.__init__ = patched_init
+    Projector.forward = patched_forward
+    Projector._unsloth_patched = True
+pass
+TEMPORARY_PATCHES.append(patch_Lfm2VlMultiModalProjector)
+
+
+def patch_peft_dispatch_bnb_4bit():
+    """Fix PEFT dispatch_bnb_4bit accessing compress_statistics on non-Params4bit weights.
+
+    In transformers 5.0+, BNB quantization loading order changed so weights may still be
+    nn.Parameter (not Params4bit) when PEFT tries to access .compress_statistics and .quant_type.
+    This wraps the original dispatch to catch AttributeError and provide defaults.
+    """
+    try:
+        import peft.tuners.lora.bnb as peft_bnb
+        original_dispatch = peft_bnb.dispatch_bnb_4bit
+    except (ImportError, AttributeError):
+        return
+
+    if hasattr(original_dispatch, "_unsloth_patched"):
+        return
+
+    def safe_dispatch_bnb_4bit(target, adapter_name, **kwargs):
+        try:
+            return original_dispatch(target, adapter_name, **kwargs)
+        except AttributeError as e:
+            if "compress_statistics" in str(e) or "quant_type" in str(e):
+                # Transformers 5.0+: weight not yet quantized as Params4bit
+                # Retry after ensuring weight has needed attributes
+                w = target.weight
+                if not hasattr(w, "compress_statistics"):
+                    w.compress_statistics = getattr(
+                        target, "_bnb_compress_statistics", True
+                    )
+                if not hasattr(w, "quant_type"):
+                    w.quant_type = getattr(target, "_bnb_quant_type", "nf4")
+                return original_dispatch(target, adapter_name, **kwargs)
+            raise
+
+    safe_dispatch_bnb_4bit._unsloth_patched = True
+    peft_bnb.dispatch_bnb_4bit = safe_dispatch_bnb_4bit
+pass
+TEMPORARY_PATCHES.append(patch_peft_dispatch_bnb_4bit)
+
+
+def patch_trl_push_to_hub_token():
+    """Ensure to_dict() always includes push_to_hub_token for TRL compat.
+
+    TRL 0.22.x through 0.27.1 do bare dict_args.pop("push_to_hub_token") in
+    SFTTrainer.__init__ and IterativeSFTTrainer.__init__. On transformers 5.0+,
+    TrainingArguments.to_dict() no longer includes push_to_hub_token, so the
+    bare pop raises KeyError. Fix: monkey-patch to_dict() to always include it.
+    """
+    try:
+        from unsloth_zoo.utils import Version
+        import transformers
+        if Version(transformers.__version__) < Version("5.0.0"):
+            return  # Not needed pre-5.0, to_dict() already includes it
+        from transformers import TrainingArguments
+        _original_to_dict = TrainingArguments.to_dict
+        if getattr(_original_to_dict, "_unsloth_patched", False):
+            return
+        def _patched_to_dict(self):
+            d = _original_to_dict(self)
+            if "push_to_hub_token" not in d:
+                d["push_to_hub_token"] = None
+            return d
+        _patched_to_dict._unsloth_patched = True
+        TrainingArguments.to_dict = _patched_to_dict
+    except Exception:
+        pass
+pass
+TEMPORARY_PATCHES.append(patch_trl_push_to_hub_token)
+
+
+def patch_trl_vision_model_mapping():
+    """Fix DPO vision model detection for TRL 0.22.x + transformers 5.0+.
+
+    TRL 0.22.x does a bare import of MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from
+    transformers.models.auto.modeling_auto. This name was removed in transformers
+    5.0.0, replaced by MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES. The import
+    failure prevents DPO trainer from loading at all.
+
+    Fix: inject the old name as an alias of the new name into the transformers
+    auto modeling module BEFORE TRL imports it, so the bare import succeeds.
+    Also patch already-loaded DPO module if it fell back to empty dict.
+    """
+    try:
+        import transformers.models.auto.modeling_auto as auto_mod
+    except ImportError:
+        return
+    # If the old name already exists and is populated, nothing to do
+    existing = getattr(auto_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None)
+    if existing is not None and len(existing) > 0:
+        return
+    # Inject the old name as alias of the new name
+    new_mapping = getattr(auto_mod, "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", None)
+    if new_mapping is not None:
+        auto_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new_mapping
+    # Also patch already-loaded DPO module if present
+    try:
+        import trl.trainer.dpo_trainer as dpo_mod
+        dpo_current = getattr(dpo_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None)
+        if (dpo_current is None or len(dpo_current) == 0) and new_mapping is not None:
+            dpo_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new_mapping
+    except ImportError:
+        pass
+pass
+TEMPORARY_PATCHES.append(patch_trl_vision_model_mapping)
+
+
+def patch_vllm_safe_apply_chat_template():
+    """Fix vLLM safe_apply_chat_template for transformers 5.0+.
+
+    transformers 5.0.0 changed apply_chat_template(tokenize=True) to default
+    return_dict=True, returning BatchEncoding instead of list[int]. vLLM's
+    safe_apply_chat_template doesn't pass return_dict=False, causing TypeError
+    in _validate_model_input when max(BatchEncoding) returns a string key.
+
+    Fix: wrap the original function to inject return_dict=False when tokenize=True.
+    """
+    try:
+        from unsloth_zoo.utils import Version
+        import transformers
+        if Version(transformers.__version__) < Version("5.0.0"):
+            return
+
+        import vllm.renderers.hf as hf_mod
+        _original_safe_apply = getattr(hf_mod, "safe_apply_chat_template", None)
+        if _original_safe_apply is None:
+            return
+        if getattr(_original_safe_apply, "_unsloth_patched", False):
+            return
+
+        def _patched_safe_apply(model_config, tokenizer, conversation, *,
+                                tools=None, chat_template=None, tokenize=True, **kwargs):
+            if tokenize:
+                kwargs["return_dict"] = False
+            return _original_safe_apply(
+                model_config, tokenizer, conversation,
+                tools=tools, chat_template=chat_template, tokenize=tokenize,
+                **kwargs,
+            )
+        _patched_safe_apply._unsloth_patched = True
+        hf_mod.safe_apply_chat_template = _patched_safe_apply
+    except Exception:
+        pass
+pass
+TEMPORARY_PATCHES.append(patch_vllm_safe_apply_chat_template)
diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py
index 9a05cc9ad..c3a0a99a9 100644
--- a/unsloth_zoo/tokenizer_utils.py
+++ b/unsloth_zoo/tokenizer_utils.py
@@ -482,6 +482,11 @@ def patch_tokenizer(model, tokenizer):
         Fixes https://github.com/unslothai/unsloth/issues/5
     """
     # All Unsloth Zoo code licensed under LGPLv3
+
+    # Guard against None tokenizer (e.g., some VLM processors without tokenizer)
+    if tokenizer is None:
+        return model, tokenizer
+
     joiner = "\1\0=+=\0\1"
     number_repetitions = 3 - 1 # Number of reserved tokens needed
 
@@ -492,7 +497,12 @@ def patch_tokenizer(model, tokenizer):
     if hasattr(tokenizer, "image_processor") and hasattr(tokenizer, "apply_chat_template"):
         patch_processor_call(tokenizer)
 
-    if hasattr(tokenizer, "tokenizer"): tokenizer = tokenizer.tokenizer
+    if hasattr(tokenizer, "tokenizer"):
+        inner = tokenizer.tokenizer
+        if inner is None:
+            # Processor exists but inner tokenizer is None - return as-is
+            return model, original_tokenizer
+        tokenizer = inner
 
     bad_pad_token = False
     if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
@@ -592,7 +602,7 @@ def patch_tokenizer(model, tokenizer):
                 model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
     else:
         if model is not None:
-            if model.config.pad_token_id is None:
+            if getattr(model.config, "pad_token_id", None) is None:
                 model.config.update({"pad_token_id" : tokenizer.pad_token_id})
                 if getattr(model, "generation_config", None) is not None:
                     model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py
index f89df7c19..ea025ac7b 100644
--- a/unsloth_zoo/vllm_utils.py
+++ b/unsloth_zoo/vllm_utils.py
@@ -1329,7 +1329,12 @@ def _override_to(self, *args, **kwargs):
                 layer.quant_method = "fbgemm_fp8"
             elif fp8_weight_scale.ndim == 2:
                 # This denotes that the model if FP8 dynamic quantized.
-                layer = FP8Linear(in_features = 0, out_features = 0, bias = has_bias, dtype = dtype, block_size = kwargs['block_size'], device = get_target_device(), activation_scheme = kwargs['activation_scheme'])
+                fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme'])
+                # transformers 5.0+ removed device param from FP8Linear.__init__
+                from transformers import __version__ as _unsloth_transformers_version
+                if Version(_unsloth_transformers_version) < Version("5.0.0"):
+                    fp8_kwargs["device"] = get_target_device()
+                layer = FP8Linear(**fp8_kwargs)
                 layer.in_features = weight.shape[1]
                 layer.out_features = weight.shape[0]
                 layer.weight = torch.nn.Parameter(weight, requires_grad = False)