Add Bnb4bit support for MoE models on transformers v5 - #4032 #527
Summary of Changes
This pull request enables 4-bit quantization for Mixture-of-Experts (MoE) models within the Transformers library, particularly for versions 5 and above. It addresses the challenge of quantizing MoE expert parameters.
Code Review
This PR introduces support for quantization of MoE parameters in transformers v5 by converting expert parameters to nn.Params4bit and handling quantization/dequantization for PEFT LoRA compatibility. It includes a new module moe_bnb_transformers.py with patching functions and modifications to misc.py and __init__.py to integrate the new functionality.
# If the parameter is a Params4bit, dequantize it
if _check_bnb_available() and isinstance(param, Params4bit):
    # Dequantize the parameter
    return bnb.functional.dequantize_4bit(param.data, param.quant_state)
Consider adding a check to ensure param.quant_state is not None before dequantizing. If quant_state is None, it could lead to an error during dequantization.
Suggested change
- # If the parameter is a Params4bit, dequantize it
- if _check_bnb_available() and isinstance(param, Params4bit):
-     # Dequantize the parameter
-     return bnb.functional.dequantize_4bit(param.data, param.quant_state)
+ # If the parameter is a Params4bit, dequantize it
+ if _check_bnb_available() and isinstance(param, Params4bit) and param.quant_state is not None:
+     # Dequantize the parameter
+     return bnb.functional.dequantize_4bit(param.data, param.quant_state)
except Exception as e:
    return raise_error("transformers.quantizers.quantizers_utils.should_convert_module", e)

if not has_been_replaced:
    logger.warning(f"Unsloth: No expert parameters were found to be replaced for {model.name_or_path}")

except Exception as e:
    logger.warning(f"Unsloth: Error handling expert param quantization for {full_layer_name}: {e}")

# TODO: Can we raise an error here?
logger.warning(
    f"Unsloth: Error checking MoE expert param_needs_quantization for {param_name}: {e}"
)
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f3f2c6eba9
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d5b567c528
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fb69ead7ca
Upstream Bnb4bitQuantize.convert unwraps input_dict[key] twice (first .values() then [0]). The patched version did only the first unwrap, which produced a list for most weight-converter dispatch paths and raised TypeError inside the Params4bit constructor. The error was silently masked by the broad except Exception, which fell back to original_convert and left experts unquantized.
Empirically: tiny DeepSeek V3 happens to dispatch experts as bare Tensors, so the bug was invisible there. Qwen3.5-35B-A3B dispatches as list[Tensor] and triggers the TypeError on every expert param.
Reviewed-by: R1 (correctness), R3 (regression)
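The unwrap mismatch described in this commit can be illustrated with plain Python containers. This is a minimal sketch and not the upstream `Bnb4bitQuantize.convert` code: `unwrap_single_value` is a hypothetical helper, the dictionary key is made up, and strings stand in for tensors.

```python
from typing import Any, Mapping


def unwrap_single_value(input_dict: Mapping[str, Any]) -> Any:
    """Return the payload of a one-entry dict, unwrapping a list if present.

    Hypothetical helper: it accepts both dispatch shapes mentioned in the
    commit message (a bare value, or a one-element list around the value).
    """
    value = next(iter(input_dict.values()))  # first unwrap: dict -> value
    if isinstance(value, (list, tuple)):     # second unwrap: [value] -> value
        value = value[0]
    return value


# Strings stand in for tensors; the key name is illustrative only.
bare = {"experts.gate_up_proj": "tensor"}
wrapped = {"experts.gate_up_proj": ["tensor"]}

print(unwrap_single_value(bare))
print(unwrap_single_value(wrapped))
```

Handling both shapes explicitly avoids relying on which dispatch path a given checkpoint happens to take.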
…-MoE warning (M7+M8)
Package-wide convention (mxfp4.py, qwen3_moe.py, qwen3_5_moe.py, qwen3_vl_moe.py, qwen3_next_moe.py, glm4_moe.py, deepseek_v3_moe.py) gates every logger.info on UNSLOTH_ENABLE_LOGGING. The three info logs in moe_bnb_transformers.py and the unconditional warning when no experts are found weren't gated and spammed every 4-bit load, including dense models (Phi3, GLM4 dense, Llama, Mistral, etc.).
Reviewed-by: R3 (regression), R4 (conventions)
…lause (m4+m6+M1)
- Add __all__ to stop wildcard re-export of bnb/torch/nn/Optional/etc. into the unsloth_zoo.temporary_patches namespace.
- Tighten param_needs_quantization: only return True when Params4bit AND bnb_quantized=False, protecting against re-invocation after the first quantize.
- Narrow the patched_param_needs_quantization except from bare Exception to (KeyError, AttributeError), the expected failures from get_module_from_name.
Reviewed-by: R1 (correctness), R3 (regression), R4 (conventions)
…h in detector (M3+B2)
M3: _get_base_weight calls bnb.functional.dequantize_4bit(param.data, None) when
param.quant_state has not been populated yet (meta-device placeholder, or before
.to(cuda)). Add explicit guard to fall through to other branches in that case.
B2: _is_moe_experts_module was tightened from (nn.Parameter|Tensor, ndim in (2,3))
to (Params4bit, ndim==2) OR (nn.Parameter, ndim==3). This contradicts the function's
own doc-comment ('After PEFT's nn.utils.parametrize wrapping, accessing gate_up_proj
returns torch.Tensor (not nn.Parameter), so we must accept both.'). Restore the
Tensor branch so the post-parametrize path is still detected.
Reviewed-by: R1 (correctness), R3 (regression)
…issing _original_shape (M5+M6+m5)
R2 found two related issues in patch_peft_param_wrapper_4bit_expert_shape:
M5/M6: _patched_get_param hard-codes the peft 0.18 dim-ordering convention (num_experts, in_features, out_features = shape) and reassigns self.in_features and self.out_features on every call. PEFT 0.19's ParamWrapper.update_layer swaps in_features<->out_features for 3D params (gated on _did_swap_in_out_features), then calls _move_adapter_to_device_of_base_layer, which calls get_param() again. The reassignment undoes PEFT's swap, breaking second-adapter add_adapter and any external reader of layer.in_features. Drop the in_features/out_features assignment; keep num_experts (which both peft versions derive identically).
m5: the 'else: # TODO: Can we raise an error here? pass' branch silently returned the packed Params4bit. PEFT's _get_in_out_features would then read the (K, 1) packed shape and create LoRA factors with wrong dims. Raise ValueError explaining the failure mode instead.
Reviewed-by: R2 (peft)
Reviewed-by: R4 (conventions)
Three changes from REV findings:
REV-narrowed-except-loses-runtime-context: patched_convert still had a broad except Exception that masked B1 for months. Narrow to (KeyError, AttributeError) for expected get_module_from_name failures, and use logger.exception() for unexpected failures so the traceback is preserved.
REV-m3-defer-not-fix: when _get_base_weight sees a Params4bit with quant_state=None, falling through to subsequent branches just returns the raw packed uint8 data, which crashes grouped_mm later with a different error. Raise an actionable error instead.
REV-typing-tuple-unused: drop the unused Tuple import.
Reviewed-by: REV (independent post-fix review)
…s (B4)
ParamWrapper.merge does in-place `param.data += delta_weight`. For a
Params4bit MoE expert, `param.data` is the *packed* 4-bit storage of
shape (N_packed, 1) while `delta_weight` is in logical 3D (E, in, out).
This raises a RuntimeError shape mismatch ("size of tensor a (16777216)
must match the size of tensor b (1024) at non-singleton dimension 1")
at the `+=` line, breaking `merge_and_unload()` for every 4-bit MoE
adapter on both peft 0.18 and 0.19.
Override ParamWrapper.merge/unmerge to do dequant -> add -> re-quantize
when the wrapped param is a Params4bit with a 3D _original_shape,
mirroring lora.bnb.Linear4bit.merge. PEFT's own get_delta_weight is
re-used so the 0.18 (e i o) vs 0.19 (e o i, post-swap) einsum
convention is honoured automatically.
Also expose `dtype` on _ParamShapeProxy from quant_state.dtype so PEFT's
internal `.to(param.dtype)` casts in get_delta_weight and
_move_adapter_to_device_of_base_layer no longer truncate bf16 LoRA
values to uint8 (the proxy was delegating .dtype to the packed
Params4bit storage dtype).
Verified by smoke/test_merge_4bit_moe.py on imdatta0/tiny_deepseek_v3
bnb4bit with both peft 0.18 and 0.19: pre-merge load + forward, then
merge_and_unload, then post-merge forward. All pass; the merged expert
parameter remains a Params4bit (re-quantized) with _original_shape
preserved.
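The dequant -> add -> re-quantize order used by the merge override can be sketched without bitsandbytes. The toy absmax quantizer below is an assumption for illustration only: it is not NF4, it does not model bit packing, and the function names are hypothetical. It only shows why the delta has to be added in the logical (dequantized) domain rather than onto the stored codes.

```python
from typing import List, Tuple


def quantize(values: List[float]) -> Tuple[List[int], float]:
    """Toy absmax quantizer: map floats to signed integer codes in [-7, 7]."""
    absmax = max((abs(v) for v in values), default=0.0) or 1.0
    scale = absmax / 7.0
    return [round(v / scale) for v in values], scale


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Recover approximate float values from codes and their scale."""
    return [c * scale for c in codes]


def merge_delta(codes: List[int], scale: float, delta: List[float]):
    """Dequantize, add the delta in the logical domain, then re-quantize."""
    weights = dequantize(codes, scale)
    if len(weights) != len(delta):
        raise ValueError("delta must match the logical (dequantized) shape")
    merged = [w + d for w, d in zip(weights, delta)]
    return quantize(merged)


codes, scale = quantize([0.7, -0.35, 0.0, 0.1])
new_codes, new_scale = merge_delta(codes, scale, [0.7, 0.0, 0.0, 0.0])
```

Re-quantizing after the add recomputes the scale, which is why the merged parameter can stay in quantized form.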
…atched MoE forward (B6)
For MoE arches whose experts class is NOT replaced by an unsloth-zoo per-arch
patch (e.g. some Glm4Moe variants, Gemma4MoE), the experts forward falls
through to transformers v5's generic dispatchers in
`transformers.integrations.moe`:
- `grouped_mm_experts_forward` reads `self.gate_up_proj` / `self.down_proj`
raw and passes them through `_grouped_linear` -> `weight.transpose(-2, -1)`
-> `_grouped_mm` -> `torch._grouped_mm(input.to(weight.dtype), ...)`. With
a Params4bit (uint8 packed storage) this raises "Expected mat_a to be
Float32, BFloat16 or Float16 matrix, got Byte" during training.
- `batched_mm_experts_forward` does `selected_weights = self.gate_up_proj[expert_ids]`
BEFORE calling `_batched_linear`. The indexing on a packed Params4bit
returns a (S, 1) uint8 slice that has lost its quant_state and dtype
information, so `_batched_linear` raises "batch1 must be a 3D tensor"
inside `torch.bmm` during autoregressive decoding.
Three layered patches in `moe_bnb_transformers.py`:
1. `_grouped_linear`: dequantize Params4bit weight (using `_original_shape`
to recover the logical 3D `(E, in, out)` shape from `(N, 1)` packed
storage), cast to input dtype, delegate to original.
2. `_batched_linear`: same treatment for the bmm path.
3. `batched_mm_experts_forward`: temporarily swap `self.gate_up_proj` /
`self.down_proj` / `self.up_proj` to dequantized 3D tensors for the
duration of the forward call, then restore — necessary because the
indexing happens before any of the helper functions get a chance to
intercept. Re-register in `ALL_EXPERTS_FUNCTIONS` so the
`use_experts_implementation` decorator picks up the patched version.
Forward-only -- base weights are frozen; gradient flow stays on LoRA paths
that the per-arch wrappers (when they exist) inject separately. For arches
without a per-arch wrapper this gives base-only forward; PEFT's
`_activate_lora` parametrization adds the LoRA delta on top.
Verified by full E2E pipeline on `imdatta0/tiny_glm4_moe_2.8B_0.7B` bnb4bit
+ peft 0.18: load -> train -> save_adapter -> merge -> merged_forward ->
merged_generate -> reload base+adapter -> forward -> generate. Pre-fix:
train_error at first step. Post-fix: full E2E SUCCESS.
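The temporary swap in patch 3 follows a common swap-and-restore pattern. Below is a minimal sketch using a plain Python object; `swapped_attrs` and `Experts` are hypothetical names, and strings stand in for the packed and dequantized tensors. On a real `torch.nn.Module`, assigning a plain tensor over a registered parameter is rejected by `nn.Module.__setattr__`, so an actual patch has to handle that; the sketch ignores it.

```python
from contextlib import contextmanager


@contextmanager
def swapped_attrs(obj, replacements):
    """Temporarily replace attributes on obj, restoring the originals on exit."""
    originals = {name: getattr(obj, name) for name in replacements}
    try:
        for name, value in replacements.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Restore even if the wrapped forward call raises.
        for name, value in originals.items():
            setattr(obj, name, value)


class Experts:
    """Hypothetical stand-in for an experts module holding packed weights."""

    def __init__(self):
        self.gate_up_proj = "packed-4bit"
        self.down_proj = "packed-4bit"


experts = Experts()
with swapped_attrs(experts, {"gate_up_proj": "dequantized-3d"}):
    inside = experts.gate_up_proj  # what the forward pass would index into
after = experts.gate_up_proj       # packed storage is back in place
```

Restoring in a `finally` block keeps the packed parameter in place even when the forward call fails partway through.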
…lelist/MergeModulelist reverse (B5)
transformers v5's `SplitModulelist.get_target_patterns` and
`MergeModulelist.get_target_pattern` raise `ValueError("Undefined Operation
encountered!")` when invoked with `len(input_dict) == 1` AND
`len(target_patterns) > 1`. This is the case `revert_weight_conversion` hits
during `save_pretrained` for fused-MoE models (DeepSeek-V3, Qwen3-MoE, GLM-4
MoE, Qwen3.5-MoE, etc.): the in-memory model has one fused expert tensor while
target_patterns derived from the forward-conversion mapping contains multiple
star-templated entries.
This blocks `merged.save_pretrained()` for any fused-MoE model regardless of
quantization.
Fix: when `len(target_patterns) > 1`, pick the target whose suffix component
(after the last `.`, with `*` stripped) matches the source's suffix; fall
back to `target_patterns[0]` otherwise. For `SplitModulelist`, expand `*`
with i over the chosen pattern. Forward-direction logic is unchanged.
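The selection rule described in this commit (which the PR later reverts) can be sketched in a few lines. The helper names and the example keys below are hypothetical and do not reproduce the transformers implementation; they only illustrate suffix matching with a fallback to the first pattern.

```python
from typing import List, Sequence


def _suffix(pattern: str) -> str:
    """Last dotted component of a pattern, with '*' stripped."""
    return pattern.rsplit(".", 1)[-1].replace("*", "")


def pick_target(source: str, target_patterns: Sequence[str]) -> str:
    """Pick the target whose suffix matches the source's, else the first one."""
    wanted = _suffix(source)
    for pattern in target_patterns:
        if _suffix(pattern) == wanted:
            return pattern
    return target_patterns[0]


def expand(pattern: str, count: int) -> List[str]:
    """Expand '*' with indices 0..count-1, as in the split direction."""
    return [pattern.replace("*", str(i)) for i in range(count)]


# Illustrative keys only; real checkpoint key names may differ.
targets = ["experts.*.gate_proj", "experts.*.up_proj", "experts.*.down_proj"]
chosen = pick_target("experts.down_proj", targets)
per_expert = expand(chosen, 2)
```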
…plitModulelist/MergeModulelist reverse (B5)" This reverts commit c57bbe5.
….19 compat)
PEFT 0.19's `convert_peft_adapter_state_dict_for_transformers` (in peft/utils/transformers_weight_conversion.py:268) constructs new WeightConverter instances with `distributed_operation` and `quantization_operation` kwargs. transformers 5.6.2's WeightConverter.__init__ signature is just `(source_patterns, target_patterns, operations)`, so the extra kwargs raise `TypeError: WeightConverter.__init__() got an unexpected keyword argument 'distributed_operation'` from inside PEFT's adapter loader.
Symptom: `PeftModel.from_pretrained(base, adapter_dir)` for any MoE-fused 4-bit model fails on peft 0.19 + transformers 5.6.2, blocking the realistic load-base-then-attach-adapter inference workflow. Affected ~7 cells in the post-B4 E2E sweep where load+train+merge succeeded on peft 0.19 but reload_adapter failed.
Fix: patch WeightConverter.__init__ to accept unknown kwargs and drop them before delegating to the original signature. Forward-compatible with peft versions written against newer transformers releases; ignored kwargs only affect distributed/quantization codepaths not exercised at adapter load.
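A kwargs-tolerant constructor patch of the kind described here can be sketched as follows. `Converter` is a hypothetical stand-in class, not transformers' `WeightConverter`, and `tolerate_unknown_kwargs` is an illustrative name; the sketch assumes the original `__init__` does not already accept `**kwargs`.

```python
import functools
import inspect


def tolerate_unknown_kwargs(cls):
    """Wrap cls.__init__ so keyword arguments it does not declare are dropped."""
    original_init = cls.__init__
    accepted = set(inspect.signature(original_init).parameters)

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        # Keep only the keywords the original signature knows about.
        known = {k: v for k, v in kwargs.items() if k in accepted}
        return original_init(self, *args, **known)

    cls.__init__ = patched_init
    return cls


class Converter:
    """Hypothetical stand-in with the three-argument signature from the text."""

    def __init__(self, source_patterns, target_patterns, operations=None):
        self.source_patterns = source_patterns
        self.target_patterns = target_patterns
        self.operations = operations


tolerate_unknown_kwargs(Converter)
conv = Converter("a", "b", operations=[], distributed_operation=None)
```

Dropping unknown keywords silently is a trade-off: it unblocks callers written against a newer signature, but it also hides typos in keyword names.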
`save_pretrained_merged(save_method="merged_16bit")` silently dropped the
entire MoE expert LoRA delta on Qwen3-MoE / Qwen3.5-MoE-style models with
peft >= 0.19.1. The per-expert helpers in `saving_utils.py` hardcoded the
PEFT 0.18 "swapped" tensor layout (`lora_A: (E*r, 2I)`, `lora_B: (H, E*r)`
for gate_up_proj; `lora_A: (E*r, H)`, `lora_B: (I, E*r)` for down_proj),
while PEFT 0.19+ swaps in/out features for non-transposed 3D parameters
and produces `lora_A: (E*r, H)`, `lora_B: (2I, E*r)` and `lora_A: (E*r, I)`,
`lora_B: (H, E*r)`. The layout mismatch hit a bare `except Exception: return W`
and the dim-heuristic fallthrough in the fused helpers, so the merge
silently wrote unmodified base weights and reported success. The
`num_experts` value used by the per-expert loop was also taken from the
shard-local key scan, which is a non-divisor of `total_rank` whenever
experts are split across multiple safetensor shards (16/17 of 128 on
Qwen3-30B-A3B). Finally the merged dir was missing `generation_config.json`,
so chat-tuned models reloaded with default eos / sampling and ran past EOS.
Changes:
- `_detect_moe_lora_layout(lora_A, lora_B, num_experts, out_dim, in_dim)`
classifies the layout by shape against the per-expert disk weight, so
no version sniffing is required. Works on transformers 4.57.x / 5.x
and peft 0.18.x / 0.19.x.
- `_merge_moe_gate_or_up_expert` and `_merge_moe_down_proj_expert`
branch on the detected layout. The "swapped" path is byte-identical
to the previous behaviour.
- `_resolve_num_experts_from_lora_stats` walks `module -> base_layer ->
...` to read the authoritative `num_experts` off the wrapped MoE
module (`Qwen3MoeExperts` etc). `_merge_and_overwrite_lora` uses it
to override `moe_num_experts[prefix]` after the converted-key build,
so the per-expert loop never trips on a shard-local count.
- `_MOE_MERGE_STATE` tracks `(attempted, applied, fallback, first_error)`.
Each helper records a fallback with role / expert / shapes / reason
on any unrecognised layout or exception. After the shard loop
`merge_and_overwrite_lora` raises `RuntimeError` if any fallback
fired, so partially-merged checkpoints can no longer be silently
written. On success it prints `applied/attempted`.
- The `merged_16bit` branch also calls
`model.generation_config.save_pretrained(save_directory)` (best-effort,
matching `fix_tokenizer_config_json`).
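The shape-based classification from the first bullet can be sketched on shape tuples alone. This is an illustrative reconstruction from the layouts listed in the commit message, not the code in `saving_utils.py`: the function name, the returned labels, and the handling of square weights are assumptions.

```python
def classify_moe_lora_layout(a_shape, b_shape, num_experts, out_dim, in_dim):
    """Classify stacked LoRA factors by shape.

    a_shape is (E*r, X) and b_shape is (Y, E*r). With a per-expert weight of
    shape (out_dim, in_dim):
      standard: X == in_dim  and Y == out_dim
      swapped:  X == out_dim and Y == in_dim
    """
    if len(a_shape) != 2 or len(b_shape) != 2:
        return "unknown"
    total_rank = a_shape[0]
    if total_rank != b_shape[1]:
        return "unknown"
    # A shard-local expert count may not divide the stacked rank.
    if num_experts <= 0 or total_rank % num_experts != 0:
        return "unknown"
    standard = a_shape[1] == in_dim and b_shape[0] == out_dim
    swapped = a_shape[1] == out_dim and b_shape[0] == in_dim
    if standard and swapped:
        return "ambiguous"  # assumption: square weights cannot be decided by shape
    if standard:
        return "standard"
    if swapped:
        return "swapped"
    return "unknown"


# E = 4 experts, rank r = 2, hidden H = 16, intermediate I = 3.
# gate_up_proj per-expert weight: (out_dim, in_dim) = (2I, H) = (6, 16).
std = classify_moe_lora_layout((8, 16), (6, 8), 4, 6, 16)
swp = classify_moe_lora_layout((8, 6), (16, 8), 4, 6, 16)
bad = classify_moe_lora_layout((8, 16), (6, 8), 3, 6, 16)
```

Classifying by shape keeps the merge independent of the installed peft version, at the cost of needing an explicit answer for shapes that match neither or both conventions.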
Tests:
- Existing 16 per-expert / fused / dense merge tests in
`test_unsloth_zoo_lora_merge.py` still pass byte-for-byte (PEFT 0.18
swapped layout is the default branch).
- 6 new tests:
* standard layout for `_merge_moe_gate_expert`, `_merge_moe_up_expert`,
`_merge_moe_down_proj_expert`,
* layout classifier for both conventions and the unknown cases,
* fallback counter increments and `first_error` populates on
unrecognised shapes,
* `_resolve_num_experts_from_lora_stats` walks the `base_layer` chain.
End-to-end verification on Qwen3-30B-A3B (128 experts x 48 layers,
fused 3D in memory, per-expert 2D on disk), full SFT + save + reload
+ logit compare:
| transformers | peft | trl | merged tensors | trained vs merged KL | samples |
|--------------|--------|--------|----------------|----------------------|---------|
| 5.5.0 | 0.19.1 | 0.25.1 | 18432 / 18432 | 1.6e-5 | 3 / 3 |
| 5.5.0 | 0.18.1 | 0.25.1 | 18432 / 18432 | 1.3e-5 | 3 / 3 |
| 4.57.6 | 0.19.1 | 0.25.1 | dense path | 5.5e-5 | 3 / 3 |
| 5.5.0 | 0.19.1 | 1.4.0 | 18432 / 18432 | 2.1e-4 | 3 / 3 |
Before the patch the M1 row was KL=1.86, samples=1/3, and 0/18432 expert
LoRA deltas were applied. transformers 4.57.6 has `experts = nn.ModuleList
(Qwen3MoeMLP)` (no fused 3D parameter) so the MoE merge helpers do not
fire and every per-expert Linear takes the standard dense `_merge_lora`
path. The MoE helpers are unreachable on transformers <5; the patch only
affects the path that produces the bug.
Fixes unslothai/unsloth#5410. Likely also resolves unslothai/unsloth#4832
(same author, same "garbage after save_pretrained_merged reload" symptom
on DevStral Small 2).
The base_layer walk in _resolve_num_experts_from_lora_stats was an unbounded `while module is not None` loop. PEFT's ParamWrapper does not self-reference in practice, but a self-referential or cyclic `base_layer` chain would hang the merge. Bound the walk to 16 hops, dedupe via an id() set, and swallow exceptions on getattr / getattr-of-attrs so a hostile module that raises on attribute access cannot abort the merge. Confirmed by a synthetic suite (52 cases) across three isolated venvs: peft 0.18.1 + transformers 5.5.0, peft 0.19.1 + transformers 5.5.0, peft 0.19.1 + transformers 4.57.6. All 22 existing merge tests still pass byte-for-byte in each.
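A bounded, cycle-safe walk of this kind can be sketched as below. The function name and return convention are assumptions for illustration and do not reproduce `_resolve_num_experts_from_lora_stats`; only the 16-hop bound, the `id()` set, and the swallowed attribute errors come from the description above.

```python
from types import SimpleNamespace


def resolve_num_experts(module, max_hops=16):
    """Follow base_layer links until an object exposing num_experts is found."""
    seen = set()
    for _ in range(max_hops):
        if module is None or id(module) in seen:
            return None  # end of chain, or a cycle
        seen.add(id(module))
        try:
            value = getattr(module, "num_experts", None)
        except Exception:
            value = None  # attribute access itself raised
        if isinstance(value, int) and value > 0:
            return value
        try:
            module = getattr(module, "base_layer", None)
        except Exception:
            return None
    return None


experts = SimpleNamespace(num_experts=128, base_layer=None)
wrapper = SimpleNamespace(base_layer=SimpleNamespace(base_layer=experts))

cyclic = SimpleNamespace()
cyclic.base_layer = cyclic  # self-referential chain


class Hostile:
    """Raises on attribute access to exercise the exception guard."""

    @property
    def num_experts(self):
        raise RuntimeError("boom")

    @property
    def base_layer(self):
        raise RuntimeError("boom")
```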
Tighten the docstrings and inline comments added by the layout-aware MoE merge work so the diff is closer to the surrounding house style (see chore unslothai#640). No behaviour change; 22 / 22 merge tests still pass.
Two compounding bugs silently disabled LoRA on all qwen-family MoE models:
1. Bound-method bug: qwen3_moe, qwen3_5_moe, qwen3_next_moe, and qwen3_vl_moe registered _unsloth_lora_extractor_fn as a plain function on the class (e.g. Qwen3MoeExperts._unsloth_lora_extractor_fn = fn). Python's descriptor protocol then bound `self` when accessed via an instance, so the call site in moe_utils._extract_lora_from_wrapper (`extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)`) passed 6 args to a 5-arg function. The TypeError was silently swallowed by the try/except in _extract_lora_from_wrapper, returning None and bypassing the LoRA injection. Fixed by wrapping with staticmethod() to match the working glm4_moe and deepseek_v3_moe pattern.
2. Layout assumption: _make_qwen_moe_lora_extractor's down_proj and gate_up_proj branches assumed PEFT 0.18's "swapped" wrapper layout (weight_A second dim is the *output* of the base param). PEFT 0.19 uses the standard layout (weight_A second dim is the *input*), producing a contraction-dim mismatch in torch._grouped_mm. Fixed by detecting layout per param_name using the experts module's hidden_dim and intermediate_dim (matches the GLM4 extractor's approach), with the existing dim_B == hidden_dim heuristic kept as fallback.
Verified on imdatta0/tiny_qwen3_moe_2.8B_0.7B with PEFT 0.19.1: 96 trainable LoRA params, grad_norm = 0.214 (was 0.0), loss decreases across steps.
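The bound-method behaviour behind the first bug is standard Python and can be reproduced in isolation. In the sketch below the two classes and the extractor body are placeholders; only the attribute name `_unsloth_lora_extractor_fn` and the five-argument call shape are taken from the description above.

```python
def extractor(wrapper, weight_A, weight_B, scaling, num_experts):
    """Placeholder five-argument extractor."""
    return (wrapper, num_experts)


class PlainExperts:
    pass


class StaticExperts:
    pass


# A plain function stored on a class becomes a bound method on instances.
PlainExperts._unsloth_lora_extractor_fn = extractor
# staticmethod() disables the binding, so the call keeps its five arguments.
StaticExperts._unsloth_lora_extractor_fn = staticmethod(extractor)

plain_fn = PlainExperts()._unsloth_lora_extractor_fn
static_fn = StaticExperts()._unsloth_lora_extractor_fn

try:
    plain_fn("w", "A", "B", 1.0, 8)  # self + 5 args = 6 positional arguments
    bound_call_failed = False
except TypeError:
    bound_call_failed = True

result = static_fn("w", "A", "B", 1.0, 8)
```

If the `TypeError` is caught by a broad handler at the call site, as described above, the failure is invisible and the extractor simply never runs.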
The existing patch_glm4_moe registered _unsloth_lora_extractor_fn and the forward override only on Glm4MoeLiteNaiveMoe / Glm4MoeLiteMoE (the GLM-4.7 "lite" variant). Standard GLM4 MoE checkpoints use Glm4MoeNaiveMoe / Glm4MoeMoE from transformers.models.glm4_moe. Those classes were left unpatched, so the original transformers experts forward ran with base weights only. ParamWrapper still intercepted parameter access and our patched ParamWrapper.forward set _unsloth_lora_* attrs on the experts module, but the original forward never read them. Net effect: LoRA was silently bypassed, training loss flat, grad_norm = 0.
Verified on imdatta0/tiny_glm4_moe_2.8B_0.7B (60 steps):
- bf16: loss 11.81 -> 3.27, grad_norm 0.63 -> 1.99, merge 4608/4608, save_merged + reload PASS
- bnb4bit: loss 11.89 -> 3.28, grad_norm 0.64 -> 1.89, merge 4608/4608, save_merged + reload PASS
The standard variant uses the same separated-LoRA computation as the lite variant (3-D gate_up_proj / down_proj layout via PEFT ParamWrapper), so the new patch_glm4_moe_standard reuses get_forward_moe_backend() and a layout-aware extractor mirroring the lite one. Production GLM-4.7 (lite) is unchanged.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cd5ad27b0f
a_slice = lora_stats.lora_A[start:end, :]
b_slice = lora_stats.lora_B[:, start:end]
device = _active_merge_device()
Define _active_merge_device before MoE merge calls
_merge_moe_gate_or_up_expert now calls _active_merge_device(), but this helper is not defined anywhere in saving_utils.py. At runtime this raises NameError, the exception path records a merge fallback for every affected expert, and merge_and_overwrite_lora then aborts with the new fallback guard, so MoE LoRA merges fail instead of producing merged weights.
Hi!
This PR adds support for quantizing MoE parameters of type nn.Parameter.
With transformers v5, MoE parameters of type nn.Parameter won't get quantized. This PR adds support for the quantization by doing the following:
- Converting nn.Parameter to nn.Params4bit
Analysis using GLM-4.7-Flash: