fix(mlx): persist LoRA adapter metadata on save#23
Conversation
Fixes a "The size of tensor a (N) must match the size of tensor b (M)"
RuntimeError in gpt-oss inference when generation crosses the sliding
window (e.g. short prompt with max_new_tokens that pushes the sequence
past 128 tokens) on transformers 5.x running below its required torch
(< 2.11). In that case the KV cache returns more positions for a
full-attention layer than the causal mask covers (pre-allocated cache
slots), so the eager path's `attn_weights += attention_mask` crashes.
The surplus key positions are masked out anyway, so attend only the
overlap: trim key/value (or the mask) to the shorter length in both
eager attention variants. This keeps inference correct and shape
consistent across transformers 4.57.x / 5.5.x and torch 2.9 - 2.11.
Also drop the dead singular `past_key_value` cache_position-free forward
variant: the singular naming predates transformers dropping
`cache_position`, so that signature combination does not exist in any
release. Only the `past_key_values` variant is reachable.
Verified on gpt-oss-20b across the full matrix (transformers
{4.57.6, 5.5.0} x torch {2.9.1, 2.10.0, 2.11.0}): all reasoning efforts
generate coherent output, and greedy continuations match across torch
versions for a given transformers version.
Reload + metadata follow-ups from review feedback: 1. loader._apply_lora_at_paths() now also recreates LoRASwitchLinear for SwitchLinear / QuantizedSwitchLinear and LoRAEmbedding for Embedding / QuantizedEmbedding. The previous version saved switch rank metadata but the reload helper only wrapped Linear, so MoE adapter weights were silently dropped at load time. Switch/embedding imports are wrapped in try/except for older mlx-lm. 2. utils._get_mlx_dropout_probability() now reads MLX _p_1 keep-prob first and falls back to .p. Compat shims that expose both (.p=0.0 default plus a real _p_1) previously wrote stale dropout=0.0. 3. utils._infer_mlx_lora_rank() returns None instead of falling back to lora_a.shape[-1] when lora_a and lora_b disagree on the rank dimension. The previous fallback wrote the input dim as rank for partially materialized or non-LoRA shapes. 4. utils._enrich_mlx_adapter_config() only infers rank/scale/dropout from modules that appear in the caller-provided unsloth_mlx_lora_module_paths set. Previously an unrelated earlier LoRA in named_modules() would write the wrong language-tower params when the caller selected a vision/projector path. Also distinguishes "caller passed nothing" from "caller passed [] or None" so explicit empty values are preserved. 5. utils._enrich_mlx_adapter_config() now fills num_layers (derived via _get_transformer_layers) and fine_tune_type when callers invoke save_lora_adapters() without a trainer-built config. mlx-lm load_adapters() dereferences config.num_layers and previously crashed with AttributeError on these direct-save artifacts.
The 3D rank-consistency check in _infer_mlx_lora_rank() used lora_b_shape[-2] which is the out_features axis. mlx-lm's LoRASwitchLinear stores lora_b as (num_experts, out_dims, rank) so rank lives on lora_b_shape[-1]. Using [-2] caused valid MoE adapters to be rejected as contradicted and rank to be left unrecorded in adapter_config.json.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request expands MLX LoRA support to include MoE switch layers and embedding LoRA, while improving adapter reload consistency by explicitly capturing rank, scale, and dropout metadata. It also implements a fix for KV cache and attention mask alignment within the gpt_oss patches to resolve potential shape mismatches. The review feedback suggests enhancing the robustness of attention mask slicing to avoid potential indexing errors and recommends removing redundant type conversions in the dropout probability calculation.
| key_states = key_states[:, :, :masklen, :] | ||
| value_states = value_states[:, :, :masklen, :] | ||
| elif masklen > kvlen: | ||
| attention_mask = attention_mask[:, :, :, :kvlen] |
There was a problem hiding this comment.
The slicing attention_mask[:, :, :, :kvlen] assumes a 4D tensor. While this is standard for causal masks in most Transformers models using SDPA, it will raise an IndexError if a 2D or 3D mask is passed. Given that this is a specific patch for GptOssAttention, if the model is guaranteed to use 4D masks in this path, it is acceptable, but a more robust slicing approach would be safer.
| # real MLX nn.Dropout stores keep-prob as _p_1; compat shims may set | ||
| # both .p (often a stale 0.0 default) and _p_1, so _p_1 must win. | ||
| if hasattr(drop, "_p_1"): | ||
| return float(1.0 - float(getattr(drop, "_p_1"))) |
There was a problem hiding this comment.
The conversion to float is applied twice here. While harmless, it's redundant. Additionally, ensure that getattr(drop, "_p_1") is indeed a numeric type or an MLX scalar array to avoid potential TypeError during the subtraction.
| return float(1.0 - float(getattr(drop, "_p_1"))) | |
| return 1.0 - float(getattr(drop, "_p_1")) |
- _enrich_mlx_adapter_config: an explicit empty unsloth_mlx_lora_module_paths list now preserves the caller topology but no longer suppresses live rank/scale/dropout inference (treat empty as no filter, with a fallback module captured before filtering). - MLXTrainer.save_model: skip lora_a-bearing modules whose rank cannot be inferred instead of silently falling back to rank 8 while still copying the bad module scale and dropout. - _apply_lora_at_paths: when from_base does not accept scale/dropout kwargs on older mlx-lm, retry with r=rank alone and restore the saved scale on the wrapped layer (scale is a Python attribute, not loaded by load_weights).
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request expands MLX LoRA support to include MoE and Embedding layers, improves the inference and persistence of LoRA parameters (rank, scale, and dropout) during model saving and reloading, and ensures models are set to evaluation mode post-reload. It also introduces a fix for attention mask shape mismatches in gpt_oss patches to improve compatibility with newer Transformers and Torch versions. I have no feedback to provide.
- _enrich_mlx_adapter_config: capture the fallback rank/scale/dropout only from modules that pass the explicit_path_set filter, so an explicit caller-supplied path list whose modules are not inferable no longer borrows metadata from unrelated LoRA layers. Auto-discovery and empty explicit-list paths still benefit from the fallback unchanged. - _apply_lora_at_paths: when from_base() cannot accept the dropout kwarg on older mlx-lm, patch wrapped.dropout._p_1 (or .p on shims) after the fallback so the saved adapter dropout is faithfully restored alongside the scale that was already being patched.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request expands MLX LoRA support to include MoE switch and embedding layers, while improving the persistence of LoRA parameters like rank, scale, and dropout during model saving and reloading. It also introduces a fix for shape mismatches between KV caches and attention masks in eager attention patches. Review feedback identified a potential AttributeError in the rank inference utility and a logic error where fallback parameter assignments were rendered unreachable due to their placement after a filtering check.
| lora_a_shape = tuple(getattr(lora_a, "shape", ()) or ()) | ||
| lora_b_shape = tuple(getattr(lora_b, "shape", ()) or ()) |
There was a problem hiding this comment.
The use of getattr(lora_a, "shape", ...) will raise an AttributeError if lora_a is None. While current callers check for the attribute's existence, getattr's default value only handles cases where the attribute is missing from the object, not when the object itself is None. It is safer to explicitly check for None to avoid potential crashes in the rank inference helper.
| lora_a_shape = tuple(getattr(lora_a, "shape", ()) or ()) | |
| lora_b_shape = tuple(getattr(lora_b, "shape", ()) or ()) | |
| lora_a_shape = lora_a.shape if lora_a is not None else () | |
| lora_b_shape = lora_b.shape if lora_b is not None else () |
| if explicit_path_set is not None and name not in explicit_path_set: | ||
| continue | ||
| if fallback_rank is None: | ||
| fallback_rank = inferred_rank | ||
| fallback_scale = float(getattr(module, "scale", 1.0)) | ||
| fallback_dropout = _get_mlx_dropout_probability( | ||
| getattr(module, "dropout", None) | ||
| ) |
There was a problem hiding this comment.
The fallback_rank logic is currently redundant because it is assigned after the explicit_path_set filter. This means fallback_rank will only be non-None if lora_rank is also set in the same iteration, making the check at line 2808 unreachable. To correctly implement a fallback for when the user-provided filter matches no modules, the fallback assignment should be moved above the filter check.
if fallback_rank is None:
fallback_rank = inferred_rank
fallback_scale = float(getattr(module, "scale", 1.0))
fallback_dropout = _get_mlx_dropout_probability(
getattr(module, "dropout", None)
)
# only infer rank/scale/dropout from modules the caller
# actually selected; otherwise an earlier unrelated LoRA
# would write the wrong language-tower params.
if explicit_path_set is not None and name not in explicit_path_set:
continue- _infer_mlx_lora_rank: replace getattr(lora_a, 'shape', ()) with explicit None and hasattr guards so a None tensor attribute can no longer raise AttributeError on callers that share this helper outside the existing hasattr-gated sites. - _enrich_mlx_adapter_config: remove fallback_rank/scale/dropout variables and the post-loop rescue. After the explicit-path filter was moved ahead of fallback capture, both fallback_* and lora_rank are assigned from the same filtered module pool, making the rescue unreachable. The scoping decision is intentional: an explicit-filtered save must not borrow metadata from unselected modules.
|
Fixes pushed to unslothai#679. |
Staging mirror of unslothai#679
Original PR: unslothai#679
Author: Lyxot
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Fix MLX LoRA adapter saves so
adapter_config.jsoncontains the metadata needed to recreate adapters faithfully on reload.rank,scale, anddropoutinto bothlora_parametersand top-level compatibility fields.por MLX's_p_1keep-probability storage(num_experts, r, input_dims)unsloth_mlx_lora_module_pathsinstead of overwriting caller-provided topology metadataMLXTrainer.save_modelWhy
Several MLX save paths call
save_lora_adapters()without a complete explicit adapter config. Missing or stale LoRA metadata can make reload reconstruct adapters with different rank, scale, or dropout settings, which changes post-reload behavior.This also fixes two edge cases:
Dropoutstores keep probability as_p_1, so reading only.pwrites staledropout=0.0.LoRASwitchLinearstores rank onlora_a.shape[-2], notshape[-1].