saving: layout-aware MoE LoRA merge + loud-fail on fallback (#5410)#647
Conversation
`save_pretrained_merged(save_method="merged_16bit")` silently dropped the
entire MoE expert LoRA delta on Qwen3-MoE / Qwen3.5-MoE-style models with
peft >= 0.19.1. The per-expert helpers in `saving_utils.py` hardcoded the
PEFT 0.18 "swapped" tensor layout (`lora_A: (E*r, 2I)`, `lora_B: (H, E*r)`
for gate_up_proj; `lora_A: (E*r, H)`, `lora_B: (I, E*r)` for down_proj),
while PEFT 0.19+ swaps in/out features for non-transposed 3D parameters
and produces `lora_A: (E*r, H)`, `lora_B: (2I, E*r)` and `lora_A: (E*r, I)`,
`lora_B: (H, E*r)`. The layout mismatch hit a bare `except Exception: return W`
and the dim-heuristic fallthrough in the fused helpers, so the merge
silently wrote unmodified base weights and reported success. The
`num_experts` value used by the per-expert loop was also taken from the
shard-local key scan, which is a non-divisor of `total_rank` whenever
experts are split across multiple safetensor shards (16/17 of 128 on
Qwen3-30B-A3B). Finally the merged dir was missing `generation_config.json`,
so chat-tuned models reloaded with default eos / sampling and ran past EOS.
Changes:
- `_detect_moe_lora_layout(lora_A, lora_B, num_experts, out_dim, in_dim)`
classifies the layout by shape against the per-expert disk weight, so
no version sniffing is required. Works on transformers 4.57.x / 5.x
and peft 0.18.x / 0.19.x.
- `_merge_moe_gate_or_up_expert` and `_merge_moe_down_proj_expert`
branch on the detected layout. The "swapped" path is byte-identical
to the previous behaviour.
- `_resolve_num_experts_from_lora_stats` walks `module -> base_layer ->
...` to read the authoritative `num_experts` off the wrapped MoE
module (`Qwen3MoeExperts` etc). `_merge_and_overwrite_lora` uses it
to override `moe_num_experts[prefix]` after the converted-key build,
so the per-expert loop never trips on a shard-local count.
- `_MOE_MERGE_STATE` tracks `(attempted, applied, fallback, first_error)`.
Each helper records a fallback with role / expert / shapes / reason
on any unrecognised layout or exception. After the shard loop
`merge_and_overwrite_lora` raises `RuntimeError` if any fallback
fired, so partially-merged checkpoints can no longer be silently
written. On success it prints `applied/attempted`.
- The `merged_16bit` branch also calls
`model.generation_config.save_pretrained(save_directory)` (best-effort,
matching `fix_tokenizer_config_json`).
Tests:
- Existing 16 per-expert / fused / dense merge tests in
`test_unsloth_zoo_lora_merge.py` still pass byte-for-byte (PEFT 0.18
swapped layout is the default branch).
- 6 new tests:
* standard layout for `_merge_moe_gate_expert`, `_merge_moe_up_expert`,
`_merge_moe_down_proj_expert`,
* layout classifier for both conventions and the unknown cases,
* fallback counter increments and `first_error` populates on
unrecognised shapes,
* `_resolve_num_experts_from_lora_stats` walks the `base_layer` chain.
End-to-end verification on Qwen3-30B-A3B (128 experts x 48 layers,
fused 3D in memory, per-expert 2D on disk), full SFT + save + reload
+ logit compare:
| transformers | peft | trl | merged tensors | trained vs merged KL | samples |
|--------------|--------|--------|----------------|----------------------|---------|
| 5.5.0 | 0.19.1 | 0.25.1 | 18432 / 18432 | 1.6e-5 | 3 / 3 |
| 5.5.0 | 0.18.1 | 0.25.1 | 18432 / 18432 | 1.3e-5 | 3 / 3 |
| 4.57.6 | 0.19.1 | 0.25.1 | dense path | 5.5e-5 | 3 / 3 |
| 5.5.0 | 0.19.1 | 1.4.0 | 18432 / 18432 | 2.1e-4 | 3 / 3 |
Before the patch the M1 row was KL=1.86, samples=1/3, and 0/18432 expert
LoRA deltas were applied. transformers 4.57.6 has `experts = nn.ModuleList
(Qwen3MoeMLP)` (no fused 3D parameter) so the MoE merge helpers do not
fire and every per-expert Linear takes the standard dense `_merge_lora`
path. The MoE helpers are unreachable on transformers <5; the patch only
affects the path that produces the bug.
Fixes unslothai/unsloth#5410. Likely also resolves unslothai/unsloth#4832
(same author, same "garbage after save_pretrained_merged reload" symptom
on DevStral Small 2).
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0a17707b12
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if _MOE_MERGE_STATE["fallback"] > 0: | ||
| err = _MOE_MERGE_STATE.get("first_error") or {} | ||
| raise RuntimeError( |
There was a problem hiding this comment.
Check MoE fallbacks before uploading merged shards
When push_to_hub=True and a MoE layout fallback is recorded, this new RuntimeError is reached only after the upload step at lines 2442-2448 (and the low-disk path uploads shards inside the merge loop), so the function can still publish a checkpoint with base weights written through before reporting failure. That defeats the loud-fail guard for the exact corrupt-result case it is meant to prevent; run this fallback check before any upload/delete, or raise as soon as a shard records a fallback.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Code Review
This pull request enhances MoE LoRA merging by supporting both "swapped" and "standard" PEFT layouts, introducing robust error tracking for expert merges, and ensuring authoritative expert counts are resolved across shards. It also adds logic to persist generation_config.json and includes comprehensive tests for these improvements. Feedback highlights that the fallback logic for calculating num_experts in _merge_moe_gate_or_up_expert and _merge_moe_down_proj_expert is flawed because it attempts to access a non-existent rank attribute on the LoraStats object, which would result in incorrect slicing for LoRA ranks greater than 1.
| num_experts = total_rank // max(1, getattr(lora_stats, "rank", 0) or 1) | ||
| if num_experts <= 0 or total_rank % num_experts != 0: | ||
| return up_W | ||
| num_experts = lora_stats.lora_A.shape[0] // max(1, getattr(lora_stats, "rank", 0) or 1) |
There was a problem hiding this comment.
The LoraStats dataclass (defined at line 278) does not have a rank attribute. Consequently, getattr(lora_stats, "rank", 0) will always return 0, and the fallback logic will default to rank=1. This will result in an incorrect num_experts calculation (setting it to total_rank) when the actual LoRA rank is greater than 1, which in turn leads to incorrect slicing of the LoRA matrices during the merge. Since _resolve_num_experts_from_lora_stats is intended to be the authoritative source, this fallback should be more robust or the rank should be explicitly stored in LoraStats during the statistics collection phase.
| if num_experts is None or num_experts <= 0: | ||
| num_experts = total_rank // max(1, getattr(lora_stats, "rank", 0) or 1) | ||
| if num_experts <= 0 or total_rank % num_experts != 0: | ||
| num_experts = lora_stats.lora_A.shape[0] // max(1, getattr(lora_stats, "rank", 0) or 1) |
There was a problem hiding this comment.
Similar to the issue in _merge_moe_gate_or_up_expert, getattr(lora_stats, "rank", 0) will always return 0 because the LoraStats dataclass lacks a rank field. This makes the fallback logic for num_experts incorrect for any LoRA with rank > 1. Consider using a more reliable heuristic to determine the rank if it's not explicitly available, or ensure it is captured in the LoraStats object.
The base_layer walk in _resolve_num_experts_from_lora_stats was an unbounded `while module is not None` loop. PEFT's ParamWrapper does not self-reference in practice, but a self-referential or cyclic `base_layer` chain would hang the merge. Bound the walk to 16 hops, dedupe via an id() set, and swallow exceptions on getattr / getattr-of-attrs so a hostile module that raises on attribute access cannot abort the merge. Confirmed by a synthetic suite (52 cases) across three isolated venvs: peft 0.18.1 + transformers 5.5.0, peft 0.19.1 + transformers 5.5.0, peft 0.19.1 + transformers 4.57.6. All 22 existing merge tests still pass byte-for-byte in each.
|
Sandbox simulation report against this branch (HEAD Three isolated
The simulation also caught one latent risk before the merge could land: the End-to-end on Qwen3-30B-A3B across the four (transformers, peft, trl) combinations in the PR description remains unchanged: 18432 / 18432 expert tensors merged, trained vs merged KL between 1.3e-5 and 2.1e-4, greedy samples 3 / 3. |
…ards (#5410) unsloth#5410 was a class of silent-write bug in the save_pretrained_merged path that the existing CI matrix could not detect because the merge-helper tests were not wired through the upstream-drift suite. The full fix lives in unslothai/unsloth-zoo#647 (layout-aware MoE merge helpers, authoritative num_experts resolver, loud-fail counter, generation_config.json save). This PR adds the unsloth-side canary that watches for the four guards staying in place in unsloth-zoo so a future refactor cannot silently regress them. tests/version_compat/test_unsloth_zoo_save_merged_pinned_symbols.py fetches unsloth_zoo/saving_utils.py + tests/test_unsloth_zoo_lora_merge.py from unslothai/unsloth-zoo:main and asserts: - _MOE_MERGE_STATE / _reset_moe_merge_state / _record_moe_merge_fallback are still defined and a `raise RuntimeError(...MoE...)` still fires when fallback > 0. - _detect_moe_lora_layout exists and both "swapped" / "standard" branch labels are reachable in the source. - _resolve_num_experts_from_lora_stats is present AND its base_layer walk is bounded by `for _ in range(N):` (a cyclic ParamWrapper chain must not hang the merge). - merge_and_overwrite_lora still calls model.generation_config.save_pretrained(...). - tests/test_unsloth_zoo_lora_merge.py keeps the six PEFT 0.19+ standard-layout regression tests added in #647. - Local unsloth/save.py still names save_pretrained_merged and routes through merge_and_overwrite_lora (i.e. the entry point still reaches the upstream fix). While #647 is still open, the four symbol tests SKIP cleanly with a message naming #647. When #647 merges into unsloth-zoo main, the same tests automatically become hard gates and catch any future regression. The sixth test (local entry-point grep) passes today. CPU-only static fetch, ~0.1s. Wired into the existing peft-pinned-symbols job in .github/workflows/version-compat-ci.yml so it runs on every PR that touches unsloth/** and on the daily schedule. Local run: 1 passed, 5 skipped (expected; #647 open).
Tighten the docstrings and inline comments added by the layout-aware MoE merge work so the diff is closer to the surrounding house style (see chore #640). No behaviour change; 22 / 22 merge tests still pass.
b0112e5 to
6b0f75e
Compare
…ards (#5410) unsloth#5410 was a class of silent-write bug in the save_pretrained_merged path that the existing CI matrix could not detect because the merge-helper tests were not wired through the upstream-drift suite. The full fix lives in unslothai/unsloth-zoo#647 (layout-aware MoE merge helpers, authoritative num_experts resolver, loud-fail counter, generation_config.json save). This PR adds the unsloth-side canary that watches for the four guards staying in place in unsloth-zoo so a future refactor cannot silently regress them. tests/version_compat/test_unsloth_zoo_save_merged_pinned_symbols.py fetches unsloth_zoo/saving_utils.py + tests/test_unsloth_zoo_lora_merge.py from unslothai/unsloth-zoo:main and asserts: - _MOE_MERGE_STATE / _reset_moe_merge_state / _record_moe_merge_fallback are still defined and a `raise RuntimeError(...MoE...)` still fires when fallback > 0. - _detect_moe_lora_layout exists and both "swapped" / "standard" branch labels are reachable in the source. - _resolve_num_experts_from_lora_stats is present AND its base_layer walk is bounded by `for _ in range(N):` (a cyclic ParamWrapper chain must not hang the merge). - merge_and_overwrite_lora still calls model.generation_config.save_pretrained(...). - tests/test_unsloth_zoo_lora_merge.py keeps the six PEFT 0.19+ standard-layout regression tests added in #647. - Local unsloth/save.py still names save_pretrained_merged and routes through merge_and_overwrite_lora (i.e. the entry point still reaches the upstream fix). While #647 is still open, the four symbol tests SKIP cleanly with a message naming #647. When #647 merges into unsloth-zoo main, the same tests automatically become hard gates and catch any future regression. The sixth test (local entry-point grep) passes today. CPU-only static fetch, ~0.1s. Wired into the existing peft-pinned-symbols job in .github/workflows/version-compat-ci.yml so it runs on every PR that touches unsloth/** and on the daily schedule. Local run: 1 passed, 5 skipped (expected; #647 open).
…ards (#5410) (#5433) * tests: pinned-symbol canary for unsloth-zoo save_pretrained_merged guards (#5410) unsloth#5410 was a class of silent-write bug in the save_pretrained_merged path that the existing CI matrix could not detect because the merge-helper tests were not wired through the upstream-drift suite. The full fix lives in unslothai/unsloth-zoo#647 (layout-aware MoE merge helpers, authoritative num_experts resolver, loud-fail counter, generation_config.json save). This PR adds the unsloth-side canary that watches for the four guards staying in place in unsloth-zoo so a future refactor cannot silently regress them. tests/version_compat/test_unsloth_zoo_save_merged_pinned_symbols.py fetches unsloth_zoo/saving_utils.py + tests/test_unsloth_zoo_lora_merge.py from unslothai/unsloth-zoo:main and asserts: - _MOE_MERGE_STATE / _reset_moe_merge_state / _record_moe_merge_fallback are still defined and a `raise RuntimeError(...MoE...)` still fires when fallback > 0. - _detect_moe_lora_layout exists and both "swapped" / "standard" branch labels are reachable in the source. - _resolve_num_experts_from_lora_stats is present AND its base_layer walk is bounded by `for _ in range(N):` (a cyclic ParamWrapper chain must not hang the merge). - merge_and_overwrite_lora still calls model.generation_config.save_pretrained(...). - tests/test_unsloth_zoo_lora_merge.py keeps the six PEFT 0.19+ standard-layout regression tests added in #647. - Local unsloth/save.py still names save_pretrained_merged and routes through merge_and_overwrite_lora (i.e. the entry point still reaches the upstream fix). While #647 is still open, the four symbol tests SKIP cleanly with a message naming #647. When #647 merges into unsloth-zoo main, the same tests automatically become hard gates and catch any future regression. The sixth test (local entry-point grep) passes today. CPU-only static fetch, ~0.1s. Wired into the existing peft-pinned-symbols job in .github/workflows/version-compat-ci.yml so it runs on every PR that touches unsloth/** and on the daily schedule. Local run: 1 passed, 5 skipped (expected; #647 open). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests/version_compat: relax MoE/generation_config regex to fit zoo#647 zoo#647 landed two layout changes that broke the pinned-symbol canary's exact-string regex matches but kept the underlying guarantees intact: - The post-loop MoE LoRA fallback `raise RuntimeError(...)` wraps the "MoE" wording onto a second line; the old `[^\n]*` did not cross newlines. Switch to `.*?` + re.DOTALL. - The generation_config save now binds the attr to a local var `gen_cfg = getattr(model, "generation_config", ...)` and calls `gen_cfg.save_pretrained(save_directory)`, so a literal `generation_config.save_pretrained(` substring no longer matches. Anchor on the conceptual operation: a `generation_config` mention followed (within a small char window) by a `.save_pretrained(` call. That is what the canary actually cares about. Verified locally: pytest tests/version_compat/test_unsloth_zoo_save_merged_pinned_symbols.py -> 2 passed (4 deselected) --------- Co-authored-by: Daniel Han-Chen <info@unsloth.ai> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Brings PR unslothai#527's moeFix branch up to current main (HEAD 57bbdc0). Conflict resolution: - unsloth_zoo/temporary_patches/qwen3_moe.py: took main's version (PR unslothai#574 already added staticmethod() + a refactored extractor that uses extract_moe_lora_weights_for_grouped_mm; my local layout-aware rewrite is now redundant). - unsloth_zoo/temporary_patches/qwen3_vl_moe.py: took main's version (same reason; PR unslothai#574 already wrapped the extractor with staticmethod). Auto-merged cleanly: - unsloth_zoo/temporary_patches/glm4_moe.py: keeps the new patch_glm4_moe_standard registration alongside main's helper-based refactor of patch_glm4_moe. - unsloth_zoo/saving_utils.py: PR unslothai#647 (saving fixes) is already in main as faee224, so my three saving cherry-picks are subsumed. Stashed-then-dropped local shims: - The local _active_merge_device backport in saving_utils.py and the _unsloth_get_mm_token_id / _unsloth_fix_mm_token_type_ids compat shims in rl_replacements.py were never committed; both symbols now exist in main, so the shims were dropped. Verified post-merge: - python -c "import unsloth; import unsloth_zoo.temporary_patches.{glm4_moe,moe_bnb_transformers,qwen3_moe,misc}" succeeds - patch_glm4_moe_standard, _ParamShapeProxy, patch_peft_param_wrapper_4bit_expert_shape, patch_peft_param_wrapper_merge_4bit are all reachable.
Aligns the bitsandbytes 4-bit MoE support with the FP8 MoE support landed on PR unslothai#548 so both quantization kinds share a single harness: - Rename moe_bnb_transformers.py -> moe_utils_bnb4bit.py (matches the moe_utils_fp8.py file name PR unslothai#548 uses). - Add forward_moe_backend_bnb4bit(self, ...) dispatcher with the same shape as forward_moe_backend_fp8: dequantize gate_up_proj/down_proj, hand off to the regular grouped_mm / triton / native backend via a temporary weight swap (_call_with_temporary_moe_weights). - Add _moe_uses_bnb4bit_expert_weights detection helper. - moe_utils.forward_moe_backend now tries bnb4bit dispatch the same way PR unslothai#548 wires fp8 dispatch (try-import + early return on hit). The two branches are independent and stack trivially. - Drop patch_transformers_grouped_linear_4bit. The lower-level _grouped_linear / _batched_linear / batched_mm_experts_forward wrapping is no longer needed for any of the per-arch MoE classes that Unsloth already patches (qwen3*, glm4_moe lite + standard, deepseek_v3, gpt_oss): they all route through forward_moe_backend which now handles bnb4bit. Arches whose experts class is NOT patched per-class (e.g. transformers-default Gemma4MoE) will need a per-class patch instead of the generic interception. Verified on the user's moe_train_infer_grad_check.py harness across 4 tiny MoE archs x {bf16, bnb4bit} on GPU 7. 4bit and bf16 trajectories match per arch (within stochastic noise), confirming the dequant path is numerically equivalent: qwen3_moe 16bit 12.01->7.39 | 4bit 11.96->7.37 qwen3_5_moe 16bit 11.09->0.47 | 4bit 11.11->0.43 glm4_moe 16bit 11.62->6.69 | 4bit 11.66->6.66 deepseek_v3_moe 16bit (33% acc) | 4bit 10.36->0.05 Save (PR unslothai#647 path) and reload load both succeed on every cell. The qwen3_5_moe 4bit reload-accuracy gap (100% post-train -> 0% post-reload) is a pre-existing bnb4bit save/reload roundtrip issue, not introduced by this consolidation.
Summary
save_pretrained_merged(..., save_method=\"merged_16bit\")silently drops the entire MoE expert LoRA delta on Qwen3-MoE / Qwen3.5-MoE-style models when running onpeft >= 0.19.1+transformers >= 5.0(reported in unslothai/unsloth#5410). Four issues stack up:unsloth_zoo/saving_utils.pyhardcode the PEFT 0.18 "swapped" layout (lora_A: (E*r, 2I),lora_B: (H, E*r)for fused gate_up_proj;lora_A: (E*r, H),lora_B: (I, E*r)for fused down_proj). PEFT 0.19+ swaps in/out features for non-transposed 3D parameters (MNT: Pin GitHub action hashes for security huggingface/peft#2521) and produces the opposite shapes.addmmshape error was swallowed by a baretry / except Exception: return W(and the dim-heuristic in the fused helpers fell through toreturn W), so the merge wrote the unmodified base weight and reported success._merge_and_overwrite_loraflow the per-expert merge loop'snum_expertscame from the shard-local key scan, which can be a non-divisor oftotal_rankwhenever experts are split across multiple safetensor shards (16 / 17 in some shards of the 128-expert Qwen3-30B-A3B layout).generation_config.json, so chat-tuned models reloaded with default eos / sampling and ran past EOS.Fix
_detect_moe_lora_layout(lora_A, lora_B, num_experts, out_dim, in_dim)classifies the layout by shape against the per-expert on-disk weight. No version sniffing, so it works on transformers 4.57.x / 5.x and peft 0.18.x / 0.19.x._merge_moe_gate_or_up_expertand_merge_moe_down_proj_expertbranch on the detected layout. The PEFT 0.18 "swapped" path is byte-identical to the previous behaviour._resolve_num_experts_from_lora_stats(lora_stats, fallback)walksmodule -> base_layer -> ...to read the authoritativenum_expertsoff the wrapped MoE module (Qwen3MoeExpertsand similar)._merge_and_overwrite_loracalls it for every MoE LoRA inconverted_lora_weightsand overrides the shard-local value inmoe_num_expertsbefore the per-expert loop runs._MOE_MERGE_STATEtracks(attempted, applied, fallback, first_error). Helpers record fallbacks with role / expert / shapes / reason on any unrecognised layout or exception. After the shard loopmerge_and_overwrite_loraraisesRuntimeErrorif any fallback fired, so a partially merged checkpoint can no longer be silently written. On success it printsapplied/attempted.merged_16bitbranch also callsmodel.generation_config.save_pretrained(save_directory)(best-effort, matching the same pattern asfix_tokenizer_config_json).Tests
tests/test_unsloth_zoo_lora_merge.pynow covers:_merge_moe_gate_expert,_merge_moe_up_expert,_merge_moe_down_proj_expert._detect_moe_lora_layoutfor swapped, standard, mismatched shapes, and non-divisornum_experts.first_errorpopulates on unrecognised shapes._resolve_num_experts_from_lora_statswalks thebase_layerchain (covers the inner ParamWrapper case formlp.experts.down_projwhere the outer ParamWrapper hasmodule = None).End-to-end verification
Full Qwen3-30B-A3B (128 experts x 48 layers, fused 3D in memory, per-expert 2D on disk): load -> attach LoRA (
r=32, alpha=64, target_modules=[q,k,v,o,gate,up,down]_proj, lora_dropout=0) -> 5 SFT steps ->save_pretrained_merged(merged_16bit)-> reload -> compare merged-reload logits to the trained in-memory model on a fixed eval batch.Reading:
KL = O(1e-4)between merged-reload and trained model is bf16 noise;samples = 3 / 3means all 3 greedy generations on held-out prompts match exactly;merged vs base KLapproxbase vs trained KLconfirms the full training delta is baked into the saved merged dir.Before this patch the first row was KL = 1.86, samples = 1 / 3, 0 / 18432 expert LoRA deltas applied.
Notes:
transformers 4.57.6hasQwen3MoeSparseMoeBlock.experts = nn.ModuleList(Qwen3MoeMLP)per expert (no fused 3D parameter), so the MoE merge helpers do not fire and every per-expert Linear takes the standard dense_merge_lorapath. The MoE helpers are unreachable ontransformers < 5; the patch only affects the path that produces the bug.trl 1.4row usespadding_free=Falsein the reproducer (TRL 1.x raises whenpadding_free=Trueis combined with a finitemax_lengthwithout packing). Unrelated to this patch.transformers 5.5 + peft 0.19 + trl 0.25separately: trained vs merged KL = 1.1e-4, top-1 agreement = 0.978, samples 4 / 4. The dense_merge_lorapath is untouched.Test plan
pytest unsloth-zoo/tests/test_unsloth_zoo_lora_merge.py-> 22 passedgeneration_config.jsonis written intosave_directoryformerged_16bit_MOE_MERGE_STATEraises a clearRuntimeErrorwhen the layout is unrecognised (synthetic test + a deliberately mangled shape fixture)Related
save_pretrained_mergedunsloth#5410