Skip to content

gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344)#5432

Open
danielhanchen wants to merge 29 commits into
fix-issue-5344-quantization-guardrailfrom
feat-gemma4-moe-4bit-swap
Open

gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344)#5432
danielhanchen wants to merge 29 commits into
fix-issue-5344-quantization-guardrailfrom
feat-gemma4-moe-4bit-swap

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

unsloth/gemma-4-26B-A4B-it loads at around 46 GB even with load_in_4bit=True. The cause is structural rather than a flag being dropped: Gemma4TextExperts stores all experts as two fused 3D nn.Parameter tensors (so torch._grouped_mm can dispatch one grouped matmul per layer), and bnb.replace_with_bnb_linear only swaps nn.Linear instances. The fused expert weights stay BF16 and dominate the VRAM footprint.

This PR adds an opt-in helper that, after the model is loaded, walks every Gemma4TextExperts module, slices each fused (num_experts, out, in) Parameter into num_experts individual bitsandbytes.nn.Linear4bit modules, and patches forward to dispatch per-expert.

Measured impact

unsloth/gemma-4-26B-A4B-it on a single B200, transformers 5.5.0, torch 2.9.1:

mode resident VRAM peak (fwd) Linear4bit count forward fidelity
BF16 baseline (full_finetuning=True) 48.08 GB 48.20 GB 0 reference
load_in_4bit=True, no swap ~46 GB n/a 206 (attention only) partial-quant
load_in_4bit=True, swap on 14.27 GB 14.40 GB 7,886 cosine 0.994 vs BF16

Forward-pass logits compared on a fixed prompt; cosine similarity 0.994 is standard QLoRA fidelity. Top-5 tokens overlap; argmax differs by 1 token as expected for nf4.

Trade-off

Per-expert dispatch loses the torch._grouped_mm throughput. Acceptable for "the model fits at 4-bit on a single GPU", which is what #5344 is asking for. QLoRA training still needs the matching per-expert LoRA path; that is the gating reason this PR keeps the swap behind UNSLOTH_GEMMA4_MOE_4BIT=1.

The swap also renames gate_up_proj to gate_up_proj_4bit (and the same for down_proj) so it does not collide with unsloth_zoo's gemma4_moe.py grouped-mm LoRA extractor which keys off the original names. That extractor would need a parallel _4bit-aware path before we can flip the default to ON.

Activation

import os
os.environ["UNSLOTH_GEMMA4_MOE_4BIT"] = "1"

from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gemma-4-26B-A4B-it",
    load_in_4bit = True,
)
# Unsloth: swapped 30 Gemma4TextExperts module(s) to per-expert Linear4bit
# (see https://github.com/unslothai/unsloth/issues/5344).

The hook is only entered when load_in_4bit=True and full_finetuning=False, and only fires if the model actually contains Gemma4TextExperts modules. On every other checkpoint the helper is a no-op.

Safety

  • Idempotent: re-running the swap sees _unsloth_gemma4_moe_4bit_swapped on the module and skips it.
  • Shape-checked: bails out without modifying weights if the fused Parameters do not have the (num_experts, 2*intermediate, hidden) and (num_experts, hidden, intermediate) layout.
  • Failure-soft: if any error is raised during the swap, falls back to BF16 experts and warns the user. The existing guardrail from guardrail: detect silent 4-bit / 8-bit quantization bypass (#5344) #5430 will then fire and explain why VRAM is high.
  • Per-instance forward rebind: does not touch the class-level Gemma4TextExperts.forward, so sibling models in the same process keep the standard path.

Tests

tests/test_gemma4_moe_4bit_swap.py covers:

  • env-var gating (is_gemma4_moe_4bit_enabled)
  • no-op on models without Gemma4TextExperts
  • graceful 0 return when transformers does not expose Gemma4TextExperts
  • idempotence on a stub Gemma4TextExperts module (GPU path; CPU path validates the no-op branch)
$ python -m pytest tests/test_gemma4_moe_4bit_swap.py tests/test_issue_5344_guardrail.py -q
...........                                                              [100%]
11 passed in 0.22s

Depends on

#5430 (the guardrail). This PR is opened against fix-issue-5344-quantization-guardrail; please merge that one first.

Follow-up (not in this PR)

Per-expert LoRA on the swapped nn.ModuleList[Linear4bit] so QLoRA training works with UNSLOTH_GEMMA4_MOE_4BIT=1. That unblocks flipping the default to ON.

Refs #5344

Test plan

  • pytest tests/test_gemma4_moe_4bit_swap.py tests/test_issue_5344_guardrail.py (11/11)
  • Manual: unsloth/gemma-4-26B-A4B-it load with UNSLOTH_GEMMA4_MOE_4BIT=1 reports swapped 30 Gemma4TextExperts module(s), resident VRAM 14.27 GB, forward cosine 0.994 vs BF16.
  • No regression on non-MoE: unsloth/gemma-3-1b-it (local + HF), unsloth/gemma-4-E2B-it (local + HF) load identically with and without UNSLOTH_GEMMA4_MOE_4BIT=1.
  • Failure-soft: a deliberately broken transformers.models.gemma4.modeling_gemma4 import yields 0 swapped and no raise.

unsloth/gemma-4-26B-A4B-it loads at ~46 GB even with load_in_4bit=True
because Gemma4TextExperts stores experts as fused 3D nn.Parameter tensors
(gate_up_proj of shape (128, 1408, 2816), down_proj of (128, 2816, 704))
so torch._grouped_mm can dispatch a single grouped matmul per layer.
bitsandbytes' replace_with_bnb_linear only swaps nn.Linear instances, so
the fused expert weights stay BF16 and dominate the VRAM footprint.

This adds an opt-in helper that walks the loaded model, finds every
Gemma4TextExperts module, slices each fused (E, O, I) Parameter into E
individual bnb.nn.Linear4bit modules (per-expert), and patches forward
to dispatch per-expert instead of via torch._grouped_mm.

Trade-off:

- VRAM win: 46 GB -> 14.27 GB resident on unsloth/gemma-4-26B-A4B-it
  (B200, transformers 5.5.0, single GPU). Linear4bit count 206 -> 7886.
  Forward-pass cosine similarity vs BF16 reference is 0.994 on a fixed
  prompt, i.e. standard QLoRA fidelity.

- Throughput loss: per-expert dispatch loses the grouped_mm speedup.
  Acceptable for "model fits at 4-bit on a single GPU"; QLoRA training
  still needs the matching per-expert LoRA path which is not in this PR.

Gated on UNSLOTH_GEMMA4_MOE_4BIT=1, default off until the per-expert
LoRA path lands (the swap renames gate_up_proj -> gate_up_proj_4bit
which would break unsloth_zoo's grouped_mm LoRA extractor as-is).

The renamed attributes also make the helper idempotent: re-entering it
sees `_unsloth_gemma4_moe_4bit_swapped` and no-ops, so multiple calls
across nested loaders are safe.

No regression on non-MoE checkpoints: the helper only touches modules
that are isinstance(Gemma4TextExperts) with the expected 3D shape.

Tests cover env-var gating, no-op behaviour on non-Gemma4 models, the
transformers-without-gemma4 ImportError path, and idempotence on a stub
Gemma4TextExperts module.

Refs #5344

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a per-expert bitsandbytes Linear4bit swap for Gemma-4 MoE experts to enable 4-bit quantization, which significantly reduces VRAM usage. The implementation includes a new module for the swap logic, integration into the vision model loading process, and comprehensive unit tests. Feedback from the review highlights a potential TypeError and performance bottleneck when iterating over GPU tensors to index nn.ModuleList, suggesting a move to CPU-based integer indexing. Additionally, a redundant index check in the forward pass was identified as dead code.

Comment on lines +73 to +76
)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Iterating over a GPU tensor (expert_hit) and using its elements (which remain GPU tensors) to index an nn.ModuleList will cause a TypeError, as nn.ModuleList only supports integer indexing. Furthermore, accessing elements of a GPU tensor within a loop triggers host-device synchronizations that significantly impact performance. It is recommended to move the active expert indices to the CPU and convert them to a list of integers. Using torch.any is also more idiomatic than sum > 0 for boolean masks.

Suggested change
)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()
expert_hit = torch.any(expert_mask, dim=(-1, -2)).nonzero().view(-1).tolist()
for expert_idx in expert_hit:
References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

Comment on lines +77 to +78
for expert_idx in expert_hit:
expert_idx = expert_idx[0]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check is redundant and represents dead code. Since expert_mask is generated via one_hot with num_classes=self.num_experts, any index returned by nonzero() is guaranteed to be strictly less than self.num_experts.

References
  1. To improve efficiency, avoid redundant data iterations. Combine checks and transformations into a single loop and return computed values for callers to reuse.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ddf54efa5f

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +193 to +194
gate_up_list.append(gu.to(device))
down_list.append(dp.to(device))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Free fused experts before allocating replacements

When the Gemma-4 MoE checkpoint only barely fits in its partially-quantized form, this loop allocates every replacement Linear4bit for the current experts module while the original BF16 gate_up_proj/down_proj tensors are still registered and referenced until the later del statements. That creates an avoidable VRAM spike during the opt-in swap and can OOM before any BF16 expert memory is released, which defeats the tight-memory use case this path is meant to unblock. Consider offloading/freeing the fused tensors before building the ModuleLists, or quantizing from CPU slices.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3394d1a81d

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/vision.py Outdated
# whose fused 3D expert weights bnb cannot quantize (#5344).
# Off by default; users enable via UNSLOTH_GEMMA4_MOE_4BIT=1.
if load_in_4bit and not full_finetuning:
try:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor 4-bit quantization_config when gating the swap

When callers provide a BitsAndBytesConfig(load_in_4bit=True) instead of the load_in_4bit boolean, FastLanguageModel clears the boolean before calling this loader to avoid double-applying the config, while leaving the user's quantization_config in kwargs. In that documented/custom-quantization path this condition is false, so UNSLOTH_GEMMA4_MOE_4BIT=1 never swaps the Gemma-4 fused experts and the model still carries the large BF16 expert tensors. Please treat a user quantization config with load_in_4bit=True as enabling this path as well.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5ae1456e19

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/vision.py Outdated
Comment on lines +1073 to +1077
compute_dtype = (
bnb_config.bnb_4bit_compute_dtype
if bnb_config is not None
else torch.bfloat16
),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Use the actual user quantization config for swapped experts

When callers pass a custom quantization_config while also setting load_in_4bit=True, from_pretrained forwards that user config in kwargs, but this swap still derives the expert settings from the synthetic default bnb_config. In that scenario the rest of the model can be loaded with e.g. fp4, a different compute dtype, or different storage settings, while the Gemma-4 MoE experts are always rebuilt with this default NF4 path, causing mixed and surprising quantization behavior instead of honoring the requested config.

Useful? React with 👍 / 👎.

pull Bot pushed a commit to ShinnChow/unsloth that referenced this pull request May 15, 2026
…nslothai#5460)

The Windows-runner "Stop Studio" step's kill + sleep block has
been observed to exit 143 (SIGTERM) even when the upstream test
work passed. Most recently caught on PR unslothai#5432 Job 3 "JSON, images":
all four assertions (json_object, plain inference, image/openai,
image/anthropic) printed PASS, then the kill step ran for ~2
seconds and exited 143, failing the job.

Teardown does not gate correctness. Wrap all three Stop Studio
steps with set +e + redirected error streams + explicit exit 0
so transient Git Bash signal weirdness no longer masks a green
test run.

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! The goal of this PR is to enable 4-bit quantization for Gemma-4 MoE experts by replacing the fused 3D expert weights with per-expert Linear4bit modules, solving the VRAM problem reported in #5344. As a summary, this PR introduces unsloth/models/gemma4_moe_4bit.py (the swap helper and per-expert forward), wires it into FastBaseModel.from_pretrained behind UNSLOTH_GEMMA4_MOE_4BIT=1, and adds a CPU-only test suite.

The VRAM reduction is real and meaningful (46 GB -> 14.27 GB on the 26B model). A few issues to address before landing:

Severity Finding
P1 vision.py:1063 -- swap gate checks only positional load_in_4bit; ignores quantization_config.load_in_4bit, so callers using BitsAndBytesConfig(load_in_4bit=True) silently bypass the swap
P1 tests:88 -- importlib.import_module = _broken_import does not intercept from ... import (which goes through builtins.__import__), so test_swap_skips_when_transformers_lacks_gemma4 gives false confidence
P2 vision.py:1086 -- exception handler always says "Falling back to BF16 experts" but once one module is swapped and the next fails, the model is in a mixed 4-bit/BF16 state, not pure BF16
P3 gemma4_moe_4bit.py:196 -- comment "bounded by one expert at a time" is inaccurate; the full fused BF16 tensor stays live throughout the expert loop and is freed only after the loop
P3 tests -- no test exercises _per_expert_forward; a regression in routing math, chunk ordering, or index_add_ accumulation would not be caught

See inline comments for details and suggested fixes.

Comment thread unsloth/models/vision.py Outdated
# Opt-in per-expert Linear4bit swap for Gemma-4 MoE checkpoints
# whose fused 3D expert weights bnb cannot quantize (#5344).
# Off by default; users enable via UNSLOTH_GEMMA4_MOE_4BIT=1.
if load_in_4bit and not full_finetuning:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] The swap gate checks only the positional load_in_4bit argument. loader.py forwards load_in_4bit=False to FastBaseModel.from_pretrained whenever the caller supplies a quantization_config object (see loader.py:715-725), so a caller doing:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-4-26B-A4B-it",
    quantization_config = BitsAndBytesConfig(load_in_4bit=True),
)

with UNSLOTH_GEMMA4_MOE_4BIT=1 will silently bypass the swap. The adjacent guardrail _warn_if_quantization_silently_dropped already normalises both sources; the swap gate should do the same.

Suggested change
if load_in_4bit and not full_finetuning:
_user_qcfg = kwargs.get("quantization_config", None)
if isinstance(_user_qcfg, dict):
_qcfg_4bit = bool(_user_qcfg.get("load_in_4bit", False))
_qcfg_dtype = _user_qcfg.get("bnb_4bit_compute_dtype", None)
elif _user_qcfg is not None:
_qcfg_4bit = bool(getattr(_user_qcfg, "load_in_4bit", False))
_qcfg_dtype = getattr(_user_qcfg, "bnb_4bit_compute_dtype", None)
else:
_qcfg_4bit = False
_qcfg_dtype = None
_effective_load_in_4bit = bool(load_in_4bit) or _qcfg_4bit
if _effective_load_in_4bit and not full_finetuning:

Comment thread unsloth/models/vision.py Outdated
Comment on lines +1086 to +1092
except Exception as _e:
warnings.warn(
f"Unsloth: Gemma-4 MoE 4-bit swap failed: "
f"{type(_e).__name__}: {_e}. Falling back to BF16 "
f"experts. Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.",
stacklevel = 2,
)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] If swap_gemma4_experts_to_per_expert_linear4bit raises mid-loop (e.g. OOM on the third of six expert layers), the already-swapped layers remain permanently in 4-bit because _unsloth_gemma4_moe_4bit_swapped is set per module before the next module is attempted. The message "Falling back to BF16 experts" misrepresents this mixed-precision state. Reloading is the only way to recover a uniform model.

Suggested change
except Exception as _e:
warnings.warn(
f"Unsloth: Gemma-4 MoE 4-bit swap failed: "
f"{type(_e).__name__}: {_e}. Falling back to BF16 "
f"experts. Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.",
stacklevel = 2,
)
except Exception as _e:
_partial = sum(
1 for _m in model.modules()
if getattr(_m, "_unsloth_gemma4_moe_4bit_swapped", False)
)
if _partial:
_state = (
f"{_partial} Gemma4TextExperts module(s) are "
f"already in 4-bit; remaining modules stay BF16. "
f"Reload the model to recover a uniform state."
)
else:
_state = "Falling back to BF16 experts."
warnings.warn(
f"Unsloth: Gemma-4 MoE 4-bit swap failed: "
f"{type(_e).__name__}: {_e}. {_state} "
f"Unset UNSLOTH_GEMMA4_MOE_4BIT to silence.",
stacklevel = 2,
)

Comment thread unsloth/models/gemma4_moe_4bit.py Outdated
Comment on lines +196 to +197
# Drop the fused Parameters before attaching the ModuleLists so peak
# VRAM during the swap stays bounded by one expert at a time.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P3] The comment is inaccurate. The fused gate_up_proj and down_proj BF16 tensors are still live (held by gate_up and down local variables, and their .data[e] slices are accessed on every iteration of the expert loop) throughout the per-expert quantization loop. They are freed only after the loop via del. Peak VRAM during the swap is therefore fused BF16 + accumulated per-expert nf4, not just one expert at a time.

Suggested change
# Drop the fused Parameters before attaching the ModuleLists so peak
# VRAM during the swap stays bounded by one expert at a time.
# The fused BF16 gate_up_proj / down_proj stay live through the loop
# above; per-module peak is fused BF16 + accumulated per-expert nf4.
# They are released here, before attaching the ModuleLists.

Comment on lines +88 to +108

def test_swap_skips_when_transformers_lacks_gemma4():
"""If transformers does not expose Gemma4TextExperts, the helper must
return 0 without raising. We simulate the ImportError by patching."""
import unsloth.models.gemma4_moe_4bit as g4m

real_import = importlib.import_module

def _broken_import(name, *args, **kwargs):
if name == "transformers.models.gemma4.modeling_gemma4":
raise ImportError("simulated absence")
return real_import(name, *args, **kwargs)

try:
importlib.import_module = _broken_import
# Re-exercise via the public helper. It imports Gemma4TextExperts
# inside its try/except, so the simulated ImportError must yield 0.
model = nn.Sequential(nn.Linear(8, 8))
assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0
finally:
importlib.import_module = real_import

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] importlib.import_module = _broken_import does not intercept from transformers.models.gemma4.modeling_gemma4 import Gemma4TextExperts inside swap_gemma4_experts_to_per_expert_linear4bit. Python resolves from X import Y through builtins.__import__, which is independent of importlib.import_module. The test passes only because nn.Sequential(nn.Linear(8, 8)) has no Gemma4TextExperts instances; the early-exit if not isinstance(module, Gemma4TextExperts): continue fires before any import attempt. The import-failure branch is never actually exercised.

The correct approach is to block the import at the sys.modules level: setting sys.modules[key] = None causes from key import ... to raise ImportError immediately.

Suggested change
def test_swap_skips_when_transformers_lacks_gemma4():
"""If transformers does not expose Gemma4TextExperts, the helper must
return 0 without raising. We simulate the ImportError by patching."""
import unsloth.models.gemma4_moe_4bit as g4m
real_import = importlib.import_module
def _broken_import(name, *args, **kwargs):
if name == "transformers.models.gemma4.modeling_gemma4":
raise ImportError("simulated absence")
return real_import(name, *args, **kwargs)
try:
importlib.import_module = _broken_import
# Re-exercise via the public helper. It imports Gemma4TextExperts
# inside its try/except, so the simulated ImportError must yield 0.
model = nn.Sequential(nn.Linear(8, 8))
assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0
finally:
importlib.import_module = real_import
def test_swap_skips_when_transformers_lacks_gemma4():
"""If transformers does not expose Gemma4TextExperts, the helper must
return 0 without raising. `from X import Y` resolves via
`builtins.__import__`, so we simulate absence via sys.modules sentinel:
setting sys.modules[key] = None forces a fresh `from key import ...`
to raise ImportError."""
import sys
import unsloth.models.gemma4_moe_4bit as g4m
MODKEY = "transformers.models.gemma4.modeling_gemma4"
_SENTINEL = object()
original = sys.modules.get(MODKEY, _SENTINEL)
sys.modules[MODKEY] = None
try:
model = nn.Sequential(nn.Linear(8, 8))
assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(model) == 0
finally:
if original is _SENTINEL:
sys.modules.pop(MODKEY, None)
else:
sys.modules[MODKEY] = original

test_swap_skips_models_without_gemma4_experts()
test_swap_skips_when_transformers_lacks_gemma4()
test_swap_idempotent_on_stub_module_without_cuda()
print("All 4 swap tests passed.")

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P3] None of the four tests invoke _per_expert_forward. They check structural properties (env var, swap count, attribute presence), but a regression in routing math, chunk(2) ordering, index_add_ accumulation, or the top_k_weights broadcast would go undetected.

A CPU-runnable test can bypass bitsandbytes entirely by monkey-patching _quantize_one_expert_to_linear4bit to return an exact-copy nn.Linear, then comparing the patched forward against a reference loop over the original fused weights:

Suggested change
print("All 4 swap tests passed.")
def test_per_expert_forward_matches_reference_on_cpu():
"""Exercise the patched _per_expert_forward on CPU by monkey-patching
the bnb quantizer to return an exact-copy nn.Linear. Verifies the
routing math, chunk(2) ordering, index_add_ accumulation, and
top-k weighting match the upstream fused-weight reference forward."""
import unsloth.models.gemma4_moe_4bit as g4m
module = _stub_gemma4_module()
if module is None:
return # transformers without gemma4 module: nothing to test
def _fake_quantize(weight_2d, compute_dtype, quant_type="nf4"):
out_features, in_features = weight_2d.shape
linear = nn.Linear(in_features, out_features, bias=False)
linear.weight = nn.Parameter(
weight_2d.detach().clone().to(torch.bfloat16),
requires_grad=False,
)
return linear
ref_gate_up = module.gate_up_proj.detach().clone()
ref_down = module.down_proj.detach().clone()
ref_act_fn = module.act_fn
num_experts = module.num_experts
hidden_dim = module.hidden_dim
real_quant = g4m._quantize_one_expert_to_linear4bit
g4m._quantize_one_expert_to_linear4bit = _fake_quantize
try:
assert g4m.swap_gemma4_experts_to_per_expert_linear4bit(module) == 1
finally:
g4m._quantize_one_expert_to_linear4bit = real_quant
torch.manual_seed(0)
n_tokens, top_k = 7, 2
hidden = torch.randn(n_tokens, hidden_dim, dtype=torch.bfloat16)
top_k_index = torch.tensor(
[[0, 1], [2, 3], [0, 2], [1, 3], [0, 1], [2, 3], [1, 0]], dtype=torch.long
)
top_k_weights = torch.full((n_tokens, top_k), 0.5, dtype=torch.bfloat16)
got = module(hidden, top_k_index, top_k_weights)
ref = torch.zeros_like(hidden)
for e in range(num_experts):
mask = top_k_index == e
if not mask.any():
continue
tok_idx, kpos = torch.where(mask)
cs = hidden[tok_idx]
gate, up = torch.nn.functional.linear(cs, ref_gate_up[e]).chunk(2, dim=-1)
ch = ref_act_fn(gate) * up
ch = torch.nn.functional.linear(ch, ref_down[e])
ch = ch * top_k_weights[tok_idx, kpos, None]
ref.index_add_(0, tok_idx, ch.to(ref.dtype))
assert got.shape == ref.shape
assert torch.allclose(got, ref, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
test_is_enabled_reads_env_var()
test_swap_skips_models_without_gemma4_experts()
test_swap_skips_when_transformers_lacks_gemma4()
test_swap_idempotent_on_stub_module_without_cuda()
test_per_expert_forward_matches_reference_on_cpu()
print("All 5 swap tests passed.")

danielhanchen and others added 6 commits May 16, 2026 15:46
The Gemma-4 MoE per-expert Linear4bit swap previously gated only on the
positional load_in_4bit argument. loader.py forwards load_in_4bit=False
to FastBaseModel.from_pretrained whenever the caller supplies a
quantization_config (BitsAndBytesConfig), so callers that opt in via
UNSLOTH_GEMMA4_MOE_4BIT=1 plus BitsAndBytesConfig(load_in_4bit=True)
silently bypassed the swap. The adjacent guardrail already normalises
load_in_4bit from quantization_config; the swap gate now does the same
and sources bnb_4bit_compute_dtype from quantization_config when no
local bnb_config is built.

The except branch around the swap also previously stated "Falling back
to BF16 experts", which misrepresents the model state when the helper
fails partway through (already-swapped Gemma4TextExperts modules stay
in 4-bit; only the remainder remain BF16). The warning now counts the
modules marked _unsloth_gemma4_moe_4bit_swapped and reports the partial
state, advising a reload to recover a uniform state.

The comment above the fused-Parameter dels in gemma4_moe_4bit.py
overstated the swap's memory bound; rephrased to describe the actual
per-module peak (fused BF16 plus accumulated per-expert nf4).
- Normalize string compute_dtype values from dict-style quantization_config
  (e.g. {"bnb_4bit_compute_dtype": "bfloat16"}) into torch.dtype before
  forwarding to bitsandbytes. A raw string would propagate to
  bnb.nn.Linear4bit and crash the first forward with
  "Invalid device string: 'bfloat16'".

- Forward bnb_4bit_quant_type from the user's BitsAndBytesConfig or dict
  config into swap_gemma4_experts_to_per_expert_linear4bit so swapped
  experts match the quantization type used by the rest of the model
  (previously always nf4 even when the caller requested fp4).

- Wrap the swap and its warning into a local closure and invoke it from
  both the regular auto_model.from_pretrained branch and the
  fast_inference=True / convert_vllm_to_huggingface branch. The closure
  is idempotent on non-Gemma-4 models, so the vLLM call is free when the
  loaded model has no Gemma4TextExperts modules.

- Escalate partial-state swap failures: if the helper raises after one
  or more Gemma4TextExperts modules were already committed to 4-bit,
  the wrapper re-raises a RuntimeError instructing the caller to reload
  the model. Previously a warning implied a clean BF16 fallback, which
  is false when partial conversion has already occurred.

The closure is multi-line (long block) because it needs to capture the
already-resolved quantization parameters and be reusable across both
load paths; the alternative is duplicating the entire block.
Shorten WHAT-style narrative on private helpers (_per_expert_forward,
_quantize_one_expert_to_linear4bit) to one-line WHY statements; collapse
the three-line per-module peak-VRAM note to a single line; drop the
three-line opt-in description at the swap call site since the closure
name already conveys the intent.
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels May 16, 2026
@danielhanchen

Copy link
Copy Markdown
Member Author

Auto-review verdict: Approved

Adds an opt-in (UNSLOTH_GEMMA4_MOE_4BIT=1) post-load swap that replaces Gemma-4 MoE fused 3D expert Parameters with per-expert bnb.nn.Linear4bit modules and patches the expert forward, cutting resident VRAM on gemma-4-26B-A4B-it from ~46 GB to ~14 GB. Review hardened the swap to honor both load_in_4bit and quantization_config (including string dtypes and quant_type), cover the fast_inference/vLLM branch, and surface partial-state failures as RuntimeError so callers reload instead of silently using a mixed model.

Reason: All real findings fixed across two review loops; remaining diff is 3 files / 165 lines of correct, well-scoped code.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f16d0df146

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +106 to +110
in_features,
out_features,
bias = False,
compute_dtype = compute_dtype,
quant_type = quant_type,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Propagate bnb double-quant setting into swapped experts

The per-expert swap hardcodes Linear4bit construction and never threads through bnb_4bit_use_double_quant, so users who load with a custom 4-bit config (for example, disabling double quantization) will get mixed quantization behavior: regular nn.Linear layers follow their config while Gemma-4 experts are always rebuilt with the default compressed-statistics path. This can change memory/accuracy characteristics relative to the requested quantization_config and is especially surprising in the custom-config path this commit added support for.

Useful? React with 👍 / 👎.

@David-AU-github

Copy link
Copy Markdown

Will this branch / can this branch be adapted to also work with Qwen sparse moes too?
-> Load in 4 bit?
This would be a game changer.

@Jeremy-Developer-Page

Copy link
Copy Markdown

Q&A is this problem solved with the latest 2026.5.7 or not because I'm still waiting to train the gemma 4 moe model with the latest version of Unsloth.
Thanks
Jeremy

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

auto-approved Auto-review approved the PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants