gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344) #5432
danielhanchen wants to merge 29 commits into
Conversation
unsloth/gemma-4-26B-A4B-it loads at ~46 GB even with load_in_4bit=True because Gemma4TextExperts stores experts as fused 3D nn.Parameter tensors (gate_up_proj of shape (128, 1408, 2816), down_proj of (128, 2816, 704)) so torch._grouped_mm can dispatch a single grouped matmul per layer. bitsandbytes' replace_with_bnb_linear only swaps nn.Linear instances, so the fused expert weights stay BF16 and dominate the VRAM footprint.
This adds an opt-in helper that walks the loaded model, finds every Gemma4TextExperts module, slices each fused (E, O, I) Parameter into E individual bnb.nn.Linear4bit modules (per-expert), and patches forward to dispatch per-expert instead of via torch._grouped_mm.
Trade-off:
- VRAM win: 46 GB -> 14.27 GB resident on unsloth/gemma-4-26B-A4B-it (B200, transformers 5.5.0, single GPU). Linear4bit count 206 -> 7886. Forward-pass cosine similarity vs BF16 reference is 0.994 on a fixed prompt, i.e. standard QLoRA fidelity.
- Throughput loss: per-expert dispatch loses the grouped_mm speedup. Acceptable for "model fits at 4-bit on a single GPU"; QLoRA training still needs the matching per-expert LoRA path which is not in this PR.
Gated on UNSLOTH_GEMMA4_MOE_4BIT=1, default off until the per-expert LoRA path lands (the swap renames gate_up_proj -> gate_up_proj_4bit which would break unsloth_zoo's grouped_mm LoRA extractor as-is). The renamed attributes also make the helper idempotent: re-entering it sees `_unsloth_gemma4_moe_4bit_swapped` and no-ops, so multiple calls across nested loaders are safe.
No regression on non-MoE checkpoints: the helper only touches modules that are isinstance(Gemma4TextExperts) with the expected 3D shape.
Tests cover env-var gating, no-op behaviour on non-Gemma4 models, the transformers-without-gemma4 ImportError path, and idempotence on a stub Gemma4TextExperts module.
Refs #5344
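The footprint figures in this message can be sanity-checked from the expert shapes alone. The following is a rough sketch, not a measurement: it assumes 30 Gemma4TextExperts modules (the count this PR's test plan reports as swapped), 2 bytes per BF16 weight, and 0.5 bytes per 4-bit weight with bitsandbytes quantization statistics ignored.

```python
# Back-of-the-envelope check of the fused expert footprint.
# Shapes are taken from the PR text; NUM_EXPERT_MODULES = 30 is the swap
# count reported in the PR's test plan. Quantization statistics are ignored.
NUM_EXPERTS = 128
GATE_UP_SHAPE = (NUM_EXPERTS, 1408, 2816)  # fused gate_up_proj
DOWN_SHAPE = (NUM_EXPERTS, 2816, 704)      # fused down_proj
NUM_EXPERT_MODULES = 30


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


params_per_module = numel(GATE_UP_SHAPE) + numel(DOWN_SHAPE)
expert_params = params_per_module * NUM_EXPERT_MODULES

bf16_gb = expert_params * 2 / 1e9        # 2 bytes per BF16 weight
four_bit_gb = expert_params * 0.5 / 1e9  # 4 bits per weight

# Each swapped module contributes two Linear4bit layers per expert.
added_linear4bit = NUM_EXPERT_MODULES * NUM_EXPERTS * 2

print(f"expert params: {expert_params / 1e9:.2f} B")
print(f"BF16 experts:  {bf16_gb:.1f} GB")
print(f"4-bit experts: {four_bit_gb:.1f} GB")
print(f"Linear4bit:    206 + {added_linear4bit} = {206 + added_linear4bit}")
```

Under these assumptions the BF16 experts alone come to roughly 45.7 GB, in line with the ~46 GB load, and the added 7680 Linear4bit modules account for the reported 206 -> 7886 count.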
for more information, see https://pre-commit.ci
Code Review
This pull request introduces a per-expert bitsandbytes Linear4bit swap for Gemma-4 MoE experts to enable 4-bit quantization, which significantly reduces VRAM usage. The implementation includes a new module for the swap logic, integration into the vision model loading process, and comprehensive unit tests. Feedback from the review highlights a potential TypeError and performance bottleneck when iterating over GPU tensors to index nn.ModuleList, suggesting a move to CPU-based integer indexing. Additionally, a redundant index check in the forward pass was identified as dead code.
| ) | ||
| expert_mask = expert_mask.permute(2, 1, 0) | ||
| expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero() | ||
Iterating over a GPU tensor (expert_hit) and using its elements (which remain GPU tensors) to index an nn.ModuleList will cause a TypeError, as nn.ModuleList only supports integer indexing. Furthermore, accessing elements of a GPU tensor within a loop triggers host-device synchronizations that significantly impact performance. It is recommended to move the active expert indices to the CPU and convert them to a list of integers. Using torch.any is also more idiomatic than sum > 0 for boolean masks.
| ) | |
| expert_mask = expert_mask.permute(2, 1, 0) | |
| expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero() | |
| expert_hit = torch.any(expert_mask, dim=(-1, -2)).nonzero().view(-1).tolist() | |
| for expert_idx in expert_hit: |
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
| for expert_idx in expert_hit: | ||
| expert_idx = expert_idx[0] |
This check is redundant and represents dead code. Since expert_mask is generated via one_hot with num_classes=self.num_experts, any index returned by nonzero() is guaranteed to be strictly less than self.num_experts.
References
- To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ddf54efa5f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| gate_up_list.append(gu.to(device)) | ||
| down_list.append(dp.to(device)) |
Free fused experts before allocating replacements
When the Gemma-4 MoE checkpoint only barely fits in its partially-quantized form, this loop allocates every replacement Linear4bit for the current experts module while the original BF16 gate_up_proj/down_proj tensors are still registered and referenced until the later del statements. That creates an avoidable VRAM spike during the opt-in swap and can OOM before any BF16 expert memory is released, which defeats the tight-memory use case this path is meant to unblock. Consider offloading/freeing the fused tensors before building the ModuleLists, or quantizing from CPU slices.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3394d1a81d
| # whose fused 3D expert weights bnb cannot quantize (#5344). | ||
| # Off by default; users enable via UNSLOTH_GEMMA4_MOE_4BIT=1. | ||
| if load_in_4bit and not full_finetuning: | ||
| try: |
Honor 4-bit quantization_config when gating the swap
When callers provide a BitsAndBytesConfig(load_in_4bit=True) instead of the load_in_4bit boolean, FastLanguageModel clears the boolean before calling this loader to avoid double-applying the config, while leaving the user's quantization_config in kwargs. In that documented/custom-quantization path this condition is false, so UNSLOTH_GEMMA4_MOE_4BIT=1 never swaps the Gemma-4 fused experts and the model still carries the large BF16 expert tensors. Please treat a user quantization config with load_in_4bit=True as enabling this path as well.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 5ae1456e19
| compute_dtype = ( | ||
| bnb_config.bnb_4bit_compute_dtype | ||
| if bnb_config is not None | ||
| else torch.bfloat16 | ||
| ), |
Use the actual user quantization config for swapped experts
When callers pass a custom quantization_config while also setting load_in_4bit=True, from_pretrained forwards that user config in kwargs, but this swap still derives the expert settings from the synthetic default bnb_config. In that scenario the rest of the model can be loaded with e.g. fp4, a different compute dtype, or different storage settings, while the Gemma-4 MoE experts are always rebuilt with this default NF4 path, causing mixed and surprising quantization behavior instead of honoring the requested config.
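One way to avoid the mixed behaviour described above is to resolve the expert settings from the user's config first and fall back to defaults only when a field is absent. This is a minimal sketch under stated assumptions: `resolve_expert_quant_settings` is a hypothetical helper, not code from this PR; `types.SimpleNamespace` stands in for a BitsAndBytesConfig object; the fallback values (nf4, bfloat16, double quantization on) are assumptions about the synthetic default config.

```python
from types import SimpleNamespace


def resolve_expert_quant_settings(
    quantization_config,
    default_quant_type="nf4",          # assumed fallback, per "default NF4 path"
    default_compute_dtype="bfloat16",  # assumed fallback
    default_double_quant=True,         # assumed fallback
):
    """Hypothetical helper: prefer the user's 4-bit settings over defaults.

    Accepts None, a dict-style config, or an object exposing the same
    attribute names as BitsAndBytesConfig.
    """

    def read(key, default):
        if quantization_config is None:
            return default
        if isinstance(quantization_config, dict):
            value = quantization_config.get(key)
        else:
            value = getattr(quantization_config, key, None)
        # Only fall back when the field is missing; an explicit False is kept.
        return default if value is None else value

    return {
        "quant_type": read("bnb_4bit_quant_type", default_quant_type),
        "compute_dtype": read("bnb_4bit_compute_dtype", default_compute_dtype),
        "use_double_quant": bool(
            read("bnb_4bit_use_double_quant", default_double_quant)
        ),
    }


# Stand-in for BitsAndBytesConfig(bnb_4bit_quant_type="fp4", ...).
user_cfg = SimpleNamespace(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_use_double_quant=False,
)
print(resolve_expert_quant_settings(user_cfg))
print(resolve_expert_quant_settings({"bnb_4bit_compute_dtype": "float16"}))
print(resolve_expert_quant_settings(None))
```

The resolved dictionary would then be passed to whatever builds the per-expert modules, so the experts and the rest of the model share one source of truth.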
…nslothai#5460) The Windows-runner "Stop Studio" step's kill + sleep block has been observed to exit 143 (SIGTERM) even when the upstream test work passed. Most recently caught on PR unslothai#5432 Job 3 "JSON, images": all four assertions (json_object, plain inference, image/openai, image/anthropic) printed PASS, then the kill step ran for ~2 seconds and exited 143, failing the job. Teardown does not gate correctness. Wrap all three Stop Studio steps with set +e + redirected error streams + explicit exit 0 so transient Git Bash signal weirdness no longer masks a green test run.
danielhanchen left a comment
Thank you for the PR! The goal of this PR is to enable 4-bit quantization for Gemma-4 MoE experts by replacing the fused 3D expert weights with per-expert Linear4bit modules, solving the VRAM problem reported in #5344. As a summary, this PR introduces unsloth/models/gemma4_moe_4bit.py (the swap helper and per-expert forward), wires it into FastBaseModel.from_pretrained behind UNSLOTH_GEMMA4_MOE_4BIT=1, and adds a CPU-only test suite.
The VRAM reduction is real and meaningful (46 GB -> 14.27 GB on the 26B model). A few issues to address before landing:
| Severity | Finding |
|---|---|
| P1 | vision.py:1063 -- swap gate checks only positional load_in_4bit; ignores quantization_config.load_in_4bit, so callers using BitsAndBytesConfig(load_in_4bit=True) silently bypass the swap |
| P1 | tests:88 -- importlib.import_module = _broken_import does not intercept from ... import (which goes through builtins.__import__), so test_swap_skips_when_transformers_lacks_gemma4 gives false confidence |
| P2 | vision.py:1086 -- exception handler always says "Falling back to BF16 experts" but once one module is swapped and the next fails, the model is in a mixed 4-bit/BF16 state, not pure BF16 |
| P3 | gemma4_moe_4bit.py:196 -- comment "bounded by one expert at a time" is inaccurate; the full fused BF16 tensor stays live throughout the expert loop and is freed only after the loop |
| P3 | tests -- no test exercises _per_expert_forward; a regression in routing math, chunk ordering, or index_add_ accumulation would not be caught |
See inline comments for details and suggested fixes.
| # Opt-in per-expert Linear4bit swap for Gemma-4 MoE checkpoints | ||
| # whose fused 3D expert weights bnb cannot quantize (#5344). | ||
| # Off by default; users enable via UNSLOTH_GEMMA4_MOE_4BIT=1. | ||
| if load_in_4bit and not full_finetuning: |
[P1] The swap gate checks only the positional load_in_4bit argument. loader.py forwards load_in_4bit=False to FastBaseModel.from_pretrained whenever the caller supplies a quantization_config object (see loader.py:715-725), so a caller doing:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/gemma-4-26B-A4B-it",
        quantization_config = BitsAndBytesConfig(load_in_4bit=True),
    )
with UNSLOTH_GEMMA4_MOE_4BIT=1 will silently bypass the swap. The adjacent guardrail _warn_if_quantization_silently_dropped already normalises both sources; the swap gate should do the same.
| if load_in_4bit and not full_finetuning: | |
| _user_qcfg = kwargs.get("quantization_config", None) | |
| if isinstance(_user_qcfg, dict): | |
| _qcfg_4bit = bool(_user_qcfg.get("load_in_4bit", False)) | |
| _qcfg_dtype = _user_qcfg.get("bnb_4bit_compute_dtype", None) | |
| elif _user_qcfg is not None: | |
| _qcfg_4bit = bool(getattr(_user_qcfg, "load_in_4bit", False)) | |
| _qcfg_dtype = getattr(_user_qcfg, "bnb_4bit_compute_dtype", None) | |
| else: | |
| _qcfg_4bit = False | |
| _qcfg_dtype = None | |
| _effective_load_in_4bit = bool(load_in_4bit) or _qcfg_4bit | |
| if _effective_load_in_4bit and not full_finetuning: |
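The gate in the suggestion above can be exercised in isolation as a small truth table. A minimal sketch, with several assumptions: `should_swap` and `config_requests_4bit` are hypothetical names, `types.SimpleNamespace` stands in for BitsAndBytesConfig, and the environment variable is treated as enabled only when it equals "1" (the PR does not show how the real helper parses it).

```python
import os
from types import SimpleNamespace


def config_requests_4bit(quantization_config):
    """True if a dict-style or object-style config asks for 4-bit loading."""
    if quantization_config is None:
        return False
    if isinstance(quantization_config, dict):
        return bool(quantization_config.get("load_in_4bit", False))
    return bool(getattr(quantization_config, "load_in_4bit", False))


def should_swap(load_in_4bit, full_finetuning, quantization_config=None, env=None):
    """Hypothetical combined gate: env opt-in, effective 4-bit, no full finetune."""
    env = os.environ if env is None else env
    # Assumption: only the literal value "1" enables the swap.
    enabled = env.get("UNSLOTH_GEMMA4_MOE_4BIT", "0") == "1"
    effective_4bit = bool(load_in_4bit) or config_requests_4bit(quantization_config)
    return enabled and effective_4bit and not full_finetuning


opt_in = {"UNSLOTH_GEMMA4_MOE_4BIT": "1"}
bnb_like = SimpleNamespace(load_in_4bit=True)  # stand-in for BitsAndBytesConfig

# The case the review flags: boolean cleared, config still requests 4-bit.
print(should_swap(False, False, quantization_config=bnb_like, env=opt_in))
# Without the opt-in variable the swap stays off.
print(should_swap(True, False, env={}))
```

Passing `env` explicitly keeps the check free of global state, which makes the gating easy to cover in a CPU-only test.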
| except Exception as _e: | ||
| warnings.warn( | ||
| f"Unsloth: Gemma-4 MoE 4-bit swap failed: " | ||
| f"{type(_e).__name__}: {_e}. Falling back to BF16 " | ||
| f"experts. Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.", | ||
| stacklevel = 2, | ||
| ) |
[P2] If swap_gemma4_experts_to_per_expert_linear4bit raises mid-loop (e.g. OOM on the third of six expert layers), the already-swapped layers remain permanently in 4-bit because _unsloth_gemma4_moe_4bit_swapped is set per module before the next module is attempted. The message "Falling back to BF16 experts" misrepresents this mixed-precision state. Reloading is the only way to recover a uniform model.
| except Exception as _e: | |
| warnings.warn( | |
| f"Unsloth: Gemma-4 MoE 4-bit swap failed: " | |
| f"{type(_e).__name__}: {_e}. Falling back to BF16 " | |
| f"experts. Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.", | |
| stacklevel = 2, | |
| ) | |
| except Exception as _e: | |
| _partial = sum( | |
| 1 for _m in model.modules() | |
| if getattr(_m, "_unsloth_gemma4_moe_4bit_swapped", False) | |
| ) | |
| if _partial: | |
| _state = ( | |
| f"{_partial} Gemma4TextExperts module(s) are " | |
| f"already in 4-bit; remaining modules stay BF16. " | |
| f"Reload the model to recover a uniform state." | |
| ) | |
| else: | |
| _state = "Falling back to BF16 experts." | |
| warnings.warn( | |
| f"Unsloth: Gemma-4 MoE 4-bit swap failed: " | |
| f"{type(_e).__name__}: {_e}. {_state} " | |
| f"Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.", | |
| stacklevel = 2, | |
| ) |
| # Drop the fused Parameters before attaching the ModuleLists so peak | ||
| # VRAM during the swap stays bounded by one expert at a time. |
[P3] The comment is inaccurate. The fused gate_up_proj and down_proj BF16 tensors are still live (held by gate_up and down local variables, and their .data[e] slices are accessed on every iteration of the expert loop) throughout the per-expert quantization loop. They are freed only after the loop via del. Peak VRAM during the swap is therefore fused BF16 + accumulated per-expert nf4, not just one expert at a time.
| # Drop the fused Parameters before attaching the ModuleLists so peak | |
| # VRAM during the swap stays bounded by one expert at a time. | |
| # The fused BF16 gate_up_proj / down_proj stay live through the loop | |
| # above; per-module peak is fused BF16 + accumulated per-expert nf4. | |
| # They are released here, before attaching the ModuleLists. |
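The gap between the original comment and the actual behaviour can be put in numbers. This is a rough accounting sketch for a single Gemma4TextExperts module, assuming the shapes quoted in the PR, 2 bytes per BF16 weight, and 4 bits per quantized weight with quantization statistics ignored; the function names are illustrative only.

```python
# Peak-memory accounting for one experts module during the swap.
# Assumptions: shapes from the PR text, no quantization statistics,
# no allocator overhead.
NUM_EXPERTS = 128
GATE_UP_NUMEL = 1408 * 2816  # one expert's gate_up_proj slice
DOWN_NUMEL = 2816 * 704      # one expert's down_proj slice
PER_EXPERT_NUMEL = GATE_UP_NUMEL + DOWN_NUMEL


def peak_bytes_fused_kept_live():
    """Fused BF16 tensors stay live while 4-bit experts accumulate."""
    fused_bf16 = NUM_EXPERTS * PER_EXPERT_NUMEL * 2
    accumulated_4bit = 0
    peak = 0
    for _ in range(NUM_EXPERTS):
        accumulated_4bit += PER_EXPERT_NUMEL // 2  # 4 bits per weight
        peak = max(peak, fused_bf16 + accumulated_4bit)
    return peak


def one_expert_bf16_bytes():
    """What 'bounded by one expert at a time' would have implied."""
    return PER_EXPERT_NUMEL * 2


print(f"per-module peak: {peak_bytes_fused_kept_live() / 1e9:.2f} GB")
print(f"one BF16 expert: {one_expert_bf16_bytes() / 1e6:.1f} MB")
```

Under these assumptions the per-module peak is about 1.90 GB, against roughly 11.9 MB for a single BF16 expert, which is why the reworded comment matters for tight-memory loads.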
| def test_swap_skips_when_transformers_lacks_gemma4(): | ||
| """If transformers does not expose Gemma4TextExperts, the helper must | ||
| return 0 without raising. We simulate the ImportError by patching.""" | ||
| import unsloth.models.gemma4_moe_4bit as g4m | ||
| real_import = importlib.import_module | ||
| def _broken_import(name, *args, **kwargs): | ||
| if name == "transformers.models.gemma4.modeling_gemma4": | ||
| raise ImportError("simulated absence") | ||
| return real_import(name, *args, **kwargs) | ||
| try: | ||
| importlib.import_module = _broken_import | ||
| # Re-exercise via the public helper. It imports Gemma4TextExperts | ||
| # inside its try/except, so the simulated ImportError must yield 0. | ||
| model = nn.Sequential(nn.Linear(8, 8)) | ||
| assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0 | ||
| finally: | ||
| importlib.import_module = real_import |
[P1] importlib.import_module = _broken_import does not intercept from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts inside swap_gemma4_experts_to_per_expert_linear4bit. Python resolves from X import Y through builtins.__import__, which is independent of importlib.import_module. The test passes only because nn.Sequential(nn.Linear(8, 8)) has no Gemma4TextExperts instances; the early-exit if not isinstance(module, Gemma4TextExperts): continue fires before any import attempt. The import-failure branch is never actually exercised.
The correct approach is to block the import at the sys.modules level: setting sys.modules[key] = None causes from key import ... to raise ImportError immediately.
| def test_swap_skips_when_transformers_lacks_gemma4(): | |
| """If transformers does not expose Gemma4TextExperts, the helper must | |
| return 0 without raising. We simulate the ImportError by patching.""" | |
| import unsloth.models.gemma4_moe_4bit as g4m | |
| real_import = importlib.import_module | |
| def _broken_import(name, *args, **kwargs): | |
| if name == "transformers.models.gemma4.modeling_gemma4": | |
| raise ImportError("simulated absence") | |
| return real_import(name, *args, **kwargs) | |
| try: | |
| importlib.import_module = _broken_import | |
| # Re-exercise via the public helper. It imports Gemma4TextExperts | |
| # inside its try/except, so the simulated ImportError must yield 0. | |
| model = nn.Sequential(nn.Linear(8, 8)) | |
| assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0 | |
| finally: | |
| importlib.import_module = real_import | |
| def test_swap_skips_when_transformers_lacks_gemma4(): | |
| """If transformers does not expose Gemma4TextExperts, the helper must | |
| return 0 without raising. `from X import Y` resolves via | |
| `builtins.__import__`, so we simulate absence via sys.modules sentinel: | |
| setting sys.modules[key] = None forces a fresh `from key import ...` | |
| to raise ImportError.""" | |
| import sys | |
| import unsloth.models.gemma4_moe_4bit as g4m | |
| MODKEY = "transformers.models.gemma4.modeling_gemma4" | |
| _SENTINEL = object() | |
| original = sys.modules.get(MODKEY, _SENTINEL) | |
| sys.modules[MODKEY] = None | |
| try: | |
| model = nn.Sequential(nn.Linear(8, 8)) | |
| assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0 | |
| finally: | |
| if original is _SENTINEL: | |
| sys.modules.pop(MODKEY, None) | |
| else: | |
| sys.modules[MODKEY] = original |
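The difference between the two patching approaches can be observed with any small standard-library module. A minimal sketch, using `colorsys` purely as a stand-in for `transformers.models.gemma4.modeling_gemma4`; the two function names are illustrative and not part of this PR.

```python
import importlib
import sys

MODKEY = "colorsys"  # stand-in module, chosen only because it is tiny


def from_import_is_blocked():
    """A None entry in sys.modules makes `from MODKEY import ...` raise."""
    sentinel = object()
    original = sys.modules.get(MODKEY, sentinel)
    sys.modules[MODKEY] = None
    try:
        try:
            from colorsys import rgb_to_hsv  # noqa: F401
        except ImportError:
            return True
        return False
    finally:
        # Restore the previous state so later imports behave normally.
        if original is sentinel:
            sys.modules.pop(MODKEY, None)
        else:
            sys.modules[MODKEY] = original


def import_module_patch_sees_from_import():
    """Replacing importlib.import_module does not intercept a `from` import."""
    real_import = importlib.import_module
    seen = []

    def spy(name, *args, **kwargs):
        seen.append(name)
        return real_import(name, *args, **kwargs)

    importlib.import_module = spy
    try:
        sys.modules.pop(MODKEY, None)    # force a fresh import
        from colorsys import rgb_to_hsv  # noqa: F401
    finally:
        importlib.import_module = real_import
    return MODKEY in seen


print(from_import_is_blocked())
print(import_module_patch_sees_from_import())
```

The first function shows the sentinel approach reaching the import statement; the second shows why the original test never exercised the ImportError branch.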
| test_swap_skips_models_without_gemma4_experts() | ||
| test_swap_skips_when_transformers_lacks_gemma4() | ||
| test_swap_idempotent_on_stub_module_without_cuda() | ||
| print("All 4 swap tests passed.") |
[P3] None of the four tests invoke _per_expert_forward. They check structural properties (env var, swap count, attribute presence), but a regression in routing math, chunk(2) ordering, index_add_ accumulation, or the top_k_weights broadcast would go undetected.
A CPU-runnable test can bypass bitsandbytes entirely by monkey-patching _quantize_one_expert_to_linear4bit to return an exact-copy nn.Linear, then comparing the patched forward against a reference loop over the original fused weights:
| print("All 4 swap tests passed.") | |
| def test_per_expert_forward_matches_reference_on_cpu(): | |
| """Exercise the patched _per_expert_forward on CPU by monkey-patching | |
| the bnb quantizer to return an exact-copy nn.Linear. Verifies the | |
| routing math, chunk(2) ordering, index_add_ accumulation, and | |
| top-k weighting match the upstream fused-weight reference forward.""" | |
| import unsloth.models.gemma4_moe_4bit as g4m | |
| module = _stub_gemma4_module() | |
| if module is None: | |
| return # transformers without gemma4 module: nothing to test | |
| def _fake_quantize(weight_2d, compute_dtype, quant_type="nf4"): | |
| out_features, in_features = weight_2d.shape | |
| linear = nn.Linear(in_features, out_features, bias=False) | |
| linear.weight = nn.Parameter( | |
| weight_2d.detach().clone().to(torch.bfloat16), | |
| requires_grad=False, | |
| ) | |
| return linear | |
| ref_gate_up = module.gate_up_proj.detach().clone() | |
| ref_down = module.down_proj.detach().clone() | |
| ref_act_fn = module.act_fn | |
| num_experts = module.num_experts | |
| hidden_dim = module.hidden_dim | |
| real_quant = g4m._quantize_one_expert_to_linear4bit | |
| g4m._quantize_one_expert_to_linear4bit = _fake_quantize | |
| try: | |
| assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(module) == 1 | |
| finally: | |
| g4m._quantize_one_expert_to_linear4bit = real_quant | |
| torch.manual_seed(0) | |
| n_tokens, top_k = 7, 2 | |
| hidden = torch.randn(n_tokens, hidden_dim, dtype=torch.bfloat16) | |
| top_k_index = torch.tensor( | |
| [[0, 1], [2, 3], [0, 2], [1, 3], [0, 1], [2, 3], [1, 0]], dtype=torch.long | |
| ) | |
| top_k_weights = torch.full((n_tokens, top_k), 0.5, dtype=torch.bfloat16) | |
| got = module(hidden, top_k_index, top_k_weights) | |
| ref = torch.zeros_like(hidden) | |
| for e in range(num_experts): | |
| mask = top_k_index == e | |
| if not mask.any(): | |
| continue | |
| tok_idx, kpos = torch.where(mask) | |
| cs = hidden[tok_idx] | |
| gate, up = torch.nn.functional.linear(cs, ref_gate_up[e]).chunk(2, dim=-1) | |
| ch = ref_act_fn(gate) * up | |
| ch = torch.nn.functional.linear(ch, ref_down[e]) | |
| ch = ch * top_k_weights[tok_idx, kpos, None] | |
| ref.index_add_(0, tok_idx, ch.to(ref.dtype)) | |
| assert got.shape == ref.shape | |
| assert torch.allclose(got, ref, atol=1e-2, rtol=1e-2) | |
| if __name__ == "__main__": | |
| test_is_enabled_reads_env_var() | |
| test_swap_skips_models_without_gemma4_experts() | |
| test_swap_skips_when_transformers_lacks_gemma4() | |
| test_swap_idempotent_on_stub_module_without_cuda() | |
| test_per_expert_forward_matches_reference_on_cpu() | |
| print("All 5 swap tests passed.") |
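The property this test checks, that grouping tokens by expert and accumulating gives the same result as a per-token weighted sum, can be illustrated without torch or bitsandbytes. The following is a toy sketch in plain Python: `expert_fn` is a hypothetical scalar stand-in for an expert MLP, and the routing table mirrors the one used in the suggested test.

```python
def expert_fn(expert_idx, x):
    """Hypothetical stand-in for an expert MLP: scale by (expert_idx + 1)."""
    return (expert_idx + 1) * x


def per_token_reference(hidden, top_k_index, top_k_weights):
    """Reference: each token sums its own weighted expert outputs."""
    out = []
    for x, experts, weights in zip(hidden, top_k_index, top_k_weights):
        out.append(sum(w * expert_fn(e, x) for e, w in zip(experts, weights)))
    return out


def per_expert_dispatch(hidden, top_k_index, top_k_weights):
    """Loop over active experts and accumulate into the output per token."""
    out = [0.0] * len(hidden)
    # Active expert ids as plain Python ints, computed once up front.
    active_experts = sorted({e for row in top_k_index for e in row})
    for e in active_experts:
        for tok, experts in enumerate(top_k_index):
            for kpos, routed in enumerate(experts):
                if routed == e:
                    # Plays the role of index_add_ in the tensor version.
                    out[tok] += top_k_weights[tok][kpos] * expert_fn(e, hidden[tok])
    return out


hidden = [float(i + 1) for i in range(7)]
top_k_index = [[0, 1], [2, 3], [0, 2], [1, 3], [0, 1], [2, 3], [1, 0]]
top_k_weights = [[0.5, 0.5] for _ in hidden]

got = per_expert_dispatch(hidden, top_k_index, top_k_weights)
ref = per_token_reference(hidden, top_k_index, top_k_weights)
print(all(abs(g - r) < 1e-9 for g, r in zip(got, ref)))
```

The last token routes to experts in the order [1, 0], so the example also covers the case where accumulation order differs between the two loops.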
The Gemma-4 MoE per-expert Linear4bit swap previously gated only on the positional load_in_4bit argument. loader.py forwards load_in_4bit=False to FastBaseModel.from_pretrained whenever the caller supplies a quantization_config (BitsAndBytesConfig), so callers that opt in via UNSLOTH_GEMMA4_MOE_4BIT=1 plus BitsAndBytesConfig(load_in_4bit=True) silently bypassed the swap. The adjacent guardrail already normalises load_in_4bit from quantization_config; the swap gate now does the same and sources bnb_4bit_compute_dtype from quantization_config when no local bnb_config is built.
The except branch around the swap also previously stated "Falling back to BF16 experts", which misrepresents the model state when the helper fails partway through (already-swapped Gemma4TextExperts modules stay in 4-bit; only the remainder stays BF16). The warning now counts the modules marked _unsloth_gemma4_moe_4bit_swapped and reports the partial state, advising a reload to recover a uniform state.
The comment above the fused-Parameter dels in gemma4_moe_4bit.py overstated the swap's memory bound; it is rephrased to describe the actual per-module peak (fused BF16 plus accumulated per-expert nf4).
- Normalize string compute_dtype values from dict-style quantization_config
(e.g. {"bnb_4bit_compute_dtype": "bfloat16"}) into torch.dtype before
forwarding to bitsandbytes. A raw string would propagate to
bnb.nn.Linear4bit and crash the first forward with
"Invalid device string: 'bfloat16'".
- Forward bnb_4bit_quant_type from the user's BitsAndBytesConfig or dict
config into swap_gemma4_experts_to_per_expert_linear4bit so swapped
experts match the quantization type used by the rest of the model
(previously always nf4 even when the caller requested fp4).
- Wrap the swap and its warning into a local closure and invoke it from
both the regular auto_model.from_pretrained branch and the
fast_inference=True / convert_vllm_to_huggingface branch. The closure
is idempotent on non-Gemma-4 models, so the vLLM call is free when the
loaded model has no Gemma4TextExperts modules.
- Escalate partial-state swap failures: if the helper raises after one
or more Gemma4TextExperts modules were already committed to 4-bit,
the wrapper re-raises a RuntimeError instructing the caller to reload
the model. Previously a warning implied a clean BF16 fallback, which
is false when partial conversion has already occurred.
The closure is multi-line (long block) because it needs to capture the
already-resolved quantization parameters and be reusable across both
load paths; the alternative is duplicating the entire block.
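The escalation rule in the last bullet can be sketched without loading a model. This is a minimal illustration under stated assumptions: `StubExperts`, `stub_swap`, and `swap_with_escalation` are hypothetical names; only the flag name `_unsloth_gemma4_moe_4bit_swapped` comes from this PR.

```python
import warnings

SWAPPED_FLAG = "_unsloth_gemma4_moe_4bit_swapped"  # flag name from the PR


class StubExperts:
    """Hypothetical stand-in for a Gemma4TextExperts module."""

    def __init__(self, fail=False):
        self.fail = fail


def stub_swap(modules):
    """Hypothetical swap helper: marks modules in order and may raise midway."""
    swapped = 0
    for module in modules:
        if getattr(module, SWAPPED_FLAG, False):
            continue  # already swapped: skip, which keeps the helper idempotent
        if module.fail:
            raise MemoryError("simulated OOM")
        setattr(module, SWAPPED_FLAG, True)
        swapped += 1
    return swapped


def swap_with_escalation(modules, swap=stub_swap):
    """Warn on a clean failure, raise when some modules are already 4-bit."""
    try:
        return swap(modules)
    except Exception as exc:
        partial = sum(1 for m in modules if getattr(m, SWAPPED_FLAG, False))
        if partial:
            raise RuntimeError(
                f"swap failed after {partial} module(s) were converted to "
                f"4-bit; reload the model to recover a uniform state"
            ) from exc
        warnings.warn(f"swap failed: {exc}; experts stay BF16", stacklevel=2)
        return 0


print(swap_with_escalation([StubExperts(), StubExperts()]))
```

Raising only in the partial case keeps the opt-in path forgiving when nothing was changed, while refusing to hand back a model in a mixed 4-bit/BF16 state.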
Shorten WHAT-style narrative on private helpers (_per_expert_forward, _quantize_one_expert_to_linear4bit) to one-line WHY statements; collapse the three-line per-module peak-VRAM note to a single line; drop the three-line opt-in description at the swap call site since the closure name already conveys the intent.
Auto-review verdict: Approved
Adds an opt-in (UNSLOTH_GEMMA4_MOE_4BIT=1) post-load swap that replaces Gemma-4 MoE fused 3D expert Parameters with per-expert bnb.nn.Linear4bit modules and patches the expert forward, cutting resident VRAM on gemma-4-26B-A4B-it from ~46 GB to ~14 GB. Review hardened the swap to honor both load_in_4bit and quantization_config (including string dtypes and quant_type), cover the fast_inference/vLLM branch, and surface partial-state failures as RuntimeError so callers reload instead of silently using a mixed model.
Reason: All real findings fixed across two review loops; remaining diff is 3 files / 165 lines of correct, well-scoped code.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f16d0df146
| in_features, | ||
| out_features, | ||
| bias = False, | ||
| compute_dtype = compute_dtype, | ||
| quant_type = quant_type, |
Propagate bnb double-quant setting into swapped experts
The per-expert swap hardcodes Linear4bit construction and never threads through bnb_4bit_use_double_quant, so users who load with a custom 4-bit config (for example, disabling double quantization) will get mixed quantization behavior: regular nn.Linear layers follow their config while Gemma-4 experts are always rebuilt with the default compressed-statistics path. This can change memory/accuracy characteristics relative to the requested quantization_config and is especially surprising in the custom-config path this commit added support for.
Can this branch be adapted to also work with Qwen sparse MoEs?
Q&A: Is this problem solved in the latest release (2026.5.7)? I'm still waiting to train the Gemma 4 MoE model with the latest version of Unsloth.
Summary
unsloth/gemma-4-26B-A4B-it loads at around 46 GB even with load_in_4bit=True. The cause is structural rather than a flag being dropped: Gemma4TextExperts stores all experts as two fused 3D nn.Parameter tensors (so torch._grouped_mm can dispatch one grouped matmul per layer), and bnb.replace_with_bnb_linear only swaps nn.Linear instances. The fused expert weights stay BF16 and dominate the VRAM footprint.
This PR adds an opt-in helper that, after the model is loaded, walks every Gemma4TextExperts module, slices each fused (num_experts, out, in) Parameter into num_experts individual bitsandbytes.nn.Linear4bit modules, and patches forward to dispatch per-expert.
Measured impact
unsloth/gemma-4-26B-A4B-it on a single B200, transformers 5.5.0, torch 2.9.1. Configurations compared:
- full_finetuning=True
- load_in_4bit=True, no swap
- load_in_4bit=True, swap on
Forward-pass logits compared on a fixed prompt; cosine similarity 0.994 is standard QLoRA fidelity. Top-5 tokens overlap; argmax differs by 1 token as expected for nf4.
Trade-off
Per-expert dispatch loses the torch._grouped_mm throughput. Acceptable for "the model fits at 4-bit on a single GPU", which is what #5344 is asking for. QLoRA training still needs the matching per-expert LoRA path; that is the gating reason this PR keeps the swap behind UNSLOTH_GEMMA4_MOE_4BIT=1.
The swap also renames gate_up_proj to gate_up_proj_4bit (and the same for down_proj) so it does not collide with unsloth_zoo's gemma4_moe.py grouped-mm LoRA extractor, which keys off the original names. That extractor would need a parallel _4bit-aware path before we can flip the default to ON.
Activation
The hook is only entered when load_in_4bit=True and full_finetuning=False, and only fires if the model actually contains Gemma4TextExperts modules. On every other checkpoint the helper is a no-op.
Safety
- Idempotent: a repeat call sees _unsloth_gemma4_moe_4bit_swapped on the module and skips it.
- Only touches modules with the expected (num_experts, 2*intermediate, hidden) and (num_experts, hidden, intermediate) layout.
- Patches forward on the swapped instances rather than on Gemma4TextExperts.forward, so sibling models in the same process keep the standard path.
Tests
tests/test_gemma4_moe_4bit_swap.py covers:
- env-var gating (is_gemma4_moe_4bit_enabled)
- no-op on models without Gemma4TextExperts
- 0 return when transformers does not expose Gemma4TextExperts
- idempotence on a stub Gemma4TextExperts module (GPU path; CPU path validates the no-op branch)
Depends on
#5430 (the guardrail). This PR is opened against fix-issue-5344-quantization-guardrail; please merge that one first.
Follow-up (not in this PR)
Per-expert LoRA on the swapped nn.ModuleList[Linear4bit] so QLoRA training works with UNSLOTH_GEMMA4_MOE_4BIT=1. That unblocks flipping the default to ON.
Refs #5344
Test plan
- pytest tests/test_gemma4_moe_4bit_swap.py tests/test_issue_5344_guardrail.py (11/11)
- unsloth/gemma-4-26B-A4B-it load with UNSLOTH_GEMMA4_MOE_4BIT=1 reports swapped 30 Gemma4TextExperts module(s), resident VRAM 14.27 GB, forward cosine 0.994 vs BF16.
- unsloth/gemma-3-1b-it (local + HF), unsloth/gemma-4-E2B-it (local + HF) load identically with and without UNSLOTH_GEMMA4_MOE_4BIT=1.
- A failing transformers.models.gemma4.modeling_gemma4 import yields 0 swapped and no raise.