From 0a15e1d09a6b915f1c62f225b454ea31e67fa692 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 May 2026 12:30:44 +0700 Subject: [PATCH 1/5] fix: refactor kernels patch to drop routing and inject into Expert registry --- .../expert_parallel/experts_fn.py | 9 +- src/axolotl/integrations/kernels/README.md | 175 ++-- src/axolotl/integrations/kernels/args.py | 43 +- src/axolotl/integrations/kernels/constants.py | 105 +-- .../{gemma4_experts.py => experts.py} | 59 +- .../kernels/libs/sonicmoe/__init__.py | 7 +- .../kernels/libs/sonicmoe/experts.py | 143 +++ .../kernels/libs/sonicmoe/gemma4_experts.py | 106 --- .../kernels/libs/sonicmoe/lora.py | 43 +- .../kernels/libs/sonicmoe/patch.py | 272 ------ .../kernels/libs/sonicmoe/routing.py | 576 ------------ .../kernels/libs/sonicmoe/weight_converter.py | 202 ---- src/axolotl/integrations/kernels/plugin.py | 148 +-- tests/e2e/integrations/test_sonicmoe.py | 258 ++---- tests/e2e/integrations/test_sonicmoe_lora.py | 206 ++--- tests/integrations/test_gemma4_moe.py | 104 +-- tests/integrations/test_routing_parity.py | 492 ---------- tests/integrations/test_sonicmoe.py | 874 +++--------------- tests/integrations/test_sonicmoe_gradients.py | 158 ---- tests/integrations/test_sonicmoe_lora.py | 49 - 20 files changed, 599 insertions(+), 3430 deletions(-) rename src/axolotl/integrations/kernels/libs/scattermoe_lora/{gemma4_experts.py => experts.py} (72%) create mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/experts.py delete mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py delete mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/patch.py delete mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/routing.py delete mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py delete mode 100644 tests/integrations/test_routing_parity.py delete mode 100644 tests/integrations/test_sonicmoe_gradients.py diff --git a/src/axolotl/integrations/expert_parallel/experts_fn.py b/src/axolotl/integrations/expert_parallel/experts_fn.py index 1cc81ef293..89473fd112 100644 --- a/src/axolotl/integrations/expert_parallel/experts_fn.py +++ b/src/axolotl/integrations/expert_parallel/experts_fn.py @@ -59,7 +59,7 @@ def _grouped_mm_local(experts, recv_x, recv_topk_idx, recv_topk_weights): def _scattermoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights): - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( scattermoe_experts_forward, ) @@ -68,13 +68,6 @@ def _scattermoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights): def _sonicmoe_local(experts, recv_x, recv_topk_idx, recv_topk_weights): raise NotImplementedError("Sonicmoe + EP is not yet properly implemented.") - # from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import ( - # gemma4_sonicmoe_experts_forward, - # ) - - # return gemma4_sonicmoe_experts_forward( - # experts, recv_x, recv_topk_idx, recv_topk_weights - # ) _LOCAL_KERNELS = { diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index 32d236da49..338381054a 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -1,16 +1,17 @@ # Kernels Integration -MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg: +MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. Transformers v5 introduced a uniform dispatch point for the per-expert grouped GEMMs via the `experts_implementation` config kwarg: ```python class ExpertsInterface(GeneralInterface): _global_mapping = { "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, + "sonicmoe": sonicmoe_experts_forward, # upstream HF integration } ``` -In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`. +Axolotl registers two additional implementations into this same global registry: **ScatterMoE** (Triton, runs on any CUDA GPU) and a LoRA-aware **SonicMoE** variant (CUTLASS / cute-DSL, Hopper or newer). Routing — softmax/sigmoid top-k, group selection, shared experts, bias correction, etc. — stays in each model's `SparseMoEBlock`, where transformers handles all per-architecture variation. Axolotl only swaps the experts forward. ## Usage @@ -28,130 +29,76 @@ use_scattermoe: true use_sonicmoe: true ``` -**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`. +`experts_implementation` is auto-set to `scattermoe` / `sonicmoe` from the kernel flag, but you can override to `eager` / `batched_mm` / `grouped_mm` to compare against the transformers reference implementations. ### SonicMoE installation **Prerequisites:** -- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU +- NVIDIA Hopper (H100/H200) or Blackwell (B200/GB200/B300) GPU - CUDA 12.9+ (13.0+ for B300) -- PyTorch 2.7+ (2.9.1 recommended) -- For B300: Triton 3.6.0 +- PyTorch 2.7+ +- For B300: Triton 3.6.x + +Sonic-MoE itself is loaded lazily from the HF [`kernels-community/sonic-moe`](https://huggingface.co/kernels-community/sonic-moe) hub on first use via the `kernels` package — no manual install is needed for the runtime. For from-source development: ```bash -pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5 +pip install --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git" \ + "nvidia-cutlass-dsl>=4.4.0" "quack-kernels>=0.3.0" ``` -See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details. - -**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels. +**Note:** Blackwell support is in upstream beta. On Blackwell GPUs Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels. ## How It Works -The `KernelsPlugin` runs before model loading and: - -### ScatterMoE -1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels). -2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library. - -### SonicMoE -1. Resolves the model's MoE block class(es) from `constants.py`. -2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format. -3. Supports pluggable routing strategies (see routing table below). - -Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution. - -## Model Support Matrix - -Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum. - -### Routing strategies - -| Routing Strategy | Description | ScatterMoE | SonicMoE | -|---|---|:---:|:---:| -| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes | -| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes | -| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes | -| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes | -| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes | -| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes | -| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes | -| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes | -| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned | - -### Per-model support - -| Model Type | Architecture | Routing | ScatterMoE | SonicMoE | -|---|---|---|:---:|:---:| -| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** | -| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** | -| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** | -| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** | -| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** | -| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** | -| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** | -| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** | -| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** | -| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** | -| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** | -| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** | -| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** | -| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** | -| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** | -| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** | -| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** | -| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** | -| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** | -| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** | -| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* | -| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned | - -\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations. - -\*\* Gemma 4 uses `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below. - -### Feature comparison - -| Feature | ScatterMoE | SonicMoE | -|---|:---:|:---:| -| Kernel backend | Triton | CUTLASS | -| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) | -| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd | -| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) | -| Gate/router LoRA | Yes | Yes | -| Expert LoRA | Yes (fused) | Yes (materialized) | -| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) | -| Selective expert dequantization | Yes (~97% memory savings) | No | -| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` | -| torch.compile routing | No | Yes (optional) | - -## Shared Expert Handling - -Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority: - -1. `shared_expert` (Qwen2-MoE) -2. `shared_experts` (GLM-MoE, DeepSeek-V3) -3. `shared_mlp` (HunYuan V1 MoE) - -If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed. - -## Gemma 4 - -Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture: - -- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed. -- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales. -- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models). -- **128 experts, top-k=8** for the 26B-A4B variant. - -Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is. - -## Limitations - -- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`). -- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant). -- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed. -- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP. +The `KernelsPlugin` runs once before model loading and: + +1. Calls `register_scattermoe_experts()` or `register_sonicmoe_experts()`, which inserts the kernel forward into `transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS`. +2. Sets `cfg.experts_implementation` to the matching name. +3. When the model loads, transformers' `@use_experts_implementation` decorator on each model's `Experts` class reads `config._experts_implementation` and dispatches to our registered forward. + +That's the entire integration — there is no per-architecture SparseMoEBlock monkey-patch, no per-model routing code, and no weight-layout conversion. As new MoE models adopt the decorator upstream they immediately benefit from both kernels. + +## LoRA Support + +Both kernels train PEFT adapters on `gate_up_proj` / `down_proj` (and `gate` for the router) end-to-end: + +- **ScatterMoE** fuses the LoRA `B @ A` product into the per-expert grouped GEMM via custom Triton kernels (`parallel_linear_lora`). No extra materialization pass. +- **SonicMoE** materializes `W_eff = W + scaling * (B @ A)` per expert inside a custom `MoELoRAMaterialize` `autograd.Function` and passes the effective weight into the CUTLASS kernel. Backward decomposes `dW_eff` into `dA` and `dB` via the chain rule, so LoRA parameters train without modifying the kernel. + +Both paths detect PEFT `ParamWrapper` on individual expert parameters (`target_parameters` API) and unwrap them before dispatch. + +## Model Support + +Any model whose `Experts` class is decorated with `@use_experts_implementation` upstream works automatically. As of transformers 5.8 this includes (verified): + +| Model Type | ScatterMoE | SonicMoE | +|-------------------|:---------:|:--------:| +| `mixtral` | Yes | Yes | +| `qwen2_moe` | Yes | Yes | +| `qwen3_moe` | Yes | Yes | +| `qwen3_5_moe` | Yes | Yes | +| `olmoe` | Yes | Yes | +| `mistral4` | Yes | Yes | +| `glm_moe_dsa` | Yes | Yes | +| `deepseek_v3` | Yes | Yes | +| `minimax_m2` | Yes | Yes | +| `ernie4_5_moe` | Yes | Yes | +| `hunyuan_v1_moe` | Yes | Yes | +| `gemma4_text` | Yes | Yes | +| `gpt_oss` | Yes | Yes | + +For `gpt_oss` the upstream decorator carries `is_concatenated=False, is_transposed=True, has_bias=True`; the registered forward reads these flags off `self` and adjusts permutation / bias handling accordingly. + +## Feature comparison + +| Feature | ScatterMoE | SonicMoE | +|----------------------------------|:----------:|:--------:| +| Kernel backend | Triton | CUTLASS / cute-DSL | +| GPU requirement | Any CUDA | Hopper+ | +| LoRA path | Fused in Triton kernel | `MoELoRAMaterialize` + custom autograd | +| LoRA overhead | Lower (fused) | Higher (materialization pass) | +| Selective expert dequantization | Yes (~97% memory savings) | No | +| Weight format | Standard `[E, 2*I, H]` | Standard `[E, 2*I, H]` (concat layout, no interleave) | ## Note on MegaBlocks diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index f532fde417..bff9e85081 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -5,6 +5,17 @@ LOG = get_logger(__name__) +# Valid experts_implementation values: +# - "eager" : transformers' per-token loop reference implementation +# - "batched_mm" : transformers' built-in batched matmul path +# - "grouped_mm" : transformers' built-in grouped matmul path (cache-efficient) +# - "scattermoe" : axolotl-registered Triton kernels with LoRA support +# - "sonicmoe" : axolotl-registered CUTLASS / cute-DSL kernels with LoRA support +_BUILTIN_EXPERTS_IMPLS = {"eager", "batched_mm", "grouped_mm"} +_KERNEL_EXPERTS_IMPLS = {"scattermoe", "sonicmoe"} +_VALID_EXPERTS_IMPLS = _BUILTIN_EXPERTS_IMPLS | _KERNEL_EXPERTS_IMPLS + + class KernelsArgs(BaseModel): use_scattermoe: bool | None = None use_sonicmoe: bool | None = None @@ -33,21 +44,39 @@ def check_use_kernels(cls, data): @model_validator(mode="before") @classmethod def check_experts_implementation(cls, data): + """Auto-select impl from kernel flags; reject mismatched/unknown values.""" experts_implementation = data.get("experts_implementation") - use_scattermoe = data.get("use_scattermoe", False) + use_scattermoe = bool(data.get("use_scattermoe")) + use_sonicmoe = bool(data.get("use_sonicmoe")) + if experts_implementation is None: - # transformers may default to batched_mm when unset - data["experts_implementation"] = "eager" - elif experts_implementation == "scattermoe" and not use_scattermoe: + if use_scattermoe: + data["experts_implementation"] = "scattermoe" + elif use_sonicmoe: + data["experts_implementation"] = "sonicmoe" + else: + # Transformers defaults to a non-eager backend when unset; pin to + # eager unless the user explicitly opts in. + data["experts_implementation"] = "eager" + return data + + if experts_implementation == "scattermoe" and not use_scattermoe: LOG.warning( "`experts_implementation='scattermoe'` requires `use_scattermoe: true`. " "Automatically setting to 'eager'." ) data["experts_implementation"] = "eager" - elif experts_implementation not in ("eager", "scattermoe"): + elif experts_implementation == "sonicmoe" and not use_sonicmoe: + LOG.warning( + "`experts_implementation='sonicmoe'` requires `use_sonicmoe: true`. " + "Automatically setting to 'eager'." + ) + data["experts_implementation"] = "eager" + elif experts_implementation not in _VALID_EXPERTS_IMPLS: LOG.warning( - f"`experts_implementation={experts_implementation!r}` is not compatible with " - f"custom MoE kernels. Automatically setting to 'eager'." + f"`experts_implementation={experts_implementation!r}` is not recognized. " + f"Valid options: {sorted(_VALID_EXPERTS_IMPLS)}. " + f"Automatically setting to 'eager'." ) data["experts_implementation"] = "eager" diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index 5239c98778..7373fa5ef2 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -1,76 +1,16 @@ -""" -Supported MoE block mappings for kernel integrations. - -Maps model_type to the SparseMoeBlock class name(s) in transformers. -Used by both ScatterMoE and SonicMoE kernel paths. - -Values can be a single class name (str) or a list of class names for models -with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker). - -Models with custom routing (see sonicmoe/routing.py for implementations): -- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing) -- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing) -- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing) -- gemma4_text: RMSNorm→scale→proj→softmax→topk→renorm→per_expert_scale (experts-level patch) -""" +"""Diagnostic helpers for MoE kernel integrations (kernel dispatch itself +is architecture-agnostic via the ExpertsInterface).""" import importlib -SPARSE_MOE_BLOCK = { - # softmax -> topk routing - "qwen2_moe": "Qwen2MoeSparseMoeBlock", - "qwen3_moe": "Qwen3MoeSparseMoeBlock", - "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock", - "qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock", - "qwen3_next": "Qwen3NextSparseMoeBlock", - "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", - # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate) - "qwen3_omni_moe": [ - "Qwen3OmniMoeThinkerTextSparseMoeBlock", - "Qwen3OmniMoeTalkerTextSparseMoeBlock", - ], - "olmoe": "OlmoeSparseMoeBlock", - "mixtral": "MixtralSparseMoeBlock", - "minimax": "MiniMaxSparseMoeBlock", - # softmax -> topk routing (with group-based expert selection) - "mistral4": "Mistral4MoE", - # sigmoid -> topk routing (with group-based expert selection) - "glm_moe_dsa": "GlmMoeDsaMoE", - "deepseek_v3": "DeepseekV3MoE", - "glm4_moe": "Glm4MoeMoE", - "glm4_moe_lite": "Glm4MoeLiteMoE", - "glm4v_moe": "Glm4vMoeTextMoE", - # sigmoid -> topk routing (no group selection) - "minimax_m2": "MiniMaxM2SparseMoeBlock", - # softmax->topk, e_score_correction_bias between softmax and topk - "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", - # softmax->topk, group_limited_greedy, different attr names (num_group) - "deepseek_v2": "DeepseekV2Moe", - # softmax->topk, gate.wg (not gate.weight) - "hunyuan_v1_moe": "HunYuanMoEV1Moe", - # TODO: gpt_oss deferred — transposed weight layout [E,H,2*I], expert biases, - # and custom GLU activation require a dedicated forward path in patch.py. - # "gpt_oss": "GptOssMLP", -} - - -# Models where MoE is NOT in a separate SparseMoeBlock but embedded in the -# decoder layer. For these, we patch the Experts class forward directly -# (same signature: hidden_states, top_k_index, top_k_weights -> Tensor). -# Routing stays untouched — the original model router runs as-is. +# Models where MoE is embedded in the decoder layer (no separate SparseMoeBlock). EXPERTS_ONLY_BLOCK = { - # gemma4: hybrid MLP+MoE in decoder layer, custom Gemma4TextRouter, - # no SparseMoeBlock. Experts use @use_experts_implementation with - # standard 3D param layout (gate_up_proj [E, 2*I, H], down_proj [E, H, I]). "gemma4_text": "Gemma4TextExperts", } def resolve_experts_class(model_type: str): - """Resolve the Experts class for models that need experts-level patching. - - Returns the class, or None if the model uses SparseMoeBlock-level patching. - """ + """Resolve the Experts class for a known model type, or ``None``.""" entry = EXPERTS_ONLY_BLOCK.get(model_type) if entry is None: return None @@ -93,41 +33,4 @@ def resolve_experts_class(model_type: str): def is_experts_only_model(model_type: str) -> bool: - """Check if a model type requires experts-level (not block-level) patching.""" return model_type in EXPERTS_ONLY_BLOCK - - -def resolve_moe_block_classes(model_type: str): - """Resolve all MoE block classes from transformers for the given model type. - - Returns a list of classes (one for most models, multiple for models with - distinct MoE block types like qwen3_omni_moe). - """ - entry = SPARSE_MOE_BLOCK.get(model_type) - if entry is None: - raise ValueError( - f"Unsupported MoE model type '{model_type}'. " - f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}" - ) - - cls_names = entry if isinstance(entry, list) else [entry] - module_path = f"transformers.models.{model_type}.modeling_{model_type}" - try: - module = importlib.import_module(module_path) - except ModuleNotFoundError: - # Text sub-model types (e.g. qwen3_5_moe_text) share the parent module - if model_type.endswith("_text"): - parent_type = model_type.removesuffix("_text") - module_path = f"transformers.models.{parent_type}.modeling_{parent_type}" - module = importlib.import_module(module_path) - else: - raise - - classes = [] - for cls_name in cls_names: - moe_cls = getattr(module, cls_name, None) - if moe_cls is None: - raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'") - classes.append(moe_cls) - - return classes diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py similarity index 72% rename from src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py rename to src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py index 66623e0173..ba11249ec4 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/gemma4_experts.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py @@ -1,18 +1,7 @@ -""" -ScatterMoE-accelerated experts forward for Gemma4. - -Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer. -The decoder layer handles routing (Gemma4TextRouter) and calls -``experts(hidden_states, top_k_index, top_k_weights)`` directly. +"""ScatterMoE experts forward for the transformers ExpertsInterface. -This module registers a ``"scattermoe"`` implementation in the transformers -``ExpertsInterface``, which the ``@use_experts_implementation`` decorator -dispatches to when ``config._experts_implementation == "scattermoe"``. - -This is the clean way to hook into transformers' MoE dispatch — no -monkeypatching required. Works for Gemma4 and any future model that uses -``@use_experts_implementation`` with the standard forward signature -``(hidden_states, top_k_index, top_k_weights) -> Tensor``. +PEFT LoRA on ``gate_up_proj`` / ``down_proj`` is fused into the +ScatterMoE Triton call via ``parallel_linear_lora``. """ import torch @@ -139,12 +128,7 @@ def scattermoe_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ScatterMoE-accelerated experts forward. - - Drop-in replacement for the standard Experts forward signature used by - ``@use_experts_implementation``-decorated classes (Gemma4, Mixtral, etc.): - ``(hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K]) -> [T, H]`` - """ + """ScatterMoE experts forward with fused-LoRA support.""" K = top_k_index.shape[1] routing_weights = top_k_weights.to(hidden_states.dtype) @@ -193,22 +177,24 @@ def scattermoe_experts_forward( return output -def register_scattermoe_experts(): - """Register ``"scattermoe"`` in the transformers ExpertsInterface. +_SCATTERMOE_PATCHED = False - After calling this, any model with ``@use_experts_implementation`` will - dispatch to ScatterMoE when ``config._experts_implementation == "scattermoe"``. - Also patches ``get_correct_experts_implementation`` to accept ``"scattermoe"`` - as a valid value (transformers hardcodes an allowlist). +def register_scattermoe_experts(): + """Register ``"scattermoe"`` in the ExpertsInterface and the validator allowlist. + + Idempotent. """ + global _SCATTERMOE_PATCHED + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS from transformers.modeling_utils import PreTrainedModel - # 1. Register the forward function in the global interface ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward) - # 2. Patch the validation to accept "scattermoe" + if _SCATTERMOE_PATCHED: + return + _original_get_correct = PreTrainedModel.get_correct_experts_implementation def _patched_get_correct(self_model, requested_experts: str | None) -> str: @@ -217,19 +203,4 @@ def _patched_get_correct(self_model, requested_experts: str | None) -> str: return _original_get_correct(self_model, requested_experts) PreTrainedModel.get_correct_experts_implementation = _patched_get_correct - - -# Legacy monkeypatch approach (kept for backward compat with existing tests) -def patch_gemma4_scattermoe(): - """Monkeypatch Gemma4TextExperts.forward with ScatterMoE kernel.""" - from axolotl.integrations.kernels.constants import resolve_experts_class - - experts_cls = resolve_experts_class("gemma4_text") - if experts_cls is None: - raise ValueError("Could not resolve Gemma4TextExperts class") - - if hasattr(experts_cls, "_original_forward"): - return # already patched - - experts_cls._original_forward = experts_cls.forward - experts_cls.forward = scattermoe_experts_forward + _SCATTERMOE_PATCHED = True diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py b/src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py index d1f5e5f603..5cd9cf0fd7 100644 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py @@ -1,3 +1,6 @@ -from .patch import patch_sonicmoe +from .experts import register_sonicmoe_experts, sonicmoe_experts_forward_with_lora -__all__ = ["patch_sonicmoe"] +__all__ = [ + "register_sonicmoe_experts", + "sonicmoe_experts_forward_with_lora", +] diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/experts.py b/src/axolotl/integrations/kernels/libs/sonicmoe/experts.py new file mode 100644 index 0000000000..785b74e45a --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/experts.py @@ -0,0 +1,143 @@ +"""LoRA-aware sonicmoe experts forward for the transformers ExpertsInterface. + +Wraps upstream ``_sonicmoe_wrapper`` and materializes expert LoRA via +``MoELoRAMaterialize`` before the CUTLASS call. +""" + +from __future__ import annotations + +import torch + +from .lora import ( + MoELoRAMaterialize, + get_lora_params_from_wrapper, + has_lora, + materialize_expert_lora, + unwrap_experts_lora, +) + + +def _maybe_unwrap_param_wrapper(param): + """Return ``(base_tensor, lora_params_or_None)`` for a PEFT-wrapped Parameter.""" + try: + from peft.tuners.param_wrapper import ParamWrapper + except ImportError: + return param, None + + if not isinstance(param, ParamWrapper): + return param, None + + base = param.original_parameter + lora_A, lora_B, scaling = get_lora_params_from_wrapper(param) + if lora_A is None: + return base, None + return base, (lora_A, lora_B, scaling) + + +def _resolve_weights_and_lora(experts_module): + """Resolve raw expert weights/biases + optional LoRA tuples. + + Handles both PEFT layouts: module-level wrap (walked via ``unwrap_experts_lora``) + and per-parameter ``ParamWrapper``. No layout permute applied. + """ + if has_lora(experts_module): + base_experts, lora_dict = unwrap_experts_lora(experts_module) + w1 = base_experts.gate_up_proj + w2 = base_experts.down_proj + b1 = getattr(base_experts, "gate_up_proj_bias", None) + b2 = getattr(base_experts, "down_proj_bias", None) + return w1, b1, w2, b2, lora_dict.get("gate_up_proj"), lora_dict.get("down_proj") + + w1, lora_w1 = _maybe_unwrap_param_wrapper(experts_module.gate_up_proj) + w2, lora_w2 = _maybe_unwrap_param_wrapper(experts_module.down_proj) + b1 = getattr(experts_module, "gate_up_proj_bias", None) + b2 = getattr(experts_module, "down_proj_bias", None) + return w1, b1, w2, b2, lora_w1, lora_w2 + + +def sonicmoe_experts_forward_with_lora( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + """Sonicmoe experts forward with PEFT LoRA materialization.""" + from transformers.integrations.sonicmoe import _sonicmoe_wrapper + + if not getattr(self, "has_gate", True): + raise ValueError("sonicmoe requires gated experts (has_gate=True)") + if hidden_states.device.type != "cuda": + raise ValueError("sonicmoe requires CUDA device") + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + + # Flatten — token indices must be int32 and sorted ascending (sonic-moe requirement). + token_idx = ( + torch.arange(num_tokens, device=device) + .unsqueeze(1) + .expand(-1, num_top_k) + .reshape(-1) + .int() + ) + router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) + expert_ids = top_k_index.reshape(-1).int() + + w1, b1, w2, b2, lora_w1, lora_w2 = _resolve_weights_and_lora(self) + if not getattr(self, "has_bias", False): + b1 = b2 = None + + # FSDP2 / EP wraps parameters as DTensors but sonic-moe takes raw CUTLASS pointers, + # so unwrap to local shards before the materialize/permute. to_local() is + # autograd-aware — backward will rewrap the gradient as a DTensor again. + if isinstance(w1, torch.distributed.tensor.DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + b1 = b1.to_local() if b1 is not None else None + b2 = b2.to_local() if b2 is not None else None + + # Materialize W_eff = W + scaling * (B @ A) per expert. No-op when no LoRA. + if lora_w1 is not None: + w1 = MoELoRAMaterialize.apply(w1, *lora_w1) + if lora_w2 is not None: + w2 = MoELoRAMaterialize.apply(w2, *lora_w2) + + # Match upstream layout expectations: + # is_transposed=False: gate_up [E, 2*I, H] / down [E, H, I] -> permute(1, 2, 0) + # is_transposed=True: gate_up [E, H, 2*I] / down [E, I, H] -> permute(2, 1, 0) + perm = (2, 1, 0) if getattr(self, "is_transposed", False) else (1, 2, 0) + w1 = w1.permute(*perm) + w2 = w2.permute(*perm) + + act_name = getattr(self.config, "hidden_act", "silu").lower() + + return _sonicmoe_wrapper( + hidden_states=hidden_states, + router_scores=router_scores, + expert_ids=expert_ids, + token_idx=token_idx, + w1=w1, + b1=b1, + w2=w2, + b2=b2, + act_name=act_name, + num_experts=self.num_experts, + concat_layout=getattr(self, "is_concatenated", True), + is_inference_mode_enabled=not torch.is_grad_enabled(), + ) + + +def register_sonicmoe_experts() -> None: + """Register the LoRA-aware ``"sonicmoe"`` forward, overriding upstream. Idempotent.""" + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + ALL_EXPERTS_FUNCTIONS.register("sonicmoe", sonicmoe_experts_forward_with_lora) + + +# Re-export utilities for tests / external callers. +__all__ = [ + "sonicmoe_experts_forward_with_lora", + "register_sonicmoe_experts", + "materialize_expert_lora", +] diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py b/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py deleted file mode 100644 index a4025dd842..0000000000 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/gemma4_experts.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -SonicMoE-accelerated experts forward for Gemma4. - -Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer. -This module provides a drop-in replacement for ``Gemma4TextExperts.forward`` -that uses SonicMoE kernels while preserving the original call signature. -""" - -import torch - -from .lora import has_lora, materialize_expert_lora, unwrap_experts_lora - - -def _get_expert_weights_gemma4(experts_module): - """Extract expert weights from Gemma4TextExperts, applying LoRA if active. - - Returns: - (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E]. - """ - if has_lora(experts_module): - base_experts, lora_dict = unwrap_experts_lora(experts_module) - gate_up = materialize_expert_lora( - base_experts.gate_up_proj, lora_dict.get("gate_up_proj") - ) - down = materialize_expert_lora( - base_experts.down_proj, lora_dict.get("down_proj") - ) - else: - gate_up = experts_module.gate_up_proj - down = experts_module.down_proj - - # Permute to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - return gate_up.permute(1, 2, 0), down.permute(1, 2, 0) - - -def gemma4_sonicmoe_experts_forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """SonicMoE-accelerated replacement for Gemma4TextExperts.forward. - - Same signature as the original: (hidden_states [T, H], top_k_index [T, K], - top_k_weights [T, K]) -> output [T, H]. - """ - from sonicmoe import moe_general_routing_inputs - from sonicmoe.enums import ActivationType - - T, _ = hidden_states.shape - K = top_k_index.shape[1] - E = self.num_experts - - # Convert routing outputs to SonicMoE's flat format - # Token indices sorted ascending (required by SonicMoE) - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - flat_scores = top_k_weights.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = top_k_index.to(torch.int32).reshape(-1) # [T*K] - - # Get weights (with LoRA materialization if needed) - gate_up_weight, down_weight = _get_expert_weights_gemma4(self) - gate_up_weight = gate_up_weight.to(hidden_states.dtype) - down_weight = down_weight.to(hidden_states.dtype) - - if not torch.cuda.is_available(): - raise RuntimeError("SonicMoE requires CUDA. No CUDA device available.") - cuda_stream = torch.cuda.current_stream().cuda_stream - - output, _ = moe_general_routing_inputs( - hidden_states, - flat_scores, - flat_token_idx, - flat_expert_idx, - gate_up_weight, - None, # b1 (no gate/up bias) - down_weight, - None, # b2 (no down bias) - E, - cuda_stream, - ActivationType.GEGLU, - False, # is_inference_mode - ) - - return output - - -def patch_gemma4_sonicmoe(): - """Monkeypatch Gemma4TextExperts.forward with SonicMoE kernel.""" - from axolotl.integrations.kernels.constants import resolve_experts_class - - experts_cls = resolve_experts_class("gemma4_text") - if experts_cls is None: - raise ValueError("Could not resolve Gemma4TextExperts class") - - if hasattr(experts_cls, "_original_forward"): - return # already patched - - experts_cls._original_forward = experts_cls.forward - experts_cls.forward = gemma4_sonicmoe_experts_forward diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py b/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py index 4d7a21925b..1fe08828cb 100644 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py @@ -61,33 +61,6 @@ def get_lora_params_from_wrapper(module) -> tuple: return lora_A, lora_B, scaling -def unwrap_gate_lora(gate_module): - """Unwrap PEFT ParamWrapper on the router gate. - - When PEFT targets ``gate.weight``, ``self.gate`` becomes:: - - ParamWrapper(weight) - -> base_layer: Router (the real module) - - Returns: - (base_gate, gate_weight, gate_lora_delta_or_None) - - ``base_gate`` is the original router module (with ``.top_k``, etc.). - ``gate_weight`` is the base router weight tensor. - ``gate_lora_delta_or_None`` is the LoRA delta if active, else None. - Kept separate to avoid mixing DTensor + Tensor under FSDP. - """ - if has_lora(gate_module): - base_gate = gate_module.base_layer - lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module) - if lora_A is not None: - delta = scaling * (lora_B @ lora_A) - return base_gate, base_gate.weight, delta - return base_gate, base_gate.weight, None - - return gate_module, gate_module.weight, None - - def unwrap_experts_lora(experts_module): """Walk a PEFT ParamWrapper chain on ``self.experts``. @@ -129,18 +102,12 @@ def unwrap_experts_lora(experts_module): class MoELoRAMaterialize(torch.autograd.Function): - """Materialize effective weight W_eff = W + scaling * (B @ A) per expert. - - Inserts into the autograd graph between PEFT's LoRA parameters and - SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff, - which this function decomposes into dA and dB via the chain rule. - - Weight layouts (PEFT rank-major): - base_weight: [E, dim1, dim2] (frozen expert parameter) - lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e) - lora_B: [dim1, r*E] (cols [:, e*r:(e+1)*r] = B_e) + """Materialize ``W_eff = W + scaling * (B @ A)`` per expert and route grads. - Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2] + Layout matches PEFT >= 0.19.1 ``ParamWrapper``: ``base [E, dim1, dim2]``, + ``lora_A [r*E, dim2]`` (E-outer, r-inner rows), ``lora_B [dim1, r*E]`` + (r-outer, E-inner cols). Equivalent to + ``einsum("o r e, e r i -> e o i", lora_B.reshape(dim1, r, E), lora_A.reshape(E, r, dim2))``. """ @staticmethod diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/patch.py b/src/axolotl/integrations/kernels/libs/sonicmoe/patch.py deleted file mode 100644 index 65095a9871..0000000000 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/patch.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -SonicMoE patching for SparseMoeBlock forward pass. - -Monkeypatches the SparseMoeBlock class for a given model type to use -SonicMoE's optimized kernels. Two forward paths are supported: - -1. **General routing path** (routing_fn is not None): - Uses a custom routing function + ``moe_general_routing_inputs``. - Suitable for models with non-standard routing (softmax->topk, sigmoid->topk). - -2. **Fused topk->softmax path** (routing_fn is None): - Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation. - Suitable for models with simple topk->softmax routing. - -Weight format conversion (interleave/deinterleave) is handled by the -WeightConverter system, so the forward assumes weights are already in -interleaved format. - -Shared experts are handled generically: if the block has a ``shared_expert`` -or ``shared_experts`` attribute, its output is computed alongside the routed -experts and added to the final output. An optional ``shared_expert_gate`` -applies sigmoid gating to the shared expert contribution. -""" - -import torch -import torch.nn.functional as F - -from axolotl.integrations.kernels.constants import resolve_moe_block_classes -from axolotl.utils.logging import get_logger - -from .lora import ( - has_lora, - materialize_expert_lora, - unwrap_experts_lora, - unwrap_gate_lora, -) - -LOG = get_logger(__name__) - - -def _get_expert_weights(experts_module): - """Extract expert weights, applying LoRA materialization if PEFT is active. - - Returns: - (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E]. - """ - if has_lora(experts_module): - base_experts, lora_dict = unwrap_experts_lora(experts_module) - gate_up = materialize_expert_lora( - base_experts.gate_up_proj, lora_dict.get("gate_up_proj") - ) - down = materialize_expert_lora( - base_experts.down_proj, lora_dict.get("down_proj") - ) - else: - gate_up = experts_module.gate_up_proj - down = experts_module.down_proj - - # Permute to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - return gate_up.permute(1, 2, 0), down.permute(1, 2, 0) - - -def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str): - """Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders.""" - if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text": - return - - try: - from transformers.conversion_mapping import ( - get_checkpoint_conversion_mapping, - register_checkpoint_conversion_mapping, - ) - from transformers.core_model_loading import WeightRenaming - except ImportError: - return - - text_mapping = get_checkpoint_conversion_mapping(model_type) - if text_mapping and isinstance(text_mapping[0], WeightRenaming): - text_mapping.pop(0) - register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True) - LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode") - - -def patch_sonicmoe( - model_type: str, - torch_compile: bool = False, - base_model_type: str | None = None, -): - """Patch SparseMoeBlock for SonicMoE support.""" - from .routing import get_model_moe_config - from .weight_converter import register_sonicmoe_weight_converter - - _fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type) - - routing_fn, activation, router_attr = get_model_moe_config(model_type) - - if torch_compile and routing_fn is not None: - routing_fn = _try_compile_routing(routing_fn) - - for moe_cls in resolve_moe_block_classes(model_type): - _patch_forward(moe_cls, routing_fn, activation, router_attr) - register_sonicmoe_weight_converter(model_type) - - -def _try_compile_routing(routing_fn): - """Attempt to torch.compile the routing function, fall back to eager on failure.""" - try: - compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False) - LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}") - return compiled_fn - except Exception as exc: # pylint: disable=broad-except - LOG.warning( - f"torch.compile failed for routing function {routing_fn.__name__}, " - f"falling back to eager: {exc}" - ) - return routing_fn - - -def _patch_forward(moe_cls, routing_fn, activation, router_attr): - """Monkeypatch the SparseMoeBlock class with a SonicMoE forward. - - The patched forward handles shared experts generically: if - ``self.shared_expert`` or ``self.shared_experts`` exists, it is computed - and added to the routed output. If ``self.shared_expert_gate`` also exists, - it applies sigmoid gating to the shared expert contribution (as in qwen2_moe). - - Args: - moe_cls: The SparseMoeBlock class to patch. - routing_fn: Routing function (e.g. softmax_topk_routing), or None - for the fused moe_TC_softmax_topk_layer path. - activation: SonicMoE ActivationType enum value. - router_attr: Name of the router module attribute on the MoE block. - """ - if hasattr(moe_cls, "_original_forward"): - LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping") - return - - original_forward = moe_cls.forward - - if routing_fn is not None: - _make_general_forward(moe_cls, routing_fn, activation) - else: - _make_fused_forward(moe_cls, activation, router_attr) - - moe_cls._original_forward = original_forward - LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation") - - -def _make_general_forward(moe_cls, routing_fn, activation): - """Create forward using routing_fn + moe_general_routing_inputs.""" - - def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - from sonicmoe import moe_general_routing_inputs - - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, hidden_dim) - - # Shared expert (computed early, matching original model ordering) - shared_expert_output = _compute_shared_expert(self, hidden_states_flat) - - # Routing - router_scores, token_indices, expert_indices, _router_logits = routing_fn( - hidden_states_flat, self - ) - - # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout - gate_up_weight, down_weight = _get_expert_weights(self.experts) - gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype) - down_weight = down_weight.to(hidden_states_flat.dtype) - E = gate_up_weight.shape[-1] - - output, _ = moe_general_routing_inputs( - hidden_states_flat, - router_scores, - token_indices, - expert_indices, - gate_up_weight, - None, # b1 (no gate/up bias) - down_weight, - None, # b2 (no down bias) - E, - torch.cuda.current_stream().cuda_stream, - activation, - False, # is_inference_mode - ) - - # Add shared expert contribution if present - if shared_expert_output is not None: - if hasattr(self, "shared_expert_gate"): - shared_expert_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states_flat)) - * shared_expert_output - ) - output = output + shared_expert_output - - return output.view(batch_size, sequence_length, hidden_dim) - - moe_cls.forward = sonicmoe_forward - - -def _make_fused_forward(moe_cls, activation, router_attr): - """Create forward using moe_TC_softmax_topk_layer (topk -> softmax).""" - - def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - from sonicmoe import moe_TC_softmax_topk_layer - - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, hidden_dim) - - # Shared expert (computed early, matching original model ordering) - shared_expert_output = _compute_shared_expert(self, hidden_states_flat) - - # Unwrap router for attribute access + optional LoRA delta - raw_router = getattr(self, router_attr) - base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router) - if router_lora_delta is not None: - # Materialize local tensor to avoid DTensor + Tensor add under FSDP - if hasattr(router_weight, "to_local"): - router_weight = router_weight.to_local() - effective_router_weight = router_weight + router_lora_delta - else: - effective_router_weight = router_weight - - # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout - gate_up_weight, down_weight = _get_expert_weights(self.experts) - gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype) - down_weight = down_weight.to(hidden_states_flat.dtype) - - output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer( - hidden_states_flat, - effective_router_weight, - gate_up_weight, - None, # b1 (no gate/up bias) - down_weight, - None, # b2 (no down bias) - base_router.top_k, - torch.cuda.current_stream().cuda_stream, - activation, - False, # is_inference_mode - ) - - # Add shared expert contribution if present - if shared_expert_output is not None: - if hasattr(self, "shared_expert_gate"): - shared_expert_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states_flat)) - * shared_expert_output - ) - output = output + shared_expert_output - - return output.view(batch_size, sequence_length, hidden_dim) - - moe_cls.forward = sonicmoe_fused_forward - - -def _compute_shared_expert(moe_block, hidden_states_flat): - """Compute shared expert output if the block has one. - - Handles singular (qwen2_moe: ``shared_expert``), plural - (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP - (hunyuan_v1_moe: ``shared_mlp``) attribute names. - """ - shared_expert = ( - getattr(moe_block, "shared_expert", None) - or getattr(moe_block, "shared_experts", None) - or getattr(moe_block, "shared_mlp", None) - ) - if shared_expert is not None: - return shared_expert(hidden_states_flat) - return None diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/routing.py b/src/axolotl/integrations/kernels/libs/sonicmoe/routing.py deleted file mode 100644 index 68654d0868..0000000000 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/routing.py +++ /dev/null @@ -1,576 +0,0 @@ -""" -Routing functions for SonicMoE integration. - -Different MoE architectures use different routing strategies: -- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization) -- mistral4: softmax -> group selection -> topk (with renormalization and scaling) -- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection) -- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing) -- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing) -- gemma4_text: RMSNorm -> scale -> proj -> softmax -> topk -> renorm -> per_expert_scale (gemma4_routing) -- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED] - -Each model type maps to a (routing_fn, activation_type, router_attr) triple. -When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. -""" - -import torch -import torch.nn.functional as F - -from .lora import unwrap_gate_lora - - -def get_model_moe_config(model_type: str): - """Returns (routing_fn, activation, router_attr) for a given model type. - - Args: - model_type: HuggingFace model type string. - - Returns: - routing_fn: Callable or None. None signals the fused - moe_TC_softmax_topk_layer path (topk -> softmax models). - activation: SonicMoE ActivationType enum value. - router_attr: Name of the router module attribute on the MoE block - (e.g. "gate" or "router"). - - The activation type cannot be derived from config.hidden_act because - e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU - (act_fn(gate) * up pattern). So we specify it per model type. - """ - from sonicmoe.enums import ActivationType - - if model_type in ( - "qwen2_moe", - "qwen3_moe", - "qwen3_5_moe", - "qwen3_5_moe_text", - "qwen3_next", - "qwen3_vl_moe", - "qwen3_omni_moe", - "olmoe", - "mixtral", - "minimax", - ): - return softmax_topk_routing, ActivationType.SWIGLU, "gate" - elif model_type in ("mistral4",): - return softmax_group_topk_routing, ActivationType.SWIGLU, "gate" - elif model_type in ( - "glm_moe_dsa", - "deepseek_v3", - "glm4_moe", - "glm4_moe_lite", - "glm4v_moe", - "minimax_m2", - ): - return sigmoid_topk_routing, ActivationType.SWIGLU, "gate" - elif model_type in ("ernie4_5_moe",): - return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate" - elif model_type in ("hunyuan_v1_moe",): - return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate" - elif model_type in ("gemma4_text",): - return gemma4_routing, ActivationType.GEGLU, "router" - # Fused topk -> softmax path (routing_fn=None): - # elif model_type in ("gpt_oss",): - # # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer - # # ignores (it only takes router_w, not bias). Also has transposed - # # weight layout [E, H, 2*I] and custom GLU activation. - # return None, ActivationType.SWIGLU, "router" - else: - raise ValueError(f"SonicMoE: unsupported model type '{model_type}'") - - -def softmax_topk_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm. - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.gate.*) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) - T, H = hidden_states.shape - K = base_gate.top_k - - # Compute router logits and softmax over all experts. - # Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP. - # Cast to float32 to match LoRA delta dtype (PEFT computes in fp32). - router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), gate_lora_delta.float() - ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - - # Select top-k experts per token - top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each - - # Renormalize if configured (default True for models without the attribute, - # e.g. Mixtral/MiniMax which always normalize) - if getattr(base_gate, "norm_topk_prob", True): - top_values = top_values / top_values.sum(dim=-1, keepdim=True) - - # no-op: matches transformers which casts to softmax output dtype (float32). - # top_values = top_values.to(router_probs.dtype) - - # Flatten for moe_general_routing_inputs. - # Token indices are naturally sorted ascending from the [T, K] layout: - # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE. - # Expert sorting is handled internally by general_routing_router_metadata. - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = top_values.reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def softmax_group_topk_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale.""" - base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) - T, _ = hidden_states.shape - K = moe_block.top_k - E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0]) - n_group = getattr(moe_block, "n_group", 1) - - router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), gate_lora_delta.float() - ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - - scores_for_choice = router_probs - - # Group selection: pick top groups, mask the rest - if n_group > 1: - group_scores = ( - scores_for_choice.view(-1, n_group, E // n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk( - group_scores, k=moe_block.topk_group, dim=-1, sorted=False - )[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) - ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) - - topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] - topk_weights = router_probs.gather(1, topk_indices) - - # Renormalization + scaling - norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) - if norm_topk_prob: - topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) - routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) - topk_weights = topk_weights * routed_scaling_factor - - # Flatten for moe_general_routing_inputs - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def sigmoid_topk_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Sigmoid-based routing: sigmoid -> optional group selection -> topk. - - Supports two variants: - - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1, - bias on gate, group-based masking before topk. - - **No group selection** (minimax_m2): n_group == 1 (or absent), - bias on moe_block, straight topk from all experts. - - Final routing weights come from the original sigmoid scores (not - bias-corrected), with optional renormalization and scaling. - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.gate.* and - optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob, - .routed_scaling_factor, .n_routed_experts) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) - T, _ = hidden_states.shape - K = moe_block.top_k - E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0]) - n_group = getattr(moe_block, "n_group", 1) - - # Compute router logits and sigmoid probabilities - router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), gate_lora_delta.float() - ) - router_probs = router_logits.sigmoid() # [T, E] - - # Bias-corrected scores for expert selection (not used for final weights). - # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block. - e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None) - if e_score_correction_bias is None: - e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None) - if e_score_correction_bias is None: - raise AttributeError( - f"sigmoid_topk_routing requires e_score_correction_bias on " - f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it" - ) - scores_for_choice = router_probs + e_score_correction_bias - - # Group-based selection: pick top groups, mask the rest (skip when n_group == 1) - if n_group > 1: - group_scores = ( - scores_for_choice.view(-1, n_group, E // n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) # [T, n_group] - group_idx = torch.topk( - group_scores, k=moe_block.topk_group, dim=-1, sorted=False - )[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) - ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) - - # Final topk from (possibly masked) scores - topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] - - # Gather weights from original sigmoid scores (not bias-corrected) - topk_weights = router_probs.gather(1, topk_indices) - - # Optional renormalization + scaling - norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) - if norm_topk_prob: - topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) - routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) - topk_weights = topk_weights * routed_scaling_factor - - # Flatten for moe_general_routing_inputs. - # Token indices are naturally sorted ascending from the [T, K] layout. - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def softmax_bias_topk_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Ernie 4.5 MoE routing: softmax → bias correction → topk → gather → renorm. - - Differs from standard softmax_topk_routing in three ways: - 1. A learned e_score_correction_bias is added to softmax probs *before* topk - (selection uses biased scores, but final weights use original probs). - 2. The bias is applied via gate.moe_statics module (not a raw tensor). - 3. Renormalization uses clamp(min=norm_min) instead of sum+epsilon. - - Reference: Ernie4_5_MoeTopKRouter.forward in transformers. - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.gate.*) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) - T, H = hidden_states.shape - K = base_gate.top_k - - # Compute router logits and softmax (force float32 for numerical stability) - router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), gate_lora_delta.float() - ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - - # Bias-corrected scores for expert selection (via moe_statics module) - scores_for_choice = base_gate.moe_statics(router_probs) # [T, E] - - # Select top-k experts using biased scores - _, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K] - - # Gather weights from *original* (unbiased) softmax probs - top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K] - - # Renormalize with clamp(min=norm_min) instead of sum+epsilon - norm_min = getattr(base_gate, "norm_min", 1e-20) - top_values = top_values / torch.clamp( - top_values.sum(dim=-1, keepdim=True), min=norm_min - ) - - # Flatten for moe_general_routing_inputs - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = selected_experts.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def softmax_group_limited_topk_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale. - - Differs from softmax_group_topk_routing (Mistral4) in several ways: - 1. Uses ``num_group`` attribute (not ``n_group``). - 2. Group score = max per group (not sum of top-2). - 3. Supports ``greedy`` method (plain topk without groups). - 4. No renormalization — just ``topk_weight * routed_scaling_factor``. - 5. Gate is ``nn.Linear`` (access weight via ``gate.weight``). - - Reference: DeepseekV2Moe.route_tokens_to_experts in transformers. - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.gate, .num_group, - .topk_group, .top_k, .topk_method, .routed_scaling_factor) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) - T, H = hidden_states.shape - K = moe_block.top_k - num_group = getattr(moe_block, "num_group", 1) - num_experts = gate_weight.shape[0] - topk_method = getattr(moe_block, "topk_method", "greedy") - - # Compute logits in float32 and softmax - router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), gate_lora_delta.float() - ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - - if topk_method == "greedy" or num_group == 1: - topk_weights, topk_indices = torch.topk(router_probs, k=K, dim=-1, sorted=False) - elif topk_method == "group_limited_greedy": - # Guard: selected groups must contain enough experts for topk - group_size = num_experts // num_group - if moe_block.topk_group * group_size < K: - raise ValueError( - f"DeepSeek V2: topk_group ({moe_block.topk_group}) * group_size " - f"({group_size}) = {moe_block.topk_group * group_size} < top_k ({K}). " - f"Not enough experts in selected groups for topk selection." - ) - # Group selection: pick top groups by max score per group - group_scores = ( - router_probs.view(T, num_group, num_experts // num_group).max(dim=-1).values - ) # [T, num_group] - group_idx = torch.topk( - group_scores, k=moe_block.topk_group, dim=-1, sorted=False - )[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(T, num_group, num_experts // num_group) - .reshape(T, -1) - ) - tmp_scores = router_probs.masked_fill(~score_mask.bool(), 0.0) - topk_weights, topk_indices = torch.topk(tmp_scores, k=K, dim=-1, sorted=False) - else: - raise ValueError( - f"DeepSeek V2: unsupported topk_method '{topk_method}'. " - f"Expected 'greedy' or 'group_limited_greedy'." - ) - - # Scale only — no renormalization (weights won't sum to 1.0 per token). - # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior. - routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) - topk_weights = topk_weights * routed_scaling_factor - - # Flatten for moe_general_routing_inputs - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def softmax_topk_wg_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """HunYuan V1 MoE routing: softmax → topk → renorm (gate weight via gate.wg). - - Differs from standard softmax_topk_routing in: - 1. Gate weight lives at ``gate.wg.weight`` (not ``gate.weight``). - 2. ``top_k`` is on ``moe_block`` (not ``gate``). - 3. Always renormalizes (no ``norm_topk_prob`` flag). - - Reference: HunYuanMoEV1Moe.route_tokens_to_experts and - HunYuanMoEV1Gate.forward in transformers. - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.gate.wg, moe_block.top_k) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - gate = moe_block.gate - T, H = hidden_states.shape - K = moe_block.top_k - - # Gate computes logits via gate.wg (nn.Linear, float32) - # Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container - base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg) - router_logits = F.linear(hidden_states.float(), wg_weight.float()) # [T, E] - if wg_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states.float(), wg_lora_delta.float() - ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - - # Select top-k experts - top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each - - # Always renormalize (HunYuan V1 has no norm_topk_prob flag) - top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20) - - # Flatten for moe_general_routing_inputs - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits - - -def gemma4_routing( - hidden_states: torch.Tensor, moe_block -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Gemma4-style routing: RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale. - - Gemma4's router (``Gemma4TextRouter``) has a unique structure: - 1. RMSNorm (without learnable scale) on hidden states - 2. Multiply by ``scale * hidden_size**-0.5`` - 3. Linear projection to expert scores - 4. Softmax → topk - 5. Normalize top-k weights to sum to 1 - 6. Multiply by per-expert learned scales - - The router lives at ``moe_block.router`` (not ``moe_block.gate``). - LoRA on the router targets ``router.proj`` (nn.Linear). - - Args: - hidden_states: [T, H] flattened token representations - moe_block: MoE block module (accesses moe_block.router) - - Returns: - router_scores: [T*K] flattened scores (float32) - token_indices: [T*K] which token each entry belongs to (int32), sorted ascending - expert_indices: [T*K] which expert (int32) - router_logits: [T, E] original logits for aux loss - """ - router = moe_block.router - - # Unwrap PEFT LoRA on router.proj (the nn.Linear) - _, proj_weight, proj_lora_delta = unwrap_gate_lora(router.proj) - - T, _ = hidden_states.shape - K = router.top_k if hasattr(router, "top_k") else router.config.top_k_experts - - # Reproduce Gemma4TextRouter.forward: - # 1. RMSNorm (no scale) + scale param * hidden_size**-0.5 - normed = router.norm(hidden_states) - scaled = normed * router.scale * router.scalar_root_size - - # 2. Project to expert scores - router_logits = F.linear(scaled.float(), proj_weight.float()) # [T, E] - if proj_lora_delta is not None: - router_logits = router_logits + F.linear( - scaled.float(), proj_lora_delta.float() - ) - - # 3. Softmax → topk - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] - top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] - - # 4. Normalize top-k weights - top_values = top_values / top_values.sum(dim=-1, keepdim=True) - - # 5. Per-expert scale - top_values = top_values * router.per_expert_scale[top_indices] - - # Flatten for moe_general_routing_inputs - token_indices = ( - torch.arange(T, device=hidden_states.device, dtype=torch.int32) - .unsqueeze(1) - .expand(T, K) - ) - - flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] - flat_token_idx = token_indices.reshape(-1) # [T*K] - flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] - - return flat_scores, flat_token_idx, flat_expert_idx, router_logits diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py b/src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py deleted file mode 100644 index 20da27ff0a..0000000000 --- a/src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -Custom WeightConverter operations for SonicMoE weight format conversion. - -SonicMoE requires gate_up_proj weights in interleaved format: -- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up -- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...] - -These ConversionOps integrate with transformers' WeightConverter system so that -weights are transparently converted during loading and reverted during saving. -""" - -from typing import Any - -import torch -from einops import rearrange -from transformers.core_model_loading import ConversionOps - -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - - -def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: - """[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension.""" - return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2) - - -def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: - """[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension.""" - return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2) - - -class ConcatenatedToInterleaved(ConversionOps): - """Convert concatenated gate/up projections to interleaved format. - - Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H] - Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...] - - This operation is applied along ``dim`` (default 1, the 2*I dimension). - """ - - def __init__(self, dim: int = 1): - self.dim = dim - - @torch.no_grad() - def convert( - self, - input_dict: dict[str, Any], - source_patterns: list[str], - target_patterns: list[str], - **kwargs, - ) -> dict[str, torch.Tensor]: - target_pattern = self._get_target_pattern( - input_dict, source_patterns, target_patterns - ) - tensors = next(iter(input_dict.values())) - tensor = tensors[0] if isinstance(tensors, list) else tensors - - interleaved = interleave_gate_up(tensor) - - return {target_pattern: interleaved} - - def _get_target_pattern( - self, - input_dict: dict[str, Any], - source_patterns: list[str], - target_patterns: list[str], - ) -> str: - # Follow the same logic as Transpose.get_target_pattern - if len(input_dict) != 1: - raise ValueError("Undefined Operation encountered!") - if len(target_patterns) > 1: - if len(source_patterns) == 1: - return source_patterns[0] - raise ValueError("Undefined Operation encountered!") - return target_patterns[0] - - @property - def reverse_op(self) -> ConversionOps: - return InterleavedToConcatenated(self.dim) - - -class InterleavedToConcatenated(ConversionOps): - """Convert interleaved gate/up projections back to concatenated format. - - Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...] - Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H] - - This is the reverse of ``ConcatenatedToInterleaved``. - """ - - def __init__(self, dim: int = 1): - self.dim = dim - - @torch.no_grad() - def convert( - self, - input_dict: dict[str, Any], - source_patterns: list[str], - target_patterns: list[str], - **kwargs, - ) -> dict[str, torch.Tensor]: - target_pattern = self._get_target_pattern( - input_dict, source_patterns, target_patterns - ) - tensors = next(iter(input_dict.values())) - tensor = tensors[0] if isinstance(tensors, list) else tensors - - concatenated = deinterleave_gate_up(tensor) - - return {target_pattern: concatenated} - - def _get_target_pattern( - self, - input_dict: dict[str, Any], - source_patterns: list[str], - target_patterns: list[str], - ) -> str: - if len(input_dict) != 1: - raise ValueError("Undefined Operation encountered!") - if len(target_patterns) > 1: - if len(source_patterns) == 1: - return source_patterns[0] - raise ValueError("Undefined Operation encountered!") - return target_patterns[0] - - @property - def reverse_op(self) -> ConversionOps: - return ConcatenatedToInterleaved(self.dim) - - -def _make_same_key_interleave_converter(): - """Create a WeightConverter that interleaves an already-fused gate_up_proj.""" - from transformers.core_model_loading import WeightConverter - - return WeightConverter( - source_patterns="mlp.experts.gate_up_proj", - target_patterns="mlp.experts.gate_up_proj", - operations=[ConcatenatedToInterleaved(dim=1)], - ) - - -def _has_same_key_interleave(mapping) -> bool: - """Check whether the mapping already has a same-key gate_up_proj interleave converter.""" - for conv in mapping: - if ( - hasattr(conv, "source_patterns") - and conv.source_patterns == ["mlp.experts.gate_up_proj"] - and conv.target_patterns == ["mlp.experts.gate_up_proj"] - and hasattr(conv, "operations") - and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations) - ): - return True - return False - - -def register_sonicmoe_weight_converter(model_type: str): - """Register weight converters to interleave gate_up_proj for SonicMoE. - - Handles two checkpoint formats: - 1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the - existing merge chain (MergeModulelist -> Concatenate -> Interleave). - 2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key - converter (gate_up_proj -> gate_up_proj with Interleave). - - The loader matches whichever source pattern exists in the checkpoint. - """ - from transformers.conversion_mapping import ( - get_checkpoint_conversion_mapping, - register_checkpoint_conversion_mapping, - ) - - existing = get_checkpoint_conversion_mapping(model_type) - - if existing is None: - # No mapping at all — create one with just the same-key converter - mapping = [_make_same_key_interleave_converter()] - register_checkpoint_conversion_mapping(model_type, mapping) - LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'") - return - - # Append interleave to any existing many-to-one merge chain - for converter in existing: - if hasattr(converter, "operations") and any( - "gate_up_proj" in pat for pat in converter.target_patterns - ): - has_separate_sources = any( - "gate_proj" in pat or "up_proj" in pat - for pat in converter.source_patterns - ) - if has_separate_sources and not any( - isinstance(op, ConcatenatedToInterleaved) for op in converter.operations - ): - converter.operations.append(ConcatenatedToInterleaved(dim=1)) - break - - # Also add a same-key converter for already-fused checkpoints - if not _has_same_key_interleave(existing): - existing.append(_make_same_key_interleave_converter()) - - register_checkpoint_conversion_mapping(model_type, existing, overwrite=True) - LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'") diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index d713095a5f..f105e5ecff 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,6 +1,5 @@ import importlib import os -from pathlib import Path import torch @@ -61,119 +60,29 @@ def get_input_args(self): return "axolotl.integrations.kernels.KernelsArgs" def pre_model_load(self, cfg): - from axolotl.integrations.kernels.constants import ( - SPARSE_MOE_BLOCK, - is_experts_only_model, - ) - - # Prefer text backbone type for VLMs, but fall back to base type - # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text) - moe_model_type = cfg.model_config_type_text or cfg.model_config_type - if ( - moe_model_type not in SPARSE_MOE_BLOCK - and not is_experts_only_model(moe_model_type) - and cfg.model_config_type in SPARSE_MOE_BLOCK - ): - moe_model_type = cfg.model_config_type - - # When expert parallelism is enabled, the EP plugin sets - # `experts_implementation` to `deep_ep_scattermoe` / `deep_ep_sonicmoe` - # and dispatches the kernel inside the experts-level forward (after - # DeepEP all-to-all). Skip the SparseMoeBlock-level patch in that case - # — patching the block-level forward bypasses EP routing and reads - # FSDP-sharded expert weights as DTensors, which the kernels do not - # accept. - ep_active = (getattr(cfg, "expert_parallel_size", 1) or 1) > 1 + """Register the requested kernel into ``ALL_EXPERTS_FUNCTIONS`` and pin cfg. + Architecture-agnostic: routing stays in each model's SparseMoEBlock; only + the experts call is dispatched through the registry. + """ if cfg.use_scattermoe: - self._register_kernels() - if is_experts_only_model(moe_model_type): - # Models like Gemma4 where MoE is embedded in the decoder layer - # — register ScatterMoE in the ExpertsInterface so that - # @use_experts_implementation dispatches to it. - self._register_experts_interface() - if not ep_active: - cfg.experts_implementation = "scattermoe" - elif ep_active: - LOG.info( - "expert_parallel_size > 1: skipping SparseMoeBlock-level " - "ScatterMoE patch; the deep_ep_scattermoe registered " - "function handles the kernel under EP." - ) - else: - self._kernelize_model(moe_model_type) - elif cfg.use_sonicmoe: - if not importlib.util.find_spec("sonicmoe"): - raise RuntimeError( - "SonicMoE is not installed. See installation instructions at " - "https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation" - ) + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( + register_scattermoe_experts, + ) + register_scattermoe_experts() + cfg.experts_implementation = "scattermoe" + LOG.info("Registered 'scattermoe' in transformers ExpertsInterface") + elif cfg.use_sonicmoe: _check_sonicmoe_gpu_compat() - if is_experts_only_model(moe_model_type): - from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import ( - patch_gemma4_sonicmoe, - ) - - LOG.info( - f"Applying SonicMoE experts-level patch for model type: {moe_model_type}" - ) - patch_gemma4_sonicmoe() - # TODO(EP+SonicMoE): grad norms explode during training. Re-enable - # once the root cause is identified. Same shape as the ScatterMoE - # branch above, but SonicMoE additionally needs the gate_up_proj - # interleave converter since its w1 layout is [g0, u0, g1, u1, ...] - # while the checkpoint stores it concatenated [gate..., up...]. - # - # elif ep_active: - # from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( - # register_sonicmoe_weight_converter, - # ) - # - # LOG.info( - # "expert_parallel_size > 1: skipping SparseMoeBlock-level " - # "SonicMoE patch; the deep_ep_sonicmoe registered function " - # "handles the kernel under EP. Registering gate_up_proj " - # "interleave converter." - # ) - # register_sonicmoe_weight_converter(moe_model_type) - else: - from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe - - LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}") - patch_sonicmoe( - moe_model_type, - torch_compile=bool(getattr(cfg, "torch_compile", False)), - base_model_type=cfg.model_config_type, - ) - - def _register_kernels(self): - from kernels import ( - LocalLayerRepository, - Mode, - register_kernel_mapping, - ) + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, + ) - plugin_root = Path(__file__).parent - register_kernel_mapping( - { - "HFScatterMoEParallelExperts": { - "cuda": { - Mode.TRAINING: LocalLayerRepository( - repo_path=plugin_root / "libs" / "scattermoe_lora", - package_name="scattermoe_lora", - layer_name="HFScatterMoEGatedMLP", - ), - Mode.INFERENCE: LocalLayerRepository( - repo_path=plugin_root / "libs" / "scattermoe_lora", - package_name="scattermoe_lora", - layer_name="HFScatterMoEGatedMLP", - ), - }, - } - } - ) + register_sonicmoe_experts() + cfg.experts_implementation = "sonicmoe" + LOG.info("Registered 'sonicmoe' in transformers ExpertsInterface") def add_callbacks_pre_trainer(self, cfg, model): callbacks = [] @@ -184,26 +93,3 @@ def add_callbacks_pre_trainer(self, cfg, model): callbacks.append(AutotuneReportCallback()) return callbacks - - def _kernelize_model(self, model_type: str): - from kernels import replace_kernel_forward_from_hub - - from axolotl.integrations.kernels.constants import resolve_moe_block_classes - - for model_moe_cls in resolve_moe_block_classes(model_type): - replace_kernel_forward_from_hub( - model_moe_cls, "HFScatterMoEParallelExperts" - ) - - def _register_experts_interface(self): - """Register ScatterMoE in the transformers ExpertsInterface. - - This allows @use_experts_implementation-decorated Experts classes - to dispatch to ScatterMoE when config._experts_implementation == "scattermoe". - """ - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( - register_scattermoe_experts, - ) - - register_scattermoe_experts() - LOG.info("Registered 'scattermoe' in transformers ExpertsInterface") diff --git a/tests/e2e/integrations/test_sonicmoe.py b/tests/e2e/integrations/test_sonicmoe.py index ff8620b2fa..021b41dc44 100644 --- a/tests/e2e/integrations/test_sonicmoe.py +++ b/tests/e2e/integrations/test_sonicmoe.py @@ -1,13 +1,19 @@ """ End-to-end gradient and convergence tests for SonicMoE integration. -Requires: - - H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90) - - sonicmoe package installed - - transformers with Qwen3MoE support +After the ExpertsInterface refactor, the flow is: + + register_sonicmoe_experts() # plug into ALL_EXPERTS_FUNCTIONS + config._experts_implementation = "sonicmoe" + model = AutoModelForCausalLM.from_config(config) # transformers dispatches + +No weight interleaving needed (cute-DSL ``concat_layout=True``); no per-arch +SparseMoEBlock monkeypatching. -Usage: - pytest tests/e2e/integrations/test_sonicmoe.py -v -s +Requires: + - Hopper (sm_90) or Blackwell (sm_100+) GPU + - sonicmoe kernel available via HF kernels-community + - transformers >= 5.8 with Qwen3MoE Experts class """ import importlib.util @@ -16,20 +22,29 @@ import pytest import torch -_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None -_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) + +def _is_hopper_or_newer() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + pytestmark = [ pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"), pytest.mark.skipif( - not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)" + not _is_hopper_or_newer(), + reason="SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+)", + ), + pytest.mark.skipif( + importlib.util.find_spec("kernels") is None, + reason="HF `kernels` package not installed", ), - pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"), ] -def _create_tiny_qwen3_config(): - """Create a minimal Qwen3MoE config for fast testing.""" +def _create_tiny_qwen3_config(experts_implementation: str): + """Create a minimal Qwen3MoE config bound to the requested experts impl.""" from transformers import AutoConfig config = AutoConfig.for_model("qwen3_moe") @@ -46,137 +61,85 @@ def _create_tiny_qwen3_config(): config.max_position_embeddings = 128 config.norm_topk_prob = True config.torch_dtype = torch.bfloat16 + config._experts_implementation = experts_implementation return config -def _interleave_gate_up_weights(model): - """Interleave all gate_up_proj parameters in-place for SonicMoE.""" - from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( - interleave_gate_up, - ) - - with torch.no_grad(): - for name, param in model.named_parameters(): - if "gate_up_proj" in name: - param.copy_(interleave_gate_up(param)) - +def _build_model(experts_implementation: str): + from transformers import AutoModelForCausalLM -def _unpatch_sonicmoe(): - """Restore original forward on the MoE block class if it was patched.""" - from axolotl.integrations.kernels.constants import resolve_moe_block_classes + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, + ) - for moe_cls in resolve_moe_block_classes("qwen3_moe"): - if hasattr(moe_cls, "_original_forward"): - moe_cls.forward = moe_cls._original_forward - del moe_cls._original_forward + register_sonicmoe_experts() + config = _create_tiny_qwen3_config(experts_implementation) + return AutoModelForCausalLM.from_config(config).cuda().bfloat16(), config class TestSonicMoEForwardCorrectness: - """Verify SonicMoE-patched model produces same output as original.""" - - def teardown_method(self): - _unpatch_sonicmoe() - - def test_forward_output_matches(self): - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + """SonicMoE-dispatched model produces output close to eager baseline.""" - # Original model - model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + def test_forward_output_matches_eager(self): + input_ids = torch.randint(0, 1000, (1, 16), device="cuda") + eager_model, _ = _build_model("eager") with torch.no_grad(): - out_orig = model_orig(input_ids) + out_eager = eager_model(input_ids).logits - # Patched model (same weights, interleaved for SonicMoE) - model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - model_patched.load_state_dict(model_orig.state_dict()) - - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model_patched) + sonic_model, _ = _build_model("sonicmoe") + sonic_model.load_state_dict(eager_model.state_dict()) with torch.no_grad(): - out_patched = model_patched(input_ids) + out_sonic = sonic_model(input_ids).logits - max_diff = (out_orig.logits - out_patched.logits).abs().max().item() - assert torch.allclose( - out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1 - ), f"Output mismatch: max diff={max_diff:.6f}" + max_diff = (out_eager - out_sonic).abs().max().item() + assert torch.allclose(out_eager, out_sonic, atol=1e-1, rtol=1e-1), ( + f"Output mismatch: max diff={max_diff:.6f}" + ) class TestSonicMoEGradientCorrectness: - """Compare gradients between original HuggingFace and SonicMoE-patched forward.""" - - def teardown_method(self): - _unpatch_sonicmoe() + """Compare gradients between eager and SonicMoE-dispatched forward.""" def test_gradients_match(self): - """Verify all parameter gradients match between original and patched.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( - deinterleave_gate_up, - ) + input_ids = torch.randint(0, 1000, (1, 16), device="cuda") - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + eager_model, _ = _build_model("eager") + out_eager = eager_model(input_ids, labels=input_ids) + out_eager.loss.backward() + grads_eager = { + n: p.grad.float().clone() + for n, p in eager_model.named_parameters() + if p.grad is not None + } + loss_eager = out_eager.loss.item() - # ---------- Original model ---------- - model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - out_orig = model_orig(input_ids, labels=input_ids) - out_orig.loss.backward() - grads_orig = { + sonic_model, _ = _build_model("sonicmoe") + sonic_model.load_state_dict(eager_model.state_dict()) + out_sonic = sonic_model(input_ids, labels=input_ids) + out_sonic.loss.backward() + grads_sonic = { n: p.grad.float().clone() - for n, p in model_orig.named_parameters() + for n, p in sonic_model.named_parameters() if p.grad is not None } - loss_orig = out_orig.loss.item() - - # ---------- SonicMoE-patched model (same weights, interleaved) ---------- - model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - model_patched.load_state_dict(model_orig.state_dict()) - - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model_patched) - - out_patched = model_patched(input_ids, labels=input_ids) - out_patched.loss.backward() - grads_patched = {} - for n, p in model_patched.named_parameters(): - if p.grad is None: - continue - g = p.grad.float().clone() - # gate_up_proj grads are in interleaved layout, de-interleave to match orig - if "gate_up_proj" in n: - g = deinterleave_gate_up(g) - grads_patched[n] = g - loss_patched = out_patched.loss.item() - - # ---------- Compare ---------- - assert abs(loss_orig - loss_patched) < 0.5, ( - f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}" + loss_sonic = out_sonic.loss.item() + + assert abs(loss_eager - loss_sonic) < 0.5, ( + f"Loss mismatch: eager={loss_eager:.4f}, sonic={loss_sonic:.4f}" ) - # All parameters with gradients in original should have them in patched - missing = set(grads_orig.keys()) - set(grads_patched.keys()) - assert not missing, f"Missing gradients in patched model: {missing}" + missing = set(grads_eager.keys()) - set(grads_sonic.keys()) + assert not missing, f"Missing gradients in sonicmoe model: {missing}" - # Compare gradient values - # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge, - # so use generous tolerance: flag only if both rel >10% AND abs >1e-2 + # bf16 + different GEMM backends can diverge; tolerate both rel >10% AND + # abs >1e-2 together. mismatches = [] - for name in grads_orig: - if name not in grads_patched: - continue - g_orig = grads_orig[name] - g_patched = grads_patched[name] - max_diff = (g_orig - g_patched).abs().max().item() - rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8) - + for name, g_eager in grads_eager.items(): + g_sonic = grads_sonic[name] + max_diff = (g_eager - g_sonic).abs().max().item() + rel_diff = max_diff / (g_eager.abs().max().item() + 1e-8) if rel_diff > 0.1 and max_diff > 1e-2: mismatches.append( f" {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}" @@ -188,18 +151,8 @@ def test_gradients_match(self): ) def test_router_weights_receive_gradients(self): - """Verify that router (gate) weights get non-zero gradients.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) - + input_ids = torch.randint(0, 1000, (1, 16), device="cuda") + model, _ = _build_model("sonicmoe") out = model(input_ids, labels=input_ids) out.loss.backward() @@ -216,21 +169,9 @@ def test_router_weights_receive_gradients(self): class TestSonicMoETrainingConvergence: """Verify loss decreases during training with SonicMoE.""" - def teardown_method(self): - _unpatch_sonicmoe() - def test_loss_decreases(self): - """Run 30 training steps, verify loss decreases and no NaN/Inf.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model, _ = _build_model("sonicmoe") optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) losses = [] @@ -251,24 +192,14 @@ def test_loss_decreases(self): ) def test_expert_weights_update(self): - """Verify expert weights change during training (not frozen).""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) - - # Snapshot expert weights before training - expert_weights_before = {} - for name, param in model.named_parameters(): - if "experts" in name: - expert_weights_before[name] = param.data.clone() + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model, _ = _build_model("sonicmoe") + expert_weights_before = { + name: param.data.clone() + for name, param in model.named_parameters() + if "experts" in name + } assert expert_weights_before, "No expert parameters found" optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) @@ -278,11 +209,10 @@ def test_expert_weights_update(self): optimizer.step() optimizer.zero_grad() - # Check that expert weights changed - changed = 0 - for name, param in model.named_parameters(): - if name in expert_weights_before: - if not torch.equal(param.data, expert_weights_before[name]): - changed += 1 - + changed = sum( + 1 + for name, param in model.named_parameters() + if name in expert_weights_before + and not torch.equal(param.data, expert_weights_before[name]) + ) assert changed > 0, "No expert weights changed after 5 training steps" diff --git a/tests/e2e/integrations/test_sonicmoe_lora.py b/tests/e2e/integrations/test_sonicmoe_lora.py index 74721ee57a..240b015141 100644 --- a/tests/e2e/integrations/test_sonicmoe_lora.py +++ b/tests/e2e/integrations/test_sonicmoe_lora.py @@ -3,20 +3,24 @@ # Licensed under the Apache License, Version 2.0 """ -End-to-end tests for SonicMoE + LoRA integration. +End-to-end tests for SonicMoE + LoRA after the ExpertsInterface refactor. -Verifies that PEFT-wrapped MoE models work correctly with SonicMoE's -runtime LoRA materialization: gradients flow to adapters, base weights -stay frozen, and loss converges. +The new flow: -Requires: - - H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90) - - sonicmoe package installed - - peft package installed - - transformers with Qwen3MoE support + register_sonicmoe_experts() # plug into ALL_EXPERTS_FUNCTIONS + config._experts_implementation = "sonicmoe" + model = AutoModelForCausalLM.from_config(config) + model = get_peft_model(model, lora_config) # PEFT wraps params/modules + +Our registered ``sonicmoe_experts_forward_with_lora`` detects the PEFT +wrappers and materializes ``W_eff = W + scaling * (B @ A)`` via +:class:`MoELoRAMaterialize`, so adapters train through the CUTLASS kernels. -Usage: - pytest tests/e2e/integrations/test_sonicmoe_lora.py -v -s +Requires: + - Hopper (sm_90) or Blackwell (sm_100+) GPU + - sonicmoe kernel available via HF kernels-community + - peft installed + - transformers >= 5.8 with Qwen3MoE Experts class """ import importlib.util @@ -25,22 +29,31 @@ import pytest import torch -_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None -_peft_available = importlib.util.find_spec("peft") is not None -_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) + +def _is_hopper_or_newer() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + pytestmark = [ pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"), pytest.mark.skipif( - not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)" + not _is_hopper_or_newer(), + reason="SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+)", + ), + pytest.mark.skipif( + importlib.util.find_spec("kernels") is None, + reason="HF `kernels` package not installed", + ), + pytest.mark.skipif( + importlib.util.find_spec("peft") is None, reason="PEFT not installed" ), - pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"), - pytest.mark.skipif(not _peft_available, reason="PEFT not installed"), ] def _create_tiny_qwen3_config(): - """Create a minimal Qwen3MoE config for fast testing.""" from transformers import AutoConfig config = AutoConfig.for_model("qwen3_moe") @@ -57,33 +70,23 @@ def _create_tiny_qwen3_config(): config.max_position_embeddings = 128 config.norm_topk_prob = True config.torch_dtype = torch.bfloat16 + config._experts_implementation = "sonicmoe" return config -def _interleave_gate_up_weights(model): - """Interleave all gate_up_proj parameters in-place for SonicMoE.""" - from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( - interleave_gate_up, - ) - - with torch.no_grad(): - for name, param in model.named_parameters(): - if "gate_up_proj" in name: - param.copy_(interleave_gate_up(param)) - +def _build_sonic_model(): + from transformers import AutoModelForCausalLM -def _unpatch_sonicmoe(): - """Restore original forward on the MoE block class if it was patched.""" - from axolotl.integrations.kernels.constants import resolve_moe_block_classes + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, + ) - for moe_cls in resolve_moe_block_classes("qwen3_moe"): - if hasattr(moe_cls, "_original_forward"): - moe_cls.forward = moe_cls._original_forward - del moe_cls._original_forward + register_sonicmoe_experts() + config = _create_tiny_qwen3_config() + return AutoModelForCausalLM.from_config(config).cuda().bfloat16() def _apply_lora(model, target_modules): - """Apply PEFT LoRA to the model.""" from peft import LoraConfig, get_peft_model lora_config = LoraConfig( @@ -97,37 +100,23 @@ def _apply_lora(model, target_modules): class TestSonicMoELoRATraining: - """Verify SonicMoE + LoRA training works end-to-end.""" - - def teardown_method(self): - _unpatch_sonicmoe() + """SonicMoE + LoRA on expert projections trains end-to-end.""" def test_loss_decreases(self): - """Run 30 training steps with LoRA on experts, verify loss decreases.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model = _build_sonic_model() model = _apply_lora(model, ["gate_up_proj", "down_proj"]) optimizer = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=1e-3 ) losses = [] - for step in range(30): out = model(input_ids, labels=input_ids) loss = out.loss assert not math.isnan(loss.item()), f"NaN loss at step {step}" assert not math.isinf(loss.item()), f"Inf loss at step {step}" losses.append(loss.item()) - loss.backward() optimizer.step() optimizer.zero_grad() @@ -137,24 +126,15 @@ def test_loss_decreases(self): ) def test_base_weights_frozen(self): - """Verify base (non-LoRA) weights don't change during training.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model = _build_sonic_model() model = _apply_lora(model, ["gate_up_proj", "down_proj"]) - # Snapshot frozen weights - frozen_before = {} - for name, param in model.named_parameters(): - if not param.requires_grad: - frozen_before[name] = param.data.clone() + frozen_before = { + name: param.data.clone() + for name, param in model.named_parameters() + if not param.requires_grad + } optimizer = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=1e-3 @@ -165,24 +145,13 @@ def test_base_weights_frozen(self): optimizer.step() optimizer.zero_grad() - for name, param in model.named_parameters(): - if name in frozen_before: - assert torch.equal(param.data, frozen_before[name]), ( - f"Frozen weight changed: {name}" - ) + for name, before in frozen_before.items(): + after = dict(model.named_parameters())[name] + assert torch.equal(after.data, before), f"Frozen weight changed: {name}" def test_lora_adapters_receive_gradients(self): - """Verify LoRA A and B matrices get non-zero gradients.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (1, 16), device="cuda") + model = _build_sonic_model() model = _apply_lora(model, ["gate_up_proj", "down_proj"]) out = model(input_ids, labels=input_ids) @@ -200,25 +169,15 @@ def test_lora_adapters_receive_gradients(self): assert lora_grads_found > 0, "No LoRA parameters found with gradients" def test_lora_adapters_update(self): - """Verify LoRA adapter weights change during training.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model = _build_sonic_model() model = _apply_lora(model, ["gate_up_proj", "down_proj"]) - # Snapshot LoRA weights - lora_before = {} - for name, param in model.named_parameters(): - if "lora_" in name and param.requires_grad: - lora_before[name] = param.data.clone() - + lora_before = { + name: param.data.clone() + for name, param in model.named_parameters() + if "lora_" in name and param.requires_grad + } assert lora_before, "No LoRA parameters found" optimizer = torch.optim.AdamW( @@ -239,38 +198,23 @@ def test_lora_adapters_update(self): class TestSonicMoEGateOnlyLoRA: - """Verify LoRA targeting only the gate (router) works with SonicMoE.""" - - def teardown_method(self): - _unpatch_sonicmoe() + """LoRA only on the router (gate) — expert path takes the no-LoRA fast path.""" def test_gate_only_lora_loss_decreases(self): - """LoRA only on gate — expert path should have zero materialization overhead.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) - # Only target the gate (router), not expert projections + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model = _build_sonic_model() model = _apply_lora(model, ["gate"]) optimizer = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=1e-3 ) losses = [] - for step in range(20): out = model(input_ids, labels=input_ids) loss = out.loss assert not math.isnan(loss.item()), f"NaN loss at step {step}" assert not math.isinf(loss.item()), f"Inf loss at step {step}" losses.append(loss.item()) - loss.backward() optimizer.step() optimizer.zero_grad() @@ -281,34 +225,20 @@ def test_gate_only_lora_loss_decreases(self): class TestSonicMoENoLoRARegression: - """Verify SonicMoE without LoRA still works after LoRA code was added.""" - - def teardown_method(self): - _unpatch_sonicmoe() + """Full fine-tuning (no PEFT) still works through the registered forward.""" def test_no_lora_loss_decreases(self): - """Full fine-tuning (no PEFT) with SonicMoE — regression test.""" - from transformers import AutoModelForCausalLM - - from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe - - config = _create_tiny_qwen3_config() - input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") - - model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() - patch_sonicmoe("qwen3_moe") - _interleave_gate_up_weights(model) + input_ids = torch.randint(0, 1000, (2, 32), device="cuda") + model = _build_sonic_model() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) losses = [] - for step in range(20): out = model(input_ids, labels=input_ids) loss = out.loss assert not math.isnan(loss.item()), f"NaN loss at step {step}" assert not math.isinf(loss.item()), f"Inf loss at step {step}" losses.append(loss.item()) - loss.backward() optimizer.step() optimizer.zero_grad() diff --git a/tests/integrations/test_gemma4_moe.py b/tests/integrations/test_gemma4_moe.py index 412d49b2f9..e3408c7822 100644 --- a/tests/integrations/test_gemma4_moe.py +++ b/tests/integrations/test_gemma4_moe.py @@ -462,94 +462,6 @@ class TestGemma4SonicMoE: import are skipped on unsupported GPUs. """ - @pytest.mark.skipif( - not _can_import_sonicmoe(), - reason="sonicmoe requires Hopper/Blackwell GPU", - ) - def test_gemma4_routing_function_config(self, gemma4_config): - """Gemma4 is registered with correct routing config.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - get_model_moe_config, - ) - - routing_fn, activation, router_attr = get_model_moe_config("gemma4_text") - - assert router_attr == "router" - assert routing_fn is not None - assert routing_fn.__name__ == "gemma4_routing" - - from sonicmoe.enums import ActivationType - - assert activation == ActivationType.GEGLU - - @pytest.mark.skipif( - not _can_import_sonicmoe(), - reason="sonicmoe requires Hopper/Blackwell GPU", - ) - def test_gemma4_routing_matches_reference(self, gemma4_config): - """Routing function output matches reference Gemma4TextRouter.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - get_model_moe_config, - ) - - routing_fn, _, _ = get_model_moe_config("gemma4_text") - H = gemma4_config["hidden_size"] - E = gemma4_config["num_experts"] - K = gemma4_config["top_k"] - T = 16 - - router = Gemma4TextRouter(H, E, K) - nn.init.normal_(router.proj.weight, std=0.01) - - class MockGemma4MoeBlock: - pass - - mock_block = MockGemma4MoeBlock() - mock_block.router = router - - hidden_states = torch.randn(T, H) - - # Reference - _ref_probs, ref_weights, ref_indices = router(hidden_states) - - # Routing function - flat_scores, flat_token_idx, flat_expert_idx, router_logits = routing_fn( - hidden_states, mock_block - ) - - # Check shapes - assert flat_scores.shape == (T * K,) - assert flat_token_idx.shape == (T * K,) - assert flat_expert_idx.shape == (T * K,) - assert router_logits.shape == (T, E) - - # Reconstruct per-token routing from flat output and compare - for t in range(T): - mask = flat_token_idx == t - assert mask.sum() == K, f"Token {t} should have {K} entries" - - flat_experts_for_t = flat_expert_idx[mask].sort().values - ref_experts_for_t = ref_indices[t].sort().values.to(torch.int32) - assert torch.equal(flat_experts_for_t, ref_experts_for_t), ( - f"Token {t}: experts mismatch" - ) - - # Verify scores match reference per-token - for t in range(T): - mask = flat_token_idx == t - flat_experts_t = flat_expert_idx[mask] - flat_scores_t = flat_scores[mask] - - sort_idx = flat_experts_t.argsort() - flat_scores_sorted = flat_scores_t[sort_idx] - - ref_sort_idx = ref_indices[t].argsort() - ref_scores_sorted = ref_weights[t][ref_sort_idx].float() - - torch.testing.assert_close( - flat_scores_sorted, ref_scores_sorted, atol=1e-4, rtol=1e-4 - ) - def test_gemma4_weight_layout_compatible(self, gemma4_config): """Verify Gemma4 expert weight layout is compatible with SonicMoE.""" E = gemma4_config["num_experts"] @@ -582,12 +494,6 @@ def test_gemma4_is_experts_only_model(self): assert cls is not None assert cls.__name__ == "Gemma4TextExperts" - def test_gemma4_not_in_sparse_moe_block(self): - """Verify gemma4_text is NOT in SPARSE_MOE_BLOCK (has no SparseMoeBlock).""" - from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK - - assert "gemma4_text" not in SPARSE_MOE_BLOCK - # ============================================================================ # Integration Tests (full layer with real model config) @@ -790,7 +696,7 @@ def test_register_scattermoe_in_experts_interface(self): """register_scattermoe_experts adds 'scattermoe' to the global interface.""" from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, scattermoe_experts_forward, ) @@ -806,7 +712,7 @@ def test_experts_implementation_dispatches_to_scattermoe(self, device): Gemma4TextExperts as HFGemma4TextExperts, ) - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, ) @@ -853,7 +759,7 @@ def test_validation_accepts_scattermoe(self): """get_correct_experts_implementation accepts 'scattermoe' after registration.""" from transformers.modeling_utils import PreTrainedModel - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, ) @@ -869,7 +775,7 @@ def test_eager_still_works_after_registration(self, device): Gemma4TextExperts as HFGemma4TextExperts, ) - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, ) @@ -982,7 +888,7 @@ def model_setup(self, request, device): """Create an Experts instance for each model type.""" import importlib - from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import ( + from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, ) diff --git a/tests/integrations/test_routing_parity.py b/tests/integrations/test_routing_parity.py deleted file mode 100644 index 8852068096..0000000000 --- a/tests/integrations/test_routing_parity.py +++ /dev/null @@ -1,492 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) Axolotl AI -# Licensed under the Apache License, Version 2.0 - -""" -Parity tests between scattermoe-lora and sonicmoe routing implementations. - -These tests verify that both implementations produce numerically identical -results for the same inputs, ensuring safe centralization of the routing code. - -ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K]. -The core algorithm should be identical — only the output format differs. -""" - -from types import SimpleNamespace - -import pytest -import torch - - -def _require_triton(): - pytest.importorskip("triton") - - -# ============================================================================ -# Fixtures / helpers -# ============================================================================ - - -def _make_softmax_block(T=8, H=16, E=4, K=2): - """Qwen/OLMoE-style block usable by both implementations.""" - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - num_experts=E, - norm_topk_prob=True, - ) - moe_block = SimpleNamespace(gate=gate) - hidden = torch.randn(T, H) - return moe_block, gate, hidden, T, H, E, K - - -def _make_sigmoid_block( - T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True -): - """GLM/DeepSeek-style block usable by both implementations.""" - if bias_on_gate: - gate = SimpleNamespace( - weight=torch.randn(E, H), - e_score_correction_bias=torch.zeros(E), - ) - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - n_routed_experts=E, - n_group=n_group, - topk_group=topk_group, - norm_topk_prob=True, - routed_scaling_factor=1.0, - ) - else: - # minimax_m2 style: bias on block - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - ) - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - e_score_correction_bias=torch.zeros(E), - ) - return moe_block, gate, hidden_states(T, H), T, H, E, K - - -def hidden_states(T, H): - return torch.randn(T, H) - - -# ============================================================================ -# 1. Softmax routing parity -# ============================================================================ - - -class TestSoftmaxRoutingParity: - """Verify scattermoe and sonicmoe softmax routing produce identical results.""" - - @pytest.fixture(autouse=True) - def _require(self): - _require_triton() - - def test_weights_match(self): - """2D weights from scattermoe == reshaped 1D weights from sonicmoe.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _softmax_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_softmax_block() - - # ScatterMoE path (no LoRA delta) - sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - - # SonicMoE path - sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing( - hidden, moe_block - ) - - # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - assert sm_topk == K - assert sm_E == E - - # Both should select the same experts and produce the same weights - assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)) - assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6) - - def test_logits_not_returned_by_scattermoe(self): - """ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_softmax_block() - _, _, _, logits = softmax_topk_routing(hidden, moe_block) - assert logits.shape == (T, E) - - def test_no_renorm(self): - """With norm_topk_prob=False, both should skip renormalization.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _softmax_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_softmax_block() - gate.norm_topk_prob = False - - sm_weights, sm_experts, _, _ = _softmax_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)) - assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6) - - def test_various_expert_counts(self): - """Parity across different E and K values.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _softmax_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_routing, - ) - - for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]: - moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K) - - sm_weights, sm_experts, _, _ = _softmax_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), ( - f"Expert mismatch for E={E}, K={K}" - ) - assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), ( - f"Weight mismatch for E={E}, K={K}" - ) - - -# ============================================================================ -# 2. Sigmoid routing parity -# ============================================================================ - - -class TestSigmoidRoutingParity: - """Verify scattermoe and sonicmoe sigmoid routing produce identical results.""" - - @pytest.fixture(autouse=True) - def _require(self): - _require_triton() - - def test_weights_match_with_groups(self): - """Both implementations should produce identical weights with group selection.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _sigmoid_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True - ) - - sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - - sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing( - hidden, moe_block - ) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - assert sm_topk == K - assert sm_E == E - - # Sort experts within each token to handle different topk orderings - sm_sorted, sm_order = sm_experts.sort(dim=-1) - sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) - - assert torch.equal(sm_sorted, sonic_sorted) - - # Gather weights in sorted order for comparison - sm_weights_sorted = sm_weights.gather(1, sm_order) - sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) - assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) - - def test_weights_match_no_groups(self): - """Both implementations match without group selection (n_group=1).""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _sigmoid_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True - ) - - sm_weights, sm_experts, _, _ = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - # Sort for comparison (topk with sorted=False may differ in order) - sm_sorted, sm_order = sm_experts.sort(dim=-1) - sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) - - assert torch.equal(sm_sorted, sonic_sorted) - sm_weights_sorted = sm_weights.gather(1, sm_order) - sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) - assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) - - def test_bias_on_block_parity(self): - """minimax_m2 style: bias on block, not gate.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _sigmoid_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - E=16, K=4, n_group=1, bias_on_gate=False - ) - - sm_weights, sm_experts, _, _ = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - sm_sorted, sm_order = sm_experts.sort(dim=-1) - sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) - - assert torch.equal(sm_sorted, sonic_sorted) - sm_weights_sorted = sm_weights.gather(1, sm_order) - sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) - assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) - - def test_scaling_factor_parity(self): - """routed_scaling_factor applied identically by both.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _sigmoid_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - n_group=1, bias_on_gate=True - ) - moe_block.routed_scaling_factor = 2.5 - - sm_weights, sm_experts, _, _ = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - sm_sorted, sm_order = sm_experts.sort(dim=-1) - sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) - - assert torch.equal(sm_sorted, sonic_sorted) - sm_weights_sorted = sm_weights.gather(1, sm_order) - sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) - assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) - - def test_no_renorm_parity(self): - """norm_topk_prob=False produces same results in both.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _sigmoid_topk_route, - ) - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - n_group=1, bias_on_gate=True - ) - moe_block.norm_topk_prob = False - - sm_weights, sm_experts, _, _ = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - sonic_weights_2d = sonic_scores.reshape(T, K) - sonic_experts_2d = sonic_exp_idx.reshape(T, K) - - sm_sorted, sm_order = sm_experts.sort(dim=-1) - sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) - - assert torch.equal(sm_sorted, sonic_sorted) - sm_weights_sorted = sm_weights.gather(1, sm_order) - sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) - assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) - - -# ============================================================================ -# 3. Shared expert parity -# ============================================================================ - - -class TestSharedExpertParity: - """Verify both _compute_shared_expert implementations behave identically.""" - - @pytest.fixture(autouse=True) - def _require(self): - _require_triton() - - def _get_both_fns(self): - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _compute_shared_expert as scatter_compute, - ) - from axolotl.integrations.kernels.libs.sonicmoe.patch import ( - _compute_shared_expert as sonic_compute, - ) - - return scatter_compute, sonic_compute - - def test_shared_expert_singular(self): - scatter_fn, sonic_fn = self._get_both_fns() - out = torch.randn(4, 8) - block = SimpleNamespace(shared_expert=lambda x: out) - hidden = torch.randn(4, 8) - - assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) - - def test_shared_experts_plural(self): - scatter_fn, sonic_fn = self._get_both_fns() - out = torch.randn(4, 8) - block = SimpleNamespace(shared_experts=lambda x: out) - hidden = torch.randn(4, 8) - - assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) - - def test_shared_mlp(self): - scatter_fn, sonic_fn = self._get_both_fns() - out = torch.randn(4, 8) - block = SimpleNamespace(shared_mlp=lambda x: out) - hidden = torch.randn(4, 8) - - assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) - - def test_no_shared_expert(self): - scatter_fn, sonic_fn = self._get_both_fns() - block = SimpleNamespace() - hidden = torch.randn(4, 8) - - assert scatter_fn(block, hidden) is None - assert sonic_fn(block, hidden) is None - - def test_shared_expert_gate_only_in_scattermoe(self): - """ScatterMoE's _compute_shared_expert handles shared_expert_gate; - SonicMoE's patch.py handles it externally in the forward function. - - This documents the known divergence: the scattermoe version applies - sigmoid gating inline, while sonicmoe applies it in the forward. - """ - scatter_fn, sonic_fn = self._get_both_fns() - - H = 8 - expert_out = torch.ones(4, H) - gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5 - - block = SimpleNamespace( - shared_expert=lambda x: expert_out, - shared_expert_gate=gate_fn, - ) - hidden = torch.randn(4, H) - - scatter_result = scatter_fn(block, hidden) - sonic_result = sonic_fn(block, hidden) - - # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5 - expected_gated = expert_out * 0.5 - assert torch.allclose(scatter_result, expected_gated, atol=1e-6) - - # SonicMoE does NOT apply the gate here (it does it in the forward) - assert torch.equal(sonic_result, expert_out) - - -# ============================================================================ -# 4. Route dispatcher parity -# ============================================================================ - - -class TestRouteDispatcherParity: - """Verify _route in scattermoe dispatches correctly and matches individual fns.""" - - @pytest.fixture(autouse=True) - def _require(self): - _require_triton() - - def test_route_dispatches_softmax(self): - """_route should use softmax when no e_score_correction_bias.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _route, - _softmax_topk_route, - ) - - moe_block, gate, hidden, T, H, E, K = _make_softmax_block() - - route_w, route_e, route_k, route_E = _route( - moe_block, gate, hidden, gate.weight, None - ) - direct_w, direct_e, direct_k, direct_E = _softmax_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - - assert torch.equal(route_w, direct_w) - assert torch.equal(route_e, direct_e) - assert route_k == direct_k - assert route_E == direct_E - - def test_route_dispatches_sigmoid(self): - """_route should use sigmoid when e_score_correction_bias is present.""" - from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( - _route, - _sigmoid_topk_route, - ) - - moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( - n_group=1, bias_on_gate=True - ) - - route_w, route_e, route_k, route_E = _route( - moe_block, gate, hidden, gate.weight, None - ) - direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route( - moe_block, gate, hidden, gate.weight, None - ) - - assert torch.equal(route_w, direct_w) - assert torch.equal(route_e, direct_e) - assert route_k == direct_k - assert route_E == direct_E diff --git a/tests/integrations/test_sonicmoe.py b/tests/integrations/test_sonicmoe.py index 864abca36d..f7261a85d9 100644 --- a/tests/integrations/test_sonicmoe.py +++ b/tests/integrations/test_sonicmoe.py @@ -1,4 +1,4 @@ -"""Unit tests for the SonicMoE integration.""" +"""Unit tests for the SonicMoE ExpertsInterface registration.""" from types import SimpleNamespace @@ -6,15 +6,6 @@ import torch from axolotl.integrations.kernels.args import KernelsArgs -from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - softmax_topk_routing, -) -from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( - ConcatenatedToInterleaved, - InterleavedToConcatenated, - register_sonicmoe_weight_converter, -) class TestKernelsArgs: @@ -43,777 +34,202 @@ def test_disables_mlp_kernel_when_sonicmoe(self): assert result["lora_mlp_kernel"] is False assert result["mlp_kernel"] is False + def test_experts_implementation_auto_sonicmoe(self): + out = KernelsArgs.check_experts_implementation({"use_sonicmoe": True}) + assert out["experts_implementation"] == "sonicmoe" -class TestConcatenatedToInterleaved: - @pytest.fixture - def sample_tensor(self): - """Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values.""" - E, I, H = 2, 2, 3 # noqa: E741 - gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H) - up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H) - return torch.cat([gate, up], dim=1) - - def test_interleave_rows_alternate(self, sample_tensor): - op = ConcatenatedToInterleaved(dim=1) - result = op.convert( - {"test": sample_tensor}, - source_patterns=["test"], - target_patterns=["test"], - ) - interleaved = result["test"] - - # For expert 0: even rows should be gate, odd rows should be up - E, two_I, H = sample_tensor.shape - I = two_I // 2 # noqa: E741 - gate_orig = sample_tensor[:, :I, :] - up_orig = sample_tensor[:, I:, :] - - assert torch.equal(interleaved[:, 0::2, :], gate_orig) - assert torch.equal(interleaved[:, 1::2, :], up_orig) - - def test_interleave_handles_list_input(self, sample_tensor): - op = ConcatenatedToInterleaved(dim=1) - result = op.convert( - {"test": [sample_tensor]}, - source_patterns=["test"], - target_patterns=["test"], - ) - assert result["test"].shape == sample_tensor.shape - - def test_reverse_op_type(self): - op = ConcatenatedToInterleaved(dim=1) - assert isinstance(op.reverse_op, InterleavedToConcatenated) - assert op.reverse_op.dim == 1 - - -class TestInterleavedToConcatenated: - @pytest.fixture - def interleaved_tensor(self): - """Create an interleaved tensor [E=2, 2*I=4, H=3].""" - E, I, H = 2, 2, 3 # noqa: E741 - gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H) - up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H) - interleaved = torch.empty(E, 2 * I, H) - interleaved[:, 0::2, :] = gate - interleaved[:, 1::2, :] = up - return interleaved - - def test_deinterleave_gate_up_separated(self, interleaved_tensor): - op = InterleavedToConcatenated(dim=1) - result = op.convert( - {"test": interleaved_tensor}, - source_patterns=["test"], - target_patterns=["test"], - ) - concatenated = result["test"] - - E, two_I, H = concatenated.shape - I = two_I // 2 # noqa: E741 - - # First half should be gate (even rows from interleaved) - assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :]) - # Second half should be up (odd rows from interleaved) - assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :]) - - def test_reverse_op_type(self): - op = InterleavedToConcatenated(dim=1) - assert isinstance(op.reverse_op, ConcatenatedToInterleaved) - assert op.reverse_op.dim == 1 - - -class TestRoundTrip: - @pytest.fixture - def concat_tensor(self): - E, I, H = 4, 8, 16 # noqa: E741 - gate = torch.randn(E, I, H) - up = torch.randn(E, I, H) - return torch.cat([gate, up], dim=1) - - def test_interleave_then_deinterleave_is_identity(self, concat_tensor): - fwd = ConcatenatedToInterleaved(dim=1) - rev = InterleavedToConcatenated(dim=1) - - interleaved = fwd.convert( - {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"] - )["k"] - recovered = rev.convert( - {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] - )["k"] - - assert torch.equal(concat_tensor, recovered) - - def test_reverse_op_chain_is_identity(self, concat_tensor): - """Verify that op.reverse_op produces an exact inverse.""" - op = ConcatenatedToInterleaved(dim=1) - rev = op.reverse_op - - interleaved = op.convert( - {"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"] - )["k"] - recovered = rev.convert( - {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] - )["k"] - - assert torch.equal(concat_tensor, recovered) - - def test_various_shapes(self): - """Test with different expert counts and dimensions.""" - fwd = ConcatenatedToInterleaved(dim=1) - rev = InterleavedToConcatenated(dim=1) - - for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]: # noqa: E741 - concat = torch.randn(E, 2 * I, H) - interleaved = fwd.convert( - {"k": concat}, source_patterns=["k"], target_patterns=["k"] - )["k"] - recovered = rev.convert( - {"k": interleaved}, source_patterns=["k"], target_patterns=["k"] - )["k"] - assert torch.equal(concat, recovered), ( - f"Failed for shape ({E}, {2 * I}, {H})" - ) - - -class TestWeightConverterRegistration: - def test_register_appends_interleave_op(self): - from transformers.conversion_mapping import get_checkpoint_conversion_mapping - - register_sonicmoe_weight_converter("qwen3_moe") - - modified = get_checkpoint_conversion_mapping("qwen3_moe") - # Find the gate_up_proj converter - gate_up_converter = None - for conv in modified: - if hasattr(conv, "operations") and any( - "gate_up_proj" in pat for pat in conv.target_patterns - ): - gate_up_converter = conv - break - - assert gate_up_converter is not None - assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved) - - def test_double_registration_is_idempotent(self): - from transformers.conversion_mapping import get_checkpoint_conversion_mapping - - register_sonicmoe_weight_converter("qwen3_moe") - register_sonicmoe_weight_converter("qwen3_moe") - - modified = get_checkpoint_conversion_mapping("qwen3_moe") - for conv in modified: - if hasattr(conv, "operations") and any( - "gate_up_proj" in pat for pat in conv.target_patterns - ): - interleave_count = sum( - isinstance(op, ConcatenatedToInterleaved) for op in conv.operations - ) - assert interleave_count == 1, ( - f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}" - ) - break - - def test_register_adds_same_key_converter(self): - from transformers.conversion_mapping import get_checkpoint_conversion_mapping - - register_sonicmoe_weight_converter("qwen3_moe") - - modified = get_checkpoint_conversion_mapping("qwen3_moe") - # Should have a same-key converter for already-fused checkpoints - same_key = [ - c - for c in modified - if hasattr(c, "source_patterns") - and c.source_patterns == ["mlp.experts.gate_up_proj"] - and c.target_patterns == ["mlp.experts.gate_up_proj"] - ] - assert len(same_key) == 1 - assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved) - - def test_register_creates_mapping_when_none(self): - from transformers.conversion_mapping import get_checkpoint_conversion_mapping - - # qwen3_5_moe has no conversion mapping in transformers - register_sonicmoe_weight_converter("qwen3_5_moe") - - mapping = get_checkpoint_conversion_mapping("qwen3_5_moe") - assert mapping is not None - same_key = [ - c - for c in mapping - if hasattr(c, "source_patterns") - and c.source_patterns == ["mlp.experts.gate_up_proj"] - and c.target_patterns == ["mlp.experts.gate_up_proj"] - ] - assert len(same_key) == 1 - assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved) - - -def _make_qwen_moe_block(T=8, H=16, E=4, K=2): - """Create a mock qwen-style MoE block for routing tests.""" - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - num_experts=E, - norm_topk_prob=True, - ) - return SimpleNamespace(gate=gate), T, H, E, K - - -def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1): - """Create a mock GLM5-style MoE block for routing tests.""" - gate = SimpleNamespace( - weight=torch.randn(E, H), - e_score_correction_bias=torch.zeros(E), - ) - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - n_routed_experts=E, - n_group=n_group, - topk_group=topk_group, - norm_topk_prob=True, - routed_scaling_factor=1.0, - ) - return moe_block, T, H, E, K - - -def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4): - """Create a mock minimax_m2-style MoE block for routing tests. - - minimax_m2 uses sigmoid->topk WITHOUT group selection: - - e_score_correction_bias is on the moe_block (not on gate) - - No n_group / topk_group attributes - - Always normalizes (norm_topk_prob defaults to True) - - No routed_scaling_factor (defaults to 1.0) - """ - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - ) - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - e_score_correction_bias=torch.zeros(E), - ) - return moe_block, T, H, E, K - - -class TestSoftmaxTopkRouting: - def test_output_shapes(self): - moe_block, T, H, E, K = _make_qwen_moe_block() - hidden = torch.randn(T, H) - - scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block) - - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) - - def test_scores_are_float32(self): - moe_block, T, H, E, K = _make_qwen_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = softmax_topk_routing(hidden, moe_block) - assert scores.dtype == torch.float32 - - def test_token_indices_sorted_ascending(self): - moe_block, T, H, E, K = _make_qwen_moe_block() - hidden = torch.randn(T, H) - - _, token_idx, _, _ = softmax_topk_routing(hidden, moe_block) - - # Token indices must be sorted ascending (SonicMoE requirement) - diffs = token_idx[1:] - token_idx[:-1] - assert (diffs >= 0).all() - - def test_expert_indices_in_range(self): - moe_block, T, H, E, K = _make_qwen_moe_block() - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block) - - assert (expert_idx >= 0).all() - assert (expert_idx < E).all() - - def test_renormalized_scores_sum_to_one(self): - moe_block, T, H, E, K = _make_qwen_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = softmax_topk_routing(hidden, moe_block) - per_token_sums = scores.reshape(T, K).sum(dim=-1) - assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) - - -class TestSigmoidTopkRouting: - def test_output_shapes(self): - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) - - scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block) - - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) - - def test_scores_are_float32(self): - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - assert scores.dtype == torch.float32 - - def test_token_indices_sorted_ascending(self): - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) - - _, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block) - - diffs = token_idx[1:] - token_idx[:-1] - assert (diffs >= 0).all() + def test_experts_implementation_auto_scattermoe(self): + out = KernelsArgs.check_experts_implementation({"use_scattermoe": True}) + assert out["experts_implementation"] == "scattermoe" - def test_expert_indices_in_range(self): - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) + def test_experts_implementation_default_eager(self): + out = KernelsArgs.check_experts_implementation({}) + assert out["experts_implementation"] == "eager" - _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - assert (expert_idx >= 0).all() - assert (expert_idx < E).all() - - def test_scores_are_nonnegative(self): - """Sigmoid outputs are in [0, 1], so scores should be non-negative.""" - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - assert (scores >= 0).all() - - def test_scaling_factor_applied(self): - moe_block, T, H, E, K = _make_glm_moe_block() - hidden = torch.randn(T, H) - - # Get scores with scaling_factor=1.0 - scores_1x, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - - # Get scores with scaling_factor=2.0 - moe_block.routed_scaling_factor = 2.0 - scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - - assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5) - - def test_group_selection_restricts_experts(self): - """With n_group=4 and topk_group=1, only 1/4 of experts should be selectable.""" - moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1) - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - # Each token's experts should all fall within a single group (size E//n_group=4) - expert_idx_2d = expert_idx.reshape(T, K) - for t in range(T): - experts = expert_idx_2d[t] - groups = experts // (E // moe_block.n_group) - # All selected experts should be from the same group - assert (groups == groups[0]).all() - - -class TestMiniMaxM2SigmoidRouting: - """Tests for minimax_m2 routing: sigmoid->topk without group selection.""" - - def test_output_shapes(self): - """Validates getattr defaults work: n_group=1, E from gate.weight.shape[0].""" - moe_block, T, H, E, K = _make_minimax_m2_moe_block() - hidden = torch.randn(T, H) - - scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block) - - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) - - def test_bias_on_block_not_gate(self): - """Verify that e_score_correction_bias on the block (not gate) is used.""" - T, H, E, K = 8, 16, 8, 2 - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - ) - # Large positive bias on expert 0 should make it selected more often - bias = torch.zeros(E) - bias[0] = 100.0 - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - e_score_correction_bias=bias, + def test_sonicmoe_impl_requires_flag(self): + out = KernelsArgs.check_experts_implementation( + {"experts_implementation": "sonicmoe"} ) - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block) - - # Expert 0 should appear for every token due to the large bias - expert_idx_2d = expert_idx.reshape(T, K) - for t in range(T): - assert 0 in expert_idx_2d[t] - - -# ============================================================================ -# Ernie 4.5 MoE: softmax -> bias correction -> topk -# ============================================================================ - - -def _make_ernie_moe_block(T=8, H=16, E=8, K=2, norm_min=1e-20): - """Create a mock Ernie 4.5 MoE block for routing tests. - - Ernie 4.5 uses a gate.moe_statics module that adds bias to softmax probs - before topk selection, then gathers from original probs. - """ - bias = torch.zeros(E) - - class MockMoeStatics: - def __init__(self, bias_tensor): - self.e_score_correction_bias = bias_tensor - - def __call__(self, probs): - return probs + self.e_score_correction_bias - - gate = SimpleNamespace( - weight=torch.randn(E, H), - top_k=K, - moe_statics=MockMoeStatics(bias), - norm_min=norm_min, - ) - moe_block = SimpleNamespace(gate=gate) - return moe_block, bias, T, H, E, K - + assert out["experts_implementation"] == "eager" -class TestSoftmaxBiasTopkRouting: - """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing).""" - - def test_output_shapes(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, + def test_scattermoe_impl_requires_flag(self): + out = KernelsArgs.check_experts_implementation( + {"experts_implementation": "scattermoe"} ) + assert out["experts_implementation"] == "eager" - moe_block, _, T, H, E, K = _make_ernie_moe_block() - hidden = torch.randn(T, H) - - scores, token_idx, expert_idx, logits = softmax_bias_topk_routing( - hidden, moe_block + def test_unknown_impl_falls_back_to_eager(self): + out = KernelsArgs.check_experts_implementation( + {"experts_implementation": "not-a-real-impl"} ) + assert out["experts_implementation"] == "eager" - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) + def test_builtin_impls_pass_through(self): + for impl in ("eager", "batched_mm", "grouped_mm"): + out = KernelsArgs.check_experts_implementation( + {"experts_implementation": impl} + ) + assert out["experts_implementation"] == impl - def test_scores_are_float32(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, - ) - moe_block, _, T, H, E, K = _make_ernie_moe_block() - hidden = torch.randn(T, H) +class TestSonicMoERegistration: + """Test that register_sonicmoe_experts plugs into ALL_EXPERTS_FUNCTIONS.""" - scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block) - assert scores.dtype == torch.float32 + def test_register_adds_entry(self): + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS - def test_token_indices_sorted_ascending(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, + sonicmoe_experts_forward_with_lora, ) - moe_block, _, T, H, E, K = _make_ernie_moe_block() - hidden = torch.randn(T, H) + register_sonicmoe_experts() + assert "sonicmoe" in ALL_EXPERTS_FUNCTIONS + assert ALL_EXPERTS_FUNCTIONS["sonicmoe"] is sonicmoe_experts_forward_with_lora - _, token_idx, _, _ = softmax_bias_topk_routing(hidden, moe_block) - diffs = token_idx[1:] - token_idx[:-1] - assert (diffs >= 0).all() + def test_register_is_idempotent(self): + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS - def test_expert_indices_in_range(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, ) - moe_block, _, T, H, E, K = _make_ernie_moe_block() - hidden = torch.randn(T, H) + register_sonicmoe_experts() + register_sonicmoe_experts() + # Just one entry, no error + assert "sonicmoe" in ALL_EXPERTS_FUNCTIONS - _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block) - assert (expert_idx >= 0).all() - assert (expert_idx < E).all() + def test_register_overrides_upstream(self): + """Axolotl's LoRA-aware variant replaces upstream's plain forward.""" + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + from transformers.integrations.sonicmoe import sonicmoe_experts_forward - def test_renormalized_scores_sum_to_one(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + register_sonicmoe_experts, + sonicmoe_experts_forward_with_lora, ) - moe_block, _, T, H, E, K = _make_ernie_moe_block() - hidden = torch.randn(T, H) + register_sonicmoe_experts() + assert ALL_EXPERTS_FUNCTIONS["sonicmoe"] is sonicmoe_experts_forward_with_lora + assert ALL_EXPERTS_FUNCTIONS["sonicmoe"] is not sonicmoe_experts_forward - scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block) - per_token_sums = scores.reshape(T, K).sum(dim=-1) - assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) - def test_bias_affects_expert_selection(self): - """Large positive bias on expert 0 should make it always selected.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_bias_topk_routing, - ) +class TestMoELoRAMaterialize: + """Verify the LoRA materialization autograd Function used by the registered forward.""" - moe_block, bias, T, H, E, K = _make_ernie_moe_block() - bias[0] = 100.0 # mutate the bias to strongly favor expert 0 - hidden = torch.randn(T, H) + def test_forward_shape_and_identity_with_zero_lora(self): + """W_eff == base when LoRA tensors are zero, regardless of layout convention.""" + from axolotl.integrations.kernels.libs.sonicmoe.lora import MoELoRAMaterialize - _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block) - expert_idx_2d = expert_idx.reshape(T, K) - for t in range(T): - assert 0 in expert_idx_2d[t] + E, dim1, dim2, r = 4, 8, 6, 2 + base = torch.randn(E, dim1, dim2) + lora_A = torch.zeros(r * E, dim2) + lora_B = torch.zeros(dim1, r * E) + scaling = 0.5 + W_eff = MoELoRAMaterialize.apply(base, lora_A, lora_B, scaling) + assert W_eff.shape == base.shape + torch.testing.assert_close(W_eff, base, atol=1e-6, rtol=1e-6) -# ============================================================================ -# DeepSeek V2: softmax -> group_limited_greedy / greedy -> topk -# ============================================================================ + def test_forward_scaling_linearity(self): + """Doubling scaling should double the LoRA delta.""" + from axolotl.integrations.kernels.libs.sonicmoe.lora import MoELoRAMaterialize + E, dim1, dim2, r = 4, 8, 6, 2 + base = torch.randn(E, dim1, dim2) + lora_A = torch.randn(r * E, dim2) + lora_B = torch.randn(dim1, r * E) -def _make_deepseek_v2_moe_block( - T=8, H=16, E=16, K=4, num_group=2, topk_group=1, topk_method="group_limited_greedy" -): - """Create a mock DeepSeek V2 MoE block for routing tests. + W_1 = MoELoRAMaterialize.apply(base, lora_A, lora_B, 1.0) + W_2 = MoELoRAMaterialize.apply(base, lora_A, lora_B, 2.0) + torch.testing.assert_close(W_2 - base, 2 * (W_1 - base), atol=1e-5, rtol=1e-5) - DeepSeek V2 uses num_group (not n_group), gate is nn.Linear, - and supports greedy / group_limited_greedy topk methods. - """ - gate = SimpleNamespace(weight=torch.randn(E, H)) - moe_block = SimpleNamespace( - gate=gate, - top_k=K, - num_group=num_group, - topk_group=topk_group, - topk_method=topk_method, - routed_scaling_factor=1.0, - ) - return moe_block, T, H, E, K - - -class TestSoftmaxGroupLimitedTopkRouting: - """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing).""" - - def test_output_shapes_group_limited(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) + def test_forward_matches_peft_einsum(self): + """Delta matches PEFT's ParamWrapper.get_delta_weight einsum convention. - moe_block, T, H, E, K = _make_deepseek_v2_moe_block( - topk_method="group_limited_greedy" - ) - hidden = torch.randn(T, H) + Reference: ``peft.tuners.lora.layer.ParamWrapper.get_delta_weight`` + on PEFT 0.19.x — ``einsum("o r e, e r i -> e o i", B_3d, A_3d)`` where + ``B_3d = lora_B.reshape(dim1, r, E)`` and ``A_3d = lora_A.reshape(E, r, dim2)``. + """ + from axolotl.integrations.kernels.libs.sonicmoe.lora import MoELoRAMaterialize - scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing( - hidden, moe_block - ) + E, dim1, dim2, r = 3, 5, 4, 2 + base = torch.zeros(E, dim1, dim2) + lora_A = torch.randn(r * E, dim2) + lora_B = torch.randn(dim1, r * E) + scaling = 0.7 - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) + W_eff = MoELoRAMaterialize.apply(base, lora_A, lora_B, scaling) - def test_output_shapes_greedy(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) + # PEFT's reference computation + A_3d = lora_A.reshape(E, r, dim2) + B_3d = lora_B.reshape(dim1, r, E) + peft_delta = torch.einsum("o r e, e r i -> e o i", B_3d, A_3d) * scaling - moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy") - hidden = torch.randn(T, H) + torch.testing.assert_close(W_eff, peft_delta, atol=1e-5, rtol=1e-5) - scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing( - hidden, moe_block - ) + def test_gradient_flows_to_lora(self): + from axolotl.integrations.kernels.libs.sonicmoe.lora import MoELoRAMaterialize - assert scores.shape == (T * K,) - assert logits.shape == (T, E) + E, dim1, dim2, r = 4, 8, 6, 2 + base = torch.randn(E, dim1, dim2, requires_grad=False) + lora_A = torch.randn(r * E, dim2, requires_grad=True) + lora_B = torch.randn(dim1, r * E, requires_grad=True) + scaling = 0.5 - def test_scores_are_float32(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) + W_eff = MoELoRAMaterialize.apply(base, lora_A, lora_B, scaling) + loss = W_eff.sum() + loss.backward() - moe_block, T, H, E, K = _make_deepseek_v2_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) - assert scores.dtype == torch.float32 - - def test_token_indices_sorted_ascending(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) - - moe_block, T, H, E, K = _make_deepseek_v2_moe_block() - hidden = torch.randn(T, H) - - _, token_idx, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) - diffs = token_idx[1:] - token_idx[:-1] - assert (diffs >= 0).all() - - def test_expert_indices_in_range(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) - - moe_block, T, H, E, K = _make_deepseek_v2_moe_block() - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block) - assert (expert_idx >= 0).all() - assert (expert_idx < E).all() - - def test_scaling_factor_applied(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) - - moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy") - hidden = torch.randn(T, H) - - scores_1x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) - - moe_block.routed_scaling_factor = 2.5 - scores_2x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) - - assert torch.allclose(scores_2x, scores_1x * 2.5, atol=1e-5) - - def test_group_selection_restricts_experts(self): - """With num_group=4 and topk_group=1, experts should come from selected groups.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) + assert lora_A.grad is not None + assert lora_B.grad is not None + assert lora_A.grad.abs().max() > 0 + assert lora_B.grad.abs().max() > 0 + # Base weight is frozen — no grad expected. + assert base.grad is None - moe_block, T, H, E, K = _make_deepseek_v2_moe_block( - E=16, K=2, num_group=4, topk_group=1, topk_method="group_limited_greedy" + def test_no_lora_returns_base_unchanged(self): + from axolotl.integrations.kernels.libs.sonicmoe.lora import ( + materialize_expert_lora, ) - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block) - expert_idx_2d = expert_idx.reshape(T, K) - group_size = E // moe_block.num_group - for t in range(T): - experts = expert_idx_2d[t] - groups = experts // group_size - # All selected experts should be from the same group - assert (groups == groups[0]).all() - - def test_unsupported_topk_method_raises(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_group_limited_topk_routing, - ) - - moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="invalid") - hidden = torch.randn(T, H) - - with pytest.raises(ValueError, match="unsupported topk_method"): - softmax_group_limited_topk_routing(hidden, moe_block) - -# ============================================================================ -# HunYuan V1 MoE: softmax -> topk -> renorm (via gate.wg) -# ============================================================================ + base = torch.randn(4, 8, 6) + result = materialize_expert_lora(base, None) + assert result is base -def _make_hunyuan_moe_block(T=8, H=16, E=8, K=2): - """Create a mock HunYuan V1 MoE block for routing tests. - - HunYuan V1 uses gate.wg (nn.Linear-like) instead of gate.weight, - and top_k on the moe_block instead of the gate. +class TestExpertsClassMetadata: + """The forward reads `has_gate`/`has_bias`/`is_transposed`/`is_concatenated` + that are set by transformers' @use_experts_implementation decorator. + Verify our forward respects these without an actual CUDA kernel call. """ - wg = SimpleNamespace(weight=torch.randn(E, H)) - gate = SimpleNamespace(wg=wg) - moe_block = SimpleNamespace(gate=gate, top_k=K) - return moe_block, T, H, E, K - -class TestSoftmaxTopkWgRouting: - """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing).""" - - def test_output_shapes(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, + def test_rejects_non_gated(self): + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + sonicmoe_experts_forward_with_lora, ) - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) + fake_self = SimpleNamespace(has_gate=False) + hidden = torch.zeros(2, 4) + top_k_index = torch.zeros(2, 1, dtype=torch.long) + top_k_weights = torch.ones(2, 1) - scores, token_idx, expert_idx, logits = softmax_topk_wg_routing( - hidden, moe_block - ) - - assert scores.shape == (T * K,) - assert token_idx.shape == (T * K,) - assert expert_idx.shape == (T * K,) - assert logits.shape == (T, E) - - def test_scores_are_float32(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, - ) - - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) - assert scores.dtype == torch.float32 - - def test_token_indices_sorted_ascending(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, - ) - - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) - - _, token_idx, _, _ = softmax_topk_wg_routing(hidden, moe_block) - diffs = token_idx[1:] - token_idx[:-1] - assert (diffs >= 0).all() - - def test_expert_indices_in_range(self): - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, - ) - - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) - - _, _, expert_idx, _ = softmax_topk_wg_routing(hidden, moe_block) - assert (expert_idx >= 0).all() - assert (expert_idx < E).all() - - def test_renormalized_scores_sum_to_one(self): - """HunYuan V1 always renormalizes (no norm_topk_prob flag).""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, - ) - - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) - - scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) - per_token_sums = scores.reshape(T, K).sum(dim=-1) - assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) + with pytest.raises(ValueError, match="has_gate"): + sonicmoe_experts_forward_with_lora( + fake_self, hidden, top_k_index, top_k_weights + ) - def test_uses_gate_wg_weight(self): - """Verify that modifying gate.wg.weight changes the routing output.""" - from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - softmax_topk_wg_routing, + def test_rejects_non_cuda(self): + from axolotl.integrations.kernels.libs.sonicmoe.experts import ( + sonicmoe_experts_forward_with_lora, ) - moe_block, T, H, E, K = _make_hunyuan_moe_block() - hidden = torch.randn(T, H) + fake_self = SimpleNamespace(has_gate=True) + hidden = torch.zeros(2, 4) # CPU tensor + top_k_index = torch.zeros(2, 1, dtype=torch.long) + top_k_weights = torch.ones(2, 1) - scores1, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) - - # Change the wg weight and verify scores change - moe_block.gate.wg.weight = torch.randn(E, H) - scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) - - assert not torch.equal(scores1, scores2) + with pytest.raises(ValueError, match="CUDA"): + sonicmoe_experts_forward_with_lora( + fake_self, hidden, top_k_index, top_k_weights + ) diff --git a/tests/integrations/test_sonicmoe_gradients.py b/tests/integrations/test_sonicmoe_gradients.py deleted file mode 100644 index cb5ef7663d..0000000000 --- a/tests/integrations/test_sonicmoe_gradients.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -Gradient correctness tests for SonicMoE routing functions (CPU-only). - -Uses torch.autograd.gradcheck with float32 inputs to match the production -code path where routing happens in float32. -""" - -import torch - -from axolotl.integrations.kernels.libs.sonicmoe.routing import ( - sigmoid_topk_routing, - softmax_topk_routing, -) - -_GC_EPS = 1e-3 -_GC_ATOL = 1e-3 -_GC_RTOL = 1e-3 - - -def _make_softmax_moe_block(weight): - gate = torch.nn.Module() - gate.weight = weight - gate.top_k = 2 - gate.norm_topk_prob = True - - moe_block = torch.nn.Module() - moe_block.gate = gate - return moe_block - - -def _make_sigmoid_moe_block(weight, bias): - gate = torch.nn.Module() - gate.weight = weight - gate.e_score_correction_bias = bias - - moe_block = torch.nn.Module() - moe_block.gate = gate - moe_block.top_k = 2 - moe_block.n_routed_experts = weight.shape[0] - moe_block.n_group = 1 - moe_block.norm_topk_prob = True - moe_block.routed_scaling_factor = 1.0 - return moe_block - - -class TestSoftmaxTopkRoutingGradcheck: - """Numerical gradient verification for softmax_topk_routing.""" - - def test_gradcheck_wrt_gate_weight(self): - T, H, E = 4, 8, 4 - - hidden = torch.randn(T, H, dtype=torch.float32) - - def fn(weight): - moe_block = _make_softmax_moe_block(weight) - scores, _, _, _ = softmax_topk_routing(hidden, moe_block) - return scores - - weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - def test_gradcheck_wrt_hidden_states(self): - T, H, E = 4, 8, 4 - - weight = torch.randn(E, H, dtype=torch.float32) - moe_block = _make_softmax_moe_block(weight) - - def fn(hidden): - scores, _, _, _ = softmax_topk_routing(hidden, moe_block) - return scores - - hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - def test_gradcheck_wrt_router_logits(self): - T, H, E = 4, 8, 4 - - hidden = torch.randn(T, H, dtype=torch.float32) - - def fn(weight): - moe_block = _make_softmax_moe_block(weight) - _, _, _, router_logits = softmax_topk_routing(hidden, moe_block) - return router_logits - - weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - def test_no_norm_variant(self): - T, H, E = 4, 8, 4 - - hidden = torch.randn(T, H, dtype=torch.float32) - - def fn(weight): - moe_block = _make_softmax_moe_block(weight) - moe_block.gate.norm_topk_prob = False - scores, _, _, _ = softmax_topk_routing(hidden, moe_block) - return scores - - weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - -class TestSigmoidTopkRoutingGradcheck: - """Numerical gradient verification for sigmoid_topk_routing.""" - - def test_gradcheck_wrt_gate_weight(self): - T, H, E = 4, 8, 4 - - hidden = torch.randn(T, H, dtype=torch.float32) - bias = torch.zeros(E, dtype=torch.float32) - - def fn(weight): - moe_block = _make_sigmoid_moe_block(weight, bias) - scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - return scores - - weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - def test_gradcheck_wrt_hidden_states(self): - T, H, E = 4, 8, 4 - - weight = torch.randn(E, H, dtype=torch.float32) - bias = torch.zeros(E, dtype=torch.float32) - moe_block = _make_sigmoid_moe_block(weight, bias) - - def fn(hidden): - scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - return scores - - hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck( - fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL - ) - - def test_gradcheck_wrt_bias(self): - T, H, E = 4, 8, 4 - - hidden = torch.randn(T, H, dtype=torch.float32) - weight = torch.randn(E, H, dtype=torch.float32) - - def fn(bias): - moe_block = _make_sigmoid_moe_block(weight, bias) - scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block) - return scores - - bias = torch.zeros(E, dtype=torch.float32, requires_grad=True) - torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL) diff --git a/tests/integrations/test_sonicmoe_lora.py b/tests/integrations/test_sonicmoe_lora.py index 4b25843fee..7b1cf1765d 100644 --- a/tests/integrations/test_sonicmoe_lora.py +++ b/tests/integrations/test_sonicmoe_lora.py @@ -15,7 +15,6 @@ has_lora, materialize_expert_lora, unwrap_experts_lora, - unwrap_gate_lora, ) # ============================================================================= @@ -44,21 +43,6 @@ def _make_mock_lora_module(weight_A, weight_B, scaling_val, param_name=None): return mock -def _make_peft_gate(hidden_size, num_experts, rank, scaling=0.5): - """Create a mock PEFT-wrapped gate module.""" - base_gate = MagicMock() - base_gate.weight = torch.randn(num_experts, hidden_size) - base_gate.top_k = 2 - base_gate.norm_topk_prob = True - - lora_A = torch.randn(rank, hidden_size) - lora_B = torch.randn(num_experts, rank) - - wrapper = _make_mock_lora_module(lora_A, lora_B, scaling) - wrapper.base_layer = base_gate - return wrapper, base_gate - - def _make_peft_experts( num_experts, gate_up_dim, down_dim, hidden_size, rank, scaling=0.5 ): @@ -134,39 +118,6 @@ def test_no_active_adapters(self): assert get_lora_params_from_wrapper(wrapper) == (None, None, None) -# ============================================================================= -# Tests: unwrap_gate_lora -# ============================================================================= - - -class TestUnwrapGateLora: - def test_plain_gate(self): - gate = MagicMock(spec=["weight", "top_k"]) - del gate.base_layer - del gate.lora_A - gate.weight = torch.randn(8, 64) - base, weight, delta = unwrap_gate_lora(gate) - assert base is gate - assert torch.equal(weight, gate.weight) - assert delta is None - - def test_wrapped_gate(self): - wrapper, base_gate = _make_peft_gate( - hidden_size=64, num_experts=8, rank=4, scaling=0.5 - ) - base, weight, delta = unwrap_gate_lora(wrapper) - assert base is base_gate - assert torch.equal(weight, base_gate.weight) - assert delta is not None - assert delta.shape == base_gate.weight.shape - - # Verify delta = scaling * B @ A - lora_A = wrapper.lora_A["default"].weight - lora_B = wrapper.lora_B["default"].weight - expected = 0.5 * (lora_B @ lora_A) - assert torch.allclose(delta, expected) - - # ============================================================================= # Tests: unwrap_experts_lora # ============================================================================= From d2e42fd4dc12653fefee92f990774c3da7691686 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 May 2026 15:54:57 +0700 Subject: [PATCH 2/5] chore: add to optim doc --- docs/optimizations.qmd | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd index 720519ec03..621c2c0bcd 100644 --- a/docs/optimizations.qmd +++ b/docs/optimizations.qmd @@ -75,11 +75,12 @@ Provides efficient Triton kernels to improve training speed and reduce memory us ### Expert Kernels -Optimized kernel implementations for Mixture of Experts (MoE) model training. +Optimized per-expert grouped-GEMM kernels for MoE training, with LoRA support. -- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support. -- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs. +- **ScatterMoE**: Triton, any CUDA GPU. +- **SonicMoE**: CUTLASS / cute-DSL, Hopper+ only. +- **Config:** `use_scattermoe: true` or `use_sonicmoe: true` - **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration) ## Long Context Models @@ -117,7 +118,7 @@ To train models that don't fit on a single GPU, you'll need to use a distributed ### N-D Parallelism (Beta) -For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once. +For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence, Expert Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once. - **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd) From a2815b91dfb626c67b3d2c87d9bdffcecdbacc8e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 May 2026 15:55:11 +0700 Subject: [PATCH 3/5] feat: update sonicmoe version --- src/axolotl/integrations/kernels/README.md | 8 +++++--- src/axolotl/integrations/kernels/plugin.py | 21 ++++++++++++++++++++ tests/e2e/integrations/test_sonicmoe.py | 10 ++++------ tests/e2e/integrations/test_sonicmoe_lora.py | 13 ++++++------ 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index 338381054a..3ec5b55898 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -39,13 +39,15 @@ use_sonicmoe: true - PyTorch 2.7+ - For B300: Triton 3.6.x -Sonic-MoE itself is loaded lazily from the HF [`kernels-community/sonic-moe`](https://huggingface.co/kernels-community/sonic-moe) hub on first use via the `kernels` package — no manual install is needed for the runtime. For from-source development: +Install [`sonic-moe`](https://github.com/Dao-AILab/sonic-moe) `>= 0.1.2` from source: ```bash -pip install --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git" \ - "nvidia-cutlass-dsl>=4.4.0" "quack-kernels>=0.3.0" +pip install --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@0.1.2" \ + "nvidia-cutlass-dsl==4.4.2" "quack-kernels>=0.3.11" ``` +The plugin checks the installed version at startup and raises if it's below `0.1.2`. + **Note:** Blackwell support is in upstream beta. On Blackwell GPUs Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels. ## How It Works diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index f105e5ecff..f00f732652 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,13 +1,33 @@ import importlib import os +from importlib.metadata import PackageNotFoundError, version as _pkg_version import torch +from packaging.version import Version from axolotl.integrations.base import BasePlugin from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +_SONICMOE_MIN_VERSION = "0.1.2" + + +def _check_sonicmoe_version(): + """Require sonic-moe >= 0.1.2 for the ``concat_layout=True`` path.""" + try: + installed = _pkg_version("sonic-moe") + except PackageNotFoundError as err: + raise RuntimeError( + f"sonic-moe is not installed. Install >= {_SONICMOE_MIN_VERSION} from " + "https://github.com/Dao-AILab/sonic-moe." + ) from err + if Version(installed) < Version(_SONICMOE_MIN_VERSION): + raise RuntimeError( + f"sonic-moe {installed} is too old; require >= {_SONICMOE_MIN_VERSION}. " + "Upgrade from https://github.com/Dao-AILab/sonic-moe." + ) + def _check_sonicmoe_gpu_compat(): """Validate GPU compute capability for SonicMoE and configure env. @@ -74,6 +94,7 @@ def pre_model_load(self, cfg): cfg.experts_implementation = "scattermoe" LOG.info("Registered 'scattermoe' in transformers ExpertsInterface") elif cfg.use_sonicmoe: + _check_sonicmoe_version() _check_sonicmoe_gpu_compat() from axolotl.integrations.kernels.libs.sonicmoe.experts import ( diff --git a/tests/e2e/integrations/test_sonicmoe.py b/tests/e2e/integrations/test_sonicmoe.py index 021b41dc44..b74e570d02 100644 --- a/tests/e2e/integrations/test_sonicmoe.py +++ b/tests/e2e/integrations/test_sonicmoe.py @@ -1,18 +1,16 @@ -""" -End-to-end gradient and convergence tests for SonicMoE integration. +"""End-to-end gradient and convergence tests for SonicMoE integration. -After the ExpertsInterface refactor, the flow is: +Flow: register_sonicmoe_experts() # plug into ALL_EXPERTS_FUNCTIONS config._experts_implementation = "sonicmoe" model = AutoModelForCausalLM.from_config(config) # transformers dispatches -No weight interleaving needed (cute-DSL ``concat_layout=True``); no per-arch -SparseMoEBlock monkeypatching. +No weight interleaving needed (``concat_layout=True``). Requires: - Hopper (sm_90) or Blackwell (sm_100+) GPU - - sonicmoe kernel available via HF kernels-community + - sonic-moe >= 0.1.2 installed from source - transformers >= 5.8 with Qwen3MoE Experts class """ diff --git a/tests/e2e/integrations/test_sonicmoe_lora.py b/tests/e2e/integrations/test_sonicmoe_lora.py index 240b015141..cc58f4dccd 100644 --- a/tests/e2e/integrations/test_sonicmoe_lora.py +++ b/tests/e2e/integrations/test_sonicmoe_lora.py @@ -2,23 +2,22 @@ # Copyright (c) Axolotl AI # Licensed under the Apache License, Version 2.0 -""" -End-to-end tests for SonicMoE + LoRA after the ExpertsInterface refactor. +"""End-to-end tests for SonicMoE + LoRA. -The new flow: +Flow: register_sonicmoe_experts() # plug into ALL_EXPERTS_FUNCTIONS config._experts_implementation = "sonicmoe" model = AutoModelForCausalLM.from_config(config) model = get_peft_model(model, lora_config) # PEFT wraps params/modules -Our registered ``sonicmoe_experts_forward_with_lora`` detects the PEFT -wrappers and materializes ``W_eff = W + scaling * (B @ A)`` via -:class:`MoELoRAMaterialize`, so adapters train through the CUTLASS kernels. +``sonicmoe_experts_forward_with_lora`` detects the PEFT wrappers and +materializes ``W_eff = W + scaling * (B @ A)`` via :class:`MoELoRAMaterialize`, +so adapters train through the CUTLASS kernels. Requires: - Hopper (sm_90) or Blackwell (sm_100+) GPU - - sonicmoe kernel available via HF kernels-community + - sonic-moe >= 0.1.2 installed from source - peft installed - transformers >= 5.8 with Qwen3MoE Experts class """ From 8ee58b1e7e992fda56406bc924ffe475ccfe6136 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 22 May 2026 15:10:54 +0700 Subject: [PATCH 4/5] chore: cleanup with DEEPEP and kernels compat --- src/axolotl/integrations/kernels/README.md | 11 +++----- src/axolotl/integrations/kernels/args.py | 22 ++++++++++++++- src/axolotl/integrations/kernels/plugin.py | 32 ++++++---------------- 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index 3ec5b55898..aae2fee97f 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -39,15 +39,12 @@ use_sonicmoe: true - PyTorch 2.7+ - For B300: Triton 3.6.x -Install [`sonic-moe`](https://github.com/Dao-AILab/sonic-moe) `>= 0.1.2` from source: +The sonic-moe kernel ships through the HF [`kernels`](https://github.com/huggingface/kernels) package. Transformers v5.8+ auto-fetches a prebuilt kernel from [`kernels-community/sonic-moe`](https://huggingface.co/kernels-community/sonic-moe) on first use: ```bash -pip install --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@0.1.2" \ - "nvidia-cutlass-dsl==4.4.2" "quack-kernels>=0.3.11" +pip install kernels "nvidia-cutlass-dsl==4.4.2" ``` -The plugin checks the installed version at startup and raises if it's below `0.1.2`. - **Note:** Blackwell support is in upstream beta. On Blackwell GPUs Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels. ## How It Works @@ -87,9 +84,9 @@ Any model whose `Experts` class is decorated with `@use_experts_implementation` | `ernie4_5_moe` | Yes | Yes | | `hunyuan_v1_moe` | Yes | Yes | | `gemma4_text` | Yes | Yes | -| `gpt_oss` | Yes | Yes | +| `gpt_oss` | No | Yes | -For `gpt_oss` the upstream decorator carries `is_concatenated=False, is_transposed=True, has_bias=True`; the registered forward reads these flags off `self` and adjusts permutation / bias handling accordingly. +`gpt_oss` carries the decorator with `is_concatenated=False, is_transposed=True, has_bias=True` and uses a sigmoid-GLU activation with clamping. The SonicMoE forward reads these flags off `self` and dispatches accordingly. The ScatterMoE forward assumes the standard `[E, 2*I, H]` concat layout and SiLU-GLU without bias, so it does not yet support `gpt_oss`. ## Feature comparison diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index bff9e85081..2cf6d2f3ba 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -11,9 +11,18 @@ # - "grouped_mm" : transformers' built-in grouped matmul path (cache-efficient) # - "scattermoe" : axolotl-registered Triton kernels with LoRA support # - "sonicmoe" : axolotl-registered CUTLASS / cute-DSL kernels with LoRA support +# - "deep_ep[_*]": EP-plugin composites; passed through when expert_parallel_size > 1 _BUILTIN_EXPERTS_IMPLS = {"eager", "batched_mm", "grouped_mm"} _KERNEL_EXPERTS_IMPLS = {"scattermoe", "sonicmoe"} -_VALID_EXPERTS_IMPLS = _BUILTIN_EXPERTS_IMPLS | _KERNEL_EXPERTS_IMPLS +_EP_EXPERTS_IMPLS = { + "deep_ep", + "deep_ep_grouped_mm", + "deep_ep_scattermoe", + "deep_ep_sonicmoe", +} +_VALID_EXPERTS_IMPLS = ( + _BUILTIN_EXPERTS_IMPLS | _KERNEL_EXPERTS_IMPLS | _EP_EXPERTS_IMPLS +) class KernelsArgs(BaseModel): @@ -41,6 +50,17 @@ def check_use_kernels(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_sonicmoe_ep_unsupported(cls, data): + """SonicMoE + EP is not yet implemented (EP `_sonicmoe_local` raises).""" + if data.get("use_sonicmoe") and (data.get("expert_parallel_size") or 1) > 1: + raise ValueError( + "use_sonicmoe=true is not supported with expert_parallel_size > 1. " + "Use use_scattermoe=true under EP, or set expert_parallel_size=1." + ) + return data + @model_validator(mode="before") @classmethod def check_experts_implementation(cls, data): diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index f00f732652..ddddb160f7 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,33 +1,13 @@ import importlib import os -from importlib.metadata import PackageNotFoundError, version as _pkg_version import torch -from packaging.version import Version from axolotl.integrations.base import BasePlugin from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -_SONICMOE_MIN_VERSION = "0.1.2" - - -def _check_sonicmoe_version(): - """Require sonic-moe >= 0.1.2 for the ``concat_layout=True`` path.""" - try: - installed = _pkg_version("sonic-moe") - except PackageNotFoundError as err: - raise RuntimeError( - f"sonic-moe is not installed. Install >= {_SONICMOE_MIN_VERSION} from " - "https://github.com/Dao-AILab/sonic-moe." - ) from err - if Version(installed) < Version(_SONICMOE_MIN_VERSION): - raise RuntimeError( - f"sonic-moe {installed} is too old; require >= {_SONICMOE_MIN_VERSION}. " - "Upgrade from https://github.com/Dao-AILab/sonic-moe." - ) - def _check_sonicmoe_gpu_compat(): """Validate GPU compute capability for SonicMoE and configure env. @@ -85,16 +65,21 @@ def pre_model_load(self, cfg): Architecture-agnostic: routing stays in each model's SparseMoEBlock; only the experts call is dispatched through the registry. """ + # When EP is active, the ExpertParallelPlugin selects a `deep_ep_*` + # composite for `experts_implementation`. Don't overwrite that here — + # plugin order is YAML-defined, so we can't rely on EP running last. + ep_active = (getattr(cfg, "expert_parallel_size", 1) or 1) > 1 + if cfg.use_scattermoe: from axolotl.integrations.kernels.libs.scattermoe_lora.experts import ( register_scattermoe_experts, ) register_scattermoe_experts() - cfg.experts_implementation = "scattermoe" + if not ep_active: + cfg.experts_implementation = "scattermoe" LOG.info("Registered 'scattermoe' in transformers ExpertsInterface") elif cfg.use_sonicmoe: - _check_sonicmoe_version() _check_sonicmoe_gpu_compat() from axolotl.integrations.kernels.libs.sonicmoe.experts import ( @@ -102,7 +87,8 @@ def pre_model_load(self, cfg): ) register_sonicmoe_experts() - cfg.experts_implementation = "sonicmoe" + if not ep_active: + cfg.experts_implementation = "sonicmoe" LOG.info("Registered 'sonicmoe' in transformers ExpertsInterface") def add_callbacks_pre_trainer(self, cfg, model): From 1f3c5d236eeab1d37f7b46964a6b4873f701ec40 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 May 2026 13:54:32 -0400 Subject: [PATCH 5/5] gate/guard model expert setup --- .../kernels/libs/scattermoe_lora/experts.py | 16 ++++++++++++++++ src/axolotl/loaders/model.py | 19 +++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py index ba11249ec4..9199c8a595 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/experts.py @@ -129,6 +129,22 @@ def scattermoe_experts_forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: """ScatterMoE experts forward with fused-LoRA support.""" + # Assumes the standard expert layout: gate_up concatenated as [E, 2I, H], + # gated SwiGLU, no expert bias. gpt_oss-style experts (interleaved gate/up, + # transposed [E, H, 2I], expert bias) would be silently miscomputed by the + # fixed transpose/chunk below, so reject rather than corrupt training. + if ( + getattr(self, "is_transposed", False) + or not getattr(self, "is_concatenated", True) + or getattr(self, "has_bias", False) + or not getattr(self, "has_gate", True) + ): + raise NotImplementedError( + "scattermoe supports only concatenated, non-transposed, gated, biasless " + "experts (qwen/mixtral/deepseek/glm/...). This model's experts use an " + "unsupported layout; use use_sonicmoe or a built-in experts_implementation." + ) + K = top_k_index.shape[1] routing_weights = top_k_weights.to(hidden_states.dtype) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 9634c1b936..2f056e19f8 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -301,8 +301,23 @@ def _reinitialize_classification_head(self): ) def _configure_experts_implementation(self): - if self.cfg.experts_implementation is not None: - self.model.set_experts_implementation(self.cfg.experts_implementation) + impl = self.cfg.experts_implementation + if impl is None: + return + + if impl in ("scattermoe", "sonicmoe"): + model_classes = { + type(m) for m in self.model.modules() if isinstance(m, PreTrainedModel) + } + if not any(cls._can_set_experts_implementation() for cls in model_classes): + LOG.warning( + f"experts_implementation={impl!r} requested, but no submodule of " + f"{type(self.model).__name__} uses transformers' ExpertsInterface " + "(@use_experts_implementation). The kernel will NOT be applied; " + "training falls back to the model's native experts path." + ) + + self.model.set_experts_implementation(impl) def _apply_activation_checkpointing(self): if self.cfg.activation_offloading is True: