Skip to content

gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344)#48

Open
danielhanchen wants to merge 28 commits into
fix-issue-5344-quantization-guardrailfrom
pr-5432-head
Open

gemma-4 moe: per-expert Linear4bit swap so 26B-A4B fits at 4-bit (#5344)#48
danielhanchen wants to merge 28 commits into
fix-issue-5344-quantization-guardrailfrom
pr-5432-head

Conversation

@danielhanchen

Copy link
Copy Markdown
Collaborator

Staging mirror of unslothai/unsloth#5432

Original PR: unslothai/unsloth#5432
Author: danielhanchen

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

unsloth/gemma-4-26B-A4B-it loads at around 46 GB even with load_in_4bit=True. The cause is structural rather than a flag being dropped: Gemma4TextExperts stores all experts as two fused 3D nn.Parameter tensors (so torch._grouped_mm can dispatch one grouped matmul per layer), and bnb.replace_with_bnb_linear only swaps nn.Linear instances. The fused expert weights stay BF16 and dominate the VRAM footprint.

This PR adds an opt-in helper that, after the model is loaded, walks every Gemma4TextExperts module, slices each fused (num_experts, out, in) Parameter into num_experts individual bitsandbytes.nn.Linear4bit modules, and patches forward to dispatch per-expert.

Measured impact

unsloth/gemma-4-26B-A4B-it on a single B200, transformers 5.5.0, torch 2.9.1:

mode resident VRAM peak (fwd) Linear4bit count forward fidelity
BF16 baseline (full_finetuning=True) 48.08 GB 48.20 GB 0 reference
load_in_4bit=True, no swap ~46 GB n/a 206 (attention only) partial-quant
load_in_4bit=True, swap on 14.27 GB 14.40 GB 7,886 cosine 0.994 vs BF16

Forward-pass logits compared on a fixed prompt; cosine similarity 0.994 is standard QLoRA fidelity. Top-5 tokens overlap; argmax differs by 1 token as expected for nf4.

Trade-off

Per-expert dispatch loses the torch._grouped_mm throughput. Acceptable for "the model fits at 4-bit on a single GPU", which is what #5344 is asking for. QLoRA training still needs the matching per-expert LoRA path; that is the gating reason this PR keeps the swap behind UNSLOTH_GEMMA4_MOE_4BIT=1.

The swap also renames gate_up_proj to gate_up_proj_4bit (and the same for down_proj) so it does not collide with unsloth_zoo's gemma4_moe.py grouped-mm LoRA extractor which keys off the original names. That extractor would need a parallel _4bit-aware path before we can flip the default to ON.

Activation

import os
os.environ["UNSLOTH_GEMMA4_MOE_4BIT"] = "1"

from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gemma-4-26B-A4B-it",
    load_in_4bit = True,
)
# Unsloth: swapped 30 Gemma4TextExperts module(s) to per-expert Linear4bit
# (see https://github.com/unslothai/unsloth/issues/5344).

The hook is only entered when load_in_4bit=True and full_finetuning=False, and only fires if the model actually contains Gemma4TextExperts modules. On every other checkpoint the helper is a no-op.

Safety

  • Idempotent: re-running the swap sees _unsloth_gemma4_moe_4bit_swapped on the module and skips it.
  • Shape-checked: bails out without modifying weights if the fused Parameters do not have the (num_experts, 2*intermediate, hidden) and (num_experts, hidden, intermediate) layout.
  • Failure-soft: if any error is raised during the swap, falls back to BF16 experts and warns the user. The existing guardrail from #5430 will then fire and explain w

This PR tracks the moving review branch (pr-5432-head). Iteration fix commits land here directly. Review-added tests are in a separate PR.

Changed files:

  • .github/workflows/consolidated-tests-ci.yml
  • .github/workflows/lint-ci.yml
  • .github/workflows/mlx-ci.yml
  • .github/workflows/notebooks-ci.yml
  • .github/workflows/release-desktop.yml
  • .github/workflows/security-audit.yml
  • .github/workflows/stale.yml
  • .github/workflows/studio-api-smoke.yml
  • .github/workflows/studio-backend-ci.yml
  • .github/workflows/studio-frontend-ci.yml
  • .github/workflows/studio-inference-smoke.yml
  • .github/workflows/studio-mac-api-smoke.yml
  • .github/workflows/studio-mac-inference-smoke.yml
  • .github/workflows/studio-mac-ui-smoke.yml
  • .github/workflows/studio-mac-update-smoke.yml
  • .github/workflows/studio-tauri-smoke.yml
  • .github/workflows/studio-ui-smoke.yml
  • .github/workflows/studio-update-smoke.yml
  • .github/workflows/studio-windows-api-smoke.yml
  • .github/workflows/studio-windows-inference-smoke.yml
  • .github/workflows/studio-windows-ui-smoke.yml
  • .github/workflows/studio-windows-update-smoke.yml
  • .github/workflows/version-compat-ci.yml
  • .github/workflows/wheel-smoke.yml
  • unsloth/models/gemma4_moe_4bit.py
  • unsloth/models/vision.py
  • tests/test_gemma4_moe_4bit_swap.py

danielhanchen and others added 24 commits May 15, 2026 03:49
unsloth/gemma-4-26B-A4B-it loads at ~46 GB even with load_in_4bit=True
because Gemma4TextExperts stores experts as fused 3D nn.Parameter tensors
(gate_up_proj of shape (128, 1408, 2816), down_proj of (128, 2816, 704))
so torch._grouped_mm can dispatch a single grouped matmul per layer.
bitsandbytes' replace_with_bnb_linear only swaps nn.Linear instances, so
the fused expert weights stay BF16 and dominate the VRAM footprint.

This adds an opt-in helper that walks the loaded model, finds every
Gemma4TextExperts module, slices each fused (E, O, I) Parameter into E
individual bnb.nn.Linear4bit modules (per-expert), and patches forward
to dispatch per-expert instead of via torch._grouped_mm.

Trade-off:

- VRAM win: 46 GB -> 14.27 GB resident on unsloth/gemma-4-26B-A4B-it
  (B200, transformers 5.5.0, single GPU). Linear4bit count 206 -> 7886.
  Forward-pass cosine similarity vs BF16 reference is 0.994 on a fixed
  prompt, i.e. standard QLoRA fidelity.

- Throughput loss: per-expert dispatch loses the grouped_mm speedup.
  Acceptable for "model fits at 4-bit on a single GPU"; QLoRA training
  still needs the matching per-expert LoRA path which is not in this PR.

Gated on UNSLOTH_GEMMA4_MOE_4BIT=1, default off until the per-expert
LoRA path lands (the swap renames gate_up_proj -> gate_up_proj_4bit
which would break unsloth_zoo's grouped_mm LoRA extractor as-is).

The renamed attributes also make the helper idempotent: re-entering it
sees `_unsloth_gemma4_moe_4bit_swapped` and no-ops, so multiple calls
across nested loaders are safe.

No regression on non-MoE checkpoints: the helper only touches modules
that are isinstance(Gemma4TextExperts) with the expected 3D shape.

Tests cover env-var gating, no-op behaviour on non-Gemma4 models, the
transformers-without-gemma4 ImportError path, and idempotence on a stub
Gemma4TextExperts module.

Refs #5344
@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a per-expert Linear4bit swap for Gemma-4 MoE models, enabling 4-bit quantization for fused expert weights to significantly reduce VRAM usage. The implementation includes a new utility module, integration into the model loading process, and comprehensive unit tests. Review feedback identifies two key memory optimization opportunities: moving fused parameters to the CPU during the swap process to minimize peak GPU VRAM and refactoring the forward pass to avoid large one-hot tensor allocations that could lead to memory issues with long sequences.

Comment on lines +178 to +194
device = gate_up.device

gate_up_list = nn.ModuleList()
down_list = nn.ModuleList()
for e in range(num_experts):
gu = _quantize_one_expert_to_linear4bit(
gate_up.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
dp = _quantize_one_expert_to_linear4bit(
down.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
gate_up_list.append(gu.to(device))
down_list.append(dp.to(device))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To minimize peak GPU VRAM during the swap, it is safer to move the fused parameters to CPU before slicing and quantizing. This ensures that we don't have both the full BF16 fused weights and the growing list of 4-bit quantized weights on the GPU simultaneously, which could lead to OOM on devices with limited memory (e.g., 48GB cards attempting to swap a 26B model).

Suggested change
device = gate_up.device
gate_up_list = nn.ModuleList()
down_list = nn.ModuleList()
for e in range(num_experts):
gu = _quantize_one_expert_to_linear4bit(
gate_up.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
dp = _quantize_one_expert_to_linear4bit(
down.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
gate_up_list.append(gu.to(device))
down_list.append(dp.to(device))
device = gate_up.device
# Move to CPU to minimize peak GPU VRAM during quantization
gate_up_cpu = gate_up.to("cpu")
down_cpu = down.to("cpu")
gate_up_list = nn.ModuleList()
down_list = nn.ModuleList()
for e in range(num_experts):
gu = _quantize_one_expert_to_linear4bit(
gate_up_cpu.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
dp = _quantize_one_expert_to_linear4bit(
down_cpu.data[e],
compute_dtype = compute_dtype,
quant_type = quant_type,
)
gate_up_list.append(gu.to(device))
down_list.append(dp.to(device))

Comment on lines +69 to +94
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(
top_k_index,
num_classes = self.num_experts,
)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate_up = self.gate_up_proj_4bit[expert_idx](current_state)
gate, up = gate_up.chunk(2, dim = -1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = self.down_proj_4bit[expert_idx](current_hidden_states)
current_hidden_states = (
current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
)
final_hidden_states.index_add_(
0,
token_idx,
current_hidden_states.to(final_hidden_states.dtype),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation creates a large one-hot tensor of shape (N, K, E), which can consume significant memory for long sequences or models with many experts. You can achieve the same result more efficiently by iterating over experts and using torch.where directly on the top_k_index tensor. This avoids the (N, K, E) allocation and simplifies the logic by removing the redundant sentinel check.

    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in range(self.num_experts):
        with torch.no_grad():
            mask = (top_k_index == expert_idx)
            if not mask.any():
                continue
            token_idx, top_k_pos = torch.where(mask)

        current_state = hidden_states[token_idx]
        gate_up = self.gate_up_proj_4bit[expert_idx](current_state)
        gate, up = gate_up.chunk(2, dim = -1)
        current_hidden_states = self.act_fn(gate) * up
        current_hidden_states = self.down_proj_4bit[expert_idx](current_hidden_states)
        current_hidden_states = (
            current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        )
        final_hidden_states.index_add_(
            0,
            token_idx,
            current_hidden_states.to(final_hidden_states.dtype),
        )

@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a mechanism to enable 4-bit quantization for Gemma-4 MoE experts by swapping fused 3D parameters with per-expert Linear4bit modules, which significantly reduces VRAM usage. The implementation includes a new utility module, integration into the vision model loading path, and comprehensive unit tests. Feedback from the review identifies a redundant index check in the per-expert forward pass that can be safely removed.

Comment on lines +79 to +81
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

This check is redundant. Since expert_mask is generated using one_hot with num_classes = self.num_experts (line 72), the indices in the first dimension of expert_mask are strictly in the range [0, self.num_experts - 1]. Consequently, expert_idx obtained from expert_hit (which indexes into that first dimension) can never equal self.num_experts. If top_k_index contained a value equal to self.num_experts, the one_hot call would have already raised a RuntimeError.

        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

The Gemma-4 MoE per-expert Linear4bit swap previously gated only on the
positional load_in_4bit argument. loader.py forwards load_in_4bit=False
to FastBaseModel.from_pretrained whenever the caller supplies a
quantization_config (BitsAndBytesConfig), so callers that opt in via
UNSLOTH_GEMMA4_MOE_4BIT=1 plus BitsAndBytesConfig(load_in_4bit=True)
silently bypassed the swap. The adjacent guardrail already normalises
load_in_4bit from quantization_config; the swap gate now does the same
and sources bnb_4bit_compute_dtype from quantization_config when no
local bnb_config is built.

The except branch around the swap also previously stated "Falling back
to BF16 experts", which misrepresents the model state when the helper
fails partway through (already-swapped Gemma4TextExperts modules stay
in 4-bit; only the remainder remain BF16). The warning now counts the
modules marked _unsloth_gemma4_moe_4bit_swapped and reports the partial
state, advising a reload to recover a uniform state.

The comment above the fused-Parameter dels in gemma4_moe_4bit.py
overstated the swap's memory bound; rephrased to describe the actual
per-module peak (fused BF16 plus accumulated per-expert nf4).
- Normalize string compute_dtype values from dict-style quantization_config
  (e.g. {"bnb_4bit_compute_dtype": "bfloat16"}) into torch.dtype before
  forwarding to bitsandbytes. A raw string would propagate to
  bnb.nn.Linear4bit and crash the first forward with
  "Invalid device string: 'bfloat16'".

- Forward bnb_4bit_quant_type from the user's BitsAndBytesConfig or dict
  config into swap_gemma4_experts_to_per_expert_linear4bit so swapped
  experts match the quantization type used by the rest of the model
  (previously always nf4 even when the caller requested fp4).

- Wrap the swap and its warning into a local closure and invoke it from
  both the regular auto_model.from_pretrained branch and the
  fast_inference=True / convert_vllm_to_huggingface branch. The closure
  is idempotent on non-Gemma-4 models, so the vLLM call is free when the
  loaded model has no Gemma4TextExperts modules.

- Escalate partial-state swap failures: if the helper raises after one
  or more Gemma4TextExperts modules were already committed to 4-bit,
  the wrapper re-raises a RuntimeError instructing the caller to reload
  the model. Previously a warning implied a clean BF16 fallback, which
  is false when partial conversion has already occurred.

The closure is multi-line (long block) because it needs to capture the
already-resolved quantization parameters and be reusable across both
load paths; the alternative is duplicating the entire block.
Shorten WHAT-style narrative on private helpers (_per_expert_forward,
_quantize_one_expert_to_linear4bit) to one-line WHY statements; collapse
the three-line per-module peak-VRAM note to a single line; drop the
three-line opt-in description at the swap call site since the closure
name already conveys the intent.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant