Skip to content

fix mlx: Adds the MLX training path used by Studio on Apple Silicon#634

Merged
danielhanchen merged 16 commits into
unslothai:mainfrom
mmathew23:fix/mlx
May 14, 2026
Merged

fix mlx: Adds the MLX training path used by Studio on Apple Silicon#634
danielhanchen merged 16 commits into
unslothai:mainfrom
mmathew23:fix/mlx

Conversation

@mmathew23

@mmathew23 mmathew23 commented May 8, 2026

Copy link
Copy Markdown
Collaborator

Adds the MLX training path used by Studio on Apple Silicon, including text/VLM loading, CCE training, compile support, LoRA/QLoRA/full-finetune
behavior, and grad norm reporting.

Key changes:

  • Refactors MLX code into unsloth_zoo/mlx/:
    • loader.py
    • trainer.py
    • utils.py
    • compile.py
    • runtime.py
    • cce/
  • Adds/updates MLX runtime detection and package exports.
  • Adds compiled MLX CCE support, including preserving auxiliary CCE outputs needed by custom VJP under mx.compile.
  • Enables compiled text training by default.
  • Supports MLX LoRA and QLoRA training through Studio.
  • Supports full finetuning on MLX.
  • Defaults MLX full finetuning to fp32 to match Unsloth Torch full-FT behavior.
  • Adds float32_mixed_precision=False opt-out for bf16 full finetuning.
  • Adds value clipping default support via max_grad_value=3.0.
  • Reports grad norm from Adam/AdamW optimizer state without retaining the backward graph.
  • Stabilizes compiled gradient accumulation with mx.stop_gradient on carried accumulation state.
  • Keeps global max_grad_norm guarded because it is expensive and disables compile with gradient accumulation.

Validation:

  • pytest tests/test_mlx_runtime_cce_compile.py tests/test_mlx_torch_shim_smoke.py
  • Result: 58 passed
  • Verified full finetuning, LoRA, and QLoRA with compiled CCE.
  • Verified CCE memory savings vs standard CE for fp32 full finetuning.

mmathew23 added 8 commits May 7, 2026 12:43
  - Move MLX text/VLM modules under unsloth_zoo.mlx and update imports.
  - Keep MLX LR schedules in the trainer and write scalar learning rates into the optimizer before each real update, avoiding callable scheduler
  evaluation inside mx.compile.
  - Fix warmup indexing so the first optimizer update uses a nonzero LR without reaching peak LR early.
  - Make cosine decay target 0.0 for HF-style scheduler behavior.
  - Add scheduler coverage for linear/cosine/constant with and without warmup.

  The scheduler change keeps mx.compile enabled for dynamic LR training while avoiding the compiled callable-scheduler path that caused flat loss /
  unchanged weights in Studio-like Qwen text training.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the MLX integration by reorganizing modules into a dedicated mlx sub-package and standardizing device detection via a new is_mlx_available utility. Significant enhancements are made to the MLX trainer, including the implementation of gradient norm reporting for Adam-family optimizers, updated learning rate scheduler logic, and the addition of a float32_mixed_precision option for model loading. Review feedback identified a potential discontinuity in the learning rate warmup transition and an issue with gradient norm tracking when resuming training from checkpoints, both of which were addressed with specific implementation improvements.

Comment on lines +331 to +334
def warmup_fn(step):
step = mx.array(step)
step = mx.minimum(step + 1, mx.array(warmup))
return step * (lr / (warmup + 1))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current warmup calculation step * (lr / (warmup + 1)) with step = mx.minimum(step + 1, mx.array(warmup)) introduces a jump in learning rate at the boundary between the warmup and the main schedule. For example, if warmup=5, at the last warmup step (step=4), the LR is 5/6 * lr, but the next step (start of main schedule) jumps to lr. Using (step + 1) * (lr / warmup) ensures a smooth transition where the warmup reaches exactly lr at the final warmup step.

Suggested change
def warmup_fn(step):
step = mx.array(step)
step = mx.minimum(step + 1, mx.array(warmup))
return step * (lr / (warmup + 1))
def warmup_fn(step):
step = mx.array(step)
# Ensure smooth transition to main schedule at step == warmup
return (step + 1) * (lr / warmup)

Comment on lines +759 to +770
optimizer_v_sum = None

def _optimizer_v_total():
total = mx.array(0.0, dtype=mx.float32)
found = False
for name, value in tree_flatten(getattr(optimizer, "state", {})):
if name != "v" and not name.endswith(".v"):
continue
found = True
value_f = value.astype(mx.float32)
total = total + mx.sum(value_f)
return total if found else None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Initializing optimizer_v_sum to None causes the first gradient norm reported after resuming training from a checkpoint to be incorrect, as it will include the magnitude of the existing optimizer state. Initializing it with the current total second moment from the optimizer state ensures that the reported norm correctly reflects only the current update's magnitude.

Suggested change
optimizer_v_sum = None
def _optimizer_v_total():
total = mx.array(0.0, dtype=mx.float32)
found = False
for name, value in tree_flatten(getattr(optimizer, "state", {})):
if name != "v" and not name.endswith(".v"):
continue
found = True
value_f = value.astype(mx.float32)
total = total + mx.sum(value_f)
return total if found else None
def _optimizer_v_total():
total = mx.array(0.0, dtype=mx.float32)
found = False
for name, value in tree_flatten(getattr(optimizer, "state", {})):
if name != "v" and not name.endswith(".v"):
continue
found = True
value_f = value.astype(mx.float32)
total = total + mx.sum(value_f)
return total if found else None
# Initialize with current state to correctly handle resumed training
optimizer_v_sum = _optimizer_v_total()

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 055b834d54

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/loader.py
Comment on lines 2420 to 2423
model_type = config_data.get("model_type", "")
try:
from unsloth.models.mlx import get_unsloth_loader
custom_loader = get_unsloth_loader(model_type)
except (ImportError, AttributeError, NotImplementedError):
custom_loader = None

if custom_loader is not None:
with _temporary_hf_token_env(token):
model, tokenizer_or_processor = custom_loader(
model_name, config_data, max_seq_length=max_seq_length, token=token
)
if text_only is False or _is_vlm(config_data):
from .mlx_utils import normalize_vlm_processor_chat_template

tokenizer_or_processor = normalize_vlm_processor_chat_template(
tokenizer_or_processor,
chat_template=chat_template,
model_name=model_name,
model_type=model_type,
strict=False,
)
model._is_vlm_model = True
model._processor = tokenizer_or_processor
_patch_mixed_precision_set_dtype(model)
elif chat_template is not None:
from .mlx_utils import normalize_mlx_chat_template

tokenizer_or_processor = normalize_mlx_chat_template(
tokenizer_or_processor,
chat_template=chat_template,
model_name=model_name,
model_type=model_type,
is_vlm=False,
strict=False,
)
model._config = config_data
model._hf_repo = model_name
model._src_path = local_path
model._unsloth_base_revision = revision
model._unsloth_base_commit_hash = _infer_snapshot_commit(local_path)
model.max_seq_length = max_seq_length
model._unsloth_patch_mode = patch_mode
model._unsloth_full_finetuning = bool(full_finetuning)
_patch_mlx_saving(model, tokenizer_or_processor)
return model, tokenizer_or_processor

# Step 3: Route based on text_only
# Step 2: Route based on text_only
is_vlm = False

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore custom MLX loader registry dispatch

This branch now goes straight from adapter handling into generic text/VLM routing and no longer consults unsloth.models.mlx.get_unsloth_loader(model_type) before deciding the load path. In environments where Unsloth provides a model-type-specific MLX loader, those architectures will now bypass their required loader and can fail to load or lose model-specific setup. Re-introducing the registry check before this routing step keeps previously supported custom model types working.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +1179 to +1181
cb(current_step, total_steps, train_loss, lr_val,
tokens_sec, peak_mem, elapsed_total, trained_tokens)
tokens_sec, peak_mem, elapsed_total, trained_tokens,
grad_norm_val)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve old step-callback call signature

add_step_callback previously worked with callbacks taking 8 positional arguments, but this call site now always passes a 9th positional value (grad_norm_val). Existing callbacks that follow the prior signature will raise TypeError on every log step and be skipped by the exception handler, silently disabling downstream progress/metrics hooks. Keep backward compatibility by detecting callback arity (or otherwise falling back to the old 8-arg call).

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 00ef8eae1b

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +20 to +24
from .runtime import is_mlx_available

__all__ = [
"is_mlx_available",
]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Add compatibility shims for legacy MLX import paths

This refactor moves MLX modules under unsloth_zoo.mlx.* but does not provide aliases for the previously public modules (unsloth_zoo.mlx_loader, unsloth_zoo.mlx_trainer, unsloth_zoo.mlx_utils, unsloth_zoo.mlx_compile, unsloth_zoo.mlx_cce). Existing user code and downstream integrations that import those paths will now fail with ModuleNotFoundError immediately after upgrade, even though functionality still exists under new paths. Please add thin compatibility modules or import-time aliases so existing scripts continue to run.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d11c19f9f0

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/loader.py Outdated
Comment on lines +670 to +677
spec = importlib.util.spec_from_file_location(
"_unsloth_optional_mlx_loader",
candidate,
)
if spec is None or spec.loader is None:
continue
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Import optional MLX loader via package module name

Load-time probing now executes unsloth/models/mlx.py as _unsloth_optional_mlx_loader instead of importing it as unsloth.models.mlx. That strips package context (__package__), so any relative import inside that module (for example from .foo import ...) raises during probe and silently disables custom loader dispatch for those model types. This is a regression for valid package-structured custom loaders and causes fallback routing to run for architectures that require the custom loader.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e2c9405e86

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +730 to +736
max_grad_value = 1.0 if _raw_mgv is None else float(_raw_mgv or 0.0)
if max_grad_norm > 0 and max_grad_value > 0:
print(
"Unsloth: max_grad_norm and max_grad_value are both enabled; "
"ignoring max_grad_norm in favor of max_grad_value."
)
max_grad_norm = 0.0

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Honor explicit max_grad_norm when max_grad_value is defaulted

This now silently disables global norm clipping whenever max_grad_value > 0, but max_grad_value was also changed to default to 1.0, so a user who sets only max_grad_norm (a common training knob) will never actually get norm clipping unless they also discover and set max_grad_value=0.0. That is a behavior regression from prior defaults and changes optimization dynamics unexpectedly across runs.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/__init__.py
Comment on lines 122 to +139
_inject_bnb()
del _inject_triton, _inject_bnb

# Temporary bridge for already-merged Unsloth code that imports the old
# flat MLX module names. Remove after the paired Unsloth PR lands and
# imports unsloth_zoo.mlx.* everywhere.
import importlib as _importlib
import sys as _sys

for _old_name, _new_name in (
("unsloth_zoo.mlx_loader", "unsloth_zoo.mlx.loader"),
("unsloth_zoo.mlx_trainer", "unsloth_zoo.mlx.trainer"),
("unsloth_zoo.mlx_utils", "unsloth_zoo.mlx.utils"),
("unsloth_zoo.mlx_compile", "unsloth_zoo.mlx.compile"),
("unsloth_zoo.mlx_cce", "unsloth_zoo.mlx.cce"),
("unsloth_zoo.mlx_cce.runtime_cce", "unsloth_zoo.mlx.cce.runtime_cce"),
):
_sys.modules.setdefault(_old_name, _importlib.import_module(_new_name))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Apply legacy MLX import shims outside MLX-only init path

The compatibility aliases for unsloth_zoo.mlx_loader / mlx_trainer / etc. are installed only inside the _SKIP_GPU_INIT branch, so they are absent on non-MLX environments even though the old module files were removed; importing those legacy paths now still raises ModuleNotFoundError there. Fresh evidence in this commit is that the bridge exists but is gated under if _SKIP_GPU_INIT: instead of being unconditional.

Useful? React with 👍 / 👎.

mmathew23 and others added 2 commits May 13, 2026 12:00
The MLX path uses elementwise value clipping by default to keep the compiled
low-memory path available. Bump the default from 1.0 to 5.0 to match the
range we are standardising on across Studio. Leaves explicit user-supplied
values untouched.
Companion to the previous commit. Without this the dataclass default
shadows the getattr fallback, so non-Studio callers instantiating
MLXTrainingConfig without an explicit max_grad_value still picked up 1.0.
@danielhanchen danielhanchen merged commit e6d8f7f into unslothai:main May 14, 2026
8 of 11 checks passed
danielhanchen added a commit that referenced this pull request May 15, 2026
* fix CI fallout from MLX subpackage refactor (#634)

Three small fixes triggered by tests on PRs #645 / unsloth#5426:

1) tests/test_upstream_pinned_symbols_accelerator.py read the legacy
   flat unsloth_zoo/mlx_trainer.py and mlx_loader.py paths directly.
   #634 moved those to unsloth_zoo/mlx/trainer.py and mlx/loader.py;
   point the tests at the new paths so they exercise the post-refactor
   sources.

2) unsloth_zoo/vision_utils.py:68 did an unconditional top-level
   import torchvision while torchvision is not a declared dependency.
   On a CPU-only install that lacks torchvision (e.g. zoo's Core CI
   matrix), every import of unsloth_zoo.vision_utils raises
   ModuleNotFoundError. Guard the import so the module loads; existing
   _read_video_torchvision call sites already only fire when video
   reading is requested.

3) tests/test_zoo_source_upstream_refs.py tries to resolve transformers
   qwen2_vl / qwen2_5_vl image-processing modules, both of which
   transitively `import torchvision`. Skip via pytest.importorskip
   when torchvision is not installed instead of failing the run.

Each fix is independent and minimal; no behavior change for installs
that have torchvision (the eager import path is identical), and no
behavior change for the MLX trainer / loader (the tests still pin
the exact same source-level invariants on the new file paths).

* tests/conftest: also stub unsloth.device_type on CPU-only runners

The post-#634 conftest stubs ``unsloth_zoo.device_type`` so importing
the package on a CPU-only CI runner doesn't trip ``get_device_type()``.
``test_unsloth_trainer_exec_marker`` (and any future test that does
``import unsloth.*``) goes one level up: it imports ``unsloth.trainer``,
which traverses ``unsloth/__init__.py`` → ``_gpu_init`` → ``unsloth/
device_type.py:get_device_type()``, raising the same NotImplementedError
zoo's stub was meant to suppress.

Parameterise ``_preload_real_device_type(package, prereqs)`` so the
same harness works for both packages, factor the fallback into
``_install_device_type_stub(name)``, and call both. unsloth's
``get_device_type()`` consumes ``unsloth_zoo`` helpers but has no
intra-unsloth prereqs, hence ``prereqs=()`` there.

* tests/conftest: don't stub unsloth.device_type after all

The previous commit (196a78b) stubbed unsloth.device_type so
test_unsloth_trainer_exec_marker could complete its
``import unsloth.trainer`` smoke. That makes ``import unsloth`` succeed
on CPU-only CI, which then runs
``unsloth/_gpu_init.py:_patch_trl_trainer()`` and rebinds
``trl.trainer.sft_trainer.SFTTrainer`` /
``transformers.models.ministral.MinistralAttention`` to Unsloth's
compiled wrappers.

``inspect.getsource(...)`` on those classes follows the wrapper's
``__init__.__code__.co_filename`` and returns the rewritten source --
which doesn't contain the literals the drift detectors anchor on
(``self._signature_columns``, ``self._prepare_dataset(``,
``class SFTTrainer``, named ``hidden_states``/``attention_mask``).
The HF=4.57.6 job went from -1 pre-existing failure
(``test_unsloth_trainer_exec_marker``, also failing on main) to -4 new
failures across ``test_MinistralAttention_forward_signature`` and the
three ``test_unsloth_rl_trainer_*`` checks. That's a regression.

Keep parameterised ``_preload_real_device_type`` and
``_install_device_type_stub`` -- they're harmless infrastructure and
useful if a future test needs the unsloth.device_type stub
specifically. Just don't fire the second preload at session setup.

``test_unsloth_trainer_exec_marker`` will continue to fail on
CPU-only CI as it does on main; that's a separate pre-existing issue
with ``unsloth/device_type.py`` raising on accelerator-less hosts and
should be fixed upstream in unsloth itself, not papered over in zoo's
conftest.

* device_type: mirror UNSLOTH_ALLOW_CPU=1 gate from unsloth #5429

unsloth/_gpu_init.py:125 imports DEVICE_TYPE from unsloth_zoo, not
local unsloth/device_type.py, so the env-var gate has to live here
too for `import unsloth.trainer` to succeed on CPU-only CI.

The companion unsloth PR landed as cb15a7a (#5429). Pairs with the
conftest change in this branch that sets UNSLOTH_ALLOW_CPU=1 before
triggering `import unsloth` from _apply_upstream_import_fixes_for_tests.

@functools.cache on get_device_type plus the module-level
DEVICE_TYPE = get_device_type() assignment mean the os.environ.get
runs exactly once per process. Production hosts with a real accelerator
hit the `cuda`/`xpu`/`hip` branches above and never reach the env-var
short-circuit.

* prod: graceful no-op for two transformers 5.x runtime bugs

Two upstream changes broke zoo at runtime (not just CI drift detection):

1. transformers 5.x removed transformers.cache_utils.HybridCache. zoo's
   utils.py falls back to `HybridCache = typing.Any`, which made
   `gemma.py:260 isinstance(past_key_values, HybridCache)` raise
   TypeError ("isinstance() arg 2 must be a type, a tuple of types,
   or a union") at runtime on the gemma mask path.

   Fix: expose HAS_HYBRID_CACHE alongside the existing HybridCache
   fallback in utils.py, and gate the isinstance call in gemma.py:260
   on the flag. When HybridCache is genuinely missing, the elif chain
   falls through to the dynamic-cache branch -- same observable
   behavior as the pre-5.x path that never matched the type.

2. transformers 5.x rewrote PreTrainedModel.save_pretrained. The
   source-string anchors zoo's LoRA-merge-on-save optimization patches
   in `merge_and_dequantize_lora` (state_dict_split assignment,
   `state_dict[tensor].contiguous()`, `def save_pretrained`, and the
   filename_to_tensors for-loop on the push_to_hub path) are all gone.
   Each downstream `raise RuntimeError("Failed to find ...")` would
   fire, killing `merge_and_dequantize_lora(push_to_hub=True)` on 5.x.

   Fix: scan all required anchors upfront in saving_utils.py. If any
   are missing on installed transformers, emit a one-time warnings.warn
   and fall back to vanilla `model.save_pretrained(...)`. The end-user
   sees the warning, no crash, and is told to call
   `model.merge_and_unload()` (or equivalent) before saving if they
   want merged LoRA weights on disk. The per-anchor RuntimeError
   raises downstream are kept as defense-in-depth for partial drift
   (one anchor renamed, others intact) that the upfront check might
   miss.

The two corresponding drift detectors in tests/ are rewritten in the
companion commit as positive assertions: 4.57.6 keeps the strict
existence check; 5.x asserts the anchor is gone AND zoo's fallback
covers it (`HAS_HYBRID_CACHE` flag / `_required_anchors` list).

* tests: version-gate transformers-5.x drift detectors

Sixteen drift detectors in the Core HF=default / HF=latest matrix
slices were pointing at upstream symbols / source strings that
transformers 5.x intentionally removed:

* 9 `temporary_patches/*_for_causal_lm_forward_named_params` --
  cache_position moved from named forward param into
  **kwargs: Unpack[TransformersKwargs] (DeepseekV3, GptOss, CsmDepth,
  Csm, Qwen3Moe, Qwen3Next, Qwen3VLMoe, MinistralModel, GraniteMoeHybrid).
* 4 compiler/source rewriter probes -- output_attentions branching
  removed, MoE routing_weights cast refactored, Trainer
  is_torch_tpu_available/is_torch_xla_available both removed,
  _update_causal_mask hasattr probe all-False.
* 1 routing_weights.to substring probe in compiler_rewriter_exhaustive.
* 1 GptOssConfig dedent-compare -- rope_theta/initial_context_length/
  rope_scaling collapsed into rope_parameters dict.
* 1 GptOssConfig kwarg signature -- same rope_theta rename.

For all 16, the zoo runtime patch already gracefully no-ops on 5.x
(verified by the Explore audit: try/except + relaxed patch_function +
hasattr guards + str.replace's silent no-op-on-miss). Skip the
detector on 5.x with a clear "<symbol> removed on transformers
{version}" message; keep it strict on 4.57.6 where real drift could
still surface.

Two more drift detectors -- test_hybrid_cache_class_present and the
two saving_utils.py pinned-string tests -- become positive assertions:
on 5.x they confirm zoo's prod fix correctly identifies the missing
anchor (HAS_HYBRID_CACHE is False / `_required_anchors` covers the
missing strings) and the graceful fallback is wired up.

Conftest also sets UNSLOTH_ALLOW_CPU=1 before the existing
`import unsloth` trigger, which unblocks `test_unsloth_trainer_exec_marker`
on CPU-only CI runners (the companion unsloth PR #5429 lets that
import succeed without an accelerator).

* tests: also gate test_MinistralAttention_forward_signature on 5.x

Missed this one in 18c9f62. Same root cause as MinistralModel:
transformers 5.x reflowed MinistralAttention.forward, zoo's
``patch_function(..., match_level='relaxed')`` falls back to a
``(self, *args, **kwargs)`` wrapper, the named-param probe sees the
wrapper signature and drift-fails. Runtime call still works because
the wrapper forwards via kwargs. Skip on 5.x, keep strict on 4.57.6.

* tests: probe Ministral stash + skip relaxed-wrapper case

Zoo's ministral.py:94-96 wraps the actual implementation in a generic
``def forward(self, *args, **kwargs): return _full_forward(...)``
adapter before calling patch_function -- this lets `check_args_kwargs`
accept params transformers 5.x removed (cache_position) by routing
them through **kwargs. The side effect: `inspect.signature(
MinistralAttention.forward)` after patching shows the wrapper, not the
real named params.

This test used to pass only because the test runner couldn't fully
trigger zoo's TEMPORARY_PATCHES loop (`unsloth/models/_utils.py:580`
runs at import time but unsloth import was being silently aborted).
With UNSLOTH_ALLOW_CPU=1 now making `import unsloth` succeed, the
patch fires correctly and the live signature is the wrapper.

Fix: probe the `_original_modeling_ministral_MinistralAttention_forward`
stash that `patch_function(store_original=True)` writes, falling back
to the live attribute. If the live attr is the relaxed wrapper and
there's no stash, the named-param probe isn't meaningful -- the
runtime kwargs contract is enforced elsewhere. Skip with a clear
message. 4.57.6: stash is present and the named-param probe still
runs strict; real drift in the upstream signature still surfaces.

* tests: gate qwen2_5_vl image-processor drift on 5.x

transformers 5.x dropped the slow image processors entirely
(image_processing_qwen2_5_vl.py and image_processing_qwen2_5_vl_fast.py
are both gone in v5.5.0). Qwen2.5-VL now reuses Qwen2VLImageProcessor
directly. zoo's misc.py:1500-1506 patch site for Qwen2_5_VLImageProcessor
is try/except ImportError-wrapped, so it silently no-ops on 5.x and the
runtime shim still fires via the Qwen2VLImageProcessor patch at
misc.py:1485-1498 (the class Qwen2.5-VL inherits at runtime).

The Qwen2.5-VL image-processor path was previously hidden by
``pytest.importorskip("torchvision")`` -- PR #648 added torchvision to
the upstream-regression matrix install, which unmasked the drift.
Runtime is unaffected. Add the same 5.x skip pattern used by the other
16 gated detectors so the dead-code patch site doesn't block CI.

4.57.6 keeps the strict path-existence check.
danielhanchen added a commit that referenced this pull request May 17, 2026
PR #634 silently flipped MLX AdamW's bias_correction from the historical
MLX default of False to True (matching torch.optim.AdamW). For real
multi-epoch fine-tunes the two converge identically after ~10-20
warmup steps, but for short memorization runs the difference is large:
bias_correction=True shrinks the step-1 effective update by ~3x.

Empirical bisection on a Mac M1 CI runner (probes 12 + 14 of the
mlx-parity-probes workflow):
  * pre-#634 trainer (bias_correction=False), 7 steps:
      loss 10.55 -> 5.04 (bouncy), generates "Unsloth! ..."
  * HEAD + PR #663 only (bias_correction=True), 7 steps:
      loss 10.55 -> 0.17 (smooth), generates "5 lbs!"
  * HEAD + bias_correction=False (this PR), 7 steps:
      loss 10.55 -> 2.44 (bouncy), generates "Unsloth! ..."

The upstream MLX smoke test in unslothai/unsloth and every other
existing MLX fine-tune script implicitly relied on the bias_correction=
False default. Restoring it as the default fixes that contract.

Add `adam_bias_correction: bool = False` to MLXTrainingConfig so users
who want true HF/torch.AdamW parity can opt in explicitly. Plumb it
through both the adamw and adam construction paths.

Regression test pins the default to False.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging-1 that referenced this pull request May 24, 2026
- save_pretrained_merged(save_method='lora') now mirrors MLXTrainer.save_model:
  when the trainable set contains keys outside the module-anchored adapter set
  (intentionally trainable embeddings/projector/vision), it routes through
  save_trainable_adapters so the public save API does not silently lose the
  trained non-LoRA state. Pure LoRA runs still take the lean save_lora_adapters
  path.
- Periodic in-loop checkpoint save now wraps save_trainable_adapters in
  try/except so a fully-frozen state at a checkpoint step degrades to a
  printed skip rather than tearing down the training loop on the empty-trainable
  ValueError introduced in the previous commit.

FA3 note: the original "Saved checkpoint to ..." print line (blame
e6d8f7f "fix mlx: Adds the MLX training path used by Studio on Apple
Silicon (unslothai#634)") is preserved as the try/except else arm; it still runs on
every successful checkpoint write and only the failure path is added.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging-1 that referenced this pull request May 24, 2026
- _ensure_lora_frozen now early-returns when trainable_parameters() is
  empty so the helper does not require model.parameters() on the stub
  models used by tests like test_adam_optimizers_enable_bias_correction.
  why safe: with zero trainable tensors there is nothing for the norm
  safeguard to flag and no LoRA detection is required.
- save_model uses _extract_mlx_lora_parameters() to derive the rank /
  scale / dropout fields and computes the trainable split inside the LoRA
  branch so the writer routing is driven by the same module-anchored
  adapter set regardless of which tensors happen to be trainable.
- save_trainable_adapters now unions the trainable tree with the full
  module-anchored adapter set so reloaded LoRA tensors that were frozen
  before a mixed fine-tune are still emitted into the artifact.
- _enrich_mlx_adapter_config now backfills lora_parameters / rank /
  scale / dropout from the model when the caller did not supply them so
  save_pretrained_merged ships a complete adapter_config.json.

FA3 note: the original `from .utils import save_merged_model` and
`from .utils import _get_transformer_layers` lines (blame e6d8f7f
"fix mlx: Adds the MLX training path used by Studio on Apple Silicon
(unslothai#634)") are preserved at their original locations in save_model; the
new `_extract_mlx_lora_parameters` import is added as a separate
statement next to save_merged_model so neither existing import is
deleted or consolidated.
danielhanchen added a commit that referenced this pull request May 27, 2026
* fix(mlx): save only adapter tensors

* Anchor LoRA filter to lora_a/lora_b modules and fix trainer detection for PR #692

Three follow-ups from review feedback:

1. save_lora_adapters now filters via a new
   collect_mlx_lora_adapter_tensors() helper that walks named_modules()
   and keeps only tensors belonging to modules that actually expose
   lora_a / lora_b. The previous "lora_" substring filter falsely
   included any path containing the prefix, e.g. router.lora_gate.weight
   or scalar_lora_alpha. Tolerates lora_A / lora_B casing too.

2. MLXTrainer.save_model() now uses the same helper to decide whether
   to call save_lora_adapters. The previous trainable_parameters() scan
   missed adapter tensors after a reload/freeze (LoRA lives in
   parameters() but is not always marked trainable), which caused
   final export to fall through to save_merged_model() instead of
   writing an adapter file.

3. tests/test_mlx_save_lora_adapters_filter.py adds five regressions
   over the split-save semantics: lora-only filter, brittleness
   demonstration (lora_router), no-LoRA ValueError, trainable
   checkpoint preserves everything, post-reload adapter detection.
   Also widens the ValueError message to point users at
   save_trainable_adapters() for non-LoRA checkpoint use.

adapter_config.json write also now pins encoding="utf-8" for
cross-platform parity.

* Scrub .github/workflows for staging push (matches staging base)

* Split: keep only 1 file(s)

* mlx: align LoRA detection across save and norm-freeze paths

_ensure_lora_frozen now drives LoRA detection from
collect_mlx_lora_adapter_tensors instead of an "lora" substring scan over
trainable_parameters(). The norm-freezing intent from blame commit
82d75e0 ("Improve compiled training loop, CCE memory optimizations, and
LoRA stability" - "Add _ensure_lora_frozen() to prevent NaN from
unfrozen LayerNorm in adaptive optimizers") is preserved verbatim; only
the LoRA presence predicate is swapped so that reloaded/frozen LoRA
models, whose adapter tensors live in parameters() but are not listed as
trainable, still trigger the safeguard. The old predicate silently
returned False in that state and let the LayerNorm NaN bug recur, which
is the exact failure mode the original commit was added to prevent.

save_pretrained_merged now uses the same collect_mlx_lora_adapter_tensors
predicate, so the outer LoRA-vs-merged gate and the inner
save_lora_adapters helper agree on what counts as a LoRA model and the
error path stays consistent.

* mlx: scope norm-freeze guard to active LoRA and hoist merged save check

_ensure_lora_frozen now requires at least one module-anchored LoRA tensor
to be in trainable_parameters() before freezing accidentally trainable
norms. The previous full-tree check would silently freeze a user's
trainable norm when a reloaded model still carried frozen LoRA tensors
in parameters() but the user was running a non-LoRA fine-tune. The
module-anchored predicate is kept (no more "lora" substring false
positives), only the trainability gate is restored, preserving the
blame-cited NaN-protection intent for active LoRA training without
disturbing the non-LoRA reload case.

save_pretrained_merged no longer collects LoRA adapter tensors when
save_method is merged_16bit or merged_4bit, since those branches do not
read has_lora. The call is hoisted into the lora branch so merged saves
skip the full-parameter walk.

* mlx: tighten save_model LoRA detection comment

Trim the four-line rationale block to the single load-bearing fact:
reloaded LoRA can sit in parameters() without trainable_parameters(),
so the old substring check fell through to save_merged_model().

* Cover MLX LoRA detection edges in adapter-filter test module

Extend the existing regression module with five behavior-named
assertions:

- norm-freeze runs when LoRA is actively trained alongside an
  accidentally trainable norm
- norm-freeze skips when adapter tensors are reloaded but not
  trainable (the user is doing a non-LoRA fine-tune)
- norm-freeze skips for non-LoRA models
- save_pretrained_merged raises the user-facing "no LoRA layers"
  message at the gate when adapter tensors are absent
- save_pretrained_merged skips the LoRA collector for merged_16bit
  and merged_4bit save paths

Also drop the unused json + pathlib imports, the unused tmp_path
fixture in test_collect_lora_helper_finds_adapters_after_reload, and
tighten the module docstring.

* Preserve non-LoRA trainables and fix uppercase/SwitchLinear paths for PR #692

- collect_mlx_lora_adapter_tensors() now requires a complete attribute pair
  (lora_a + lora_b, or lora_A + lora_B). Half-adapter modules no longer slip
  through into adapters.safetensors as unreloadable artifacts.
- New iter_mlx_lora_modules() helper feeds the collector, _enrich_mlx_adapter_config(),
  and MLXTrainer.save_model() so uppercase tensors get matching module paths,
  rank, scale, and dropout metadata instead of falling back to defaults.
- MLXTrainer.save_model() routes mixed LoRA + non-LoRA trainables to
  save_trainable_adapters() so intentionally trainable embeddings, projector,
  vision, or norm weights are not silently dropped from the final artifact.
  Frozen-LoRA + non-LoRA-trainable runs now correctly fall through to
  save_merged_model().
- LoRASwitchLinear rank reads shape[-2] for (num_experts, rank, in_dims)
  instead of writing in_dims into adapter_config["rank"].

Test cases:
- test_collect_lora_helper_accepts_uppercase_pair
- test_collect_lora_helper_drops_half_adapter_module
- test_iter_mlx_lora_modules_reports_attr_pair
- test_enrich_adapter_config_records_uppercase_lora_paths

* Scrub .github/workflows for staging push (matches staging base)

* Split: keep only 1 file(s)

* mlx: tighten LoRA save routing and guard frozen-trainable checkpoint

- trainer.save_model now calls the unsloth_zoo save_lora_adapters utility
  directly, matching the calling convention used for save_trainable_adapters
  on the sibling branch and removing the dependence on the loader-side
  monkeypatch when the trainer is constructed without _patch_mlx_saving.
- _ensure_lora_frozen reuses one collect_mlx_lora_adapter_tensors pass and
  uses the resulting adapter key set for the norm safeguard so paths that
  merely contain "lora" (e.g. router.lora_gate.weight) no longer disable
  the freeze.
- save_trainable_adapters now raises when trainable_parameters() is empty
  so checkpoint directories never carry adapter_config.json without an
  adapters.safetensors next to it.
- Drop unused unpack names (_b_attr, _module) in iter_mlx_lora_modules
  loops.

* mlx: route save_pretrained_merged through trainable-aware adapter writer

- save_pretrained_merged(save_method='lora') now mirrors MLXTrainer.save_model:
  when the trainable set contains keys outside the module-anchored adapter set
  (intentionally trainable embeddings/projector/vision), it routes through
  save_trainable_adapters so the public save API does not silently lose the
  trained non-LoRA state. Pure LoRA runs still take the lean save_lora_adapters
  path.
- Periodic in-loop checkpoint save now wraps save_trainable_adapters in
  try/except so a fully-frozen state at a checkpoint step degrades to a
  printed skip rather than tearing down the training loop on the empty-trainable
  ValueError introduced in the previous commit.

FA3 note: the original "Saved checkpoint to ..." print line (blame
e6d8f7f "fix mlx: Adds the MLX training path used by Studio on Apple
Silicon (#634)") is preserved as the try/except else arm; it still runs on
every successful checkpoint write and only the failure path is added.

* mlx: guard norm-freeze on empty trainable + tighten LoRA save path

- _ensure_lora_frozen now early-returns when trainable_parameters() is
  empty so the helper does not require model.parameters() on the stub
  models used by tests like test_adam_optimizers_enable_bias_correction.
  why safe: with zero trainable tensors there is nothing for the norm
  safeguard to flag and no LoRA detection is required.
- save_model uses _extract_mlx_lora_parameters() to derive the rank /
  scale / dropout fields and computes the trainable split inside the LoRA
  branch so the writer routing is driven by the same module-anchored
  adapter set regardless of which tensors happen to be trainable.
- save_trainable_adapters now unions the trainable tree with the full
  module-anchored adapter set so reloaded LoRA tensors that were frozen
  before a mixed fine-tune are still emitted into the artifact.
- _enrich_mlx_adapter_config now backfills lora_parameters / rank /
  scale / dropout from the model when the caller did not supply them so
  save_pretrained_merged ships a complete adapter_config.json.

FA3 note: the original `from .utils import save_merged_model` and
`from .utils import _get_transformer_layers` lines (blame e6d8f7f
"fix mlx: Adds the MLX training path used by Studio on Apple Silicon
(#634)") are preserved at their original locations in save_model; the
new `_extract_mlx_lora_parameters` import is added as a separate
statement next to save_merged_model so neither existing import is
deleted or consolidated.

* mlx: trim verbose adapter-save commentary

Drop multi-line restatements of routing intent and the test-name back
reference; keep only the load-bearing why for the empty-trainable guard
and the structural-detect comment in save_model. No behavior change.

* mlx: extend adapter-filter coverage with mixed and frozen-LoRA edges

Adds regression coverage for:
- save_trainable_adapters raises when nothing trainable AND no LoRA
- save_trainable_adapters preserves frozen LoRA alongside trainable norms
- _ensure_lora_frozen freezes norms whose path contains literal "lora"
- save_pretrained_merged(save_method='lora') routes mixed vs pure cases
- save_pretrained_merged(save_method='lora') writes complete adapter_config

Guards the autouse _install_shim against double-injecting the torch shim
on hosts where real mlx is already importable.

* Make adapter export strictly LoRA-only and drop uppercase paths for PR #692

The earlier mixed-trainable routing in save_pretrained_merged() and
MLXTrainer.save_model() let base tensors like q_proj.weight leak into
adapters.safetensors whenever a reloaded checkpoint had base weights
marked trainable. save_method='lora' now always emits a clean LoRA-only
artifact that mlx-lm.load_adapters() can read. Callers that intentionally
ship mixed LoRA + embedding / projector / vision trainables use the
explicit save_trainable_adapters() API; in-loop checkpoints still cover
the full trainable tree.

Same review pass also exposed three smaller asymmetries; fix all three:

- iter_mlx_lora_modules / collect_mlx_lora_adapter_tensors no longer
  detect uppercase lora_A / lora_B. mlx-lm reload only ever recreates
  lowercase wrappers, so accepting uppercase made save_lora_adapters
  emit weights that load_weights(..., strict=False) silently dropped.
- _is_lm_head_trainable() now consults collect_mlx_lora_adapter_tensors()
  instead of an "lora" substring test, so trainable keys whose names
  merely contain "lora" (lora_router.weight or base.lora_special.lm_head
  .weight) are no longer misclassified as adapter state.
- _enrich_mlx_adapter_config() backfills num_layers + fine_tune_type +
  peft_type so save_pretrained_merged(save_method='lora') ships a config
  mlx-lm.load_adapters() can consume, matching the trainer's save path.
- _extract_mlx_lora_parameters() reads MLX Dropout._p_1 (keep prob) with
  a .p shim fallback so nonzero dropout no longer serializes as 0.0.

Drop the now-unused iter_mlx_lora_modules import from trainer.py to
clear the F401 added by the previous round.

* Preserve external trainables, gate LoRA metadata, fix LoRA+ scope for PR #692

The earlier strictly-LoRA-only save dropped intentionally trained tensors
that live outside any LoRA-wrapped module (embed_tokens, lm_head,
projector, vision, norm) from the final artifact. Use a module-prefix
filter so trainables OUTSIDE LoRA modules survive while base weights
INSIDE a LoRA module (which a reload may mark trainable) are still
excluded, keeping the original Studio reload leak shut.

Adjacent review pass fixes:

- save_pretrained_merged(save_method='lora', push_to_hub=True) now
  uploads the adapter directory via HfApi.upload_folder instead of
  re-routing through push_to_hub_merged, which would call
  save_merged_model() and overwrite the adapter-only artifact.
- _enrich_mlx_adapter_config() only writes lora_parameters / num_layers
  / fine_tune_type='lora' / peft_type='LORA' when the model actually has
  LoRA modules (or the caller declared a lora/dora artifact). Full
  fine-tune checkpoints saved via save_trainable_adapters() no longer
  carry fake LoRA metadata that mlx-lm.load_adapters() would treat as a
  LoRA reload.
- LoRA+ gradient scaling in _grad_leaf_scale() now anchors on the
  parameter suffix (.lora_b) rather than substring "lora_b", so an
  unrelated trainable like lora_b_router.weight cannot receive the LoRA+
  multiplier.
- collect_mlx_lora_adapter_tensors() also picks up DoRA modules'
  trained magnitude vector m alongside lora_a / lora_b so DoRA reload
  recovers the learned magnitudes.

Tests: route preserve external trainables; full-checkpoint config has no
LoRA metadata.

* Stop leaking base weights under LoRA and restore push-to-hub LoRA API for PR #692

save_trainable_adapters() previously dumped every tensor returned by
model.trainable_parameters(). After a checkpoint reload, base tensors
inside LoRA-wrapped modules (e.g. q_proj.weight under a wrapped q_proj)
can end up marked trainable, so the public mixed-fine-tune route in
save_pretrained_merged(save_method='lora') and MLXTrainer.save_model
would silently include those base weights in adapters.safetensors. Apply
the same lora_module_prefixes filter inside save_trainable_adapters
itself: keep external trainables (embed_tokens, lm_head, projector,
vision, norm) and the module-anchored LoRA tensors, but drop base
weights that live INSIDE a LoRA module. Regression test:
test_save_pretrained_merged_lora_mixed_external_drops_inside_lora_base.

Adjacent push-to-hub fix:

The earlier LoRA push branch used a bare HfApi.upload_folder call that
silently dropped four push_to_hub_merged behaviours: caller-supplied
repo_id, private-flag updates on re-push, ModelCard tag merging, and the
"Trained with Unsloth" commit message + description convention. It also
used upload_folder instead of upload_large_folder, so large adapter
exports (e.g. save_trainable_adapters dumps with embeddings) could not
resume on a flaky connection.

Factor _push_lora_adapters_to_hub() that mirrors push_to_hub_merged()
end-to-end without re-saving a full merged model on top of the adapter
directory, and forward repo_id / commit_message / commit_description /
create_pr / revision through save_pretrained_merged so both save methods
share the same Hub-push surface.

* Honor commit_message / commit_description / create_pr on LoRA push for PR #692

_push_lora_adapters_to_hub previously called HfApi.upload_large_folder which,
on huggingface_hub>=0.34 (this repo's floor), silently drops commit_message,
commit_description, create_pr, and revision. Callers passing
save_pretrained_merged(save_method='lora', push_to_hub=True, commit_message=...,
create_pr=True) got a commit titled "Upload N LFS files" on main with no PR
opened. The kwargs were only honored by the upload_folder fallback, which
huggingface_hub>=0.34 never triggers.

Route LoRA pushes through upload_folder primarily (LoRA adapter directories
are small, typically <500MB and ~5 files, so chunking buys nothing) and keep
upload_large_folder as the fallback for environments where upload_folder is
unavailable or has an incompatible signature.

Add tests: one asserts that custom commit_message / commit_description /
create_pr / revision actually reach the upload call and that
upload_large_folder is not hit on the happy path; the other asserts the
fallback to upload_large_folder still runs when upload_folder raises
TypeError or AttributeError.

* Honor custom commit metadata on push_to_hub_merged for PR #692

push_to_hub_merged had the same upload_large_folder kwarg-drop defect as the
LoRA helper that was fixed in 50c1db3: on huggingface_hub>=0.34, calling
upload_large_folder silently ignores commit_message, commit_description, and
create_pr. Callers passing push_to_hub_merged(..., commit_message="Release v2",
create_pr=True) got a commit titled "Upload N LFS files" on main with no PR
opened.

Unlike LoRA adapters, merged saves can be multi-GB so upload_large_folder's
chunked-resume behavior is still the right default for the common case. Gate
the route choice on whether the caller passed custom commit metadata
(commit_message / commit_description differ from the function defaults, or
create_pr=True): use upload_folder when the caller explicitly cares about
commit semantics, else default to upload_large_folder for resume-on-multi-GB
behavior. Keep the symmetric AttributeError/TypeError fallback either way.

Hoist the defaults into module-level _PUSH_MERGED_DEFAULT_* sentinels so the
"did the caller customize this" check is robust to future default-string edits.

Add tests asserting custom create_pr / commit_message routes through
upload_folder, and that the no-custom-metadata case still uses
upload_large_folder for the large-merge resume path.

* Tighten DoRA m gate + reject empty save_adapter_artifacts for PR #692

`m` is a generic 1-letter attribute name. If a future LoRA wrapper exposes
self.m as a learned mixing scalar (not a DoRA magnitude vector), the prior
hasattr(module, "m") branch would ship it under DoRA semantics. Gate the
collection on type(module).__name__.startswith("DoRA") so only real DoRA
modules contribute the magnitude tensor; today's real DoRALinear /
QDoRALinear both match.

_save_adapter_artifacts previously emitted adapter_config.json with no
adapters.safetensors next to it when tensors={}. Every public caller already
raises a clear ValueError before getting here, but enforce the invariant
locally so any future direct call cannot produce a half-written artifact
that mlx-lm reload chokes on.

Add tests: collect_mlx_lora_adapter_tensors skips an unrelated `m` attr on
a non-DoRA module; _save_adapter_artifacts raises ValueError on empty input.

* Treat custom revision as commit-metadata for push_to_hub_merged + DoRA test for PR #692

push_to_hub_merged previously dropped a custom `revision=` kwarg on the
default route: when commit_message/commit_description/create_pr were all
defaults but revision was non-None, _caller_wants_commit_metadata was
False and the upload routed through upload_large_folder, which silently
lands on main regardless of the revision arg. Add `revision is not None`
to the predicate so a custom target branch also forces upload_folder.

Add positive DoRA collect test: a class whose name starts with "DoRA"
must have its `m` magnitude tensor included in the collected adapter
keys. Without this lock, a future typo (DORA, Dora, wrong attribute
name) would silently strip DoRA magnitudes from every export and the
existing negative test would still pass.

Add push_to_hub_merged revision-routing test asserting the new predicate.

* Fix 4 P1s from reviewer.py round on PR #692

1) Full-finetune checkpoints now stamp fine_tune_type="full". When
   _enrich_mlx_adapter_config runs on a no-LoRA model, it previously
   left fine_tune_type unset. mlx-lm's load_adapters() defaults missing
   fine_tune_type to "lora" and then reads num_layers / lora_parameters,
   so the saved full-precision tensors failed to reload. The else branch
   now setdefault("fine_tune_type", "full") so reload routes correctly.

2) DoRA exports now stamp fine_tune_type="dora". The collector already
   includes the q_proj.m magnitude tensor for DoRA classes, but the
   adapter_config still said "lora". mlx-lm only recreates DoRA wrappers
   when use_dora=(fine_tune_type=="dora"), so the saved m tensor was
   silently dropped on reload via strict=False. Detect DoRA modules via
   type(module).__name__.startswith("DoRA") and override the fine_tune_type
   to "dora" before setdefault picks the lora default.

3) _is_lm_head_trainable now filters base weights inside LoRA-wrapped
   modules using the same lora_module_prefixes pattern that
   save_trainable_adapters uses. A LoRA-wrapped lm_head with a
   reload-leaked lm_head.weight previously made the function return True,
   defeating the CCE memory guard and computing a full V x H weight
   gradient per chunk. Now reload-leaked base tensors under LoRA modules
   are correctly treated as non-trainable.

4) _push_lora_adapters_to_hub now uploads with allow_patterns scoped to
   adapter / tokenizer / config / readme files. Without this filter, a
   save_directory that already contained stale model-*.safetensors,
   model.safetensors.index.json, or *.gguf files from a prior merged save
   would silently push them to the LoRA adapter repo. Public-by-default
   repos would have leaked merged weights under a "LoRA adapter" repo;
   the allow-list rules out catch-alls like "*.safetensors" / "*" so a
   future allow_patterns expansion cannot re-introduce the regression.

Also unblock the lora_parameters / top-level rank-scale-dropout backfill:
the top-level keys are now backfilled whenever lora_parameters is present,
not only when lora_parameters was computed by this function. mlx-vlm reads
the top-level keys directly; without this fix a caller-supplied
lora_parameters dict left rank/scale/dropout absent at the top level.

Also drop the num_layers=-1 sentinel here too (mirroring the trainer side
None-gating in PR #679): if _get_transformer_layers returns nothing,
num_layers is simply omitted so mlx-lm's loader raises a clear
AttributeError instead of slicing range(-1) and applying zero LoRA layers.

Tests: enrich stamps "full" on no-LoRA models, "dora" on DoRA models;
_is_lm_head_trainable returns False when only LoRA-wrapped lm_head plus
reload-leaked weight are trainable; _push_lora_adapters_to_hub passes a
restrictive allow_patterns list that excludes stale full-model artifacts.

* Preserve trainable .bias under LoRA-wrapped modules for PR #692

The reload-leak guard previously dropped EVERY trainable key under a
LoRA module prefix. That correctly excludes the wrapped base .weight
(the V x H matmul gradient that defeats the memory guard), but also
dropped other trainable params at the same path: notably q_proj.bias
when the user explicitly trained bias=True on a LoRA-wrapped Linear.

Refine save_trainable_adapters's filter to drop only `.weight` keys
under LoRA prefixes. .bias and other params under the same prefix now
survive, so a checkpoint+reload roundtrip preserves them.

The routing decision in save_pretrained_merged and MLXTrainer.save_model
follows the same refinement: previously routed to save_trainable_adapters
only when there was a trainable OUTSIDE any LoRA module. A
trainable .bias INSIDE a LoRA module (with no external trainables)
would have routed to save_lora_adapters and been dropped. Now routing
treats inside-LoRA `.weight` as the only reload-leak risk; any other
trainable (bias / external / etc.) triggers the trainable-aware writer.

Add a test that exercises a LoRA-wrapped Linear with bias=True and an
external trainable, asserting q_proj.bias survives to the safetensors
output while q_proj.weight is correctly dropped.

* Tighten LoRA filter + push routing + fine_tune_type stamps for PR #692

Several R9 reviewer.py findings, batched.

(1) Quantized base tensors leaked through the adapter writer. The
    .weight-only filter missed .scales, .biases, and their .linear.*
    variants on mlx-lm QuantizedLinear, so QLoRA reload-trainable
    layers' quantization tensors slipped into adapters.safetensors.
    Hoist the filter into a shared _is_base_tensor_inside_lora_module
    helper with an extended suffix tuple (weight / scales / biases /
    linear.weight / linear.scales / linear.biases). save_trainable_adapters,
    save_pretrained_merged routing, MLXTrainer.save_model routing, and
    _is_lm_head_trainable all use the same predicate now.

(2) Root-level LoRA wrappers (module_name == "") leaked bare `weight` /
    `scales` / `biases`. The empty-name prefix is intentionally
    excluded from lora_module_prefixes (otherwise it swallows the
    whole tree), so the filter never matched. Add a has_root_lora_module
    flag and special-case the bare-key check when one exists.

(3) Caller fine_tune_type="full" while the model has LoRA modules
    produced an internally-inconsistent artifact: saved lora_a/lora_b
    but config said full, so mlx-lm reload skipped LoRA wrapping and
    silently dropped the adapter. Override the caller to "lora" in
    that case.

(4) Full-finetune saves carried stale LoRA fields (peft_type,
    lora_parameters, rank, scale, dropout, num_layers,
    unsloth_mlx_lora_module_paths) from re-used config dicts.
    setdefault did not strip them. Now the no-LoRA branch
    unconditionally pops stale LoRA keys so reload sees a clean
    full-finetune dict.

(5) revision is not None forced upload_folder on merged pushes, losing
    the resumable/chunked path for multi-GB merged uploads.
    upload_large_folder accepts revision natively; only commit_message /
    commit_description / create_pr force upload_folder. Drop revision
    from _caller_wants_commit_metadata.

(6) create_pr=True silently fell back to upload_large_folder on
    TypeError/AttributeError from upload_folder, landing on main with
    no PR. Refuse to silently lose the PR boundary; raise a clear
    RuntimeError telling the user to upgrade huggingface_hub or call
    create_pull_request() + pass revision themselves. Applied to both
    the LoRA helper and the merged push fallback.

(7) Embedding/lm-head LR multiplier still used substring matching, so
    names like decoder.not_lm_head_router.weight or
    foo.embed_tokens_aux.weight could pick up the embedding LR.
    Anchor on path segments via name.split(".") so the multiplier
    only fires on a real `embed_tokens` or `lm_head` segment.

Update the prior revision-routing test to assert the new behavior
(upload_large_folder honors revision alone). Add tests for: quantized
base scales/biases dropped, caller fine_tune_type="full" override on
LoRA-bearing models, stale LoRA fields stripped on full-finetune,
create_pr=True failure raises instead of silent main push.

* Write minimal README before ModelCard.load so tags propagate for PR #692

Round-9 review caught that `_push_lora_adapters_to_hub` (and the merged
push) only applied caller-provided `tags=[...]` when a README.md already
existed in `save_directory`. Fresh LoRA adapter directories carry the
adapter and tokenizer files but no model card, so `ModelCard.load()`
raised `FileNotFoundError` and the requested tags silently never reached
the Hub.

Seed a minimal YAML front-matter `README.md` before loading the card when
no card exists yet, then merge the requested tags with the existing ones.
Applied to both the LoRA push and the merged push (merged saves usually
have a card from `save_merged_model()`'s `create_model_card` fallback,
but seed defensively so an upstream card-fallback failure does not
quietly drop tags either).

* Segment-match lm_head + harden commit-metadata fallback + cross-check fine_tune_type for PR #692

Round-10 review caught several asymmetric / silent-fallback bugs in the
adapter save and Hub push paths:

* `_is_lm_head_trainable` matched `'lm_head'` / `'embed_tokens.weight'`
  as substrings, so unrelated trainables like
  `decoder.not_lm_head_router.weight` or `foo.embed_tokens_aux.weight`
  were misclassified as the real LM head. Switch to segment matching to
  mirror the trainer-side LR-multiplier fix already in place.
* Both `_push_lora_adapters_to_hub` and `push_to_hub_merged` fallback
  paths previously raised only when `create_pr=True`, silently dropping
  caller-provided `commit_message` / `commit_description` when
  upload_folder() was unavailable. Refuse the fallback whenever any of
  the three commit-metadata fields was set by the caller; capture the
  intent BEFORE the default backfill so `None` vs. user-string is still
  distinguishable.
* `push_to_hub_merged(..., commit_message=None, commit_description=None)`
  forced the non-resumable upload_folder route because `None` did not
  equal the default constant. Treat explicit `None` as equivalent to
  the default so wrapper layers that forward optional kwargs do not
  accidentally lose large-folder resume semantics.
* `_enrich_mlx_adapter_config` trusted caller-supplied
  `fine_tune_type='lora'` even when the live model had no LoRA modules,
  writing LoRA fields to a full-finetune artifact and breaking reload.
  Cross-check the declared artifact against the live model so stale
  caller config always yields to ground truth.
* `_is_base_tensor_inside_lora_module` only caught bare `weight` /
  `scales` / `biases` at the root LoRA wrapper, missing the
  QuantizedLinear-wrapped variants (`linear.weight`, `linear.scales`,
  `linear.biases`) that live there too. Introduce a small
  `_ROOT_LORA_WRAPPED_BASE_KEYS` whitelist covering both shapes.

* Fail loud on private=True + derive fine_tune_type from live model + mirror lora_parameters top-level for PR #692

Round-11 review caught three remaining symmetric-fix gaps in the
adapter save and Hub push paths:

* `update_repo_settings` failure used to print-and-continue, which means
  an existing public repo stays public when the caller asked for
  `private=True`. Refuse to upload in that case so adapters never get
  published to a public Hub URL the caller explicitly tried to make
  private. Public-by-request failures still print and continue (their
  intent is unchanged).
* `_enrich_mlx_adapter_config` previously preserved caller-provided
  `fine_tune_type` when the value was `"dora"` but the live model had
  no DoRA modules (or vice versa). mlx-lm would then rebuild DoRA
  wrappers and look for a `.m` magnitude tensor the saved
  adapters.safetensors does not contain, dropping every adapter via
  strict=False. Derive `fine_tune_type` strictly from the live model
  presence of `DoRA*` modules; the caller's stale value yields to
  ground truth.
* Top-level `rank` / `scale` / `dropout` were only backfilled when
  absent, so a caller-supplied stale `rank=99` could shadow the real
  `lora_parameters.rank=4`. mirror `lora_parameters` to the top level
  verbatim so the two shapes mlx-vlm reload checks always agree.

* Harden push_to_hub_gguf private=True to match LoRA/merged push for PR #692

Round-13 Opus review caught that `push_to_hub_gguf` missed the same
visibility hardening applied to `_push_lora_adapters_to_hub` and
`push_to_hub_merged` in R11. `HfApi.create_repo(exist_ok=True)` is a
no-op for the visibility flag on existing repos, so
`model.push_to_hub_gguf(repo_id="me/existing-public-repo", private=True)`
would silently upload the GGUF shards (often multi-GB merged weights)
to a public Hub URL the caller explicitly tried to make private.

Apply the same pattern: paired `update_repo_settings(private=...)` +
hard-fail RuntimeError when `private=True` was requested and the
visibility update fails. Public-by-request failures continue to
print-and-continue.

This is the highest-impact of the three push paths because GGUF files
are typically the ones users distribute most widely, and they contain
full merged weights rather than adapters.

* Honour omitted private=None on create_repo + cover root + non-root LoRAEmbedding for PR #692

Round-15 review caught two regressions I introduced plus one
asymmetric-fix gap:

* All three create_repo sites (LoRA push, merged push, GGUF push) used
  `private=bool(private) if private is not None else False`, which
  silently forces a brand-new repo to `private=False` whenever the
  caller did not pass the kwarg. That overrides Hugging Face Hub
  organization-level default-private policies: a user inside a
  default-private org would get a public repo on first push. Build the
  create_repo kwargs conditionally so `private` is omitted unless the
  caller actually set it; the Hub then applies the account/org policy
  for initial visibility. The existing update_repo_settings + R11
  hard-fail-on-failure logic still enforces explicit `private=True`.
* `_LORA_WRAPPED_BASE_SUFFIXES` only covered `.linear.*` wrapped-base
  state. LoRAEmbedding and DoRAEmbedding wrap an inner nn.Embedding at
  `.embedding`, so non-root embedding adapters could leak
  `embed_tokens.embedding.weight` through `save_trainable_adapters`
  and `save_pretrained_merged(save_method="lora")`. Add the
  `.embedding.weight` / `.embedding.scales` / `.embedding.biases`
  suffixes alongside the existing `.linear.*` variants.
* `_ROOT_LORA_WRAPPED_BASE_KEYS` had the same gap for root-level
  LoRAEmbedding wrappers. Extend the frozenset to include
  `embedding.weight` / `embedding.scales` / `embedding.biases`.

* Drop wrapped linear.bias + reload DoRA paths with DoRA wrappers for PR #692

Fixes two asymmetric-fix gaps from round 16 review:

1. unsloth_zoo/mlx/utils.py: `_LORA_WRAPPED_BASE_SUFFIXES` and
   `_ROOT_LORA_WRAPPED_BASE_KEYS` already filtered the wrapped base
   weight / scales / biases of an mlx-lm `LoRALinear`, but missed the
   inner `linear.bias` of a bias-bearing wrapped `nn.Linear`. Without
   it, `q_proj.linear.bias` and `lm_head.linear.bias` leaked into
   adapter saves and `_is_lm_head_trainable()` reported True for
   adapter-only training. A bare `.bias` is deliberately NOT in the
   suffix tuple because `q_proj.bias` at the LoRA-module level is
   user-trained bias state that an earlier round explicitly preserved.

2. unsloth_zoo/mlx/loader.py: `_apply_lora_at_paths()` only imported
   `LoRALinear` and wrapped every saved module path as plain LoRA.
   After the PR added DoRA export (fine_tune_type='dora' + saving the
   `.m` magnitude tensor), the path-aware reload branch silently
   downgraded DoRA back to LoRA because the recreated wrapper had no
   `.m` parameter, and the downstream `model.load_weights(
   strict=False)` dropped `q_proj.m`. Honour the saved
   `fine_tune_type`: import `DoRALinear` / `DoRAEmbedding` and wrap
   with DoRA classes when the adapter declares DoRA. Cover
   `nn.Embedding` and `SwitchLinear` paths too (DoRA falls back to
   LoRA on switch modules since mlx-lm has no DoRA switch wrapper).

* Optional newer LoRA classes + verify-before-fail private guard + switch rank fix for PR #692

Fixes three asymmetric-fix / regression issues from round 17 review:

1. unsloth_zoo/mlx/loader.py: `_apply_lora_at_paths` unconditionally
   imported `LoRAEmbedding` and `LoRASwitchLinear` from
   `mlx_lm.tuner.lora`. Older mlx-lm wheels that only ship
   `LoRALinear` raised an ImportError before any wrapper was attached,
   so linear-only adapters could not reload. Wrap both imports in
   try/except and treat missing classes as "skip this module type"
   instead of failing the whole adapter load path; LoRALinear (which
   has shipped since the first mlx-lm release we support) is still
   required.

2. unsloth_zoo/mlx/utils.py: the round-16 `update_repo_settings` block
   always raised on private=True when the call failed, even though
   `create_repo(private=True)` had already set visibility on freshly-
   created repos. A token without `write:repo_settings` then blocked
   the upload even though the repo was already private. Refactor the
   three duplicated blocks into `_ensure_hub_repo_visibility(api,
   repo_id, private)` which (a) skips the update when private is None,
   (b) prints-and-continues on private=False, (c) on private=True
   verifies via `repo_info` whether the repo is already private after
   the failed update; only raises when the visibility is confirmed
   non-private.

3. unsloth_zoo/mlx/utils.py: `_extract_mlx_lora_parameters` inferred
   switch LoRA rank from `lora_a.shape[-1]` for 2-D layouts, but
   `mlx-lm==0.22.x` stores switch `lora_a` as
   `(rank * num_experts, in_dims)`, returning ranks like 64 or 4096
   instead of 4 or 8. Prefer `lora_b.shape[-1]` when the module
   declares `num_experts` since both layouts (newer mlx-lm and 0.22.x)
   keep rank as the trailing axis of `lora_b`. Plain LoRALinear keeps
   the existing `shape[-1]` fallback.

* Re-raise DoRA-unavailable + restore .linear.bias as trainable + drop stale adapter_model.safetensors + overwrite stale lora_parameters for PR #692

Fixes four round 18 findings:

1. unsloth_zoo/mlx/loader.py: the path-aware reload site caught every
   Exception from `_apply_lora_at_paths()` and fell back to upstream
   `load_adapters()`. That fallback rebuilds plain LoRALinear wrappers
   at every saved path, so the explicit `RuntimeError` raised when
   `fine_tune_type='dora'` but `mlx_lm.tuner.dora` is unavailable was
   silently swallowed and DoRA `*.m` magnitude tensors dropped via
   `load_weights(strict=False)`. Catch RuntimeError separately and
   re-raise when the message contains the DoRA-unavailable signal,
   then continue catching every other Exception.

2. unsloth_zoo/mlx/utils.py: round 16 added `.linear.bias` to
   `_LORA_WRAPPED_BASE_SUFFIXES` and `_ROOT_LORA_WRAPPED_BASE_KEYS`
   to drop the wrapped base bias of LoRALinear. That was wrong:
   upstream mlx-lm LoRALinear also stores the legitimate user-trained
   bias at `q_proj.linear.bias`, and dropping it silently loses
   training state when the user unfreezes / trains bias. Remove
   `.linear.bias` from both filter sets and document why bare
   `.bias` and `.linear.bias` must survive trainable checkpoints.

3. unsloth_zoo/mlx/utils.py: the LoRA upload allow-list still
   permitted `adapter_model.safetensors`. The MLX adapter save path
   writes `adapters.safetensors`, never `adapter_model.safetensors`;
   a matching file in the directory is therefore stale by definition
   (PEFT/HF leftover) and would re-upload that stale PEFT adapter
   alongside the current MLX adapter. Drop it from the allow-list.

4. unsloth_zoo/mlx/utils.py: `_enrich_mlx_adapter_config` only wrote
   `lora_parameters` from `_extract_mlx_lora_parameters` when the
   caller dict did not already have it. A reused config with stale
   `lora_parameters={"rank": 99}` then survived against actual
   rank-4 saved tensors, causing reload wrapper mismatch. When the
   live model has any LoRA modules, derive rank/scale/dropout from
   the live walker and overwrite caller-supplied stale values; only
   preserve caller `lora_parameters` when the live model has no
   LoRA modules (declared_lora_artifact-only path).

* Let namespaced Unsloth MLX RuntimeError propagate past the outer adapter-detection except for PR #692

Round 18 added a DoRA-unavailable RuntimeError re-raise inside the
inner _apply_lora_at_paths handler so the saved DoRA `*.m` magnitude
tensors would not silently drop into a plain-LoRA reload. The re-raise
worked, but the outer adapter-detection `except Exception as e:` block
at loader.py only re-raised ValueError("Unsloth:") and swallowed every
other exception into a print + standard-load fallback. The DoRA
RuntimeError was therefore re-raised by the inner handler and then
immediately caught and silenced by the outer handler, falling back to
the base-only load path with no warning that fine_tune_type='dora'
was unsupported.

Extend the outer guard to also re-raise namespaced
`RuntimeError("Unsloth MLX:")` so the inner handler's intentional
re-raise reaches the user; unrelated runtime errors keep falling
through to the soft standard-load fallback.

* Navigate list-indexed module paths in _apply_lora_at_paths for PR #692

`_apply_lora_at_paths` re-attached LoRA wrappers via
`setattr(parent, leaf, wrapped)` and `parent = by_name.get(parent_path,
model)`. That silently no-ops when the saved path ends in a numeric
segment such as `vision_tower.merger.layers.0`, which the training-time
`_lora_walk_module` deliberately produces for Qwen2.5-VL's vision merger
and any projector whose Linears live inside a Python list:

- `name.rpartition(".")` -> `parent_path="vision_tower.merger.layers"`,
  `leaf="0"`.
- `by_name.get(parent_path, model)` returns `model` (the fallback)
  because `named_modules()` does not emit list containers themselves.
- `hasattr(model, "0")` is False, so the `setattr` branch is skipped
  silently. The base nn.Linear stays in place.
- The follow-up `load_weights(strict=False)` then silently drops the
  saved `...layers.0.lora_{a,b}` tensors and the user gets a
  "successful" reload with a fully reverted vision/projector LoRA
  and no warning.

Mirror the navigation pattern from `_lora_walk_module`: split the saved
path, walk segments trying `parent[int(seg)]` first then `getattr` for
attribute access, and apply the same `parent[int(leaf)] = wrapped`
fallback to `setattr` for the final segment. List-indexed wrappers now
install correctly and the subsequent `load_weights(strict=False)`
binds the saved tensors instead of dropping them.

* Guard DoRA fallback + fail-loud LoRAEmbedding + treat default commit string as default for PR #692

Three asymmetric-fix gaps caught by reviewers.

1) FastMLXModel.from_pretrained's adapter-detection inner block had a DoRA
   capability check inside the saved-paths branch (via _apply_lora_at_paths
   raising RuntimeError when mlx_lm.tuner.dora is missing), but the no-
   saved-paths fallback fell straight through to load_adapters(model,
   local_path) without that check. A DoRA adapter without
   unsloth_mlx_lora_module_paths would silently rebuild plain LoRA wrappers,
   then drop the .m magnitude tensors via the strict=False reload. Add the
   same capability check before the fallback load_adapters call so the user
   gets the namespaced "install a DoRA-capable mlx-lm" error.

2) _apply_lora_at_paths silently continued past embedding LoRA paths when
   the installed mlx-lm predated LoRAEmbedding. The saved embed_tokens
   adapter tensors then dropped via the downstream strict=False reload with
   no warning, leaving the user with a partially loaded adapter. Raise a
   namespaced RuntimeError instead, parallel to the DoRA-unavailable guard.

3) _push_lora_adapters_to_hub treated commit_message is not None as
   "caller wants custom metadata", which made the upload_large_folder
   fallback refuse uploads even when the caller forwarded the Unsloth
   default string verbatim (push_to_hub_lora -> _push_lora_adapters_to_hub
   pattern). Compare against the default strings the same way
   _push_merged_model_to_hub already does, so default-string forwarding no
   longer blocks the fallback.

* Broaden inner-handler re-raise to match LoRAEmbedding fail-loud

The fail-loud RuntimeError raised inside _apply_lora_at_paths when
mlx_lm.tuner.lora.LoRAEmbedding is unavailable was being swallowed by
the inner except clause's narrow DoRA-only re-raise filter, and falling
through to load_adapters() silently dropped the saved
embed_tokens.lora_a / lora_b tensors via the strict=False reload.

Broaden the inner filter to re-raise any "Unsloth MLX:" namespaced
RuntimeError so the embedding-LoRA capability gap surfaces to the
caller the same way the DoRA capability gap does. Mirrors the outer
handler's broader check directly below.

* Tighten verbose comments in mlx adapter export code

Comments-only refactor: drop pure narration and compress multi-line WHY
explanations to a single line per intent. No behavior change.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants