diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 3ce601b22..17593b25e 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.5.5" +__version__ = "2025.5.6" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 7c8ecef83..35c9caca9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -493,7 +493,7 @@ def create_standalone_class( source = f"{compile}\n{source}\n" - left = re.match("[\s\n]{4,}", leftover).span()[1] + left = re.match(r"[\s\n]{4,}", leftover).span()[1] new_forward = definition + leftover[:left] + \ f"return {module}_forward({parameters})\n" full_class = full_class.replace(old_source, new_forward) @@ -505,6 +505,9 @@ def create_standalone_class( # Combine all into file source = source + full_class + # Remove @auto_docstring, with or without call arguments such as @auto_docstring(...) + source = re.sub(r"@auto_docstring[ \t]*(\([^\)]*\))?", "", source) + # Fix Gemma 3 ignore_index being not set! source = source.replace("self.config.ignore_index", "-100") return source @@ -1470,18 +1473,45 @@ def unsloth_compile_transformers( if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return # Use transformers model_type logger to supress message: Remove `use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False` - exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals()) + exec("model_logger.addFilter(HideLoggingMessage('`use_cache=True`'))", globals(), locals()) + + # Override the tqdm used by torch._inductor.async_compile so its description reads "Unsloth: Compiling kernels" (best-effort): + try: + import torch._inductor.async_compile + from torch.hub import tqdm + def replaced_tqdm(*args, **kwargs): + kwargs["desc"] = "Unsloth: Compiling kernels" + return tqdm(*args, **kwargs) + torch._inductor.async_compile.tqdm = replaced_tqdm + except Exception as error: + print(f"Unsloth: Failed editing tqdm to replace Inductor Compilation: {error}") + pass # torch_compile_options UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" + UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" torch_compile_options = { - "epilogue_fusion" : epilogue_fusion, - "max_autotune" : max_autotune, - "shape_padding" : shape_padding, - "trace.enabled" : UNSLOTH_COMPILE_DEBUG or debug, - "triton.cudagraphs" : cudagraphs, + "epilogue_fusion" : epilogue_fusion, + "max_autotune" : max_autotune, + "shape_padding" : shape_padding, + "trace.enabled" : UNSLOTH_COMPILE_DEBUG or debug, + "triton.cudagraphs" : cudagraphs, + "debug" : UNSLOTH_COMPILE_DEBUG or debug, + "dce" : True, + "memory_planning" : True, + "coordinate_descent_tuning" : UNSLOTH_COMPILE_MAXIMUM, + "trace.graph_diagram" : UNSLOTH_COMPILE_DEBUG or debug, + "compile_threads" : 24, + "combo_kernels" : False, # Causes incompatible gradient sizes on 2.6 + "group_fusion" : True, + "disable_progress" : not UNSLOTH_ENABLE_LOGGING, + "verbose_progress" : UNSLOTH_ENABLE_LOGGING, + "triton.multi_kernel" : False, # Sometimes fails + "triton.use_block_ptr" : True, + "triton.enable_persistent_tma_matmul" : True, + "triton.autotune_at_compile_time" : True, } # Return logits @@ -1705,6 
+1735,18 @@ def unsloth_compile_transformers( bad_torch_modules.add(module) pass + # Remove decoder layers + if "for layer in self." in source: + print(f"Unsloth: Failed compiling function {module} since it looks like a decoder!") + bad_torch_modules.add(module) + pass + + # Remove padding + if "nn.functional.pad" in source or "padding" in source: + print(f"Unsloth: Failed compiling function {module} since there is padding done.") + bad_torch_modules.add(module) + pass + # Check for residual streams optimizations if fast_residual_stream and "residual" in source: new_source = patch_residual_stream(source) @@ -1790,7 +1832,7 @@ def unsloth_compile_transformers( # Remove causal masks do_not_remove = False for module in remove_causal_masks: - if module.endswith(("ForConditionalGeneration")): + if module.endswith(("ForConditionalGeneration", "Gemma3Model")): do_not_remove = True print(f"Unsloth: Will not remove causal mask for {model_location} since it's a VLM!") break diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 127484dd7..4a092faf1 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -189,6 +189,7 @@ def patch_model_and_tokenizer( downcast_rope = True, fix_embeddings = True, do_forced_float32 = False, + correct_dtype = None, ): # All Unsloth Zoo code licensed under LGPLv3 assert(type(downcast_rope) is bool) @@ -223,10 +224,12 @@ def patch_model_and_tokenizer( pass # Get most likely the correct data-type of the model - try: - correct_dtype = _get_dtype(model.config.torch_dtype) - except: - correct_dtype = model.get_input_embeddings().weight.dtype + if correct_dtype is None: + try: + correct_dtype = _get_dtype(model.config.torch_dtype) + except: + correct_dtype = model.get_input_embeddings().weight.dtype + pass # If we force float32, we first use bfloat16, then downcast to float16 if do_forced_float32: correct_dtype = torch.float16 @@ -242,29 +245,31 @@ def patch_model_and_tokenizer( 
assert(module.weight.dtype == torch.float32) torch.cuda.empty_cache() pass + pass - # Correct torch_dtype - def __fix_dtype(config): - if not hasattr(config, "to_dict"): return - dicts = config.to_dict() - for key, value in dicts.items(): - if key == "torch_dtype": - setattr(config, "torch_dtype", torch.float16) - else: - __fix_dtype(getattr(config, key)) - m = model - while hasattr(m, "model"): - if hasattr(m, "dtype"): - try: setattr(m, "dtype", torch.float16) - except: pass - if hasattr(m, "config"): __fix_dtype(m.config) - m = m.model - pass - if hasattr(m, "config"): __fix_dtype(m.config) + # Correct torch_dtype + def __fix_dtype(config): + if not hasattr(config, "to_dict"): return + dicts = config.to_dict() + for key, value in dicts.items(): + if key == "torch_dtype": + setattr(config, "torch_dtype", correct_dtype) + else: + __fix_dtype(getattr(config, key)) + m = model + while hasattr(m, "model"): if hasattr(m, "dtype"): - try: setattr(m, "dtype", torch.float16) + try: setattr(m, "dtype", correct_dtype) except: pass + if hasattr(m, "config"): __fix_dtype(m.config) + m = m.model pass + if hasattr(m, "config"): __fix_dtype(m.config) + if hasattr(m, "dtype"): + try: setattr(m, "dtype", correct_dtype) + except: pass + pass + # Check all params and patch! 
for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 0cfdd2bb4..688165afa 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -18,6 +18,7 @@ from typing import Union, List, Any, Tuple, Dict, Callable, Optional import inspect import torch +import torch.nn as nn import os import logging @@ -272,6 +273,7 @@ def forward( **lm_kwargs, ) labels = None + # We NEVER ENTER if labels is not None: since we already accounted for it logits = outputs.logits @@ -307,13 +309,109 @@ def forward( image_hidden_states=image_features if pixel_values is not None else None, ) pass + + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + pass + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward + return + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states 
is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None and attention_mask is not None: + attention_mask = attention_mask.to(device = labels.device) + labels[attention_mask == 0] = -100 + pass + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **lm_kwargs, + ) + labels = None + # We NEVER ENTER if labels is not None: since we already accounted for it + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + loss = outputs.loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward - return pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) @@ -395,12 +493,21 @@ def _update_causal_mask( return causal_mask pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask).parameters - new_keys = inspect.signature(_update_causal_mask).parameters - if old_keys != new_keys: - print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") + + if hasattr(transformers.models.gemma3.modeling_gemma3, "Gemma3Model"): + old_keys = 
inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Model._update_causal_mask).parameters + new_keys = inspect.signature(_update_causal_mask).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3Model.") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3Model._update_causal_mask = _update_causal_mask else: - transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask).parameters + new_keys = inspect.signature(_update_causal_mask).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3ForConditionalGeneration._update_causal_mask.") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask return pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_causal_mask) @@ -583,7 +690,6 @@ def patch_SmolVLMForConditionalGeneration_forward(): from typing import List, Optional, Tuple, Union from transformers.models.smolvlm.modeling_smolvlm import ( - CrossEntropyLoss, SmolVLMCausalLMOutputWithPast, SmolVLMForConditionalGeneration, ) @@ -675,7 +781,7 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(