diff --git a/pyproject.toml b/pyproject.toml index 02bcf4bb60..7f24aabbf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.16", + "unsloth_zoo>=2025.3.17", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.16", + "unsloth_zoo>=2025.3.17", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 708eeaf9e4..d401b7205f 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.16"): + if Version(unsloth_zoo_version) < Version("2025.3.17"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0044c7e761..840c15c003 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.18" +__version__ = "2025.3.19" __all__ = [ "SUPPORTS_BFLOAT16", @@ -1176,9 +1176,10 @@ def unsloth_compile_transformers( "so turning off some optimizations!" ) return - if disable: return - model_types = list(dict().fromkeys(model_types).keys()) + if disable: return model_types, False + + supports_sdpa = [True] for model_type in model_types: _unsloth_compile_transformers( model_type, @@ -1206,12 +1207,13 @@ def unsloth_compile_transformers( import_from_cache = import_from_cache, disable = disable, return_logits = return_logits, + supports_sdpa = supports_sdpa, ) pass # Redo patches which override compiler for temporary_patch in TEMPORARY_PATCHES: temporary_patch() - return model_types + return model_types, supports_sdpa[0] pass # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b3b49a0436..722b50d27a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2024,6 +2024,14 @@ def get_peft_model( **kwargs, ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": + # Check for other PEFT args in kwargs + for (peft_arg, flag) in ( + ("finetune_vision_layers", False), + ("finetune_language_layers", True), + ("finetune_attention_modules", True), + ("finetune_mlp_modules", True), + ): + if peft_arg not in kwargs: kwargs[peft_arg] = flag return FastBaseModel.get_peft_model( model = model, r = r, @@ -2031,10 +2039,6 @@ def get_peft_model( lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = bias, - finetune_vision_layers = False, - finetune_language_layers = True, - finetune_attention_modules = True, - finetune_mlp_modules = True, layers_to_transform = layers_to_transform, layers_pattern = layers_pattern, use_gradient_checkpointing = use_gradient_checkpointing, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 670e082580..cac5acd838 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -663,7 +663,7 @@ def from_pretrained( with redirector: patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( + model_types, supports_sdpa = unsloth_compile_transformers( dtype = dtype, model_name = model_name, model_types = model_types, @@ -726,6 +726,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, auto_model = auto_model, use_gradient_checkpointing = use_gradient_checkpointing, + supports_sdpa = supports_sdpa, *args, **kwargs, ) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index cf250dd498..91ed262502 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -728,6 +728,16 @@ "mistralai/Mistral-Small-3.1-24B-Base-2503", "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", ), + "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : ( + "unsloth/orpheus-3b-0.1-pretrained", + "canopylabs/orpheus-3b-0.1-pretrained", + "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit", + ), + "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : ( + "unsloth/orpheus-3b-0.1-ft", + "canopylabs/orpheus-3b-0.1-ft", + "unsloth/orpheus-3b-0.1-ft-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a3b2d1de8a..376d1e9a28 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -79,7 +79,7 @@ def sft_trainer_prepare_dataset(function_name, function): function_name != "_prepare_dataset": return function fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None) - if fast_sft_prepare_dataset is not None and "pack_examples" in function: + if fast_sft_prepare_dataset is not None: params = inspect.signature(fast_sft_prepare_dataset).parameters.keys() params = ".*?".join(params) matched = re.match( diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ef32ab1847..f05cc95d60 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -66,11 +66,6 @@ "FastBaseModel", ] -global FORCE_EAGER_ATTENTION -FORCE_EAGER_ATTENTION = [ - "pixtral", # Pixtral SDPA not implemented -] - global NUM_LOGITS_TO_KEEP NUM_LOGITS_TO_KEEP = dict() global PROMPT_LOOPKUP @@ -145,8 +140,11 @@ def unsloth_base_fast_generate( kwargs[key] = 1 global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: - PROMPT_LOOPKUP[arch] = True - + # Only works for VLMs and not LLMs! + if is_vlm: + PROMPT_LOOPKUP[arch] = False + else: + PROMPT_LOOPKUP[arch] = True if bsz == 1 and PROMPT_LOOPKUP[arch]: kwargs["prompt_lookup_num_tokens"] = 3 @@ -237,8 +235,14 @@ def from_pretrained( tokenizer_name = None, auto_model = AutoModelForVision2Seq, use_gradient_checkpointing = "unsloth", + supports_sdpa = True, **kwargs, ): + if model_types is None: + raise RuntimeError( + "Unsloth: Please use FastModel or FastVisionModel and not use FastBaseModel directly!" + ) + os.environ["UNSLOTH_USE_NEW_MODEL"] = "1" if trust_remote_code: print( @@ -299,16 +303,11 @@ def from_pretrained( bnb_compute_dtype = torch.float16 do_forced_float32 = True pass - - global FORCE_EAGER_ATTENTION - attn_implementation = "sdpa" - for disable_name in FORCE_EAGER_ATTENTION: - if (disable_name.lower() == model_type_arch.lower() or \ - disable_name.lower() in model_name.lower()): - - print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") - attn_implementation = "eager" - break + # Stop SDPA for some archs like Pixtral / Mistral3 + kwargs["attn_implementation"] = "sdpa" + if not supports_sdpa: + print(f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to eager!") + del kwargs["attn_implementation"] pass bnb_config = None @@ -347,8 +346,6 @@ def from_pretrained( os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "0" pass - kwargs.pop("attn_implementation", None); # No need since we auto call it - # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config @@ -362,7 +359,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = attn_implementation, + # attn_implementation = attn_implementation, **kwargs, ) # Return old flag @@ -431,11 +428,12 @@ def from_pretrained( m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) - + if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": + if model.generate.__name__ != "unsloth_base_fast_generate": + model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_base_fast_generate, model) + pass # Post patches model = FastBaseModel.post_patch_model( model,