fix(mlx): persist LoRA adapter metadata on save #23

Closed
danielhanchen wants to merge 13 commits into main from pr-679-head

@danielhanchen

Collaborator

Staging mirror of unslothai#679

Original PR: unslothai#679
Author: Lyxot

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Fix MLX LoRA adapter saves so adapter_config.json contains the metadata needed to recreate adapters faithfully on reload.

  • Persist live LoRA rank, scale, and dropout into both lora_parameters and top-level compatibility fields
  • Read MLX dropout correctly from either .p or MLX's _p_1 keep-probability storage
  • Infer LoRA rank from the correct tensor axis, including switch/MoE LoRA tensors shaped like (num_experts, r, input_dims)
  • Preserve explicit unsloth_mlx_lora_module_paths instead of overwriting caller-provided topology metadata
  • Re-apply eval mode after adapter reload so nonzero-dropout adapters do not run dropout during default inference
  • Share the same rank/dropout inference helpers between adapter config enrichment and MLXTrainer.save_model
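
As a rough illustration of the first bullet, the enriched config could be assembled along these lines. This is a sketch under assumptions, not the PR's code: `build_adapter_config` is a hypothetical helper, the nested keys are assumed to follow the `rank` / `scale` / `dropout` layout of `lora_parameters`, and the top-level names (`lora_rank`, `lora_scale`, `lora_dropout`) are placeholders because the description does not name the compatibility fields.

```python
import json


def build_adapter_config(rank, scale, dropout, base=None):
    """Hypothetical helper: mirror live LoRA values into both locations."""
    config = dict(base or {})

    # Nested block; key names are an assumption about the lora_parameters layout.
    params = dict(config.get("lora_parameters") or {})
    params.update(
        {"rank": int(rank), "scale": float(scale), "dropout": float(dropout)}
    )
    config["lora_parameters"] = params

    # Placeholder top-level names; the real compatibility fields are not
    # specified in the description above.
    config["lora_rank"] = int(rank)
    config["lora_scale"] = float(scale)
    config["lora_dropout"] = float(dropout)
    return config


config = build_adapter_config(rank=16, scale=20.0, dropout=0.05)
print(json.dumps(config, indent=2))
```

Writing the values in one place keeps the nested and top-level copies from drifting apart.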

Why

Several MLX save paths call save_lora_adapters() without a complete explicit adapter config. Missing or stale LoRA metadata can make reload reconstruct adapters with different rank, scale, or dropout settings, which changes post-reload behavior.

This also fixes two edge cases:

  • MLX Dropout stores keep probability as _p_1, so reading only .p writes stale dropout=0.0.
  • LoRASwitchLinear stores rank on lora_a.shape[-2], not shape[-1].
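
The first edge case can be sketched with stand-in objects. The classes below are hypothetical and only mimic the attributes discussed here; they are not the real MLX dropout layer, and `read_dropout_probability` is an illustrative name rather than the helper in this PR.

```python
class KeepProbDropout:
    """Stand-in that stores only the keep probability, as described above."""

    def __init__(self, p):
        self._p_1 = 1.0 - p


class ShimDropout:
    """Stand-in compat shim: a stale .p default next to a real _p_1."""

    def __init__(self, p):
        self.p = 0.0
        self._p_1 = 1.0 - p


def read_dropout_probability(drop):
    """Return the dropout probability, preferring _p_1 over .p."""
    if drop is None:
        return 0.0
    if hasattr(drop, "_p_1"):
        # _p_1 is the keep probability, so dropout = 1 - keep.
        return 1.0 - float(drop._p_1)
    if hasattr(drop, "p"):
        return float(drop.p)
    return 0.0


print(read_dropout_probability(KeepProbDropout(0.25)))
print(read_dropout_probability(ShimDropout(0.25)))
```

Reading `.p` first on the shim would report 0.0, which is the stale value the fix avoids.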

Lyxot and others added 10 commits May 20, 2026 01:31
Fixes a "The size of tensor a (N) must match the size of tensor b (M)"
RuntimeError in gpt-oss inference when generation crosses the sliding
window (e.g. short prompt with max_new_tokens that pushes the sequence
past 128 tokens) on transformers 5.x running below its required torch
(< 2.11). In that case the KV cache returns more positions for a
full-attention layer than the causal mask covers (pre-allocated cache
slots), so the eager path's `attn_weights += attention_mask` crashes.

The surplus key positions are masked out anyway, so attend only the
overlap: trim key/value (or the mask) to the shorter length in both
eager attention variants. This keeps inference correct and shape
consistent across transformers 4.57.x / 5.5.x and torch 2.9 - 2.11.

Also drop the dead singular `past_key_value` cache_position-free forward
variant: the singular naming predates transformers dropping
`cache_position`, so that signature combination does not exist in any
release. Only the `past_key_values` variant is reachable.

Verified on gpt-oss-20b across the full matrix (transformers
{4.57.6, 5.5.0} x torch {2.9.1, 2.10.0, 2.11.0}): all reasoning efforts
generate coherent output, and greedy continuations match across torch
versions for a given transformers version.
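
The trimming rule in this commit message can be sketched at the shape level. The helper below is hypothetical and works on shape tuples only; it assumes the 4D layouts `(batch, heads, kv_len, head_dim)` for keys/values and `(batch, 1, q_len, mask_len)` for the mask, and it is not the patched attention code itself.

```python
def trimmed_shapes(key_shape, mask_shape):
    """Return (key_shape, mask_shape) after trimming to the shared length."""
    kv_len, mask_len = key_shape[2], mask_shape[3]
    if kv_len > mask_len:
        # Surplus cache slots: keep only the key positions the mask covers.
        key_shape = key_shape[:2] + (mask_len,) + key_shape[3:]
    elif mask_len > kv_len:
        # Mask is wider than the cache: trim the mask instead.
        mask_shape = mask_shape[:3] + (kv_len,)
    return key_shape, mask_shape


# Cache returns 256 positions, mask covers 130.
print(trimmed_shapes((1, 8, 256, 64), (1, 1, 1, 130)))
# Mask covers 128 positions, cache holds 100.
print(trimmed_shapes((1, 8, 100, 64), (1, 1, 1, 128)))
```

After trimming, the last axis of the attention weights and of the mask agree, so the in-place addition no longer fails on a size mismatch.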

Reload + metadata follow-ups from review feedback:

1. loader._apply_lora_at_paths() now also recreates LoRASwitchLinear
   for SwitchLinear / QuantizedSwitchLinear and LoRAEmbedding for
   Embedding / QuantizedEmbedding. The previous version saved switch
   rank metadata but the reload helper only wrapped Linear, so MoE
   adapter weights were silently dropped at load time. Switch/embedding
   imports are wrapped in try/except for older mlx-lm.

2. utils._get_mlx_dropout_probability() now reads MLX _p_1 keep-prob
   first and falls back to .p. Compat shims that expose both (.p=0.0
   default plus a real _p_1) previously wrote stale dropout=0.0.

3. utils._infer_mlx_lora_rank() returns None instead of falling back
   to lora_a.shape[-1] when lora_a and lora_b disagree on the rank
   dimension. The previous fallback wrote the input dim as rank for
   partially materialized or non-LoRA shapes.

4. utils._enrich_mlx_adapter_config() only infers rank/scale/dropout
   from modules that appear in the caller-provided
   unsloth_mlx_lora_module_paths set. Previously an unrelated earlier
   LoRA in named_modules() would write the wrong language-tower params
   when the caller selected a vision/projector path. Also distinguishes
   "caller passed nothing" from "caller passed [] or None" so explicit
   empty values are preserved.

5. utils._enrich_mlx_adapter_config() now fills num_layers (derived
   via _get_transformer_layers) and fine_tune_type when callers invoke
   save_lora_adapters() without a trainer-built config. mlx-lm
   load_adapters() dereferences config.num_layers and previously
   crashed with AttributeError on these direct-save artifacts.

The 3D rank-consistency check in _infer_mlx_lora_rank() used
lora_b_shape[-2] which is the out_features axis. mlx-lm's
LoRASwitchLinear stores lora_b as (num_experts, out_dims, rank)
so rank lives on lora_b_shape[-1]. Using [-2] caused valid MoE
adapters to be rejected as contradicted and rank to be left
unrecorded in adapter_config.json.
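
The axis rules from the last two commit messages can be summarized in a small shape-only sketch. `infer_lora_rank` is a hypothetical name, not the PR's `_infer_mlx_lora_rank()`, and the 2D layout (`lora_a` as `(input_dims, r)`, `lora_b` as `(r, output_dims)`) is an assumption; the 3D layouts are the ones stated above.

```python
def infer_lora_rank(lora_a_shape, lora_b_shape):
    """Return the rank both tensors agree on, or None if they disagree."""
    a, b = tuple(lora_a_shape), tuple(lora_b_shape)
    if len(a) == 2 and len(b) == 2:
        # Assumed 2D layout: lora_a (input_dims, r), lora_b (r, output_dims).
        rank_a, rank_b = a[-1], b[-2]
    elif len(a) == 3 and len(b) == 3:
        # Switch/MoE layout from the text:
        # lora_a (num_experts, r, input_dims), lora_b (num_experts, out_dims, r).
        rank_a, rank_b = a[-2], b[-1]
    else:
        return None
    # No fallback to a single tensor: a mismatch means "unknown".
    return rank_a if rank_a == rank_b else None


print(infer_lora_rank((4096, 8), (8, 4096)))        # 2D case
print(infer_lora_rank((4, 8, 2048), (4, 1024, 8)))  # switch/MoE case
print(infer_lora_rank((4096, 8), (16, 4096)))       # disagreement
```

Returning None on disagreement lets the caller skip the module rather than record an input dimension as the rank.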

@danielhanchen

Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request expands MLX LoRA support to include MoE switch layers and embedding LoRA, while improving adapter reload consistency by explicitly capturing rank, scale, and dropout metadata. It also implements a fix for KV cache and attention mask alignment within the gpt_oss patches to resolve potential shape mismatches. The review feedback suggests enhancing the robustness of attention mask slicing to avoid potential indexing errors and recommends removing redundant type conversions in the dropout probability calculation.

    key_states = key_states[:, :, :masklen, :]
    value_states = value_states[:, :, :masklen, :]
elif masklen > kvlen:
    attention_mask = attention_mask[:, :, :, :kvlen]

medium

The slicing attention_mask[:, :, :, :kvlen] assumes a 4D tensor. While this is standard for causal masks in most Transformers models using SDPA, it will raise an IndexError if a 2D or 3D mask is passed. Given that this is a specific patch for GptOssAttention, if the model is guaranteed to use 4D masks in this path, it is acceptable, but a more robust slicing approach would be safer.

Comment thread unsloth_zoo/mlx/utils.py
# real MLX nn.Dropout stores keep-prob as _p_1; compat shims may set
# both .p (often a stale 0.0 default) and _p_1, so _p_1 must win.
if hasattr(drop, "_p_1"):
    return float(1.0 - float(getattr(drop, "_p_1")))

low

The conversion to float is applied twice here. While harmless, it's redundant. Additionally, ensure that getattr(drop, "_p_1") is indeed a numeric type or an MLX scalar array to avoid potential TypeError during the subtraction.

Suggested change
- return float(1.0 - float(getattr(drop, "_p_1")))
+ return 1.0 - float(getattr(drop, "_p_1"))

- _enrich_mlx_adapter_config: an explicit empty unsloth_mlx_lora_module_paths
  list now preserves the caller topology but no longer suppresses live
  rank/scale/dropout inference (treat empty as no filter, with a fallback
  module captured before filtering).
- MLXTrainer.save_model: skip lora_a-bearing modules whose rank cannot be
  inferred instead of silently falling back to rank 8 while still copying
  the bad module scale and dropout.
- _apply_lora_at_paths: when from_base does not accept scale/dropout
  kwargs on older mlx-lm, retry with r=rank alone and restore the saved
  scale on the wrapped layer (scale is a Python attribute, not loaded by
  load_weights).
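
The retry described in the last bullet can be sketched as follows. `OldLoRALinear` is a hypothetical stand-in for a wrapper whose `from_base` accepts only `r`, and `wrap_with_fallback` is an illustrative name; neither is the mlx-lm or unsloth_zoo implementation.

```python
class OldLoRALinear:
    """Stand-in for an older wrapper whose from_base only accepts r."""

    def __init__(self, base, r):
        self.base = base
        self.r = r
        self.scale = 1.0  # default that would otherwise survive reload

    @classmethod
    def from_base(cls, base, r=8):
        return cls(base, r)


def wrap_with_fallback(lora_cls, base, rank, scale, dropout):
    """Try the full signature first, then retry with r alone."""
    try:
        return lora_cls.from_base(base, r=rank, scale=scale, dropout=dropout)
    except TypeError:
        # Unknown keyword arguments raise TypeError on the older signature.
        wrapped = lora_cls.from_base(base, r=rank)
        # scale is a plain Python attribute, so it has to be set by hand.
        wrapped.scale = scale
        return wrapped


wrapped = wrap_with_fallback(OldLoRALinear, object(), rank=16, scale=20.0, dropout=0.0)
print(wrapped.r, wrapped.scale)
```

One caveat of this shape of fallback is that any unrelated TypeError raised inside `from_base` also triggers the retry.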

@danielhanchen

Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request expands MLX LoRA support to include MoE and Embedding layers, improves the inference and persistence of LoRA parameters (rank, scale, and dropout) during model saving and reloading, and ensures models are set to evaluation mode post-reload. It also introduces a fix for attention mask shape mismatches in gpt_oss patches to improve compatibility with newer Transformers and Torch versions. I have no feedback to provide.

- _enrich_mlx_adapter_config: capture the fallback rank/scale/dropout only
  from modules that pass the explicit_path_set filter, so an explicit
  caller-supplied path list whose modules are not inferable no longer
  borrows metadata from unrelated LoRA layers. Auto-discovery and empty
  explicit-list paths still benefit from the fallback unchanged.
- _apply_lora_at_paths: when from_base() cannot accept the dropout kwarg
  on older mlx-lm, patch wrapped.dropout._p_1 (or .p on shims) after the
  fallback so the saved adapter dropout is faithfully restored alongside
  the scale that was already being patched.
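
The dropout patch in the second bullet is the inverse of reading the keep probability. The sketch below uses `types.SimpleNamespace` objects as stand-ins for dropout layers; `restore_dropout_probability` is a hypothetical name and the range check is an added assumption.

```python
from types import SimpleNamespace


def restore_dropout_probability(drop, p):
    """Write dropout probability p back; return True if something was patched."""
    if drop is None:
        return False
    if not 0.0 <= p < 1.0:
        raise ValueError("dropout probability must be in [0, 1)")
    if hasattr(drop, "_p_1"):
        # Keep-probability storage: store 1 - p.
        drop._p_1 = 1.0 - p
        return True
    if hasattr(drop, "p"):
        # Shim that stores the dropout probability directly.
        drop.p = p
        return True
    return False


keep_prob_layer = SimpleNamespace(_p_1=1.0)  # default: no dropout
shim_layer = SimpleNamespace(p=0.0)

restore_dropout_probability(keep_prob_layer, 0.25)
restore_dropout_probability(shim_layer, 0.25)
print(keep_prob_layer._p_1, shim_layer.p)
```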

@danielhanchen

Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request expands MLX LoRA support to include MoE switch and embedding layers, while improving the persistence of LoRA parameters like rank, scale, and dropout during model saving and reloading. It also introduces a fix for shape mismatches between KV caches and attention masks in eager attention patches. Review feedback identified a potential AttributeError in the rank inference utility and a logic error where fallback parameter assignments were rendered unreachable due to their placement after a filtering check.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2682 to +2683
lora_a_shape = tuple(getattr(lora_a, "shape", ()) or ())
lora_b_shape = tuple(getattr(lora_b, "shape", ()) or ())

medium

The use of getattr(lora_a, "shape", ...) will raise an AttributeError if lora_a is None. While current callers check for the attribute's existence, getattr's default value only handles cases where the attribute is missing from the object, not when the object itself is None. It is safer to explicitly check for None to avoid potential crashes in the rank inference helper.

Suggested change
- lora_a_shape = tuple(getattr(lora_a, "shape", ()) or ())
- lora_b_shape = tuple(getattr(lora_b, "shape", ()) or ())
+ lora_a_shape = lora_a.shape if lora_a is not None else ()
+ lora_b_shape = lora_b.shape if lora_b is not None else ()
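
For reference, the original expression can be checked in isolation. In standard Python, the three-argument form of `getattr` returns the default for any object that lacks the attribute, and `None` is such an object, so the sketch below does not raise. The stand-in classes are hypothetical.

```python
def shape_of(tensor):
    """Same expression as the original lines, wrapped for testing."""
    return tuple(getattr(tensor, "shape", ()) or ())


class WithShape:
    shape = (8, 4096)


class ShapeIsNone:
    shape = None


class NoShape:
    pass


print(shape_of(None))           # ()
print(shape_of(WithShape()))    # (8, 4096)
print(shape_of(ShapeIsNone()))  # ()
print(shape_of(NoShape()))      # ()
```

Under this reading, an explicit `is not None` check mainly documents intent; note that `lora_a.shape` without a default would raise on an object that has no `shape` attribute at all.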

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2794 to +2801
if explicit_path_set is not None and name not in explicit_path_set:
    continue
if fallback_rank is None:
    fallback_rank = inferred_rank
    fallback_scale = float(getattr(module, "scale", 1.0))
    fallback_dropout = _get_mlx_dropout_probability(
        getattr(module, "dropout", None)
    )

medium

The fallback_rank logic is currently redundant because it is assigned after the explicit_path_set filter. This means fallback_rank will only be non-None if lora_rank is also set in the same iteration, making the check at line 2808 unreachable. To correctly implement a fallback for when the user-provided filter matches no modules, the fallback assignment should be moved above the filter check.

                if fallback_rank is None:
                    fallback_rank = inferred_rank
                    fallback_scale = float(getattr(module, "scale", 1.0))
                    fallback_dropout = _get_mlx_dropout_probability(
                        getattr(module, "dropout", None)
                    )

                # only infer rank/scale/dropout from modules the caller
                # actually selected; otherwise an earlier unrelated LoRA
                # would write the wrong language-tower params.
                if explicit_path_set is not None and name not in explicit_path_set:
                    continue

- _infer_mlx_lora_rank: replace getattr(lora_a, 'shape', ()) with explicit
  None and hasattr guards so a None tensor attribute can no longer raise
  AttributeError on callers that share this helper outside the existing
  hasattr-gated sites.
- _enrich_mlx_adapter_config: remove fallback_rank/scale/dropout variables
  and the post-loop rescue. After the explicit-path filter was moved ahead
  of fallback capture, both fallback_* and lora_rank are assigned from the
  same filtered module pool, making the rescue unreachable. The scoping
  decision is intentional: an explicit-filtered save must not borrow
  metadata from unselected modules.
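
The resulting scoping rule can be sketched with plain data. `pick_lora_metadata` is a hypothetical helper, and modules are represented as `(name, dict)` pairs rather than real MLX modules; treating an empty explicit list as "no filter" follows the earlier commit message.

```python
def pick_lora_metadata(named_modules, explicit_paths=None):
    """Return (rank, scale, dropout) from the first selected, inferable module."""
    # An empty or missing path list means no filter.
    path_set = set(explicit_paths) if explicit_paths else None
    for name, meta in named_modules:
        if path_set is not None and name not in path_set:
            continue  # never borrow metadata from unselected modules
        if meta.get("rank") is None:
            continue  # rank not inferable for this module
        return meta["rank"], meta.get("scale", 1.0), meta.get("dropout", 0.0)
    return None  # no rescue from outside the filtered pool


modules = [
    ("language.q_proj", {"rank": 8, "scale": 20.0, "dropout": 0.0}),
    ("vision.proj", {"rank": 4, "scale": 10.0, "dropout": 0.1}),
]

print(pick_lora_metadata(modules))                   # auto-discovery
print(pick_lora_metadata(modules, ["vision.proj"]))  # explicit selection
print(pick_lora_metadata(modules, ["missing"]))      # nothing inferable
```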

@danielhanchen

Collaborator Author

Fixes pushed to unslothai#679.


3 participants