diff --git a/tests/test_mlx_runtime_cce_compile.py b/tests/test_mlx_runtime_cce_compile.py new file mode 100644 index 000000000..9168cfe0f --- /dev/null +++ b/tests/test_mlx_runtime_cce_compile.py @@ -0,0 +1,119 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# SPDX-License-Identifier: AGPL-3.0-or-later + +from __future__ import annotations + +import sys + +import pytest + + +mx = pytest.importorskip("mlx.core") +if "mlx_simulation" in str(getattr(mx, "__file__", "")): + pytest.skip("requires real MLX runtime", allow_module_level=True) + + +def _stable_norm(values): + max_abs = mx.array(0.0, dtype=mx.float32) + for value in values: + value32 = value.astype(mx.float32) + max_abs = mx.maximum(max_abs, mx.max(mx.abs(value32))) + + denom = mx.maximum(max_abs, mx.array(1e-30, dtype=mx.float32)) + norm_sq = mx.array(0.0, dtype=mx.float32) + for value in values: + scaled = value.astype(mx.float32) / denom + norm_sq = norm_sq + mx.sum(scaled * scaled) + return denom * mx.sqrt(norm_sq) + + +def _skip_torch_shim(): + if any(name.startswith("mlx_simulation") for name in sys.modules): + pytest.skip("requires real MLX runtime") + + +def test_compiled_runtime_cce_preserves_aux_lse_for_gradients(): + _skip_torch_shim() + from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss + + runtime_cce, _ = make_chunked_cross_entropy_loss( + ignore_index=-100, + chunk_size=32, + ) + hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0 + weight = (mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0) - 1.0 + targets = (mx.arange(64, dtype=mx.int32) * 7) % 128 + targets = mx.where(mx.arange(64) % 11 
== 0, -100, targets) + ntoks = mx.maximum( + mx.sum((targets != -100).astype(mx.float32)), + mx.array(1.0, dtype=mx.float32), + ) + + def loss_and_grad_norm(h, w): + def loss_fn(hh, ww): + losses = runtime_cce(hh, ww, targets) + return losses.astype(mx.float32).sum() / ntoks + + loss, grads = mx.value_and_grad(loss_fn, argnums=(0, 1))(h, w) + return loss, _stable_norm(grads) + + eager_loss, eager_norm = loss_and_grad_norm(hidden, weight) + compiled_loss, compiled_norm = mx.compile(loss_and_grad_norm)(hidden, weight) + mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm) + + assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5) + assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4) + + +def test_compiled_quantized_runtime_cce_preserves_aux_lse_for_gradients(): + _skip_torch_shim() + import mlx.nn as nn + + from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss + + linear = nn.Linear(32, 128, bias=False) + linear.weight = ( + mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0 + ) - 1.0 + qlinear = nn.QuantizedLinear.from_linear(linear, group_size=32, bits=4) + runtime_cce, _ = make_chunked_cross_entropy_loss( + ignore_index=-100, + chunk_size=32, + quantized=True, + group_size=qlinear.group_size, + bits=qlinear.bits, + ) + hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0 + targets = (mx.arange(64, dtype=mx.int32) * 7) % 128 + targets = mx.where(mx.arange(64) % 11 == 0, -100, targets) + ntoks = mx.maximum( + mx.sum((targets != -100).astype(mx.float32)), + mx.array(1.0, dtype=mx.float32), + ) + + def loss_and_grad_norm(h): + def loss_fn(hh): + losses = runtime_cce( + hh, + qlinear.weight, + qlinear.scales, + qlinear.biases, + targets, + ) + return losses.astype(mx.float32).sum() / ntoks + + loss, grad = mx.value_and_grad(loss_fn)(h) + return loss, _stable_norm((grad,)) + + eager_loss, eager_norm = loss_and_grad_norm(hidden) + compiled_loss, compiled_norm = 
mx.compile(loss_and_grad_norm)(hidden) + mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm) + + assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5) + assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4) diff --git a/tests/test_pr_a_components.py b/tests/test_pr_a_components.py index 963270f1a..74e4aa750 100644 --- a/tests/test_pr_a_components.py +++ b/tests/test_pr_a_components.py @@ -52,7 +52,7 @@ def _ce_reference(hidden, weight, targets, ignore_index=-100, softcap=0.0): def test_cce_forward_pure_python_matches_reference(): """The kernel-less branch of _forward_chunked_fused_finalize.""" - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hidden_dim, vocab = 4, 8, 32 @@ -72,7 +72,7 @@ def test_cce_forward_pure_python_matches_reference(): def test_cce_with_softcap(): """Softcap path matches reference too.""" - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hidden_dim, vocab = 4, 8, 32 @@ -93,7 +93,7 @@ def test_cce_with_softcap(): def test_cce_with_ignore_index(): """Ignore-index zeros out loss for those positions.""" - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hidden_dim, vocab = 4, 8, 32 @@ -114,7 +114,7 @@ def test_cce_with_ignore_index(): def test_cce_chunked_matches_unchunked(): """Chunked online LSE must equal a single-shot computation.""" - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hidden_dim, vocab = 4, 8, 64 diff --git a/tests/test_pr_a_deep_components.py 
b/tests/test_pr_a_deep_components.py index 33f8b8fb2..a8b8bc119 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -15,7 +15,7 @@ # along with this program. If not, see . """ -PR-A deeper component exercises: mlx_trainer, mlx_compile discovery, +PR-A deeper component exercises: trainer, compile discovery, cce backward, and quantization helpers — beyond just imports. If a test fails, the failing component identifies the next gap. @@ -40,7 +40,7 @@ def _install_shim(): # --------------------------------------------------------------------------- def test_mlx_training_config_is_dataclass_with_all_fields(): - from unsloth_zoo.mlx_trainer import MLXTrainingConfig + from unsloth_zoo.mlx.trainer import MLXTrainingConfig assert dataclasses.is_dataclass(MLXTrainingConfig) fields = {f.name for f in dataclasses.fields(MLXTrainingConfig)} # Required SFT-compat fields @@ -67,37 +67,102 @@ def test_mlx_training_config_is_dataclass_with_all_fields(): @pytest.mark.parametrize("optim_name", ["adamw", "adam", "sgd", "adafactor"]) def test_mlx_training_config_each_optim(optim_name): """Every PR-A-supported optim string at least constructs cleanly in config.""" - from unsloth_zoo.mlx_trainer import MLXTrainingConfig + from unsloth_zoo.mlx.trainer import MLXTrainingConfig cfg = MLXTrainingConfig(optim=optim_name) assert cfg.optim == optim_name +def test_trainer_drives_dynamic_lr_outside_optimizer_scheduler(): + from unsloth_zoo.mlx.trainer import ( + MLXTrainer, + MLXTrainingConfig, + ) + + trainer = MLXTrainer.__new__(MLXTrainer) + trainer.args = MLXTrainingConfig( + learning_rate=5e-5, + lr_scheduler_type="linear", + warmup_steps=5, + ) + schedule = trainer._build_schedule(total_steps=8) + def value_at(step): + value = schedule(step) + return value.item() if hasattr(value, "item") else float(value) + + assert value_at(0) > 0.0 + assert value_at(4) < trainer.args.learning_rate + assert value_at(5) == pytest.approx(trainer.args.learning_rate) + 
+ trainer.model = object() + optimizer = trainer._build_optimizer(total_steps=8) + assert not callable(optimizer.learning_rate) + first_lr = float(optimizer.learning_rate) + trainer._set_optimizer_lr_for_step(optimizer, 1) + second_lr = float(optimizer.learning_rate) + assert second_lr > first_lr + + +@pytest.mark.parametrize( + ("scheduler", "warmup"), + [ + ("linear", 0), + ("linear", 5), + ("cosine", 0), + ("cosine", 5), + ("constant", 0), + ("constant", 5), + ], +) +def test_scheduler_lr_is_nonzero_for_optimizer_update_steps(scheduler, warmup): + from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig + + total_steps = 8 + trainer = MLXTrainer.__new__(MLXTrainer) + trainer.args = MLXTrainingConfig( + learning_rate=5e-5, + lr_scheduler_type=scheduler, + warmup_steps=warmup, + ) + schedule = trainer._build_schedule(total_steps=total_steps) + + if callable(schedule): + raw_values = [schedule(step) for step in range(total_steps)] + else: + raw_values = [schedule] * total_steps + values = [ + value.item() if hasattr(value, "item") else float(value) + for value in raw_values + ] + + assert all(value > 0.0 for value in values) + + # --------------------------------------------------------------------------- -# 2. mlx_compile module-level discovery functions return sensible defaults +# 2. compile module-level discovery functions return sensible defaults # on a host with no real MLX architectures. 
# --------------------------------------------------------------------------- -def test_mlx_compile_discovers_no_archs_under_shim(): +def test_compile_discovers_no_archs_under_shim(): """No real mlx_vlm.models.* installed -> empty discovery, not crash.""" - import unsloth_zoo.mlx_compile as mc + import unsloth_zoo.mlx.compile as mc archs = mc.discover_architectures() assert isinstance(archs, tuple) -def test_mlx_compile_patch_primitives_exist(): - import unsloth_zoo.mlx_compile as mc +def test_compile_patch_primitives_exist(): + import unsloth_zoo.mlx.compile as mc primitives = mc.list_compile_patch_primitives() assert len(primitives) > 0 -def test_mlx_compile_protocol_requirements_exist(): - import unsloth_zoo.mlx_compile as mc +def test_compile_protocol_requirements_exist(): + import unsloth_zoo.mlx.compile as mc reqs = mc.list_protocol_requirements() assert len(reqs) > 0 -def test_mlx_compile_summarize_qualifications_returns_dict(): - import unsloth_zoo.mlx_compile as mc +def test_compile_summarize_qualifications_returns_dict(): + import unsloth_zoo.mlx.compile as mc s = mc.summarize_compile_qualifications() assert isinstance(s, dict) assert "architectures" in s @@ -109,7 +174,7 @@ def test_mlx_compile_summarize_qualifications_returns_dict(): def test_cce_backward_via_torch_autograd(): """Build a tiny CCE forward and verify torch.autograd traverses it.""" - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hd, vocab = 4, 8, 32 diff --git a/tests/test_pr_a_dequantize.py b/tests/test_pr_a_dequantize.py index a1b03bb4d..8a98a8a6d 100644 --- a/tests/test_pr_a_dequantize.py +++ b/tests/test_pr_a_dequantize.py @@ -15,7 +15,7 @@ # along with this program. If not, see . """ -PR-A integration: exercise unsloth_zoo.mlx_loader._dequantize_selected_mlx_modules. +PR-A integration: exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules. 
Builds a synthetic MLX-style model with one QuantizedLinear submodule, runs PR-A's dequantize-and-replace helper, verifies the result is @@ -68,7 +68,7 @@ def test_dequantize_selected_mlx_modules_swap(): """Build a Module with a QuantizedLinear, dequant-replace, verify swap.""" import mlx.core as mx import mlx.nn as nn - from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules + from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules # 4-bit weight: 8 input dims (one group), 2 output dims, group_size=8. bits, group_size = 4, 8 @@ -122,7 +122,7 @@ def __init__(self): def test_dequantize_predicate_filters(): """Predicate should let some QuantizedLinear modules through unchanged.""" import mlx.nn as nn - from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules + from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules bits, group_size = 4, 8 packed, scales, biases = _make_packed_4bit([0]*8, group_size, bits) @@ -145,7 +145,7 @@ def __init__(self): def test_dequantize_no_match_returns_zero(): import mlx.nn as nn - from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules + from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules class Empty(nn.Module): def __init__(self): diff --git a/tests/test_pr_a_imports.py b/tests/test_pr_a_imports.py index 39cdf271e..7eef0ebf8 100644 --- a/tests/test_pr_a_imports.py +++ b/tests/test_pr_a_imports.py @@ -38,12 +38,12 @@ def _install_shim(): # --------------------------------------------------------------------------- @pytest.mark.parametrize("module_path", [ - "unsloth_zoo.mlx_loader", - "unsloth_zoo.mlx_trainer", - "unsloth_zoo.mlx_utils", - "unsloth_zoo.mlx_compile", - "unsloth_zoo.mlx_cce", - "unsloth_zoo.mlx_cce.runtime_cce", + "unsloth_zoo.mlx.loader", + "unsloth_zoo.mlx.trainer", + "unsloth_zoo.mlx.utils", + "unsloth_zoo.mlx.compile", + "unsloth_zoo.mlx.cce", + "unsloth_zoo.mlx.cce.runtime_cce", "unsloth_zoo.gated_delta_vjp", ]) def 
test_pr_a_module_imports(module_path): @@ -58,48 +58,70 @@ def test_pr_a_module_imports(module_path): # --------------------------------------------------------------------------- def test_fast_mlx_model_class_exists(): - from unsloth_zoo.mlx_loader import FastMLXModel + from unsloth_zoo.mlx.loader import FastMLXModel assert hasattr(FastMLXModel, "from_pretrained") +def test_full_finetune_dtype_default_matches_torch_bf16(): + import mlx.core as mx + from unsloth_zoo.mlx.loader import _resolve_full_finetune_dtype + + assert _resolve_full_finetune_dtype(mx.bfloat16, None, mx) == ( + mx.bfloat16, + False, + ) + assert _resolve_full_finetune_dtype(mx.bfloat16, False, mx) == ( + mx.bfloat16, + False, + ) + assert _resolve_full_finetune_dtype(mx.bfloat16, True, mx) == ( + mx.float32, + True, + ) + assert _resolve_full_finetune_dtype(mx.float16, None, mx) == ( + mx.float32, + True, + ) + + def test_fast_mlx_model_save_helpers_exist(): """PR-B calls model.save_pretrained_merged / save_lora_adapters / push_to_hub_merged on the FastMLXModel INSTANCE returned by FastMLXModel.from_pretrained. The helpers are module-level in - mlx_loader.py and attached via types.MethodType after load. + loader.py and attached via types.MethodType after load. 
""" - import unsloth_zoo.mlx_loader as ml + import unsloth_zoo.mlx.loader as ml # The free functions must exist: assert hasattr(ml, "_mlx_save_pretrained_merged") assert hasattr(ml, "_mlx_save_lora_adapters") assert hasattr(ml, "_mlx_push_to_hub_merged") - # And the underlying mlx_utils targets: - import unsloth_zoo.mlx_utils as mu + # And the underlying utils targets: + import unsloth_zoo.mlx.utils as mu assert hasattr(mu, "save_pretrained_merged") assert hasattr(mu, "save_lora_adapters") assert hasattr(mu, "push_to_hub_merged") -def test_mlx_trainer_classes(): - from unsloth_zoo.mlx_trainer import ( +def test_trainer_classes(): + from unsloth_zoo.mlx.trainer import ( MLXTrainer, MLXTrainingConfig, ) # train_on_responses_only is the third symbol PR-B imports - import unsloth_zoo.mlx_trainer as mt + import unsloth_zoo.mlx.trainer as mt assert hasattr(mt, "train_on_responses_only") or hasattr(mt, "MLXTrainer") # --------------------------------------------------------------------------- -# 3. mlx_loader: dequantize-and-replace logic surface +# 3. MLX loader: dequantize-and-replace logic surface # --------------------------------------------------------------------------- def test_mlx_loader_dequantize_replace_callable(): """The dequantize-and-replace helper used by FastMLXModel.from_pretrained.""" - import unsloth_zoo.mlx_loader as ml + import unsloth_zoo.mlx.loader as ml # PR-A names this `_dequantize_selected_mlx_modules`. assert hasattr(ml, "_dequantize_selected_mlx_modules"), ( - "expected _dequantize_selected_mlx_modules in mlx_loader. " + "expected _dequantize_selected_mlx_modules in unsloth_zoo.mlx.loader. 
" f"Got dequant-related: {[a for a in dir(ml) if 'dequant' in a.lower()]}" ) @@ -112,7 +134,7 @@ def test_cce_fallback_path_runs(): """Construct a tiny CCE loss and verify the no-kernel branch fires.""" import torch import mlx.core as mx - from unsloth_zoo.mlx_cce.runtime_cce import _build_kernel_set + from unsloth_zoo.mlx.cce.runtime_cce import _build_kernel_set # With shim, is_available()=False -> kernel set returns (None, None, None) kernels = _build_kernel_set() @@ -124,7 +146,7 @@ def test_cce_forward_chunked_pure_python(): """Run the pure-Python CCE forward directly and verify a finite loss.""" import torch import mlx.core as mx - from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize + from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize torch.manual_seed(0) n, hidden, vocab = 4, 8, 32 @@ -149,11 +171,11 @@ def test_cce_forward_chunked_pure_python(): # --------------------------------------------------------------------------- -# 5. mlx_compile: VLM dispatcher should at minimum import. +# 5. compile: VLM dispatcher should at minimum import. # --------------------------------------------------------------------------- -def test_mlx_compile_import_does_not_error(): - import unsloth_zoo.mlx_compile # full module-level execution +def test_compile_import_does_not_error(): + import unsloth_zoo.mlx.compile # full module-level execution # --------------------------------------------------------------------------- @@ -171,9 +193,9 @@ def test_gated_delta_vjp_imports(): # 7. Optimizer construction: each MLXTrainingConfig optim string maps cleanly. 
# --------------------------------------------------------------------------- -def test_mlx_trainer_config_smoke(): +def test_trainer_config_smoke(): """MLXTrainingConfig should construct with sane defaults.""" - from unsloth_zoo.mlx_trainer import MLXTrainingConfig + from unsloth_zoo.mlx.trainer import MLXTrainingConfig # Try the default constructor — many MLX configs require keyword args. import dataclasses try: @@ -184,3 +206,21 @@ def test_mlx_trainer_config_smoke(): # We just want the class to be inspectable. ok = dataclasses.is_dataclass(MLXTrainingConfig) or True assert ok + + +def test_adam_optimizers_enable_bias_correction(): + from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig + + class DummyModel: + def trainable_parameters(self): + return {} + + for optim_name in ("adamw", "adam"): + trainer = MLXTrainer( + model=DummyModel(), + tokenizer=None, + train_dataset=[], + args=MLXTrainingConfig(optim=optim_name), + ) + optimizer = trainer._build_optimizer(total_steps=10) + assert optimizer._kw["bias_correction"] is True diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index a5af788ca..bdea4b631 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -90,27 +90,26 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: from importlib.util import find_spec -import platform as _check_platform +from .mlx.runtime import is_mlx_available # Detect Apple Silicon MLX mode: # Either torch is absent (pure MLX), or unsloth already detected MLX -_is_mlx_only = ( - os.environ.get("UNSLOTH_FORCE_GPU_PATH", "0") != "1" - and _check_platform.system() == "Darwin" - and _check_platform.machine() == "arm64" - and find_spec("mlx") is not None -) +_is_mlx_only = is_mlx_available() if _is_mlx_only: # MLX mode: skip all CUDA/torch-specific initialization. 
os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1" UNSLOTH_ZOO_IS_PRESENT = True - del _is_mlx_only, _check_platform, find_spec + DEVICE_TYPE = "mlx" + DEVICE_TYPE_TORCH = "mps" + DEVICE_COUNT = 1 + ALLOW_PREQUANTIZED_MODELS = True + del _is_mlx_only, is_mlx_available, find_spec # Everything below this point is GPU-only. Use a flag to gate it. _SKIP_GPU_INIT = True else: _SKIP_GPU_INIT = False - del _is_mlx_only, _check_platform + del _is_mlx_only, is_mlx_available # Inject triton & bitsandbytes stubs on Apple Silicon with MLX so unsloth's # CUDA-only imports don't error at startup. _SKIP_GPU_INIT=True is set only @@ -123,6 +122,24 @@ def has_429_exact_full_read(log_dir: str | Path) -> str: _inject_bnb() del _inject_triton, _inject_bnb + # Temporary bridge for already-merged Unsloth code that imports the old + # flat MLX module names. Remove after the paired Unsloth PR lands and + # imports unsloth_zoo.mlx.* everywhere. + import importlib as _importlib + import sys as _sys + + for _old_name, _new_name in ( + ("unsloth_zoo.mlx_loader", "unsloth_zoo.mlx.loader"), + ("unsloth_zoo.mlx_trainer", "unsloth_zoo.mlx.trainer"), + ("unsloth_zoo.mlx_utils", "unsloth_zoo.mlx.utils"), + ("unsloth_zoo.mlx_compile", "unsloth_zoo.mlx.compile"), + ("unsloth_zoo.mlx_cce", "unsloth_zoo.mlx.cce"), + ("unsloth_zoo.mlx_cce.runtime_cce", "unsloth_zoo.mlx.cce.runtime_cce"), + ): + _sys.modules.setdefault(_old_name, _importlib.import_module(_new_name)) + + del _old_name, _new_name, _importlib, _sys + if not _SKIP_GPU_INIT: if find_spec("unsloth") is None: raise ImportError("Please install Unsloth via `pip install unsloth`!") diff --git a/unsloth_zoo/device_type.py b/unsloth_zoo/device_type.py index e677cb770..8a4f96730 100644 --- a/unsloth_zoo/device_type.py +++ b/unsloth_zoo/device_type.py @@ -23,11 +23,12 @@ "device_synchronize", "device_empty_cache", "device_is_bf16_supported", + "is_mlx_available", ] -import torch import functools from .utils import Version +from .mlx.runtime import 
is_mlx_available import inspect import os import re @@ -35,6 +36,11 @@ import subprocess import urllib.request +_IS_MLX = is_mlx_available() + +if not _IS_MLX: + import torch + _PYTORCH_WHL_BASE_URL = "https://download.pytorch.org/whl" def _safe_run_command(command, timeout = 2.0): @@ -120,6 +126,8 @@ def _nearest_rocm_index(detected_major_minor, available_indices): @functools.cache def _detect_rocm_major_minor(): + if _IS_MLX: + return None # Preferred sources ordered from most direct to fallback. sources = [] hip_version = getattr(getattr(torch, "version", None), "hip", None) @@ -200,11 +208,15 @@ def _amd_installation_hint(): @functools.cache def is_hip(): + if _IS_MLX: + return False return bool(getattr(getattr(torch, "version", None), "hip", None)) pass @functools.cache def get_device_type(): + if _IS_MLX: + return "mlx" if hasattr(torch, "cuda") and torch.cuda.is_available(): if is_hip(): return "hip" @@ -234,6 +246,7 @@ def get_device_type(): # HIP fails for autocast and other torch functions. Use CUDA instead DEVICE_TYPE_TORCH = DEVICE_TYPE if DEVICE_TYPE_TORCH == "hip": DEVICE_TYPE_TORCH = "cuda" +elif DEVICE_TYPE_TORCH == "mlx": DEVICE_TYPE_TORCH = "mps" @functools.cache def get_device_count(): diff --git a/unsloth_zoo/mlx/__init__.py b/unsloth_zoo/mlx/__init__.py new file mode 100644 index 000000000..149fc9991 --- /dev/null +++ b/unsloth_zoo/mlx/__init__.py @@ -0,0 +1,24 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. +# +# Keep this package initializer lightweight. The loader and trainer modules +# import MLX libraries and should stay lazy for non-MLX import paths. + +from .runtime import is_mlx_available + +__all__ = [ + "is_mlx_available", +] diff --git a/unsloth_zoo/mlx_cce/__init__.py b/unsloth_zoo/mlx/cce/__init__.py similarity index 100% rename from unsloth_zoo/mlx_cce/__init__.py rename to unsloth_zoo/mlx/cce/__init__.py diff --git a/unsloth_zoo/mlx_cce/runtime_cce.py b/unsloth_zoo/mlx/cce/runtime_cce.py similarity index 97% rename from unsloth_zoo/mlx_cce/runtime_cce.py rename to unsloth_zoo/mlx/cce/runtime_cce.py index e5a2cf43a..464051c97 100644 --- a/unsloth_zoo/mlx_cce/runtime_cce.py +++ b/unsloth_zoo/mlx/cce/runtime_cce.py @@ -768,7 +768,13 @@ def runtime_cce_loss( biases: mx.array, targets: mx.array, ) -> mx.array: - return runtime_cce_loss_full(hidden, weight, scales, biases, targets)[0] + losses, lse = runtime_cce_loss_full( + hidden, weight, scales, biases, targets + ) + # Preserve the public return value/type as losses while keeping + # auxiliary log-sum-exp live for the custom VJP under mx.compile. + # The VJP reads lse from custom-function outputs during backward. 
+ return losses + lse * mx.array(0.0, dtype=mx.float32) return runtime_cce_loss, use_metal_kernel @@ -869,7 +875,11 @@ def runtime_cce_loss_vjp(primals, cotangents, outputs): return grad_hidden.astype(hidden.dtype), grad_weight.astype(weight.dtype), mx.zeros_like(targets) def runtime_cce_loss(hidden: mx.array, weight: mx.array, targets: mx.array) -> mx.array: - return runtime_cce_loss_full(hidden, weight, targets)[0] + losses, lse = runtime_cce_loss_full(hidden, weight, targets) + # Preserve the public return value/type as losses while keeping + # auxiliary log-sum-exp live for the custom VJP under mx.compile. + # The VJP reads lse from custom-function outputs during backward. + return losses + lse * mx.array(0.0, dtype=mx.float32) return runtime_cce_loss, use_metal_kernel diff --git a/unsloth_zoo/mlx_compile.py b/unsloth_zoo/mlx/compile.py similarity index 100% rename from unsloth_zoo/mlx_compile.py rename to unsloth_zoo/mlx/compile.py diff --git a/unsloth_zoo/mlx_loader.py b/unsloth_zoo/mlx/loader.py similarity index 97% rename from unsloth_zoo/mlx_loader.py rename to unsloth_zoo/mlx/loader.py index 73962cb48..9b2151aa0 100644 --- a/unsloth_zoo/mlx_loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -27,13 +27,14 @@ import inspect import math import os +import sys import types import warnings from contextlib import contextmanager from dataclasses import asdict, dataclass from fnmatch import fnmatch -from .mlx_compile import ( +from .compile import ( explain_compile_support, get_compile_qualification, get_compile_trait_report, @@ -644,6 +645,17 @@ def _fp16_needs_bf16_modules(model): return tuple(modules) +def _resolve_full_finetune_dtype(target_dtype, float32_mixed_precision, mx): + if target_dtype == mx.bfloat16: + if type(float32_mixed_precision) is not bool: + # Match the Torch post-patch default: bf16 full finetuning stays + # bf16 unless float32_mixed_precision=True is explicitly requested. 
+ float32_mixed_precision = False + if float32_mixed_precision is False: + return mx.bfloat16, False + return mx.float32, True + + def _patch_mixed_precision_set_dtype(model): """Patch set_dtype so unstable fp16 vision towers keep a safer dtype.""" if getattr(model, "_unsloth_mixed_precision_set_dtype_patched", False): @@ -1826,21 +1838,21 @@ def patched_apply_chat_template( def _mlx_save_pretrained_merged(self, save_directory, tokenizer=None, **kwargs): - from .mlx_utils import save_pretrained_merged + from .utils import save_pretrained_merged tokenizer = tokenizer or self._tokenizer save_pretrained_merged(self, tokenizer, save_directory, **kwargs) def _mlx_save_pretrained_gguf(self, save_directory, tokenizer=None, quantization_method="fast_quantized", **kwargs): - from .mlx_utils import save_pretrained_gguf + from .utils import save_pretrained_gguf tokenizer = tokenizer or self._tokenizer save_pretrained_gguf(self, tokenizer, save_directory, quantization_method=quantization_method) def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None, **kwargs): - from .mlx_utils import push_to_hub_merged + from .utils import push_to_hub_merged tokenizer = tokenizer or self._tokenizer # If save_directory wasn't given, fall back to repo_id (relative dir # named after the repo). 
Callers that already saved locally should @@ -1851,14 +1863,14 @@ def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None, def _mlx_push_to_hub_gguf(self, repo_id, tokenizer=None, quantization_method="fast_quantized", **kwargs): - from .mlx_utils import push_to_hub_gguf + from .utils import push_to_hub_gguf tokenizer = tokenizer or self._tokenizer push_to_hub_gguf(self, tokenizer, repo_id, repo_id=repo_id, quantization_method=quantization_method, **kwargs) def _mlx_save_lora_adapters(self, path, adapter_config=None): - from .mlx_utils import save_lora_adapters + from .utils import save_lora_adapters save_lora_adapters(self, path, adapter_config=adapter_config) @@ -2066,6 +2078,7 @@ def from_pretrained( patch_mode="patched", revision=None, random_state=3407, + float32_mixed_precision=None, **kwargs, # Accept and ignore GPU-only kwargs ): """Load a model via mlx-lm (text) or mlx-vlm (vision) on Apple Silicon. @@ -2082,8 +2095,11 @@ def from_pretrained( load_in_4bit: Accepted for API compat with CUDA unsloth. full_finetuning: When True, force-disable runtime quantization (``load_in_4bit`` etc.) so the full-precision weights are - trainable. ``get_peft_model`` becomes a no-op for models - loaded this way. + trainable. By default MLX mirrors the Unsloth Torch default + and keeps a bfloat16 target dtype in bfloat16; pass + ``float32_mixed_precision=True`` to upcast to float32. Other + target dtypes always train in float32. ``get_peft_model`` + becomes a no-op for models loaded this way. token: HuggingFace token for gated models. 
text_only: Loading mode: None — auto-detect from config (default) @@ -2127,6 +2143,26 @@ def from_pretrained( f"Pass dtype='float16' on M1/M2.", stacklevel=2, ) + if full_finetuning: + original_target_dtype = target_dtype + target_dtype, using_float32_full_ft = _resolve_full_finetune_dtype( + target_dtype, + float32_mixed_precision, + mx, + ) + if not using_float32_full_ft: + print( + "Unsloth: Using bfloat16 MLX full finetuning. " + "This reduces memory but can differ from Unsloth Torch's " + "float32_mixed_precision=True path." + ) + else: + if original_target_dtype != mx.float32: + print( + "Unsloth: Using float32 MLX full finetuning to match " + "Unsloth Torch's explicit float32_mixed_precision=True " + "path." + ) try: from mlx_lm import load as mlx_load from mlx_lm.utils import _download @@ -2338,6 +2374,7 @@ def from_pretrained( patch_mode=patch_mode, revision=adapter_base_revision, random_state=random_state, + float32_mixed_precision=float32_mixed_precision, **( {"mlx_quantization_config": adapter_mlx_quant_config} if adapter_mlx_quant_config is not None @@ -2405,57 +2442,13 @@ def from_pretrained( raise print(f"Unsloth: LoRA adapter detection failed ({e}), falling back to standard load.") - # Step 2: Check unsloth custom loader registry model_type = config_data.get("model_type", "") - try: - from unsloth.models.mlx import get_unsloth_loader - custom_loader = get_unsloth_loader(model_type) - except (ImportError, AttributeError, NotImplementedError): - custom_loader = None - if custom_loader is not None: - with _temporary_hf_token_env(token): - model, tokenizer_or_processor = custom_loader( - model_name, config_data, max_seq_length=max_seq_length, token=token - ) - if text_only is False or _is_vlm(config_data): - from .mlx_utils import normalize_vlm_processor_chat_template - - tokenizer_or_processor = normalize_vlm_processor_chat_template( - tokenizer_or_processor, - chat_template=chat_template, - model_name=model_name, - model_type=model_type, - 
strict=False, - ) - model._is_vlm_model = True - model._processor = tokenizer_or_processor - _patch_mixed_precision_set_dtype(model) - elif chat_template is not None: - from .mlx_utils import normalize_mlx_chat_template - - tokenizer_or_processor = normalize_mlx_chat_template( - tokenizer_or_processor, - chat_template=chat_template, - model_name=model_name, - model_type=model_type, - is_vlm=False, - strict=False, - ) - model._config = config_data - model._hf_repo = model_name - model._src_path = local_path - model._unsloth_base_revision = revision - model._unsloth_base_commit_hash = _infer_snapshot_commit(local_path) - model.max_seq_length = max_seq_length - model._unsloth_patch_mode = patch_mode - model._unsloth_full_finetuning = bool(full_finetuning) - _patch_mlx_saving(model, tokenizer_or_processor) - return model, tokenizer_or_processor - - # Step 3: Route based on text_only + # Step 2: Route based on text_only is_vlm = False - force_vlm_text_path = bool(text_only is True and _prefer_vlm_loader_for_text(config_data, model_type)) + force_vlm_text_path = bool( + text_only is True and _prefer_vlm_loader_for_text(config_data, model_type) + ) if text_only is True and not force_vlm_text_path: is_vlm = False @@ -2545,7 +2538,7 @@ def from_pretrained( import mlx.core as mx mx.eval(model.parameters()) - from .mlx_utils import ( + from .utils import ( normalize_mlx_chat_template, normalize_vlm_processor_chat_template, ) @@ -2655,7 +2648,7 @@ def from_pretrained( elif want_runtime_quant: import mlx.core as mx mx.eval(model.parameters()) - from .mlx_utils import normalize_mlx_chat_template + from .utils import normalize_mlx_chat_template tokenizer = normalize_mlx_chat_template( tokenizer, @@ -2929,7 +2922,7 @@ def get_peft_model( _apply_gc = bool(use_gradient_checkpointing) if _apply_gc: - from .mlx_utils import apply_gradient_checkpointing + from .utils import apply_gradient_checkpointing apply_gradient_checkpointing(model) import mlx.utils diff --git 
a/unsloth_zoo/mlx/runtime.py b/unsloth_zoo/mlx/runtime.py new file mode 100644 index 000000000..a77024b62 --- /dev/null +++ b/unsloth_zoo/mlx/runtime.py @@ -0,0 +1,34 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +__all__ = [ + "is_mlx_available", +] + +import functools +import importlib.util +import os +import platform + + +@functools.cache +def is_mlx_available() -> bool: + return ( + os.environ.get("UNSLOTH_FORCE_GPU_PATH", "0") != "1" + and platform.system() == "Darwin" + and platform.machine() == "arm64" + and importlib.util.find_spec("mlx") is not None + ) diff --git a/unsloth_zoo/mlx_trainer.py b/unsloth_zoo/mlx/trainer.py similarity index 88% rename from unsloth_zoo/mlx_trainer.py rename to unsloth_zoo/mlx/trainer.py index d9600f656..de17925d2 100644 --- a/unsloth_zoo/mlx_trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -19,7 +19,7 @@ Usage mirrors TRL notebooks: - from unsloth_zoo.mlx_trainer import MLXTrainer, MLXTrainingConfig + from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig trainer = MLXTrainer( model=model, @@ -52,7 +52,7 @@ SUPPORTED_MLX_OPTIMIZERS = ("adafactor", "adamw", "adam", "sgd", "muon", "lion") SUPPORTED_MLX_LR_SCHEDULERS = ("linear", "cosine", "constant") -from .mlx_utils import ( +from .utils import ( 
make_cce_loss_fn, make_baseline_loss_fn, make_vlm_cce_loss_fn, @@ -69,7 +69,7 @@ remove_gradient_checkpointing, _is_vlm_model, ) -from .mlx_compile import ( +from .compile import ( build_compile_policy, explain_compile_support, get_compile_qualification, @@ -120,15 +120,9 @@ class MLXTrainingConfig: weight_decay: float = 0.001 max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead # Elementwise clipping (PyTorch's torch.nn.utils.clip_grad_value_). - # Clamps every grad value to [-max_grad_value, max_grad_value] leaf-by-leaf — - # zero extra memory (no cross-leaf reduction). Mirrors mlx-vlm's defaults: - # None → auto-pick from trainable param dtype - # bfloat16 → 5.0 (more dynamic range, looser threshold) - # float16 → 1.0 (less dynamic range, tighter threshold) - # else → 0.0 (off) - # 0.0 → disabled - # >0 → explicit threshold - max_grad_value: float | None = None + # Clamps every grad value to [-max_grad_value, max_grad_value] leaf-by-leaf + # with no cross-leaf reduction. Set 0.0 to disable. + max_grad_value: float | None = 5.0 seed: int = 3407 lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended @@ -238,7 +232,8 @@ def __init__( def add_step_callback(self, fn): """Register a callback called after each logged step. 
- fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, num_tokens) + fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, + num_tokens, grad_norm=None) """ self._step_callbacks.append(fn) @@ -326,15 +321,17 @@ def _build_schedule(self, total_steps): decay_steps = max(total_steps - warmup, 1) if sched_type == "cosine": - end_lr = lr * 0.1 - main_schedule = optim.cosine_decay(lr, decay_steps, end=end_lr) + main_schedule = optim.cosine_decay(lr, decay_steps, end=0.0) elif sched_type == "linear": main_schedule = optim.linear_schedule(lr, 0.0, decay_steps) else: # constant main_schedule = lr if warmup > 0: - warmup_fn = optim.linear_schedule(0.0, lr, warmup) + def warmup_fn(step): + step = mx.array(step) + step = mx.minimum(step + 1, mx.array(warmup)) + return step * (lr / (warmup + 1)) if callable(main_schedule): return optim.join_schedules( [warmup_fn, main_schedule], [warmup] @@ -347,6 +344,18 @@ def _build_schedule(self, total_steps): return main_schedule + @staticmethod + def _schedule_value(schedule, step): + if callable(schedule): + return schedule(mx.array(step)) + return schedule + + def _set_optimizer_lr_for_step(self, optimizer, step): + schedule = getattr(self, "_lr_schedule", None) + if schedule is None: + return + optimizer.learning_rate = self._schedule_value(schedule, step) + def _build_optimizer(self, total_steps): """Create MLX optimizer with LR schedule from config. @@ -355,6 +364,8 @@ def _build_optimizer(self, total_steps): (matching HuggingFace Trainer behavior). 
""" schedule = self._build_schedule(total_steps) + initial_lr = self._schedule_value(schedule, 0) + self._lr_schedule = schedule if callable(schedule) else None wd = self.args.weight_decay opt_name = _normalize_mlx_optimizer_name(self.args.optim) @@ -375,20 +386,29 @@ def _build_optimizer(self, total_steps): if opt_name == "adafactor": optimizer = optim.Adafactor( - learning_rate=schedule, + learning_rate=initial_lr, relative_step=False, scale_parameter=False, ) elif opt_name == "adamw": - optimizer = optim.AdamW(learning_rate=schedule, weight_decay=wd) + # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction + # to False, which makes early warmup updates much larger. + optimizer = optim.AdamW( + learning_rate=initial_lr, + weight_decay=wd, + bias_correction=True, + ) elif opt_name == "adam": - optimizer = optim.Adam(learning_rate=schedule) + optimizer = optim.Adam( + learning_rate=initial_lr, + bias_correction=True, + ) elif opt_name == "sgd": - optimizer = optim.SGD(learning_rate=schedule, weight_decay=wd) + optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) elif opt_name == "muon": - optimizer = optim.Muon(learning_rate=schedule, weight_decay=wd) + optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=wd) elif opt_name == "lion": - optimizer = optim.Lion(learning_rate=schedule, weight_decay=wd) + optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=wd) self._resolved_optimizer_name = opt_name return optimizer @@ -608,9 +628,9 @@ def train(self): config = getattr(model, "_config", {}) model_type = config.get("model_type", "") if isinstance(config, dict) else "" if "qwen3_5" in model_type: - from .mlx_loader import _fix_qwen35_attention_cache + from .loader import _fix_qwen35_attention_cache _fix_qwen35_attention_cache(model) - from .gated_delta_vjp import patch_gated_delta + from ..gated_delta_vjp import patch_gated_delta patch_gated_delta() try: @@ -699,31 +719,21 @@ def _train_inner(self): 
f"(ratio={embedding_lr_ratio:.3f} of main LR {main_lr:.2e}).") _needs_grad_scaling = use_lora_plus or use_embedding_lr + _warned_skip_optimizer_state_grad_norm = False # Build step functions following mlx-lm's pattern - max_grad_norm = args.max_grad_norm + max_grad_norm = float(args.max_grad_norm or 0.0) # Elementwise clip (clip_grad_value_): leaf-local, free memory. - # None → auto-pick by trainable dtype (mlx-vlm defaults: bf16=5, fp16=1). - _raw_mgv = getattr(args, "max_grad_value", None) - if _raw_mgv is None: - _trainable_dtype = None - for _, _v in tree_flatten(model.trainable_parameters()): - if _v.dtype in (mx.bfloat16, mx.float16, mx.float32): - _trainable_dtype = _v.dtype - break - if _trainable_dtype == mx.bfloat16: - max_grad_value = 5.0 - elif _trainable_dtype == mx.float16: - max_grad_value = 1.0 - else: - max_grad_value = 0.0 - if max_grad_value > 0: - print( - f"Unsloth: auto grad value clip = ±{max_grad_value:g} " - f"(trainable dtype: {_trainable_dtype})." - ) - else: - max_grad_value = float(_raw_mgv or 0.0) + # Prefer value clipping when both clipping modes are requested; global + # norm clipping is exact but materially increases memory on MLX. + _raw_mgv = getattr(args, "max_grad_value", 5.0) # TODO: expose MLX grad-clip in Studio UI for power users + max_grad_value = 5.0 if _raw_mgv is None else float(_raw_mgv or 0.0) + if max_grad_norm > 0 and max_grad_value > 0: + print( + "Unsloth: max_grad_norm and max_grad_value are both enabled; " + "ignoring max_grad_norm in favor of max_grad_value." 
+ ) + max_grad_norm = 0.0 _clip_grad_value = max_grad_value > 0 state = [model.state, optimizer.state, mx.random.state] # The direct grad_accum==1 fast path delegates clipping to @@ -755,43 +765,76 @@ def _grad_leaf_scale(name, safe_toks_f, clip_scale=None, dtype=None): scale = scale.astype(dtype) return scale + optimizer_v_sum = None + + def _optimizer_v_total(): + total = mx.array(0.0, dtype=mx.float32) + found = False + for name, value in tree_flatten(getattr(optimizer, "state", {})): + if name != "v" and not name.endswith(".v"): + continue + found = True + value_f = value.astype(mx.float32) + total = total + mx.sum(value_f) + return total if found else None + + def _grad_norm_from_optimizer_state(): + nonlocal optimizer_v_sum + betas = getattr(optimizer, "betas", None) + if not betas or len(betas) < 2: + return None + current_v_sum = _optimizer_v_total() + if current_v_sum is None: + return None + previous_v_sum = ( + optimizer_v_sum + if optimizer_v_sum is not None + else mx.array(0.0, dtype=mx.float32) + ) + beta2 = mx.array(float(betas[1]), dtype=mx.float32) + denom = mx.maximum( + mx.array(1.0, dtype=mx.float32) - beta2, + mx.array(1e-30, dtype=mx.float32), + ) + grad_norm_sq = mx.maximum( + (current_v_sum - beta2 * previous_v_sum) / denom, + mx.array(0.0, dtype=mx.float32), + ) + grad_norm = mx.sqrt(grad_norm_sq) + mx.eval(current_v_sum, grad_norm) + optimizer_v_sum = current_v_sum + return grad_norm + + def _can_report_optimizer_state_norm(): + # For Adam-family optimizers, recover ||g|| from the second moment + # after the update: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2. + # This avoids adding a second consumer to the lazy backward graph. + return getattr(optimizer, "betas", None) + def _apply_update(grad, toks_f): """Common gradient post-processing and optimizer update. - This uses a two-pass exact clip path instead of materializing - multiple full gradient trees. 
The resulting update is numerically - equivalent to the older implementation while avoiding the extra - tree allocation from ``optim.clip_grad_norm`` on large MLX runs. + Scale accumulated gradients by supervised-token count, then apply + the selected clipping mode. Global norm clipping uses MLX's helper + and reports that norm. Non-global modes report after update from + Adam's optimizer state so the backward graph stays single-consumer. """ safe_toks_f = mx.maximum( toks_f, mx.array(1.0, dtype=mx.float32) ) flat_grad = tree_flatten(grad) - need_norm = max_grad_norm > 0 - grad_norm = mx.array(0.0, dtype=mx.float32) - clip_scale = None - - if need_norm: - norm_sq = mx.array(0.0, dtype=mx.float32) - for name, value in flat_grad: - scaled = value.astype(mx.float32) * _grad_leaf_scale( - name, safe_toks_f - ) - norm_sq = norm_sq + mx.sum(scaled * scaled) - grad_norm = mx.sqrt(norm_sq) - if max_grad_norm > 0: - clip_scale = mx.minimum( - mx.array(max_grad_norm, dtype=mx.float32) / (grad_norm + 1e-6), - mx.array(1.0, dtype=mx.float32), - ) - - final_grad = tree_unflatten([ - ( - name, - value * _grad_leaf_scale(name, safe_toks_f, clip_scale, value.dtype), + grad_norm = None + final_items = [] + for name, value in flat_grad: + scaled = value * _grad_leaf_scale( + name, safe_toks_f, None, value.dtype + ) + final_items.append((name, scaled)) + final_grad = tree_unflatten(final_items) + if max_grad_norm > 0: + final_grad, grad_norm = optim.clip_grad_norm( + final_grad, max_norm=max_grad_norm ) - for name, value in flat_grad - ]) if _clip_grad_value: # Elementwise clip after norm-scaling, before optimizer step. final_grad = tree_map( @@ -799,6 +842,7 @@ def _apply_update(grad, toks_f): final_grad, ) optimizer.update(model, final_grad) + return grad_norm def _apply_update_direct(grad): """Fast exact path for the common ``grad_accum == 1`` case. 
@@ -813,8 +857,9 @@ def _apply_update_direct(grad): In this case the raw gradients already represent the final per-token average, so we can clip/update the original tree directly. """ + grad_norm = None if max_grad_norm > 0: - grad, _ = optim.clip_grad_norm(grad, max_norm=max_grad_norm) + grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) if _clip_grad_value: # Elementwise clip per leaf — free memory (no cross-leaf reduction). grad = tree_map( @@ -822,6 +867,7 @@ def _apply_update_direct(grad): grad, ) optimizer.update(model, grad) + return grad_norm # Unified step function for both VLM and text training. # VLM: batch_data is a dict → loss_and_grad_fn(model, batch_dict) @@ -833,10 +879,11 @@ def step_fn(batch_data, prev_state, do_update): (lvalue, toks), grad = loss_and_grad_fn(model, batch_data[0], batch_data[1], batch_data[2]) if _direct_single_step_update: - _apply_update_direct(grad) - return lvalue, toks, None + grad_norm = _apply_update_direct(grad) + return lvalue, toks, None, grad_norm toks_f = toks.astype(mx.float32) + grad_norm = mx.array(0.0, dtype=mx.float32) # Scale-and-accumulate in a single tree_map per micro-batch, casting # the scalar to each leaf's dtype so bf16/fp16 grad trees stay in @@ -844,6 +891,12 @@ def step_fn(batch_data, prev_state, do_update): # and forces optimizer.update to promote params/m/v to fp32). if prev_state is not None: prev_grad, prev_toks = prev_state + # Accumulated gradients are optimizer state, not something to + # differentiate through on the next compiled micro-batch. This + # keeps custom VJP losses such as CCE from retaining/corrupting + # the carried bf16 accumulation graph. 
+ prev_grad = tree_map(mx.stop_gradient, prev_grad) + prev_toks = mx.stop_gradient(prev_toks) grad = tree_map( lambda g, p: p + g * toks_f.astype(g.dtype), grad, prev_grad, @@ -856,17 +909,21 @@ def step_fn(batch_data, prev_state, do_update): ) if do_update: - _apply_update(grad, toks_f) - return lvalue, toks, None + grad_norm = _apply_update(grad, toks_f) + return lvalue, toks, None, grad_norm - return lvalue, toks, (grad, toks_f) + grad = tree_map(mx.stop_gradient, grad) + toks_f = mx.stop_gradient(toks_f) + return lvalue, toks, (grad, toks_f), None compile_policy = build_compile_policy(args=args) _compile_decision = getattr(self, "_compile_decision", None) _use_compile = compile_policy.mode != "eager" - # why: text MLX has no compile qualification pipeline; stay eager - # unless UNSLOTH_MLX_TEXT_COMPILE=1. - if not is_vlm and _use_compile and os.environ.get("UNSLOTH_MLX_TEXT_COMPILE", "0") != "1": + if _use_compile and max_grad_norm > 0 and grad_accum > 1: + print( + "Unsloth: mx.compile disabled because MLX global norm " + "clipping is enabled with gradient accumulation." + ) _use_compile = False if is_vlm and _use_compile: qual = getattr(model, "_unsloth_compile_qualification", None) or get_compile_qualification(model) @@ -1006,9 +1063,13 @@ def step_fn(batch_data, prev_state, do_update): batch_idx += 1 do_update = (it % grad_accum == 0) + if do_update: + # Keep callable scheduler evaluation outside mx.compile. The + # compiled step reads the scalar LR already in optimizer state. 
+ self._set_optimizer_lr_for_step(optimizer, it // grad_accum - 1) try: - lvalue, toks, grad_accum_state = step_fn( + lvalue, toks, grad_accum_state, grad_norm = step_fn( batch_data, grad_accum_state, do_update, ) except (ValueError, RuntimeError) as e: @@ -1032,7 +1093,7 @@ def step_fn(batch_data, prev_state, do_update): step_fn = _uncompiled_step_fn _use_compile = False state = [model.state, optimizer.state, mx.random.state] - lvalue, toks, grad_accum_state = step_fn( + lvalue, toks, grad_accum_state, grad_norm = step_fn( batch_data, grad_accum_state, do_update, ) else: @@ -1041,10 +1102,31 @@ def step_fn(batch_data, prev_state, do_update): losses += lvalue * toks n_tokens += toks steps += 1 + if grad_norm is not None: + mx.eval(grad_norm) if grad_accum_state is not None: mx.eval(state, losses, n_tokens, grad_accum_state[0], grad_accum_state[1]) else: mx.eval(state, losses, n_tokens) + if ( + do_update + and grad_norm is None + and max_grad_norm <= 0 + and _can_report_optimizer_state_norm() + ): + grad_norm = _grad_norm_from_optimizer_state() + elif ( + do_update + and grad_norm is None + and max_grad_norm <= 0 + and not _can_report_optimizer_state_norm() + and not _warned_skip_optimizer_state_grad_norm + ): + print( + "Unsloth: skipping grad norm reporting for this MLX " + "optimizer/mode to avoid materializing the gradient graph." 
+ ) + _warned_skip_optimizer_state_grad_norm = True if int(toks.item()) == 0: raise ValueError( "Unsloth MLX: a training batch produced zero supervised " @@ -1070,6 +1152,12 @@ def step_fn(batch_data, prev_state, do_update): peak_mem = mx.get_peak_memory() / 1e9 self._train_loss_history.append(train_loss) + grad_norm_val = ( + float(grad_norm.item()) + if grad_norm is not None else None + ) + if grad_norm_val is not None: + self._grad_norm_history.append(grad_norm_val) self._tokens_per_second_history.append(tokens_sec) self._peak_memory_history.append(peak_mem) self._step_times.append(train_time / steps if steps > 0 else 0) @@ -1082,9 +1170,14 @@ def step_fn(batch_data, prev_state, do_update): elapsed_total = time.perf_counter() - start_time + grad_text = ( + f"Grad: {grad_norm_val:.4f} | " + if grad_norm_val is not None else "" + ) print( f" Step {current_step}/{total_steps} | " f"Loss: {train_loss:.4f} | " + f"{grad_text}" f"LR: {lr_val:.2e} | " f"Tok/s: {tokens_sec:.0f} | " f"Peak: {peak_mem:.2f} GB" @@ -1092,8 +1185,11 @@ def step_fn(batch_data, prev_state, do_update): for cb in self._step_callbacks: try: - cb(current_step, total_steps, train_loss, lr_val, - tokens_sec, peak_mem, elapsed_total, trained_tokens) + cb( + current_step, total_steps, train_loss, lr_val, + tokens_sec, peak_mem, elapsed_total, trained_tokens, + grad_norm_val, + ) except Exception as e: print(f"Unsloth: step callback error: {e}") @@ -1267,7 +1363,7 @@ def _prepare_data(self, is_vlm): def save_model(self, output_dir=None): """Save LoRA adapters or full merged model (if no LoRA).""" - from .mlx_utils import save_merged_model + from .utils import save_merged_model output_dir = output_dir or self.args.output_dir trainable = dict(tree_flatten(self.model.trainable_parameters())) @@ -1288,7 +1384,7 @@ def save_model(self, output_dir=None): break - from .mlx_utils import _get_transformer_layers + from .utils import _get_transformer_layers layers = _get_transformer_layers(self.model) _num_layers 
= len(layers) if layers else -1 @@ -1590,7 +1686,7 @@ def train_on_responses_only( Returns: The trainer (for chaining), or the masking closure if return_function=True. """ - from .dataset_utils import ( + from ..dataset_utils import ( train_on_responses_only as _hf_train_on_responses_only, ) diff --git a/unsloth_zoo/mlx_utils.py b/unsloth_zoo/mlx/utils.py similarity index 99% rename from unsloth_zoo/mlx_utils.py rename to unsloth_zoo/mlx/utils.py index 0aac7dbd8..3caf75688 100644 --- a/unsloth_zoo/mlx_utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -36,7 +36,7 @@ from pathlib import Path -from .mlx_cce import _get_runtime_cce +from .cce import _get_runtime_cce def _safe_token_denominator(ntoks): @@ -2116,7 +2116,7 @@ def _extract_vlm_images(item, messages, image_size): if not images and isinstance(messages, list): try: - from .vision_utils import process_vision_info + from ..vision_utils import process_vision_info extracted = process_vision_info(messages, return_video_kwargs=True) if isinstance(extracted, tuple) and extracted: @@ -2980,7 +2980,7 @@ def save_pretrained_gguf( compresses it to ``quantization_method``. 
Pass ``"f32"`` / ``"f16"`` / ``"bf16"`` to force a specific intermediate """ - from .llama_cpp import ( + from ..llama_cpp import ( convert_to_gguf, quantize_gguf, install_llama_cpp, diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index d4d800909..26cecd385 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -21,6 +21,7 @@ import importlib.util from typing import Optional, Tuple from torch.autograd import Function +from unsloth_zoo.mlx import is_mlx_available # Get compile location UNSLOTH_COMPILE_LOCATION = os.environ.get( @@ -224,6 +225,7 @@ def persistent_alloc_fn(size: int, alignment: int, stream): def _check_grouped_gemm_available(): """Check if Unsloth grouped GEMM kernels are available.""" if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False + if is_mlx_available(): return False global _GROUPED_GEMM_AVAILABLE if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE