diff --git a/pyproject.toml b/pyproject.toml index c2cb87ce3b..c860a92db6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.9.3", + "unsloth_zoo>=2025.9.4", "packaging", "tyro", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1", @@ -453,7 +453,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.9.3", + "unsloth_zoo>=2025.9.4", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 597ed0244b..56b98489f6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.9.2" +__version__ = "2025.9.3" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index fc8b584caf..a57deef000 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -83,6 +83,7 @@ global FORCE_FLOAT32 FORCE_FLOAT32 = [ "gemma3", + "gemma3n", "gpt_oss", ] @@ -177,6 +178,8 @@ def from_pretrained( autoconfig_error = None peft_error = None + model_config = None + peft_config = None try: model_config = AutoConfig.from_pretrained( model_name, @@ -200,8 +203,12 @@ def from_pretrained( peft_error = str(error) is_peft = False pass - - # Both config.json and adapter_config.json should not exist! 
+ model_types = get_transformers_model_type(model_config or peft_config) + if len(model_types) == 1: + model_type = model_types[0] + else: + # Leave as tuple if more than one arch + model_type = model_types # Old transformers versions check both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32 @@ -266,8 +273,6 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - model_type = model_config.model_type - if model_type == "llama": scaling_type = None if getattr(model_config, "rope_scaling", None) is not None: @@ -493,10 +498,11 @@ def from_pretrained( from transformers import AutoModelForVision2Seq pass +# Must be alphabetically sorted for each entry DISABLE_COMPILE_MODEL_NAMES = [ - "aya-vision", + "aya_vision", "modernbert", - "granite-vision", + "granite,llava_next", # Granite-vision 3 ] @@ -573,20 +579,61 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + # Check modelscope + if USE_MODELSCOPE and not os.path.exists(model_name): + from modelscope import snapshot_download + model_name = snapshot_download(model_name) + pass + + # First check if it's a normal model via AutoConfig + from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled + was_disabled = are_progress_bars_disabled() + disable_progress_bars() + + autoconfig_error = None + peft_error = None + model_config = None + peft_config = None + try: + model_config = AutoConfig.from_pretrained( + model_name, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, + ) + is_model = True + except Exception as error: + autoconfig_error = str(error) + is_model = False + try: + peft_config = PeftConfig.from_pretrained( + model_name, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, + ) + is_peft = True + except Exception as error: + peft_error = str(error) + is_peft = False + pass + model_types = get_transformers_model_type(model_config 
or peft_config) + model_types_all = ",".join(model_types) + # Check versions lowered_model_name = model_name.lower() os.environ["UNSLOTH_MODEL_NAME"] = lowered_model_name LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`' # Pixtral - if "pixtral" in lowered_model_name and transformers_version < Version("4.49.0"): + if "pixtral" in model_types_all and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) # Qwen 2.5 - elif "qwen2.5" in lowered_model_name and transformers_version < Version("4.49.0"): + elif "qwen2_5" in model_types_all and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) # Gemma 3 - elif "gemma-3" in lowered_model_name: - if "gemma-3n" in lowered_model_name: + elif "gemma3" in model_types_all: + if "gemma3n" in model_types_all: if transformers_version < Version("4.53.0"): raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" @@ -604,10 +651,10 @@ def from_pretrained( # common in both gemma-3 and gemma-3n os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1" # Cohere - elif "c4ai-command-a-03-2025" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): + elif "cohere2" in model_types_all and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." 
+ NIGHTLY) # Sesame - elif "csm-1b" in lowered_model_name: + elif "csm" in model_types_all: os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" # Inference is too slow os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ @@ -615,14 +662,14 @@ def from_pretrained( "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"\ ";" # Granite 4 - elif 'granite-4' in lowered_model_name: + elif 'granitemoehybrid' in model_types_all: # Granite-4 rms norms are stored as 16 bit, but we upcast os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1" os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Olmo 2 - elif "olmo-2" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): + elif "olmo2" in model_types_all and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) - elif "falcon-h1" in lowered_model_name: + elif "falcon_h1" in model_types_all: # Falcon must use float32 Triton ie TRITON_F32_DEFAULT = 'ieee' # since Mamba kernels error out on using lower precision os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ @@ -630,7 +677,7 @@ def from_pretrained( "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"\ ";"\ "os.environ['TRITON_F32_DEFAULT'] = 'ieee'" - elif "gpt-oss" in lowered_model_name: + elif "gpt_oss" in model_types_all: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" if not load_in_4bit: # Only upcast MoE biases for MXFP4, not BnB @@ -675,44 +722,6 @@ def from_pretrained( os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" pass - if USE_MODELSCOPE and not os.path.exists(model_name): - from modelscope import snapshot_download - model_name = snapshot_download(model_name) - pass - - # First check if it's a normal model via AutoConfig - from huggingface_hub.utils import disable_progress_bars, 
enable_progress_bars, are_progress_bars_disabled - was_disabled = are_progress_bars_disabled() - disable_progress_bars() - - autoconfig_error = None - peft_error = None - try: - model_config = AutoConfig.from_pretrained( - model_name, - token = token, - revision = revision, - trust_remote_code = trust_remote_code, - ) - is_model = True - except Exception as error: - autoconfig_error = str(error) - is_model = False - try: - peft_config = PeftConfig.from_pretrained( - model_name, - token = token, - revision = revision, - trust_remote_code = trust_remote_code, - ) - is_peft = True - except Exception as error: - peft_error = str(error) - is_peft = False - pass - - # Both config.json and adapter_config.json should not exist! - # Old transformers versions check both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32 @@ -782,15 +791,7 @@ def from_pretrained( else: redirector = contextlib.redirect_stdout(open(os.devnull, "w")) - # Get model types like Gemma3 etc - model_types = get_transformers_model_type( - model_name = model_name, - token = token, - revision = revision, - trust_remote_code = trust_remote_code, - ) model_types = ["siglip"] + model_types - # Set forced float32 env flag os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False @@ -798,8 +799,8 @@ def from_pretrained( if model_type_arch != "siglip": break global FORCE_FLOAT32 for disable_name in FORCE_FLOAT32: - if (disable_name.lower() == model_type_arch.lower().replace("-", "_") or \ - disable_name.lower() in model_name.lower()) and \ + if (disable_name.lower() == model_type_arch.lower().replace("-", "").replace("_", "") or \ + disable_name.lower() in model_types_all) and \ ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" dtype = torch.bfloat16 # Change to bfloat16 loading @@ -808,7 +809,6 @@ def from_pretrained( # Patch gradient checkpointing if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) - 
with redirector: patch_loss_functions(torch_compile = False) model_types, supports_sdpa = unsloth_compile_transformers( @@ -845,7 +845,7 @@ def from_pretrained( ) pass # Fix SDPA - if "gemma-3n" in lowered_model_name: + if "gemma3n" in model_types_all: supports_sdpa = False pass diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f342a4d86b..14b75f6746 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -44,6 +44,8 @@ } from trl import __version__ as trl_version +from unsloth_zoo.utils import Version +trl_version = Version(trl_version) def vLLMSamplingParams(**kwargs): from vllm import SamplingParams @@ -804,7 +806,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import " " * 12 + "if (getattr(args, 'use_vllm', False) == False):\n" + \ " " * 16 + "args.use_vllm = True\n" - if "grpo" in trainer_file and trl_version >= "0.18": + if "grpo" in trainer_file and trl_version >= Version("0.18.0"): # If model has vllm_engine, then use vllm in colocate mode. Do not wait for server vllm_setter += \ " " * 12 + "args.vllm_mode='colocate'\n" @@ -850,26 +852,27 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params # Add spaces # count the indentation of last line of sampling_params. 
- last_line = sampling_params.split("\n")[-1] - last_prev_line = sampling_params.split("\n")[-2] - last_prev_indentation = len(last_prev_line) - len(last_prev_line.lstrip()) - last_indentation = len(last_line) - len(last_line.lstrip()) - - - # Add extra arguments to SamplingParams - extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})" - # Backwards replace - to_replace = ",\n" + " "*last_prev_indentation + extra + ",\n" + " "*last_indentation + ")" - sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) - # Strip multiple commas - sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) - - new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\ - f"\n{' '*8}else:\n" + splitted_sampling_params = sampling_params.split("\n") + if len(splitted_sampling_params) >= 2: + last_line = splitted_sampling_params[-1] + last_prev_line = splitted_sampling_params[-2] + last_prev_indentation = len(last_prev_line) - len(last_prev_line.lstrip()) + last_indentation = len(last_line) - len(last_line.lstrip()) + + # Add extra arguments to SamplingParams + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})" + # Backwards replace + to_replace = ",\n" + " "*last_prev_indentation + extra + ",\n" + " "*last_indentation + ")" + sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) + # Strip multiple commas + sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) + + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\ + f"\n{' '*8}else:\n" pass - if trl_version >= "0.18": + if trl_version >= Version("0.18.0"): # Replace LLM init with already existing vLLM engine for colocate mode vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\(.*?\)*\)\s*?\n(?!,)" vllm_llm_replacement = "self.llm = model.vllm_engine\n" @@ -881,7 +884,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) init = 
init.replace(vllm_part, new_vllm_part) - pass # Search for vLLM calling in all child functions