diff --git a/pyproject.toml b/pyproject.toml index 7b1d2efda4..6e1bea6960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.11", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.9", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -511,4 +511,4 @@ cu126-ampere-torch260 = [ [project.urls] homepage = "http://www.unsloth.ai" documentation = "https://github.com/unslothai/unsloth" -repository = "https://github.com/unslothai/unsloth" +repository = "https://github.com/unslothai/unsloth" \ No newline at end of file diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 7ffddde9b0..80aa3bda67 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,10 +198,10 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.11"): + if Version(unsloth_zoo_version) < Version("2025.3.13"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ - "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'" + "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" ) if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0": try: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 69cc1e6884..e2b35c5ff6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.14" +__version__ = "2025.3.15" __all__ = [ "SUPPORTS_BFLOAT16", @@ -182,6 +182,15 @@ def filter(self, x): return not (self.text in x.getMessage()) except: pass +# Gemma3 It is strongly recommended to train Gemma3 models with the `eager` +try: + from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger + gemma3_logger.addFilter(HideLoggingMessage("strongly recommended")) + del gemma3_logger +except: + pass + + # Patch get_model_param_count to record correct 4bit / 8bit from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled def get_model_param_count(model, trainable_only = False): @@ -1016,13 +1025,7 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass - - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - autocaster = contextlib.nullcontext() - else: - autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32) - with autocaster: - outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + outputs = self._old_compute_loss(model, inputs, *args, **kwargs) return outputs pass @@ -1126,7 +1129,9 @@ def patch_fast_lora(): def unsloth_compile_transformers( + dtype, model_name, + model_types, token = None, revision = None, trust_remote_code = False, @@ -1164,15 +1169,12 @@ def unsloth_compile_transformers( ) return pass - - model_types = get_transformers_model_type( - model_name = model_name, - token = token, - revision = revision, - trust_remote_code = trust_remote_code, - ) - model_types = ["siglip"] + model_types - + if trust_remote_code: + print( + "Unsloth: We can't trace models if `trust_remote_code = True`, "\ + "so turning off some optimizations!" + ) + return if disable: return for model_type in model_types: @@ -1204,6 +1206,9 @@ def unsloth_compile_transformers( return_logits = return_logits, ) pass + # Redo patches which override compiler + for temporary_patch in TEMPORARY_PATCHES: + temporary_patch() return model_types pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 893a09dd14..07805271f5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1548,7 +1548,7 @@ def unsloth_fast_generate( if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings: raise ValueError( - f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ + f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n'\ 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' ) pass @@ -1562,7 +1562,10 @@ def unsloth_fast_generate( # For newer HF kwargs["cache_implementation"] = "dynamic" # For num_logits_to_keep - kwargs["num_logits_to_keep"] = 1 + num_logits_to_keep = kwargs.get("num_logits_to_keep", None) + logits_to_keep = kwargs.get("logits_to_keep", None) + if num_logits_to_keep is None and logits_to_keep is None: + kwargs["num_logits_to_keep"] = 1 # Remove token_type_ids kwargs.pop("token_type_ids", None) @@ -1822,7 +1825,7 @@ def from_pretrained( # Convert to HF format _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) - model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) + model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype, bnb_config) model.vllm_engine = llm model.fast_generate = model.vllm_engine.generate model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 44475780af..cd59e0365d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -17,6 +17,7 @@ HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING, USE_MODELSCOPE, + get_transformers_model_type, ) from .granite import FastGraniteModel from .llama import FastLlamaModel, logger @@ -66,6 +67,11 @@ unsloth_compile_transformers, ) +global FORCE_FLOAT32 +FORCE_FLOAT32 = [ + "gemma3", +] + class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( @@ -212,7 +218,13 @@ def from_pretrained( f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\ f"to obtain the latest transformers build, then restart this session."\ ) - raise RuntimeError(autoconfig_error or peft_error) + # Create a combined error message showing both failures + combined_error = ( + "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n" + f"AutoConfig error: {autoconfig_error}\n\n" + f"PeftConfig error: {peft_error}\n\n" + ) + raise RuntimeError(combined_error) pass # Get base model for PEFT: @@ -460,12 +472,17 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() - assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) + + SUPPORTS_BFLOAT16 = is_bfloat16_supported() + if dtype is None: + dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + logger.warning_once("Device does not support bfloat16. Will change to float16.") + dtype = torch.float16 + assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) patch_compiled_autograd() patch_compiling_bitsandbytes() - if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing(dtype = dtype) if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") @@ -479,11 +496,6 @@ def from_pretrained( "Also, we by default set `load_in_4bit = True`.\n"\ "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`" ) - if load_in_4bit: pass - elif load_in_8bit: pass - elif not load_in_4bit and not load_in_8bit and not full_finetuning: - print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") - load_in_4bit = True pass old_model_name = model_name @@ -591,7 +603,13 @@ def from_pretrained( f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\ f"to obtain the latest transformers build, then restart this session."\ ) - raise RuntimeError(autoconfig_error or peft_error) + # Create a combined error message showing both failures + combined_error = ( + "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n" + f"AutoConfig error: {autoconfig_error}\n\n" + f"PeftConfig error: {peft_error}\n\n" + ) + raise RuntimeError(combined_error) pass # Get base model for PEFT: @@ -616,10 +634,39 @@ def from_pretrained( else: redirector = contextlib.redirect_stdout(open(os.devnull, "w")) + # Get model types like Gemma3 etc + model_types = get_transformers_model_type( + model_name = model_name, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, + ) + model_types = ["siglip"] + model_types + + # Set forced float32 env flag + os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + do_forced_float32 = False + model_type_arch = model_types[1] + global FORCE_FLOAT32 + for disable_name in FORCE_FLOAT32: + if (disable_name.lower() == model_type_arch.lower() or \ + disable_name.lower() in model_name.lower()) and \ + ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): + os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" + dtype = torch.bfloat16 # Change to bfloat16 loading + break + pass + # Patch gradient checkpointing + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) + with redirector: patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( + dtype = dtype, model_name = model_name, + model_types = model_types, + token = token, sdpa_dynamic_mask = True, sdpa_bool_masks = True, sdpa_gqa_replace = True, @@ -644,6 +691,7 @@ def from_pretrained( import_from_cache = False, disable = False, return_logits = return_logits, + trust_remote_code = trust_remote_code, ) pass diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c450ef6df5..5d2270810c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -439,6 +439,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "eval_accumulation_steps" : 2, "torch_empty_cache_steps" : 250, "logging_steps" : 1, + "max_seq_length" : None, } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4071ef835a..a3b2d1de8a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -176,8 +176,9 @@ def grpo_trainer__prepare_inputs(function_name, function): "with torch.inference_mode(), "\ "torch.amp.autocast(device_type = 'cuda', "\ - "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext():", + "dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext())"\ + "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):", ) # Disable attaching a float32 conversion hook which upcasts logits to FP32 @@ -212,7 +213,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 - if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32 + if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits @@ -254,11 +255,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape - # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - attention_mask = None + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + # attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 24015f82fe..53a873d168 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -53,6 +53,7 @@ import functools from typing import Optional, Tuple, List, Union import re, inspect, sys +import contextlib import types try: from huggingface_hub.utils import get_token @@ -65,11 +66,6 @@ "FastBaseModel", ] -global FORCE_FLOAT32 -FORCE_FLOAT32 = [ - "gemma3", -] - global FORCE_EAGER_ATTENTION FORCE_EAGER_ATTENTION = [ "pixtral", # Pixtral SDPA not implemented @@ -77,12 +73,23 @@ global NUM_LOGITS_TO_KEEP NUM_LOGITS_TO_KEEP = dict() +global PROMPT_LOOPKUP +PROMPT_LOOPKUP = dict() def unsloth_base_fast_generate( self, *args, **kwargs, ): + if len(args) != 0: + x = args[0] + elif "input_ids" in kwargs: + x = kwargs["input_ids"] + else: + raise TypeError("Unsloth: You need to pass in input_ids to .generate!") + assert(type(x) is torch.Tensor) + bsz = x.shape[0] + FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -98,34 +105,35 @@ def unsloth_base_fast_generate( kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep - if not is_vlm: - global NUM_LOGITS_TO_KEEP - if arch not in NUM_LOGITS_TO_KEEP: - m = self - # Find which is needed ie - # num_logits_to_keep or logits_to_keep - while hasattr(m, "model"): - if hasattr(m, "forward"): - keys = inspect.signature(m.forward).parameters.keys() - if "num_logits_to_keep" in keys: - NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep" - break - elif "logits_to_keep" in keys: - NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep" - break - m = m.model - pass - if arch not in NUM_LOGITS_TO_KEEP: - NUM_LOGITS_TO_KEEP[arch] = None - pass + global NUM_LOGITS_TO_KEEP + if arch not in NUM_LOGITS_TO_KEEP: + m = self + # Find which is needed ie + # num_logits_to_keep or logits_to_keep + while hasattr(m, "model"): + if hasattr(m, "forward"): + keys = inspect.signature(m.forward).parameters.keys() + if "num_logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep" + break + elif "logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep" + break + m = m.model pass - key = NUM_LOGITS_TO_KEEP[arch] - if key is not None and key not in kwargs: - kwargs[key] = 1 - else: + if arch not in NUM_LOGITS_TO_KEEP: + NUM_LOGITS_TO_KEEP[arch] = None pass - # kwargs.pop("logits_to_keep", None) - # kwargs.pop("num_logits_to_keep", None) + pass + key = NUM_LOGITS_TO_KEEP[arch] + if key is not None and key not in kwargs: + kwargs[key] = 1 + global PROMPT_LOOPKUP + if arch not in PROMPT_LOOPKUP: + PROMPT_LOOPKUP[arch] = True + + if bsz == 1 and PROMPT_LOOPKUP[arch]: + kwargs["prompt_lookup_num_tokens"] = 3 # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) @@ -138,10 +146,20 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + if "use_cache" not in kwargs: kwargs["use_cache"] = True + # Mixed precision autocast - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 - with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): - output = self._old_generate(*args, **kwargs) + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + else: + autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + with torch.inference_mode(), autocaster: + try: + output = self._old_generate(*args, **kwargs) + except: + PROMPT_LOOPKUP[arch] = False + kwargs.pop("prompt_lookup_num_tokens", None) + output = self._old_generate(*args, **kwargs) pass FastBaseModel.for_training(self) @@ -209,24 +227,20 @@ def from_pretrained( if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + if dtype == torch.float16: dtype = torch.bfloat16 elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 + pass + assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) - assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) - - global FORCE_FLOAT32 - os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype - for disable_name in FORCE_FLOAT32: - if (disable_name.lower() == model_type_arch.lower() or \ - disable_name.lower() in model_name.lower()) and \ - dtype == torch.float16: - - print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") - os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" - bnb_compute_dtype = torch.float32 - break + do_forced_float32 = False + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") + bnb_compute_dtype = torch.float16 + do_forced_float32 = True pass global FORCE_EAGER_ATTENTION @@ -263,15 +277,7 @@ def from_pretrained( llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif not load_in_4bit and not load_in_8bit and not full_finetuning: - print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") - load_in_4bit = True - bnb_config = BitsAndBytesConfig( - load_in_4bit = True, - bnb_4bit_use_double_quant = True, - bnb_4bit_quant_type = "nf4", - bnb_4bit_compute_dtype = bnb_compute_dtype, - llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), - ) + print("Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.") pass if full_finetuning: @@ -289,10 +295,13 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config + # Check if using forced float32 - we load it in bfloat16, then cast to float16! + torch_dtype = dtype + if do_forced_float32: torch_dtype = torch.bfloat16 model = auto_model.from_pretrained( model_name, device_map = device_map, - torch_dtype = dtype, + torch_dtype = torch_dtype, # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, @@ -325,15 +334,16 @@ def from_pretrained( tokenizer.pad_token = __tokenizer.pad_token tokenizer.pad_token_id = __tokenizer.pad_token_id pass - model, tokenizer = patch_tokenizer(model, tokenizer) - model = post_patch_loss_function(model) # Fix other stuff like BnB compute data types model, tokenizer = patch_model_and_tokenizer( model, tokenizer, downcast_rope = False, fix_embeddings = False, + do_forced_float32 = do_forced_float32, ) + model, tokenizer = patch_tokenizer(model, tokenizer) + model = post_patch_loss_function(model) # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 26669127d7..067f2596c6 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -686,12 +686,12 @@ def fix_chat_template(tokenizer): raise RuntimeError( f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\ "does not have a {% if add_generation_prompt %} for generation purposes.\n"\ - "Please file a bug report immediately - thanks!" + f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!" ) else: logger.warning_once( "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\ - "This is not a bug, but please notify the Unsloth maintainers - thanks!" + f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!" ) chat_template = new_chat_template pass