gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344) #48
danielhanchen wants to merge 28 commits into
unsloth/gemma-4-26B-A4B-it loads at ~46 GB even with load_in_4bit=True because Gemma4TextExperts stores experts as fused 3D nn.Parameter tensors (gate_up_proj of shape (128, 1408, 2816), down_proj of (128, 2816, 704)) so torch._grouped_mm can dispatch a single grouped matmul per layer. bitsandbytes' replace_with_bnb_linear only swaps nn.Linear instances, so the fused expert weights stay BF16 and dominate the VRAM footprint.

This adds an opt-in helper that walks the loaded model, finds every Gemma4TextExperts module, slices each fused (E, O, I) Parameter into E individual bnb.nn.Linear4bit modules (per-expert), and patches forward to dispatch per-expert instead of via torch._grouped_mm.

Trade-off:
- VRAM win: 46 GB -> 14.27 GB resident on unsloth/gemma-4-26B-A4B-it (B200, transformers 5.5.0, single GPU). Linear4bit count 206 -> 7886. Forward-pass cosine similarity vs BF16 reference is 0.994 on a fixed prompt, i.e. standard QLoRA fidelity.
- Throughput loss: per-expert dispatch loses the grouped_mm speedup. Acceptable for "model fits at 4-bit on a single GPU"; QLoRA training still needs the matching per-expert LoRA path which is not in this PR.

Gated on UNSLOTH_GEMMA4_MOE_4BIT=1, default off until the per-expert LoRA path lands (the swap renames gate_up_proj -> gate_up_proj_4bit which would break unsloth_zoo's grouped_mm LoRA extractor as-is). The renamed attributes also make the helper idempotent: re-entering it sees `_unsloth_gemma4_moe_4bit_swapped` and no-ops, so multiple calls across nested loaders are safe.

No regression on non-MoE checkpoints: the helper only touches modules that are isinstance(Gemma4TextExperts) with the expected 3D shape.

Tests cover env-var gating, no-op behaviour on non-Gemma4 models, the transformers-without-gemma4 ImportError path, and idempotence on a stub Gemma4TextExperts module.

Refs #5344
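As a rough sanity check on the numbers above, the fused expert footprint can be estimated from the quoted shapes alone. This is a back-of-the-envelope sketch, not a measurement: it assumes the Linear4bit count delta (206 -> 7886) consists entirely of the new per-expert modules, two per expert per MoE layer, and it ignores quantization metadata such as per-block scales.

```python
# Shapes quoted in the commit message.
NUM_EXPERTS = 128
GATE_UP_SHAPE = (NUM_EXPERTS, 1408, 2816)
DOWN_SHAPE = (NUM_EXPERTS, 2816, 704)


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


# Assumption: every new Linear4bit module is a per-expert gate_up or down
# projection, so the count delta divided by (2 * experts) is the layer count.
new_modules = 7886 - 206
moe_layers = new_modules // (2 * NUM_EXPERTS)

params_per_layer = numel(GATE_UP_SHAPE) + numel(DOWN_SHAPE)
expert_params = params_per_layer * moe_layers

bf16_gb = expert_params * 2 / 1e9        # 2 bytes per BF16 weight
four_bit_gb = expert_params * 0.5 / 1e9  # 4 bits per weight, metadata excluded

print(f"MoE layers (inferred): {moe_layers}")
print(f"expert params: {expert_params / 1e9:.2f} B")
print(f"BF16: {bf16_gb:.2f} GB, 4-bit payload: {four_bit_gb:.2f} GB")
```

Under those assumptions the expert weights alone account for most of the ~46 GB figure when left in BF16, which is consistent with the description's claim that they dominate the footprint.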
for more information, see https://pre-commit.ci
/gemini review
Code Review
This pull request implements a per-expert Linear4bit swap for Gemma-4 MoE models, enabling 4-bit quantization for fused expert weights to significantly reduce VRAM usage. The implementation includes a new utility module, integration into the model loading process, and comprehensive unit tests. Review feedback identifies two key memory optimization opportunities: moving fused parameters to the CPU during the swap process to minimize peak GPU VRAM and refactoring the forward pass to avoid large one-hot tensor allocations that could lead to memory issues with long sequences.
    device = gate_up.device

    gate_up_list = nn.ModuleList()
    down_list = nn.ModuleList()
    for e in range(num_experts):
        gu = _quantize_one_expert_to_linear4bit(
            gate_up.data[e],
            compute_dtype = compute_dtype,
            quant_type = quant_type,
        )
        dp = _quantize_one_expert_to_linear4bit(
            down.data[e],
            compute_dtype = compute_dtype,
            quant_type = quant_type,
        )
        gate_up_list.append(gu.to(device))
        down_list.append(dp.to(device))
To minimize peak GPU VRAM during the swap, it is safer to move the fused parameters to CPU before slicing and quantizing. This ensures that we don't have both the full BF16 fused weights and the growing list of 4-bit quantized weights on the GPU simultaneously, which could lead to OOM on devices with limited memory (e.g., 48GB cards attempting to swap a 26B model).
    device = gate_up.device
    # Move to CPU to minimize peak GPU VRAM during quantization
    gate_up_cpu = gate_up.to("cpu")
    down_cpu = down.to("cpu")
    gate_up_list = nn.ModuleList()
    down_list = nn.ModuleList()
    for e in range(num_experts):
        gu = _quantize_one_expert_to_linear4bit(
            gate_up_cpu.data[e],
            compute_dtype = compute_dtype,
            quant_type = quant_type,
        )
        dp = _quantize_one_expert_to_linear4bit(
            down_cpu.data[e],
            compute_dtype = compute_dtype,
            quant_type = quant_type,
        )
        gate_up_list.append(gu.to(device))
        down_list.append(dp.to(device))
    with torch.no_grad():
        expert_mask = torch.nn.functional.one_hot(
            top_k_index,
            num_classes = self.num_experts,
        )
        expert_mask = expert_mask.permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()

    for expert_idx in expert_hit:
        expert_idx = expert_idx[0]
        if expert_idx == self.num_experts:
            continue
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        current_state = hidden_states[token_idx]
        gate_up = self.gate_up_proj_4bit[expert_idx](current_state)
        gate, up = gate_up.chunk(2, dim = -1)
        current_hidden_states = self.act_fn(gate) * up
        current_hidden_states = self.down_proj_4bit[expert_idx](current_hidden_states)
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
        final_hidden_states.index_add_(
            0,
            token_idx,
            current_hidden_states.to(final_hidden_states.dtype),
        )
The current implementation creates a large one-hot tensor of shape (N, K, E), which can consume significant memory for long sequences or models with many experts. You can achieve the same result more efficiently by iterating over experts and using torch.where directly on the top_k_index tensor. This avoids the (N, K, E) allocation and simplifies the logic by removing the redundant sentinel check.
    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in range(self.num_experts):
        with torch.no_grad():
            mask = (top_k_index == expert_idx)
            if not mask.any():
                continue
            token_idx, top_k_pos = torch.where(mask)
        current_state = hidden_states[token_idx]
        gate_up = self.gate_up_proj_4bit[expert_idx](current_state)
        gate, up = gate_up.chunk(2, dim = -1)
        current_hidden_states = self.act_fn(gate) * up
        current_hidden_states = self.down_proj_4bit[expert_idx](current_hidden_states)
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
        final_hidden_states.index_add_(
            0,
            token_idx,
            current_hidden_states.to(final_hidden_states.dtype),
        )
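The claim that comparing top_k_index against each expert id yields the same routing as the one-hot mask can be checked without any tensors. The following is a dependency-free sketch using nested lists in place of the (N, K) index tensor; the function names are hypothetical and are not part of the PR.

```python
def route_via_one_hot(top_k_index, num_experts):
    """Materialise the full (N, K, E) one-hot mask, then read pairs per expert."""
    one_hot = [
        [[1 if e == chosen else 0 for e in range(num_experts)] for chosen in row]
        for row in top_k_index
    ]
    routes = {}
    for e in range(num_experts):
        pairs = [
            (token_idx, top_k_pos)
            for token_idx, row in enumerate(one_hot)
            for top_k_pos, hot in enumerate(row)
            if hot[e]
        ]
        if pairs:  # mirrors the expert_hit filter: skip experts with no tokens
            routes[e] = pairs
    return routes


def route_via_compare(top_k_index, num_experts):
    """Compare indices directly per expert; no (N, K, E) structure is built."""
    routes = {}
    for e in range(num_experts):
        pairs = [
            (token_idx, top_k_pos)
            for token_idx, row in enumerate(top_k_index)
            for top_k_pos, chosen in enumerate(row)
            if chosen == e
        ]
        if pairs:  # mirrors `if not mask.any(): continue`
            routes[e] = pairs
    return routes


# Three tokens, top-2 routing over four experts; expert 1 is never selected.
top_k_index = [[0, 2], [2, 3], [0, 3]]
routes = route_via_compare(top_k_index, num_experts=4)
print(routes)
```

Both functions return the same (token_idx, top_k_pos) pairs per expert; the direct comparison trades one large allocation for a loop over all experts, including those that received no tokens.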
/gemini review
Code Review
This pull request introduces a mechanism to enable 4-bit quantization for Gemma-4 MoE experts by swapping fused 3D parameters with per-expert Linear4bit modules, which significantly reduces VRAM usage. The implementation includes a new utility module, integration into the vision model loading path, and comprehensive unit tests. Feedback from the review identifies a redundant index check in the per-expert forward pass that can be safely removed.
        if expert_idx == self.num_experts:
            continue
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
This check is redundant. Since expert_mask is generated using one_hot with num_classes = self.num_experts (line 72), the indices in the first dimension of expert_mask are strictly in the range [0, self.num_experts - 1]. Consequently, expert_idx obtained from expert_hit (which indexes into that first dimension) can never equal self.num_experts. If top_k_index contained a value equal to self.num_experts, the one_hot call would have already raised a RuntimeError.
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

The Gemma-4 MoE per-expert Linear4bit swap previously gated only on the positional load_in_4bit argument. loader.py forwards load_in_4bit=False to FastBaseModel.from_pretrained whenever the caller supplies a quantization_config (BitsAndBytesConfig), so callers that opt in via UNSLOTH_GEMMA4_MOE_4BIT=1 plus BitsAndBytesConfig(load_in_4bit=True) silently bypassed the swap. The adjacent guardrail already normalises load_in_4bit from quantization_config; the swap gate now does the same and sources bnb_4bit_compute_dtype from quantization_config when no local bnb_config is built.

The except branch around the swap also previously stated "Falling back to BF16 experts", which misrepresents the model state when the helper fails partway through (already-swapped Gemma4TextExperts modules stay in 4-bit; only the remainder remain BF16). The warning now counts the modules marked _unsloth_gemma4_moe_4bit_swapped and reports the partial state, advising a reload to recover a uniform state.

The comment above the fused-Parameter dels in gemma4_moe_4bit.py overstated the swap's memory bound; rephrased to describe the actual per-module peak (fused BF16 plus accumulated per-expert nf4).
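The gate normalisation described in this commit can be sketched roughly as follows. The helper name `resolve_4bit_settings` is hypothetical, the "nf4" fallback is an assumption, and `SimpleNamespace` stands in for a BitsAndBytesConfig object so the snippet runs without transformers; only the three attribute names come from the text above.

```python
from types import SimpleNamespace


def resolve_4bit_settings(load_in_4bit, quantization_config=None):
    """Return (effective_load_in_4bit, compute_dtype, quant_type).

    Accepts either an object with attributes or a plain dict, since the
    commit messages mention both BitsAndBytesConfig and dict-style configs.
    """
    def read(key, default=None):
        if quantization_config is None:
            return default
        if isinstance(quantization_config, dict):
            return quantization_config.get(key, default)
        return getattr(quantization_config, key, default)

    # The positional flag may arrive as False when a config object is supplied,
    # so the config's own load_in_4bit must also be consulted.
    effective = bool(load_in_4bit) or bool(read("load_in_4bit", False))
    compute_dtype = read("bnb_4bit_compute_dtype")
    quant_type = read("bnb_4bit_quant_type", "nf4")  # assumed fallback
    return effective, compute_dtype, quant_type


# Stand-in for BitsAndBytesConfig(load_in_4bit=True, ...).
config = SimpleNamespace(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_type="nf4",
)
print(resolve_4bit_settings(False, config))
```

A caller that passes load_in_4bit=False alongside a 4-bit config still resolves to an effective 4-bit load here, which is the case the original gate missed.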
- Normalize string compute_dtype values from dict-style quantization_config
(e.g. {"bnb_4bit_compute_dtype": "bfloat16"}) into torch.dtype before
forwarding to bitsandbytes. A raw string would propagate to
bnb.nn.Linear4bit and crash the first forward with
"Invalid device string: 'bfloat16'".
- Forward bnb_4bit_quant_type from the user's BitsAndBytesConfig or dict
config into swap_gemma4_experts_to_per_expert_linear4bit so swapped
experts match the quantization type used by the rest of the model
(previously always nf4 even when the caller requested fp4).
- Wrap the swap and its warning into a local closure and invoke it from
both the regular auto_model.from_pretrained branch and the
fast_inference=True / convert_vllm_to_huggingface branch. The closure
is idempotent on non-Gemma-4 models, so the vLLM call is free when the
loaded model has no Gemma4TextExperts modules.
- Escalate partial-state swap failures: if the helper raises after one
or more Gemma4TextExperts modules were already committed to 4-bit,
the wrapper re-raises a RuntimeError instructing the caller to reload
the model. Previously a warning implied a clean BF16 fallback, which
is false when partial conversion has already occurred.
The closure is multi-line (long block) because it needs to capture the
already-resolved quantization parameters and be reusable across both
load paths; the alternative is duplicating the entire block.
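The partial-state escalation in the last bullet could look something like the sketch below. It is illustrative only: `swap_with_escalation`, `StubExperts`, and `make_flaky_swap` are hypothetical names, and the stub carries no weights; only the marker attribute name comes from the PR text.

```python
import warnings

SWAP_MARKER = "_unsloth_gemma4_moe_4bit_swapped"  # attribute name from the PR text


def swap_with_escalation(expert_modules, swap_one):
    """Run `swap_one` on each module; escalate if a failure leaves mixed state.

    `swap_one` is assumed to convert one module in place and set SWAP_MARKER.
    """
    expert_modules = list(expert_modules)
    try:
        for module in expert_modules:
            swap_one(module)
    except Exception as exc:
        swapped = sum(1 for m in expert_modules if getattr(m, SWAP_MARKER, False))
        if swapped > 0:
            # Some modules are already 4-bit, so a "BF16 fallback" would be false.
            raise RuntimeError(
                f"{swapped} of {len(expert_modules)} expert modules are already "
                "4-bit; reload the model to recover a uniform state."
            ) from exc
        warnings.warn(f"Expert swap failed before any module was converted: {exc}")
        return False
    return True


class StubExperts:
    """Stand-in for a Gemma4TextExperts module."""


def make_flaky_swap(fail_at):
    """Build a swap function that fails on the call with index `fail_at`."""
    calls = {"n": 0}

    def swap_one(module):
        if calls["n"] == fail_at:
            raise ValueError("simulated quantization failure")
        calls["n"] += 1
        setattr(module, SWAP_MARKER, True)

    return swap_one
```

A failure on the very first module leaves everything in its original state, so a warning is enough; a failure after that point raises, because the model is then a mix of 4-bit and BF16 experts.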
Shorten WHAT-style narrative on private helpers (_per_expert_forward, _quantize_one_expert_to_linear4bit) to one-line WHY statements; collapse the three-line per-module peak-VRAM note to a single line; drop the three-line opt-in description at the swap call site since the closure name already conveys the intent.
63c0b32 to 6746f4d
Staging mirror of unslothai/unsloth#5432
Original PR: unslothai/unsloth#5432
Author: danielhanchen
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
unsloth/gemma-4-26B-A4B-it loads at around 46 GB even with load_in_4bit=True. The cause is structural rather than a flag being dropped: Gemma4TextExperts stores all experts as two fused 3D nn.Parameter tensors (so torch._grouped_mm can dispatch one grouped matmul per layer), and bnb.replace_with_bnb_linear only swaps nn.Linear instances. The fused expert weights stay BF16 and dominate the VRAM footprint.
This PR adds an opt-in helper that, after the model is loaded, walks every Gemma4TextExperts module, slices each fused (num_experts, out, in) Parameter into num_experts individual bitsandbytes.nn.Linear4bit modules, and patches forward to dispatch per-expert.
Measured impact
unsloth/gemma-4-26B-A4B-it on a single B200, transformers 5.5.0, torch 2.9.1. Configurations compared:
- … full_finetuning=True)
- load_in_4bit=True, no swap
- load_in_4bit=True, swap on
Forward-pass logits compared on a fixed prompt; cosine similarity 0.994 is standard QLoRA fidelity. Top-5 tokens overlap; argmax differs by 1 token as expected for nf4.
Trade-off
Per-expert dispatch loses the torch._grouped_mm throughput. Acceptable for "the model fits at 4-bit on a single GPU", which is what #5344 is asking for. QLoRA training still needs the matching per-expert LoRA path; that is the gating reason this PR keeps the swap behind UNSLOTH_GEMMA4_MOE_4BIT=1.
The swap also renames gate_up_proj to gate_up_proj_4bit (and the same for down_proj) so it does not collide with unsloth_zoo's gemma4_moe.py grouped-mm LoRA extractor, which keys off the original names. That extractor would need a parallel _4bit-aware path before we can flip the default to ON.
Activation
The hook is only entered when load_in_4bit=True and full_finetuning=False, and only fires if the model actually contains Gemma4TextExperts modules. On every other checkpoint the helper is a no-op.
Safety
- Re-entering the helper sees _unsloth_gemma4_moe_4bit_swapped on the module and skips it.
- Only modules with the expected (num_experts, 2*intermediate, hidden) and (num_experts, hidden, intermediate) layout are touched.
This PR tracks the moving review branch (pr-5432-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.
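The two safety properties in the description (skip already-swapped modules, skip unexpected layouts) can be illustrated with a small stand-alone sketch. The stub class and function names are hypothetical and the real helper operates on tensors rather than shape tuples; the marker name and the expected layouts are taken from the description.

```python
SWAP_MARKER = "_unsloth_gemma4_moe_4bit_swapped"  # attribute name from the PR text


def has_expected_layout(gate_up_shape, down_shape):
    """Check (E, 2*intermediate, hidden) against (E, hidden, intermediate)."""
    if len(gate_up_shape) != 3 or len(down_shape) != 3:
        return False
    experts, two_inter, hidden = gate_up_shape
    experts_d, hidden_d, inter = down_shape
    return experts == experts_d and hidden == hidden_d and two_inter == 2 * inter


def maybe_swap(module, do_swap):
    """Swap at most once per module, and only when the layout matches."""
    if getattr(module, SWAP_MARKER, False):
        return False  # already swapped: re-entry is a no-op
    if not has_expected_layout(module.gate_up_shape, module.down_shape):
        return False  # unexpected layout: leave the module untouched
    do_swap(module)
    setattr(module, SWAP_MARKER, True)
    return True


class StubExperts:
    """Stand-in for Gemma4TextExperts that records shapes only."""

    def __init__(self, gate_up_shape, down_shape):
        self.gate_up_shape = gate_up_shape
        self.down_shape = down_shape


swap_log = []
experts = StubExperts((128, 1408, 2816), (128, 2816, 704))
first = maybe_swap(experts, swap_log.append)
second = maybe_swap(experts, swap_log.append)
print(first, second, len(swap_log))
```

Setting the marker only after the swap succeeds means a module that failed midway is retried rather than silently skipped on the next call.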
Changed files:
- .github/workflows/consolidated-tests-ci.yml
- .github/workflows/lint-ci.yml
- .github/workflows/mlx-ci.yml
- .github/workflows/notebooks-ci.yml
- .github/workflows/release-desktop.yml
- .github/workflows/security-audit.yml
- .github/workflows/stale.yml
- .github/workflows/studio-api-smoke.yml
- .github/workflows/studio-backend-ci.yml
- .github/workflows/studio-frontend-ci.yml
- .github/workflows/studio-inference-smoke.yml
- .github/workflows/studio-mac-api-smoke.yml
- .github/workflows/studio-mac-inference-smoke.yml
- .github/workflows/studio-mac-ui-smoke.yml
- .github/workflows/studio-mac-update-smoke.yml
- .github/workflows/studio-tauri-smoke.yml
- .github/workflows/studio-ui-smoke.yml
- .github/workflows/studio-update-smoke.yml
- .github/workflows/studio-windows-api-smoke.yml
- .github/workflows/studio-windows-inference-smoke.yml
- .github/workflows/studio-windows-ui-smoke.yml
- .github/workflows/studio-windows-update-smoke.yml
- .github/workflows/version-compat-ci.yml
- .github/workflows/wheel-smoke.yml
- unsloth/models/gemma4_moe_4bit.py
- unsloth/models/vision.py
- tests/test_gemma4_moe_4bit_swap.py