diff --git a/tests/test_mlx_runtime_cce_compile.py b/tests/test_mlx_runtime_cce_compile.py
new file mode 100644
index 000000000..9168cfe0f
--- /dev/null
+++ b/tests/test_mlx_runtime_cce_compile.py
@@ -0,0 +1,119 @@
+# Unsloth Zoo - Utilities for Unsloth
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+from __future__ import annotations
+
+import sys
+
+import pytest
+
+
+mx = pytest.importorskip("mlx.core")
+if "mlx_simulation" in str(getattr(mx, "__file__", "")):
+ pytest.skip("requires real MLX runtime", allow_module_level=True)
+
+
+def _stable_norm(values):
+ max_abs = mx.array(0.0, dtype=mx.float32)
+ for value in values:
+ value32 = value.astype(mx.float32)
+ max_abs = mx.maximum(max_abs, mx.max(mx.abs(value32)))
+
+ denom = mx.maximum(max_abs, mx.array(1e-30, dtype=mx.float32))
+ norm_sq = mx.array(0.0, dtype=mx.float32)
+ for value in values:
+ scaled = value.astype(mx.float32) / denom
+ norm_sq = norm_sq + mx.sum(scaled * scaled)
+ return denom * mx.sqrt(norm_sq)
+
+
+def _skip_torch_shim():
+ if any(name.startswith("mlx_simulation") for name in sys.modules):
+ pytest.skip("requires real MLX runtime")
+
+
+def test_compiled_runtime_cce_preserves_aux_lse_for_gradients():
+ _skip_torch_shim()
+ from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss
+
+ runtime_cce, _ = make_chunked_cross_entropy_loss(
+ ignore_index=-100,
+ chunk_size=32,
+ )
+ hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0
+ weight = (mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0) - 1.0
+ targets = (mx.arange(64, dtype=mx.int32) * 7) % 128
+ targets = mx.where(mx.arange(64) % 11 == 0, -100, targets)
+ ntoks = mx.maximum(
+ mx.sum((targets != -100).astype(mx.float32)),
+ mx.array(1.0, dtype=mx.float32),
+ )
+
+ def loss_and_grad_norm(h, w):
+ def loss_fn(hh, ww):
+ losses = runtime_cce(hh, ww, targets)
+ return losses.astype(mx.float32).sum() / ntoks
+
+ loss, grads = mx.value_and_grad(loss_fn, argnums=(0, 1))(h, w)
+ return loss, _stable_norm(grads)
+
+ eager_loss, eager_norm = loss_and_grad_norm(hidden, weight)
+ compiled_loss, compiled_norm = mx.compile(loss_and_grad_norm)(hidden, weight)
+ mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm)
+
+ assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5)
+ assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4)
+
+
+def test_compiled_quantized_runtime_cce_preserves_aux_lse_for_gradients():
+ _skip_torch_shim()
+ import mlx.nn as nn
+
+ from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss
+
+ linear = nn.Linear(32, 128, bias=False)
+ linear.weight = (
+ mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0
+ ) - 1.0
+ qlinear = nn.QuantizedLinear.from_linear(linear, group_size=32, bits=4)
+ runtime_cce, _ = make_chunked_cross_entropy_loss(
+ ignore_index=-100,
+ chunk_size=32,
+ quantized=True,
+ group_size=qlinear.group_size,
+ bits=qlinear.bits,
+ )
+ hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0
+ targets = (mx.arange(64, dtype=mx.int32) * 7) % 128
+ targets = mx.where(mx.arange(64) % 11 == 0, -100, targets)
+ ntoks = mx.maximum(
+ mx.sum((targets != -100).astype(mx.float32)),
+ mx.array(1.0, dtype=mx.float32),
+ )
+
+ def loss_and_grad_norm(h):
+ def loss_fn(hh):
+ losses = runtime_cce(
+ hh,
+ qlinear.weight,
+ qlinear.scales,
+ qlinear.biases,
+ targets,
+ )
+ return losses.astype(mx.float32).sum() / ntoks
+
+ loss, grad = mx.value_and_grad(loss_fn)(h)
+ return loss, _stable_norm((grad,))
+
+ eager_loss, eager_norm = loss_and_grad_norm(hidden)
+ compiled_loss, compiled_norm = mx.compile(loss_and_grad_norm)(hidden)
+ mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm)
+
+ assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5)
+ assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4)
diff --git a/tests/test_pr_a_components.py b/tests/test_pr_a_components.py
index 963270f1a..74e4aa750 100644
--- a/tests/test_pr_a_components.py
+++ b/tests/test_pr_a_components.py
@@ -52,7 +52,7 @@ def _ce_reference(hidden, weight, targets, ignore_index=-100, softcap=0.0):
def test_cce_forward_pure_python_matches_reference():
"""The kernel-less branch of _forward_chunked_fused_finalize."""
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
@@ -72,7 +72,7 @@ def test_cce_forward_pure_python_matches_reference():
def test_cce_with_softcap():
"""Softcap path matches reference too."""
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
@@ -93,7 +93,7 @@ def test_cce_with_softcap():
def test_cce_with_ignore_index():
"""Ignore-index zeros out loss for those positions."""
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
@@ -114,7 +114,7 @@ def test_cce_with_ignore_index():
def test_cce_chunked_matches_unchunked():
"""Chunked online LSE must equal a single-shot computation."""
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 64
diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py
index 33f8b8fb2..a8b8bc119 100644
--- a/tests/test_pr_a_deep_components.py
+++ b/tests/test_pr_a_deep_components.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
"""
-PR-A deeper component exercises: mlx_trainer, mlx_compile discovery,
+PR-A deeper component exercises: trainer, compile discovery,
cce backward, and quantization helpers — beyond just imports.
If a test fails, the failing component identifies the next gap.
@@ -40,7 +40,7 @@ def _install_shim():
# ---------------------------------------------------------------------------
def test_mlx_training_config_is_dataclass_with_all_fields():
- from unsloth_zoo.mlx_trainer import MLXTrainingConfig
+ from unsloth_zoo.mlx.trainer import MLXTrainingConfig
assert dataclasses.is_dataclass(MLXTrainingConfig)
fields = {f.name for f in dataclasses.fields(MLXTrainingConfig)}
# Required SFT-compat fields
@@ -67,37 +67,102 @@ def test_mlx_training_config_is_dataclass_with_all_fields():
@pytest.mark.parametrize("optim_name", ["adamw", "adam", "sgd", "adafactor"])
def test_mlx_training_config_each_optim(optim_name):
"""Every PR-A-supported optim string at least constructs cleanly in config."""
- from unsloth_zoo.mlx_trainer import MLXTrainingConfig
+ from unsloth_zoo.mlx.trainer import MLXTrainingConfig
cfg = MLXTrainingConfig(optim=optim_name)
assert cfg.optim == optim_name
+def test_trainer_drives_dynamic_lr_outside_optimizer_scheduler():
+ from unsloth_zoo.mlx.trainer import (
+ MLXTrainer,
+ MLXTrainingConfig,
+ )
+
+ trainer = MLXTrainer.__new__(MLXTrainer)
+ trainer.args = MLXTrainingConfig(
+ learning_rate=5e-5,
+ lr_scheduler_type="linear",
+ warmup_steps=5,
+ )
+ schedule = trainer._build_schedule(total_steps=8)
+ def value_at(step):
+ value = schedule(step)
+ return value.item() if hasattr(value, "item") else float(value)
+
+ assert value_at(0) > 0.0
+ assert value_at(4) < trainer.args.learning_rate
+ assert value_at(5) == pytest.approx(trainer.args.learning_rate)
+
+ trainer.model = object()
+ optimizer = trainer._build_optimizer(total_steps=8)
+ assert not callable(optimizer.learning_rate)
+ first_lr = float(optimizer.learning_rate)
+ trainer._set_optimizer_lr_for_step(optimizer, 1)
+ second_lr = float(optimizer.learning_rate)
+ assert second_lr > first_lr
+
+
+@pytest.mark.parametrize(
+ ("scheduler", "warmup"),
+ [
+ ("linear", 0),
+ ("linear", 5),
+ ("cosine", 0),
+ ("cosine", 5),
+ ("constant", 0),
+ ("constant", 5),
+ ],
+)
+def test_scheduler_lr_is_nonzero_for_optimizer_update_steps(scheduler, warmup):
+ from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig
+
+ total_steps = 8
+ trainer = MLXTrainer.__new__(MLXTrainer)
+ trainer.args = MLXTrainingConfig(
+ learning_rate=5e-5,
+ lr_scheduler_type=scheduler,
+ warmup_steps=warmup,
+ )
+ schedule = trainer._build_schedule(total_steps=total_steps)
+
+ if callable(schedule):
+ raw_values = [schedule(step) for step in range(total_steps)]
+ else:
+ raw_values = [schedule] * total_steps
+ values = [
+ value.item() if hasattr(value, "item") else float(value)
+ for value in raw_values
+ ]
+
+ assert all(value > 0.0 for value in values)
+
+
# ---------------------------------------------------------------------------
-# 2. mlx_compile module-level discovery functions return sensible defaults
+# 2. compile module-level discovery functions return sensible defaults
# on a host with no real MLX architectures.
# ---------------------------------------------------------------------------
-def test_mlx_compile_discovers_no_archs_under_shim():
+def test_compile_discovers_no_archs_under_shim():
"""No real mlx_vlm.models.* installed -> empty discovery, not crash."""
- import unsloth_zoo.mlx_compile as mc
+ import unsloth_zoo.mlx.compile as mc
archs = mc.discover_architectures()
assert isinstance(archs, tuple)
-def test_mlx_compile_patch_primitives_exist():
- import unsloth_zoo.mlx_compile as mc
+def test_compile_patch_primitives_exist():
+ import unsloth_zoo.mlx.compile as mc
primitives = mc.list_compile_patch_primitives()
assert len(primitives) > 0
-def test_mlx_compile_protocol_requirements_exist():
- import unsloth_zoo.mlx_compile as mc
+def test_compile_protocol_requirements_exist():
+ import unsloth_zoo.mlx.compile as mc
reqs = mc.list_protocol_requirements()
assert len(reqs) > 0
-def test_mlx_compile_summarize_qualifications_returns_dict():
- import unsloth_zoo.mlx_compile as mc
+def test_compile_summarize_qualifications_returns_dict():
+ import unsloth_zoo.mlx.compile as mc
s = mc.summarize_compile_qualifications()
assert isinstance(s, dict)
assert "architectures" in s
@@ -109,7 +174,7 @@ def test_mlx_compile_summarize_qualifications_returns_dict():
def test_cce_backward_via_torch_autograd():
"""Build a tiny CCE forward and verify torch.autograd traverses it."""
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hd, vocab = 4, 8, 32
diff --git a/tests/test_pr_a_dequantize.py b/tests/test_pr_a_dequantize.py
index a1b03bb4d..8a98a8a6d 100644
--- a/tests/test_pr_a_dequantize.py
+++ b/tests/test_pr_a_dequantize.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
"""
-PR-A integration: exercise unsloth_zoo.mlx_loader._dequantize_selected_mlx_modules.
+PR-A integration: exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules.
Builds a synthetic MLX-style model with one QuantizedLinear submodule,
runs PR-A's dequantize-and-replace helper, verifies the result is
@@ -68,7 +68,7 @@ def test_dequantize_selected_mlx_modules_swap():
"""Build a Module with a QuantizedLinear, dequant-replace, verify swap."""
import mlx.core as mx
import mlx.nn as nn
- from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
+ from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules
# 4-bit weight: 8 input dims (one group), 2 output dims, group_size=8.
bits, group_size = 4, 8
@@ -122,7 +122,7 @@ def __init__(self):
def test_dequantize_predicate_filters():
"""Predicate should let some QuantizedLinear modules through unchanged."""
import mlx.nn as nn
- from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
+ from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules
bits, group_size = 4, 8
packed, scales, biases = _make_packed_4bit([0]*8, group_size, bits)
@@ -145,7 +145,7 @@ def __init__(self):
def test_dequantize_no_match_returns_zero():
import mlx.nn as nn
- from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
+ from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules
class Empty(nn.Module):
def __init__(self):
diff --git a/tests/test_pr_a_imports.py b/tests/test_pr_a_imports.py
index 39cdf271e..7eef0ebf8 100644
--- a/tests/test_pr_a_imports.py
+++ b/tests/test_pr_a_imports.py
@@ -38,12 +38,12 @@ def _install_shim():
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("module_path", [
- "unsloth_zoo.mlx_loader",
- "unsloth_zoo.mlx_trainer",
- "unsloth_zoo.mlx_utils",
- "unsloth_zoo.mlx_compile",
- "unsloth_zoo.mlx_cce",
- "unsloth_zoo.mlx_cce.runtime_cce",
+ "unsloth_zoo.mlx.loader",
+ "unsloth_zoo.mlx.trainer",
+ "unsloth_zoo.mlx.utils",
+ "unsloth_zoo.mlx.compile",
+ "unsloth_zoo.mlx.cce",
+ "unsloth_zoo.mlx.cce.runtime_cce",
"unsloth_zoo.gated_delta_vjp",
])
def test_pr_a_module_imports(module_path):
@@ -58,48 +58,70 @@ def test_pr_a_module_imports(module_path):
# ---------------------------------------------------------------------------
def test_fast_mlx_model_class_exists():
- from unsloth_zoo.mlx_loader import FastMLXModel
+ from unsloth_zoo.mlx.loader import FastMLXModel
assert hasattr(FastMLXModel, "from_pretrained")
+def test_full_finetune_dtype_default_matches_torch_bf16():
+ import mlx.core as mx
+ from unsloth_zoo.mlx.loader import _resolve_full_finetune_dtype
+
+ assert _resolve_full_finetune_dtype(mx.bfloat16, None, mx) == (
+ mx.bfloat16,
+ False,
+ )
+ assert _resolve_full_finetune_dtype(mx.bfloat16, False, mx) == (
+ mx.bfloat16,
+ False,
+ )
+ assert _resolve_full_finetune_dtype(mx.bfloat16, True, mx) == (
+ mx.float32,
+ True,
+ )
+ assert _resolve_full_finetune_dtype(mx.float16, None, mx) == (
+ mx.float32,
+ True,
+ )
+
+
def test_fast_mlx_model_save_helpers_exist():
"""PR-B calls model.save_pretrained_merged / save_lora_adapters /
push_to_hub_merged on the FastMLXModel INSTANCE returned by
FastMLXModel.from_pretrained. The helpers are module-level in
- mlx_loader.py and attached via types.MethodType after load.
+ loader.py and attached via types.MethodType after load.
"""
- import unsloth_zoo.mlx_loader as ml
+ import unsloth_zoo.mlx.loader as ml
# The free functions must exist:
assert hasattr(ml, "_mlx_save_pretrained_merged")
assert hasattr(ml, "_mlx_save_lora_adapters")
assert hasattr(ml, "_mlx_push_to_hub_merged")
- # And the underlying mlx_utils targets:
- import unsloth_zoo.mlx_utils as mu
+ # And the underlying utils targets:
+ import unsloth_zoo.mlx.utils as mu
assert hasattr(mu, "save_pretrained_merged")
assert hasattr(mu, "save_lora_adapters")
assert hasattr(mu, "push_to_hub_merged")
-def test_mlx_trainer_classes():
- from unsloth_zoo.mlx_trainer import (
+def test_trainer_classes():
+ from unsloth_zoo.mlx.trainer import (
MLXTrainer,
MLXTrainingConfig,
)
# train_on_responses_only is the third symbol PR-B imports
- import unsloth_zoo.mlx_trainer as mt
+ import unsloth_zoo.mlx.trainer as mt
assert hasattr(mt, "train_on_responses_only") or hasattr(mt, "MLXTrainer")
# ---------------------------------------------------------------------------
-# 3. mlx_loader: dequantize-and-replace logic surface
+# 3. MLX loader: dequantize-and-replace logic surface
# ---------------------------------------------------------------------------
def test_mlx_loader_dequantize_replace_callable():
"""The dequantize-and-replace helper used by FastMLXModel.from_pretrained."""
- import unsloth_zoo.mlx_loader as ml
+ import unsloth_zoo.mlx.loader as ml
# PR-A names this `_dequantize_selected_mlx_modules`.
assert hasattr(ml, "_dequantize_selected_mlx_modules"), (
- "expected _dequantize_selected_mlx_modules in mlx_loader. "
+ "expected _dequantize_selected_mlx_modules in unsloth_zoo.mlx.loader. "
f"Got dequant-related: {[a for a in dir(ml) if 'dequant' in a.lower()]}"
)
@@ -112,7 +134,7 @@ def test_cce_fallback_path_runs():
"""Construct a tiny CCE loss and verify the no-kernel branch fires."""
import torch
import mlx.core as mx
- from unsloth_zoo.mlx_cce.runtime_cce import _build_kernel_set
+ from unsloth_zoo.mlx.cce.runtime_cce import _build_kernel_set
# With shim, is_available()=False -> kernel set returns (None, None, None)
kernels = _build_kernel_set()
@@ -124,7 +146,7 @@ def test_cce_forward_chunked_pure_python():
"""Run the pure-Python CCE forward directly and verify a finite loss."""
import torch
import mlx.core as mx
- from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
+ from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize
torch.manual_seed(0)
n, hidden, vocab = 4, 8, 32
@@ -149,11 +171,11 @@ def test_cce_forward_chunked_pure_python():
# ---------------------------------------------------------------------------
-# 5. mlx_compile: VLM dispatcher should at minimum import.
+# 5. compile: VLM dispatcher should at minimum import.
# ---------------------------------------------------------------------------
-def test_mlx_compile_import_does_not_error():
- import unsloth_zoo.mlx_compile # full module-level execution
+def test_compile_import_does_not_error():
+ import unsloth_zoo.mlx.compile # full module-level execution
# ---------------------------------------------------------------------------
@@ -171,9 +193,9 @@ def test_gated_delta_vjp_imports():
# 7. Optimizer construction: each MLXTrainingConfig optim string maps cleanly.
# ---------------------------------------------------------------------------
-def test_mlx_trainer_config_smoke():
+def test_trainer_config_smoke():
"""MLXTrainingConfig should construct with sane defaults."""
- from unsloth_zoo.mlx_trainer import MLXTrainingConfig
+ from unsloth_zoo.mlx.trainer import MLXTrainingConfig
# Try the default constructor — many MLX configs require keyword args.
import dataclasses
try:
@@ -184,3 +206,21 @@ def test_mlx_trainer_config_smoke():
# We just want the class to be inspectable.
ok = dataclasses.is_dataclass(MLXTrainingConfig) or True
assert ok
+
+
+def test_adam_optimizers_enable_bias_correction():
+ from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig
+
+ class DummyModel:
+ def trainable_parameters(self):
+ return {}
+
+ for optim_name in ("adamw", "adam"):
+ trainer = MLXTrainer(
+ model=DummyModel(),
+ tokenizer=None,
+ train_dataset=[],
+ args=MLXTrainingConfig(optim=optim_name),
+ )
+ optimizer = trainer._build_optimizer(total_steps=10)
+ assert optimizer._kw["bias_correction"] is True
diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py
index a5af788ca..bdea4b631 100644
--- a/unsloth_zoo/__init__.py
+++ b/unsloth_zoo/__init__.py
@@ -90,27 +90,26 @@ def has_429_exact_full_read(log_dir: str | Path) -> str:
from importlib.util import find_spec
-import platform as _check_platform
+from .mlx.runtime import is_mlx_available
# Detect Apple Silicon MLX mode:
# Either torch is absent (pure MLX), or unsloth already detected MLX
-_is_mlx_only = (
- os.environ.get("UNSLOTH_FORCE_GPU_PATH", "0") != "1"
- and _check_platform.system() == "Darwin"
- and _check_platform.machine() == "arm64"
- and find_spec("mlx") is not None
-)
+_is_mlx_only = is_mlx_available()
if _is_mlx_only:
# MLX mode: skip all CUDA/torch-specific initialization.
os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1"
UNSLOTH_ZOO_IS_PRESENT = True
- del _is_mlx_only, _check_platform, find_spec
+ DEVICE_TYPE = "mlx"
+ DEVICE_TYPE_TORCH = "mps"
+ DEVICE_COUNT = 1
+ ALLOW_PREQUANTIZED_MODELS = True
+ del _is_mlx_only, is_mlx_available, find_spec
# Everything below this point is GPU-only. Use a flag to gate it.
_SKIP_GPU_INIT = True
else:
_SKIP_GPU_INIT = False
- del _is_mlx_only, _check_platform
+ del _is_mlx_only, is_mlx_available
# Inject triton & bitsandbytes stubs on Apple Silicon with MLX so unsloth's
# CUDA-only imports don't error at startup. _SKIP_GPU_INIT=True is set only
@@ -123,6 +122,24 @@ def has_429_exact_full_read(log_dir: str | Path) -> str:
_inject_bnb()
del _inject_triton, _inject_bnb
+ # Temporary bridge for already-merged Unsloth code that imports the old
+ # flat MLX module names. Remove after the paired Unsloth PR lands and
+ # imports unsloth_zoo.mlx.* everywhere.
+ import importlib as _importlib
+ import sys as _sys
+
+ for _old_name, _new_name in (
+ ("unsloth_zoo.mlx_loader", "unsloth_zoo.mlx.loader"),
+ ("unsloth_zoo.mlx_trainer", "unsloth_zoo.mlx.trainer"),
+ ("unsloth_zoo.mlx_utils", "unsloth_zoo.mlx.utils"),
+ ("unsloth_zoo.mlx_compile", "unsloth_zoo.mlx.compile"),
+ ("unsloth_zoo.mlx_cce", "unsloth_zoo.mlx.cce"),
+ ("unsloth_zoo.mlx_cce.runtime_cce", "unsloth_zoo.mlx.cce.runtime_cce"),
+ ):
+ _sys.modules.setdefault(_old_name, _importlib.import_module(_new_name))
+
+ del _old_name, _new_name, _importlib, _sys
+
if not _SKIP_GPU_INIT:
if find_spec("unsloth") is None:
raise ImportError("Please install Unsloth via `pip install unsloth`!")
diff --git a/unsloth_zoo/device_type.py b/unsloth_zoo/device_type.py
index e677cb770..8a4f96730 100644
--- a/unsloth_zoo/device_type.py
+++ b/unsloth_zoo/device_type.py
@@ -23,11 +23,12 @@
"device_synchronize",
"device_empty_cache",
"device_is_bf16_supported",
+ "is_mlx_available",
]
-import torch
import functools
from .utils import Version
+from .mlx.runtime import is_mlx_available
import inspect
import os
import re
@@ -35,6 +36,11 @@
import subprocess
import urllib.request
+_IS_MLX = is_mlx_available()
+
+if not _IS_MLX:
+ import torch
+
_PYTORCH_WHL_BASE_URL = "https://download.pytorch.org/whl"
def _safe_run_command(command, timeout = 2.0):
@@ -120,6 +126,8 @@ def _nearest_rocm_index(detected_major_minor, available_indices):
@functools.cache
def _detect_rocm_major_minor():
+ if _IS_MLX:
+ return None
# Preferred sources ordered from most direct to fallback.
sources = []
hip_version = getattr(getattr(torch, "version", None), "hip", None)
@@ -200,11 +208,15 @@ def _amd_installation_hint():
@functools.cache
def is_hip():
+ if _IS_MLX:
+ return False
return bool(getattr(getattr(torch, "version", None), "hip", None))
pass
@functools.cache
def get_device_type():
+ if _IS_MLX:
+ return "mlx"
if hasattr(torch, "cuda") and torch.cuda.is_available():
if is_hip():
return "hip"
@@ -234,6 +246,7 @@ def get_device_type():
# HIP fails for autocast and other torch functions. Use CUDA instead
DEVICE_TYPE_TORCH = DEVICE_TYPE
if DEVICE_TYPE_TORCH == "hip": DEVICE_TYPE_TORCH = "cuda"
+elif DEVICE_TYPE_TORCH == "mlx": DEVICE_TYPE_TORCH = "mps"
@functools.cache
def get_device_count():
diff --git a/unsloth_zoo/mlx/__init__.py b/unsloth_zoo/mlx/__init__.py
new file mode 100644
index 000000000..149fc9991
--- /dev/null
+++ b/unsloth_zoo/mlx/__init__.py
@@ -0,0 +1,24 @@
+# Unsloth Zoo - Utilities for Unsloth
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+#
+# Keep this package initializer lightweight. The loader and trainer modules
+# import MLX libraries and should stay lazy for non-MLX import paths.
+
+from .runtime import is_mlx_available
+
+__all__ = [
+ "is_mlx_available",
+]
diff --git a/unsloth_zoo/mlx_cce/__init__.py b/unsloth_zoo/mlx/cce/__init__.py
similarity index 100%
rename from unsloth_zoo/mlx_cce/__init__.py
rename to unsloth_zoo/mlx/cce/__init__.py
diff --git a/unsloth_zoo/mlx_cce/runtime_cce.py b/unsloth_zoo/mlx/cce/runtime_cce.py
similarity index 97%
rename from unsloth_zoo/mlx_cce/runtime_cce.py
rename to unsloth_zoo/mlx/cce/runtime_cce.py
index e5a2cf43a..464051c97 100644
--- a/unsloth_zoo/mlx_cce/runtime_cce.py
+++ b/unsloth_zoo/mlx/cce/runtime_cce.py
@@ -768,7 +768,13 @@ def runtime_cce_loss(
biases: mx.array,
targets: mx.array,
) -> mx.array:
- return runtime_cce_loss_full(hidden, weight, scales, biases, targets)[0]
+ losses, lse = runtime_cce_loss_full(
+ hidden, weight, scales, biases, targets
+ )
+ # Preserve the public return value/type as losses while keeping
+ # auxiliary log-sum-exp live for the custom VJP under mx.compile.
+ # The VJP reads lse from custom-function outputs during backward.
+ return losses + lse * mx.array(0.0, dtype=mx.float32)
return runtime_cce_loss, use_metal_kernel
@@ -869,7 +875,11 @@ def runtime_cce_loss_vjp(primals, cotangents, outputs):
return grad_hidden.astype(hidden.dtype), grad_weight.astype(weight.dtype), mx.zeros_like(targets)
def runtime_cce_loss(hidden: mx.array, weight: mx.array, targets: mx.array) -> mx.array:
- return runtime_cce_loss_full(hidden, weight, targets)[0]
+ losses, lse = runtime_cce_loss_full(hidden, weight, targets)
+ # Preserve the public return value/type as losses while keeping
+ # auxiliary log-sum-exp live for the custom VJP under mx.compile.
+ # The VJP reads lse from custom-function outputs during backward.
+ return losses + lse * mx.array(0.0, dtype=mx.float32)
return runtime_cce_loss, use_metal_kernel
diff --git a/unsloth_zoo/mlx_compile.py b/unsloth_zoo/mlx/compile.py
similarity index 100%
rename from unsloth_zoo/mlx_compile.py
rename to unsloth_zoo/mlx/compile.py
diff --git a/unsloth_zoo/mlx_loader.py b/unsloth_zoo/mlx/loader.py
similarity index 97%
rename from unsloth_zoo/mlx_loader.py
rename to unsloth_zoo/mlx/loader.py
index 73962cb48..9b2151aa0 100644
--- a/unsloth_zoo/mlx_loader.py
+++ b/unsloth_zoo/mlx/loader.py
@@ -27,13 +27,14 @@
import inspect
import math
import os
+import sys
import types
import warnings
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from fnmatch import fnmatch
-from .mlx_compile import (
+from .compile import (
explain_compile_support,
get_compile_qualification,
get_compile_trait_report,
@@ -644,6 +645,17 @@ def _fp16_needs_bf16_modules(model):
return tuple(modules)
+def _resolve_full_finetune_dtype(target_dtype, float32_mixed_precision, mx):
+ if target_dtype == mx.bfloat16:
+ if type(float32_mixed_precision) is not bool:
+ # Match the Torch post-patch default: bf16 full finetuning stays
+ # bf16 unless float32_mixed_precision=True is explicitly requested.
+ float32_mixed_precision = False
+ if float32_mixed_precision is False:
+ return mx.bfloat16, False
+ return mx.float32, True
+
+
def _patch_mixed_precision_set_dtype(model):
"""Patch set_dtype so unstable fp16 vision towers keep a safer dtype."""
if getattr(model, "_unsloth_mixed_precision_set_dtype_patched", False):
@@ -1826,21 +1838,21 @@ def patched_apply_chat_template(
def _mlx_save_pretrained_merged(self, save_directory, tokenizer=None, **kwargs):
- from .mlx_utils import save_pretrained_merged
+ from .utils import save_pretrained_merged
tokenizer = tokenizer or self._tokenizer
save_pretrained_merged(self, tokenizer, save_directory, **kwargs)
def _mlx_save_pretrained_gguf(self, save_directory, tokenizer=None,
quantization_method="fast_quantized", **kwargs):
- from .mlx_utils import save_pretrained_gguf
+ from .utils import save_pretrained_gguf
tokenizer = tokenizer or self._tokenizer
save_pretrained_gguf(self, tokenizer, save_directory,
quantization_method=quantization_method)
def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None, **kwargs):
- from .mlx_utils import push_to_hub_merged
+ from .utils import push_to_hub_merged
tokenizer = tokenizer or self._tokenizer
# If save_directory wasn't given, fall back to repo_id (relative dir
# named after the repo). Callers that already saved locally should
@@ -1851,14 +1863,14 @@ def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None,
def _mlx_push_to_hub_gguf(self, repo_id, tokenizer=None,
quantization_method="fast_quantized", **kwargs):
- from .mlx_utils import push_to_hub_gguf
+ from .utils import push_to_hub_gguf
tokenizer = tokenizer or self._tokenizer
push_to_hub_gguf(self, tokenizer, repo_id, repo_id=repo_id,
quantization_method=quantization_method, **kwargs)
def _mlx_save_lora_adapters(self, path, adapter_config=None):
- from .mlx_utils import save_lora_adapters
+ from .utils import save_lora_adapters
save_lora_adapters(self, path, adapter_config=adapter_config)
@@ -2066,6 +2078,7 @@ def from_pretrained(
patch_mode="patched",
revision=None,
random_state=3407,
+ float32_mixed_precision=None,
**kwargs, # Accept and ignore GPU-only kwargs
):
"""Load a model via mlx-lm (text) or mlx-vlm (vision) on Apple Silicon.
@@ -2082,8 +2095,11 @@ def from_pretrained(
load_in_4bit: Accepted for API compat with CUDA unsloth.
full_finetuning: When True, force-disable runtime quantization
(``load_in_4bit`` etc.) so the full-precision weights are
- trainable. ``get_peft_model`` becomes a no-op for models
- loaded this way.
+ trainable. By default MLX mirrors Unsloth Torch full
+ finetuning and upcasts trainable floating weights to float32;
+ pass ``float32_mixed_precision=False`` to keep native bf16
+ weights on bf16-capable Apple Silicon. ``get_peft_model``
+ becomes a no-op for models loaded this way.
token: HuggingFace token for gated models.
text_only: Loading mode:
None — auto-detect from config (default)
@@ -2127,6 +2143,26 @@ def from_pretrained(
f"Pass dtype='float16' on M1/M2.",
stacklevel=2,
)
+ if full_finetuning:
+ original_target_dtype = target_dtype
+ target_dtype, using_float32_full_ft = _resolve_full_finetune_dtype(
+ target_dtype,
+ float32_mixed_precision,
+ mx,
+ )
+ if not using_float32_full_ft:
+ print(
+ "Unsloth: Using bfloat16 MLX full finetuning. "
+ "This reduces memory but can differ from Unsloth Torch's "
+ "float32_mixed_precision=True path."
+ )
+ else:
+ if original_target_dtype != mx.float32:
+ print(
+ "Unsloth: Using float32 MLX full finetuning to match "
+ "Unsloth Torch's explicit float32_mixed_precision=True "
+ "path."
+ )
try:
from mlx_lm import load as mlx_load
from mlx_lm.utils import _download
@@ -2338,6 +2374,7 @@ def from_pretrained(
patch_mode=patch_mode,
revision=adapter_base_revision,
random_state=random_state,
+ float32_mixed_precision=float32_mixed_precision,
**(
{"mlx_quantization_config": adapter_mlx_quant_config}
if adapter_mlx_quant_config is not None
@@ -2405,57 +2442,13 @@ def from_pretrained(
raise
print(f"Unsloth: LoRA adapter detection failed ({e}), falling back to standard load.")
- # Step 2: Check unsloth custom loader registry
model_type = config_data.get("model_type", "")
- try:
- from unsloth.models.mlx import get_unsloth_loader
- custom_loader = get_unsloth_loader(model_type)
- except (ImportError, AttributeError, NotImplementedError):
- custom_loader = None
- if custom_loader is not None:
- with _temporary_hf_token_env(token):
- model, tokenizer_or_processor = custom_loader(
- model_name, config_data, max_seq_length=max_seq_length, token=token
- )
- if text_only is False or _is_vlm(config_data):
- from .mlx_utils import normalize_vlm_processor_chat_template
-
- tokenizer_or_processor = normalize_vlm_processor_chat_template(
- tokenizer_or_processor,
- chat_template=chat_template,
- model_name=model_name,
- model_type=model_type,
- strict=False,
- )
- model._is_vlm_model = True
- model._processor = tokenizer_or_processor
- _patch_mixed_precision_set_dtype(model)
- elif chat_template is not None:
- from .mlx_utils import normalize_mlx_chat_template
-
- tokenizer_or_processor = normalize_mlx_chat_template(
- tokenizer_or_processor,
- chat_template=chat_template,
- model_name=model_name,
- model_type=model_type,
- is_vlm=False,
- strict=False,
- )
- model._config = config_data
- model._hf_repo = model_name
- model._src_path = local_path
- model._unsloth_base_revision = revision
- model._unsloth_base_commit_hash = _infer_snapshot_commit(local_path)
- model.max_seq_length = max_seq_length
- model._unsloth_patch_mode = patch_mode
- model._unsloth_full_finetuning = bool(full_finetuning)
- _patch_mlx_saving(model, tokenizer_or_processor)
- return model, tokenizer_or_processor
-
- # Step 3: Route based on text_only
+ # Step 2: Route based on text_only
is_vlm = False
- force_vlm_text_path = bool(text_only is True and _prefer_vlm_loader_for_text(config_data, model_type))
+ force_vlm_text_path = bool(
+ text_only is True and _prefer_vlm_loader_for_text(config_data, model_type)
+ )
if text_only is True and not force_vlm_text_path:
is_vlm = False
@@ -2545,7 +2538,7 @@ def from_pretrained(
import mlx.core as mx
mx.eval(model.parameters())
- from .mlx_utils import (
+ from .utils import (
normalize_mlx_chat_template,
normalize_vlm_processor_chat_template,
)
@@ -2655,7 +2648,7 @@ def from_pretrained(
elif want_runtime_quant:
import mlx.core as mx
mx.eval(model.parameters())
- from .mlx_utils import normalize_mlx_chat_template
+ from .utils import normalize_mlx_chat_template
tokenizer = normalize_mlx_chat_template(
tokenizer,
@@ -2929,7 +2922,7 @@ def get_peft_model(
_apply_gc = bool(use_gradient_checkpointing)
if _apply_gc:
- from .mlx_utils import apply_gradient_checkpointing
+ from .utils import apply_gradient_checkpointing
apply_gradient_checkpointing(model)
import mlx.utils
diff --git a/unsloth_zoo/mlx/runtime.py b/unsloth_zoo/mlx/runtime.py
new file mode 100644
index 000000000..a77024b62
--- /dev/null
+++ b/unsloth_zoo/mlx/runtime.py
@@ -0,0 +1,34 @@
+# Unsloth Zoo - Utilities for Unsloth
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+__all__ = [
+ "is_mlx_available",
+]
+
+import functools
+import importlib.util
+import os
+import platform
+
+
+@functools.cache
+def is_mlx_available() -> bool:
+ return (
+ os.environ.get("UNSLOTH_FORCE_GPU_PATH", "0") != "1"
+ and platform.system() == "Darwin"
+ and platform.machine() == "arm64"
+ and importlib.util.find_spec("mlx") is not None
+ )
diff --git a/unsloth_zoo/mlx_trainer.py b/unsloth_zoo/mlx/trainer.py
similarity index 88%
rename from unsloth_zoo/mlx_trainer.py
rename to unsloth_zoo/mlx/trainer.py
index d9600f656..de17925d2 100644
--- a/unsloth_zoo/mlx_trainer.py
+++ b/unsloth_zoo/mlx/trainer.py
@@ -19,7 +19,7 @@
Usage mirrors TRL notebooks:
- from unsloth_zoo.mlx_trainer import MLXTrainer, MLXTrainingConfig
+ from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig
trainer = MLXTrainer(
model=model,
@@ -52,7 +52,7 @@
SUPPORTED_MLX_OPTIMIZERS = ("adafactor", "adamw", "adam", "sgd", "muon", "lion")
SUPPORTED_MLX_LR_SCHEDULERS = ("linear", "cosine", "constant")
-from .mlx_utils import (
+from .utils import (
make_cce_loss_fn,
make_baseline_loss_fn,
make_vlm_cce_loss_fn,
@@ -69,7 +69,7 @@
remove_gradient_checkpointing,
_is_vlm_model,
)
-from .mlx_compile import (
+from .compile import (
build_compile_policy,
explain_compile_support,
get_compile_qualification,
@@ -120,15 +120,9 @@ class MLXTrainingConfig:
weight_decay: float = 0.001
max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead
# Elementwise clipping (PyTorch's torch.nn.utils.clip_grad_value_).
- # Clamps every grad value to [-max_grad_value, max_grad_value] leaf-by-leaf —
- # zero extra memory (no cross-leaf reduction). Mirrors mlx-vlm's defaults:
- # None → auto-pick from trainable param dtype
- # bfloat16 → 5.0 (more dynamic range, looser threshold)
- # float16 → 1.0 (less dynamic range, tighter threshold)
- # else → 0.0 (off)
- # 0.0 → disabled
- # >0 → explicit threshold
- max_grad_value: float | None = None
+ # Clamps every grad value to [-max_grad_value, max_grad_value] leaf-by-leaf
+ # with no cross-leaf reduction. Set 0.0 to disable.
+ max_grad_value: float | None = 5.0
seed: int = 3407
lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended
embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended
@@ -238,7 +232,8 @@ def __init__(
def add_step_callback(self, fn):
"""Register a callback called after each logged step.
- fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, num_tokens)
+ fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed,
+ num_tokens, grad_norm=None)
"""
self._step_callbacks.append(fn)
@@ -326,15 +321,17 @@ def _build_schedule(self, total_steps):
decay_steps = max(total_steps - warmup, 1)
if sched_type == "cosine":
- end_lr = lr * 0.1
- main_schedule = optim.cosine_decay(lr, decay_steps, end=end_lr)
+ main_schedule = optim.cosine_decay(lr, decay_steps, end=0.0)
elif sched_type == "linear":
main_schedule = optim.linear_schedule(lr, 0.0, decay_steps)
else: # constant
main_schedule = lr
if warmup > 0:
- warmup_fn = optim.linear_schedule(0.0, lr, warmup)
+ def warmup_fn(step):
+ step = mx.array(step)
+ step = mx.minimum(step + 1, mx.array(warmup))
+ return step * (lr / (warmup + 1))
if callable(main_schedule):
return optim.join_schedules(
[warmup_fn, main_schedule], [warmup]
@@ -347,6 +344,18 @@ def _build_schedule(self, total_steps):
return main_schedule
+ @staticmethod
+ def _schedule_value(schedule, step):
+ if callable(schedule):
+ return schedule(mx.array(step))
+ return schedule
+
+ def _set_optimizer_lr_for_step(self, optimizer, step):
+ schedule = getattr(self, "_lr_schedule", None)
+ if schedule is None:
+ return
+ optimizer.learning_rate = self._schedule_value(schedule, step)
+
def _build_optimizer(self, total_steps):
"""Create MLX optimizer with LR schedule from config.
@@ -355,6 +364,8 @@ def _build_optimizer(self, total_steps):
(matching HuggingFace Trainer behavior).
"""
schedule = self._build_schedule(total_steps)
+ initial_lr = self._schedule_value(schedule, 0)
+ self._lr_schedule = schedule if callable(schedule) else None
wd = self.args.weight_decay
opt_name = _normalize_mlx_optimizer_name(self.args.optim)
@@ -375,20 +386,29 @@ def _build_optimizer(self, total_steps):
if opt_name == "adafactor":
optimizer = optim.Adafactor(
- learning_rate=schedule,
+ learning_rate=initial_lr,
relative_step=False,
scale_parameter=False,
)
elif opt_name == "adamw":
- optimizer = optim.AdamW(learning_rate=schedule, weight_decay=wd)
+ # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction
+ # to False, which makes early warmup updates much larger.
+ optimizer = optim.AdamW(
+ learning_rate=initial_lr,
+ weight_decay=wd,
+ bias_correction=True,
+ )
elif opt_name == "adam":
- optimizer = optim.Adam(learning_rate=schedule)
+ optimizer = optim.Adam(
+ learning_rate=initial_lr,
+ bias_correction=True,
+ )
elif opt_name == "sgd":
- optimizer = optim.SGD(learning_rate=schedule, weight_decay=wd)
+ optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd)
elif opt_name == "muon":
- optimizer = optim.Muon(learning_rate=schedule, weight_decay=wd)
+ optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=wd)
elif opt_name == "lion":
- optimizer = optim.Lion(learning_rate=schedule, weight_decay=wd)
+ optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=wd)
self._resolved_optimizer_name = opt_name
return optimizer
@@ -608,9 +628,9 @@ def train(self):
config = getattr(model, "_config", {})
model_type = config.get("model_type", "") if isinstance(config, dict) else ""
if "qwen3_5" in model_type:
- from .mlx_loader import _fix_qwen35_attention_cache
+ from .loader import _fix_qwen35_attention_cache
_fix_qwen35_attention_cache(model)
- from .gated_delta_vjp import patch_gated_delta
+ from ..gated_delta_vjp import patch_gated_delta
patch_gated_delta()
try:
@@ -699,31 +719,21 @@ def _train_inner(self):
f"(ratio={embedding_lr_ratio:.3f} of main LR {main_lr:.2e}).")
_needs_grad_scaling = use_lora_plus or use_embedding_lr
+ _warned_skip_optimizer_state_grad_norm = False
# Build step functions following mlx-lm's pattern
- max_grad_norm = args.max_grad_norm
+ max_grad_norm = float(args.max_grad_norm or 0.0)
# Elementwise clip (clip_grad_value_): leaf-local, free memory.
- # None → auto-pick by trainable dtype (mlx-vlm defaults: bf16=5, fp16=1).
- _raw_mgv = getattr(args, "max_grad_value", None)
- if _raw_mgv is None:
- _trainable_dtype = None
- for _, _v in tree_flatten(model.trainable_parameters()):
- if _v.dtype in (mx.bfloat16, mx.float16, mx.float32):
- _trainable_dtype = _v.dtype
- break
- if _trainable_dtype == mx.bfloat16:
- max_grad_value = 5.0
- elif _trainable_dtype == mx.float16:
- max_grad_value = 1.0
- else:
- max_grad_value = 0.0
- if max_grad_value > 0:
- print(
- f"Unsloth: auto grad value clip = ±{max_grad_value:g} "
- f"(trainable dtype: {_trainable_dtype})."
- )
- else:
- max_grad_value = float(_raw_mgv or 0.0)
+ # Prefer value clipping when both clipping modes are requested; global
+ # norm clipping is exact but materially increases memory on MLX.
+ _raw_mgv = getattr(args, "max_grad_value", 5.0) # TODO: expose MLX grad-clip in Studio UI for power users
+ max_grad_value = 5.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
+ if max_grad_norm > 0 and max_grad_value > 0:
+ print(
+ "Unsloth: max_grad_norm and max_grad_value are both enabled; "
+ "ignoring max_grad_norm in favor of max_grad_value."
+ )
+ max_grad_norm = 0.0
_clip_grad_value = max_grad_value > 0
state = [model.state, optimizer.state, mx.random.state]
# The direct grad_accum==1 fast path delegates clipping to
@@ -755,43 +765,76 @@ def _grad_leaf_scale(name, safe_toks_f, clip_scale=None, dtype=None):
scale = scale.astype(dtype)
return scale
+ optimizer_v_sum = None
+
+ def _optimizer_v_total():
+ total = mx.array(0.0, dtype=mx.float32)
+ found = False
+ for name, value in tree_flatten(getattr(optimizer, "state", {})):
+ if name != "v" and not name.endswith(".v"):
+ continue
+ found = True
+ value_f = value.astype(mx.float32)
+ total = total + mx.sum(value_f)
+ return total if found else None
+
+ def _grad_norm_from_optimizer_state():
+ nonlocal optimizer_v_sum
+ betas = getattr(optimizer, "betas", None)
+ if not betas or len(betas) < 2:
+ return None
+ current_v_sum = _optimizer_v_total()
+ if current_v_sum is None:
+ return None
+ previous_v_sum = (
+ optimizer_v_sum
+ if optimizer_v_sum is not None
+ else mx.array(0.0, dtype=mx.float32)
+ )
+ beta2 = mx.array(float(betas[1]), dtype=mx.float32)
+ denom = mx.maximum(
+ mx.array(1.0, dtype=mx.float32) - beta2,
+ mx.array(1e-30, dtype=mx.float32),
+ )
+ grad_norm_sq = mx.maximum(
+ (current_v_sum - beta2 * previous_v_sum) / denom,
+ mx.array(0.0, dtype=mx.float32),
+ )
+ grad_norm = mx.sqrt(grad_norm_sq)
+ mx.eval(current_v_sum, grad_norm)
+ optimizer_v_sum = current_v_sum
+ return grad_norm
+
+ def _can_report_optimizer_state_norm():
+ # For Adam-family optimizers, recover ||g|| from the second moment
+ # after the update: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2.
+ # This avoids adding a second consumer to the lazy backward graph.
+ return getattr(optimizer, "betas", None)
+
def _apply_update(grad, toks_f):
"""Common gradient post-processing and optimizer update.
- This uses a two-pass exact clip path instead of materializing
- multiple full gradient trees. The resulting update is numerically
- equivalent to the older implementation while avoiding the extra
- tree allocation from ``optim.clip_grad_norm`` on large MLX runs.
+ Scale accumulated gradients by supervised-token count, then apply
+ the selected clipping mode. Global norm clipping uses MLX's helper
+ and reports that norm. Non-global modes report after update from
+ Adam's optimizer state so the backward graph stays single-consumer.
"""
safe_toks_f = mx.maximum(
toks_f, mx.array(1.0, dtype=mx.float32)
)
flat_grad = tree_flatten(grad)
- need_norm = max_grad_norm > 0
- grad_norm = mx.array(0.0, dtype=mx.float32)
- clip_scale = None
-
- if need_norm:
- norm_sq = mx.array(0.0, dtype=mx.float32)
- for name, value in flat_grad:
- scaled = value.astype(mx.float32) * _grad_leaf_scale(
- name, safe_toks_f
- )
- norm_sq = norm_sq + mx.sum(scaled * scaled)
- grad_norm = mx.sqrt(norm_sq)
- if max_grad_norm > 0:
- clip_scale = mx.minimum(
- mx.array(max_grad_norm, dtype=mx.float32) / (grad_norm + 1e-6),
- mx.array(1.0, dtype=mx.float32),
- )
-
- final_grad = tree_unflatten([
- (
- name,
- value * _grad_leaf_scale(name, safe_toks_f, clip_scale, value.dtype),
+ grad_norm = None
+ final_items = []
+ for name, value in flat_grad:
+ scaled = value * _grad_leaf_scale(
+ name, safe_toks_f, None, value.dtype
+ )
+ final_items.append((name, scaled))
+ final_grad = tree_unflatten(final_items)
+ if max_grad_norm > 0:
+ final_grad, grad_norm = optim.clip_grad_norm(
+ final_grad, max_norm=max_grad_norm
)
- for name, value in flat_grad
- ])
if _clip_grad_value:
# Elementwise clip after norm-scaling, before optimizer step.
final_grad = tree_map(
@@ -799,6 +842,7 @@ def _apply_update(grad, toks_f):
final_grad,
)
optimizer.update(model, final_grad)
+ return grad_norm
def _apply_update_direct(grad):
"""Fast exact path for the common ``grad_accum == 1`` case.
@@ -813,8 +857,9 @@ def _apply_update_direct(grad):
In this case the raw gradients already represent the final
per-token average, so we can clip/update the original tree directly.
"""
+ grad_norm = None
if max_grad_norm > 0:
- grad, _ = optim.clip_grad_norm(grad, max_norm=max_grad_norm)
+ grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm)
if _clip_grad_value:
# Elementwise clip per leaf — free memory (no cross-leaf reduction).
grad = tree_map(
@@ -822,6 +867,7 @@ def _apply_update_direct(grad):
grad,
)
optimizer.update(model, grad)
+ return grad_norm
# Unified step function for both VLM and text training.
# VLM: batch_data is a dict → loss_and_grad_fn(model, batch_dict)
@@ -833,10 +879,11 @@ def step_fn(batch_data, prev_state, do_update):
(lvalue, toks), grad = loss_and_grad_fn(model, batch_data[0], batch_data[1], batch_data[2])
if _direct_single_step_update:
- _apply_update_direct(grad)
- return lvalue, toks, None
+ grad_norm = _apply_update_direct(grad)
+ return lvalue, toks, None, grad_norm
toks_f = toks.astype(mx.float32)
+ grad_norm = mx.array(0.0, dtype=mx.float32)
# Scale-and-accumulate in a single tree_map per micro-batch, casting
# the scalar to each leaf's dtype so bf16/fp16 grad trees stay in
@@ -844,6 +891,12 @@ def step_fn(batch_data, prev_state, do_update):
# and forces optimizer.update to promote params/m/v to fp32).
if prev_state is not None:
prev_grad, prev_toks = prev_state
+ # Accumulated gradients are optimizer state, not something to
+ # differentiate through on the next compiled micro-batch. This
+ # keeps custom VJP losses such as CCE from retaining/corrupting
+ # the carried bf16 accumulation graph.
+ prev_grad = tree_map(mx.stop_gradient, prev_grad)
+ prev_toks = mx.stop_gradient(prev_toks)
grad = tree_map(
lambda g, p: p + g * toks_f.astype(g.dtype),
grad, prev_grad,
@@ -856,17 +909,21 @@ def step_fn(batch_data, prev_state, do_update):
)
if do_update:
- _apply_update(grad, toks_f)
- return lvalue, toks, None
+ grad_norm = _apply_update(grad, toks_f)
+ return lvalue, toks, None, grad_norm
- return lvalue, toks, (grad, toks_f)
+ grad = tree_map(mx.stop_gradient, grad)
+ toks_f = mx.stop_gradient(toks_f)
+ return lvalue, toks, (grad, toks_f), None
compile_policy = build_compile_policy(args=args)
_compile_decision = getattr(self, "_compile_decision", None)
_use_compile = compile_policy.mode != "eager"
- # why: text MLX has no compile qualification pipeline; stay eager
- # unless UNSLOTH_MLX_TEXT_COMPILE=1.
- if not is_vlm and _use_compile and os.environ.get("UNSLOTH_MLX_TEXT_COMPILE", "0") != "1":
+ if _use_compile and max_grad_norm > 0 and grad_accum > 1:
+ print(
+ "Unsloth: mx.compile disabled because MLX global norm "
+ "clipping is enabled with gradient accumulation."
+ )
_use_compile = False
if is_vlm and _use_compile:
qual = getattr(model, "_unsloth_compile_qualification", None) or get_compile_qualification(model)
@@ -1006,9 +1063,13 @@ def step_fn(batch_data, prev_state, do_update):
batch_idx += 1
do_update = (it % grad_accum == 0)
+ if do_update:
+ # Keep callable scheduler evaluation outside mx.compile. The
+ # compiled step reads the scalar LR already in optimizer state.
+ self._set_optimizer_lr_for_step(optimizer, it // grad_accum - 1)
try:
- lvalue, toks, grad_accum_state = step_fn(
+ lvalue, toks, grad_accum_state, grad_norm = step_fn(
batch_data, grad_accum_state, do_update,
)
except (ValueError, RuntimeError) as e:
@@ -1032,7 +1093,7 @@ def step_fn(batch_data, prev_state, do_update):
step_fn = _uncompiled_step_fn
_use_compile = False
state = [model.state, optimizer.state, mx.random.state]
- lvalue, toks, grad_accum_state = step_fn(
+ lvalue, toks, grad_accum_state, grad_norm = step_fn(
batch_data, grad_accum_state, do_update,
)
else:
@@ -1041,10 +1102,31 @@ def step_fn(batch_data, prev_state, do_update):
losses += lvalue * toks
n_tokens += toks
steps += 1
+ if grad_norm is not None:
+ mx.eval(grad_norm)
if grad_accum_state is not None:
mx.eval(state, losses, n_tokens, grad_accum_state[0], grad_accum_state[1])
else:
mx.eval(state, losses, n_tokens)
+ if (
+ do_update
+ and grad_norm is None
+ and max_grad_norm <= 0
+ and _can_report_optimizer_state_norm()
+ ):
+ grad_norm = _grad_norm_from_optimizer_state()
+ elif (
+ do_update
+ and grad_norm is None
+ and max_grad_norm <= 0
+ and not _can_report_optimizer_state_norm()
+ and not _warned_skip_optimizer_state_grad_norm
+ ):
+ print(
+ "Unsloth: skipping grad norm reporting for this MLX "
+ "optimizer/mode to avoid materializing the gradient graph."
+ )
+ _warned_skip_optimizer_state_grad_norm = True
if int(toks.item()) == 0:
raise ValueError(
"Unsloth MLX: a training batch produced zero supervised "
@@ -1070,6 +1152,12 @@ def step_fn(batch_data, prev_state, do_update):
peak_mem = mx.get_peak_memory() / 1e9
self._train_loss_history.append(train_loss)
+ grad_norm_val = (
+ float(grad_norm.item())
+ if grad_norm is not None else None
+ )
+ if grad_norm_val is not None:
+ self._grad_norm_history.append(grad_norm_val)
self._tokens_per_second_history.append(tokens_sec)
self._peak_memory_history.append(peak_mem)
self._step_times.append(train_time / steps if steps > 0 else 0)
@@ -1082,9 +1170,14 @@ def step_fn(batch_data, prev_state, do_update):
elapsed_total = time.perf_counter() - start_time
+ grad_text = (
+ f"Grad: {grad_norm_val:.4f} | "
+ if grad_norm_val is not None else ""
+ )
print(
f" Step {current_step}/{total_steps} | "
f"Loss: {train_loss:.4f} | "
+ f"{grad_text}"
f"LR: {lr_val:.2e} | "
f"Tok/s: {tokens_sec:.0f} | "
f"Peak: {peak_mem:.2f} GB"
@@ -1092,8 +1185,11 @@ def step_fn(batch_data, prev_state, do_update):
for cb in self._step_callbacks:
try:
- cb(current_step, total_steps, train_loss, lr_val,
- tokens_sec, peak_mem, elapsed_total, trained_tokens)
+ cb(
+ current_step, total_steps, train_loss, lr_val,
+ tokens_sec, peak_mem, elapsed_total, trained_tokens,
+ grad_norm_val,
+ )
except Exception as e:
print(f"Unsloth: step callback error: {e}")
@@ -1267,7 +1363,7 @@ def _prepare_data(self, is_vlm):
def save_model(self, output_dir=None):
"""Save LoRA adapters or full merged model (if no LoRA)."""
- from .mlx_utils import save_merged_model
+ from .utils import save_merged_model
output_dir = output_dir or self.args.output_dir
trainable = dict(tree_flatten(self.model.trainable_parameters()))
@@ -1288,7 +1384,7 @@ def save_model(self, output_dir=None):
break
- from .mlx_utils import _get_transformer_layers
+ from .utils import _get_transformer_layers
layers = _get_transformer_layers(self.model)
_num_layers = len(layers) if layers else -1
@@ -1590,7 +1686,7 @@ def train_on_responses_only(
Returns:
The trainer (for chaining), or the masking closure if return_function=True.
"""
- from .dataset_utils import (
+ from ..dataset_utils import (
train_on_responses_only as _hf_train_on_responses_only,
)
diff --git a/unsloth_zoo/mlx_utils.py b/unsloth_zoo/mlx/utils.py
similarity index 99%
rename from unsloth_zoo/mlx_utils.py
rename to unsloth_zoo/mlx/utils.py
index 0aac7dbd8..3caf75688 100644
--- a/unsloth_zoo/mlx_utils.py
+++ b/unsloth_zoo/mlx/utils.py
@@ -36,7 +36,7 @@
from pathlib import Path
-from .mlx_cce import _get_runtime_cce
+from .cce import _get_runtime_cce
def _safe_token_denominator(ntoks):
@@ -2116,7 +2116,7 @@ def _extract_vlm_images(item, messages, image_size):
if not images and isinstance(messages, list):
try:
- from .vision_utils import process_vision_info
+ from ..vision_utils import process_vision_info
extracted = process_vision_info(messages, return_video_kwargs=True)
if isinstance(extracted, tuple) and extracted:
@@ -2980,7 +2980,7 @@ def save_pretrained_gguf(
compresses it to ``quantization_method``. Pass ``"f32"`` /
``"f16"`` / ``"bf16"`` to force a specific intermediate
"""
- from .llama_cpp import (
+ from ..llama_cpp import (
convert_to_gguf,
quantize_gguf,
install_llama_cpp,
diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py
index d4d800909..26cecd385 100644
--- a/unsloth_zoo/temporary_patches/moe_utils.py
+++ b/unsloth_zoo/temporary_patches/moe_utils.py
@@ -21,6 +21,7 @@
import importlib.util
from typing import Optional, Tuple
from torch.autograd import Function
+from unsloth_zoo.mlx import is_mlx_available
# Get compile location
UNSLOTH_COMPILE_LOCATION = os.environ.get(
@@ -224,6 +225,7 @@ def persistent_alloc_fn(size: int, alignment: int, stream):
def _check_grouped_gemm_available():
"""Check if Unsloth grouped GEMM kernels are available."""
if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False
+ if is_mlx_available(): return False
global _GROUPED_GEMM_AVAILABLE
if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE