
feat: add sonicmoe #3411

Merged
winglian merged 20 commits into main from feat/sonicmoe
Mar 5, 2026

Collaborator

@NanoCode012 NanoCode012 commented Feb 15, 2026

Description

Adds SonicMoE kernel integration with per-MoE-class routing. It is faster than ScatterMoE and uses less VRAM, but it requires a Hopper GPU; see https://github.com/Dao-AILab/sonic-moe/tree/main?tab=readme-ov-file#-installation

How it works: replaces the forward of the MoE sparse block with our custom wrapper and reshapes the weights to match SonicMoE's requirements. If the fused path is not available, it applies model-specific routing and then calls the MoE kernel.
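
The patching approach described above can be sketched roughly as follows. This is a minimal, framework-free illustration and not the PR's actual code: `SparseMoeBlock`, `toy_routing`, `toy_kernel`, and `make_patched_forward` are hypothetical names standing in for the model's MoE block, a model-specific routing function, the SonicMoE kernel call, and the wrapper factory.

```python
from typing import Callable, List, Tuple

Routing = Tuple[List[float], List[int], List[int]]


class SparseMoeBlock:
    """Hypothetical stand-in for a model's sparse MoE block."""

    def __init__(self, num_experts: int, top_k: int) -> None:
        self.num_experts = num_experts
        self.top_k = top_k

    def forward(self, hidden_states: List[float]) -> List[float]:
        # Original (unpatched) path: identity, just for illustration.
        return list(hidden_states)


def toy_routing(hidden_states: List[float], block: SparseMoeBlock) -> Routing:
    """Hypothetical routing: each token goes to experts 0..top_k-1 with equal weight."""
    scores, token_idx, expert_idx = [], [], []
    for t in range(len(hidden_states)):
        for e in range(block.top_k):
            scores.append(1.0 / block.top_k)
            token_idx.append(t)  # token indices stay sorted ascending
            expert_idx.append(e)
    return scores, token_idx, expert_idx


def toy_kernel(hidden_states, scores, token_idx, expert_idx) -> List[float]:
    """Hypothetical kernel: expert e scales its input by (e + 1)."""
    out = [0.0] * len(hidden_states)
    for s, t, e in zip(scores, token_idx, expert_idx):
        out[t] += s * (e + 1) * hidden_states[t]
    return out


def make_patched_forward(routing_fn: Callable) -> Callable:
    """Build a replacement forward that runs routing, then the kernel."""

    def patched_forward(self, hidden_states):
        scores, token_idx, expert_idx = routing_fn(hidden_states, self)
        return toy_kernel(hidden_states, scores, token_idx, expert_idx)

    return patched_forward


# Patch at class level so every instance picks up the new forward.
SparseMoeBlock.forward = make_patched_forward(toy_routing)

block = SparseMoeBlock(num_experts=4, top_k=2)
output = block.forward([1.0, 2.0])
```

Patching the class rather than individual instances means the routing function is chosen once per MoE class, which matches the "per-MoE-class routing" idea in the description.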

How to test in the meantime (requires Hopper and the pip install below):

plugins:
  - axolotl.integrations.kernels.KernelsPlugin

use_kernels: true

# pip install git+https://github.com/Dao-AILab/sonic-moe@022992fef6a6aee53e0c3ba709e22f740cec547e
use_sonicmoe: true

TODOS:

  • Test general routings
    • softmax -> topk (qwen series)
    • sigmoid -> topk (glm series)
  • Follow up on fused softmax -> topk: "Support for Qwen3-style MoE router ?" (Dao-AILab/sonic-moe#5)
    • While fused, also need to check handling of shared experts (qwen2 specific)
    • Performance might not differ much, as routing is < ~1% of the MoE kernel computation
  • Test regular fused forward
    • topk -> softmax (gpt-oss, but it has some other gpt-oss specific integration issues)
  • GPT-OSS requires an extra weight transpose, a custom GLU implementation (hard, needs a fork), and routing-bias handling
  • Review whether with norm_topk_prob=True, softmax -> topk is equivalent to the fused topk -> softmax
    • it is equivalent in the forward pass, but the backward pass differs (the softmax gradient is dense over experts in one, sparse in the other)
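
The forward-pass equivalence noted in the last TODO item can be checked with a small worked example. This is a minimal sketch in plain Python (no PyTorch), and the helper names are hypothetical. It only covers the forward pass; it says nothing about the gradient difference mentioned above.

```python
import math
from typing import List, Tuple


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk(xs: List[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest values and their indices, largest first."""
    idx = sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]
    return [xs[i] for i in idx], idx


def softmax_then_topk(logits: List[float], k: int):
    """softmax over all experts -> topk -> renormalize (norm_topk_prob=True)."""
    probs = softmax(logits)
    vals, idx = topk(probs, k)
    denom = sum(vals)
    return [v / denom for v in vals], idx


def topk_then_softmax(logits: List[float], k: int):
    """topk over raw logits -> softmax over the selected k only."""
    vals, idx = topk(logits, k)
    return softmax(vals), idx


logits = [2.0, -1.0, 0.5, 3.0, 0.0]
weights_a, experts_a = softmax_then_topk(logits, k=2)
weights_b, experts_b = topk_then_softmax(logits, k=2)
```

The two orderings agree because softmax is monotonic (so the same experts are selected) and the global normalizer cancels during renormalization: p_i / Σ_{j∈topk} p_j = exp(l_i) / Σ_{j∈topk} exp(l_j).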

Extras:

  • torch.compile routing
    • test
  • fuse general routings (not much better than torch.compile)
  • cutlass softmax -> topk kernel
  • custom GLU in gpt-oss (need to check whether the results are close without it)
  • cutlass sigmoid -> topk kernel

Limitations (unplanned):

  • No jitter handling

Motivation and Context

How has this been tested?

Not yet tested, but verified routing against existing modeling code.

AI Usage Disclaimer

Manual initial routing -> Claude integration

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added support for SonicMoE as an alternative MoE kernel optimization alongside existing ScatterMoE option
    • Introduced mutually exclusive configuration flags for kernel selection
  • Documentation

    • Expanded kernel integration documentation with SonicMoE setup instructions, prerequisites, and model compatibility details
    • Added GPU compatibility notes and installation requirements
  • Bug Fixes

    • Fixed MLP kernel disabling logic to account for multiple MoE kernel options
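
The mutually exclusive flags mentioned in the summary could be validated along these lines. This is a sketch under assumptions: `KernelFlags` is a hypothetical plain dataclass, not the plugin's actual args model, and the error message is illustrative.

```python
from dataclasses import dataclass


@dataclass
class KernelFlags:
    """Hypothetical stand-in for the kernel plugin's config flags."""

    use_scattermoe: bool = False
    use_sonicmoe: bool = False

    def __post_init__(self) -> None:
        # Only one MoE kernel backend may be active at a time.
        if self.use_scattermoe and self.use_sonicmoe:
            raise ValueError(
                "use_scattermoe and use_sonicmoe are mutually exclusive; "
                "enable only one."
            )


flags = KernelFlags(use_sonicmoe=True)

try:
    KernelFlags(use_scattermoe=True, use_sonicmoe=True)
    conflict_rejected = False
except ValueError:
    conflict_rejected = True
```

Rejecting the conflicting combination at config-validation time surfaces the problem before any model is loaded or patched.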

Contributor

coderabbitai Bot commented Feb 15, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b93676dd-6e90-48fc-9718-5e79a22589b4

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Adds SonicMoE kernel integration with mutual exclusivity against ScatterMoE, including dynamic MoE block resolution, GPU compatibility checks, forward patching, routing implementations, weight converters, and comprehensive unit and end-to-end tests.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Core SonicMoE Infrastructure: src/axolotl/integrations/kernels/args.py, src/axolotl/integrations/kernels/constants.py, src/axolotl/integrations/kernels/plugin.py | Added mutual exclusivity validation between use_scattermoe and use_sonicmoe; introduced dynamic MoE block class resolution via SPARSE_MOE_BLOCK mapping; added GPU compatibility checks, pre_model_load handling, and replaced hard-coded model blocks with resolved classes. |
| SonicMoE Patch & Routing Modules: src/axolotl/integrations/kernels/sonicmoe/* | Introduced SonicMoE patching with forward monkeypatching, two routing strategies (softmax_topk and sigmoid_topk), and weight interleaving/de-interleaving converters for gate/up projections. |
| Documentation: src/axolotl/integrations/kernels/README.md | Updated to document SonicMoE support alongside ScatterMoE, added installation prerequisites, GPU requirements, workflow details, and limitations specific to each kernel type. |
| Test Updates: tests/integrations/test_scattermoe_lora.py | Renamed validator method references from disable_mlp_kernel_scattermoe to disable_mlp_kernel. |
| SonicMoE Unit Tests: tests/integrations/test_sonicmoe.py, tests/integrations/test_sonicmoe_gradients.py | Added comprehensive unit tests for KernelsArgs mutual exclusivity, weight converter registration, routing function behavior, and gradient correctness across mock MoE block implementations. |
| SonicMoE E2E Tests: tests/e2e/integrations/test_sonicmoe.py | Added end-to-end GPU tests validating forward correctness, gradient consistency, training convergence, and expert weight updates with Qwen3MoE under SonicMoE patching. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • djsaunde
  • winglian
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 46.88% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'feat: add sonicmoe' accurately and concisely describes the main change: adding SonicMoE kernel integration to the codebase. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@NanoCode012 NanoCode012 marked this pull request as ready for review March 3, 2026 14:51
@NanoCode012 NanoCode012 added the scheduled_release This PR is slated for the upcoming release label Mar 3, 2026
Contributor

github-actions Bot commented Mar 3, 2026

📖 Documentation Preview: https://69a8006696e3e160904ba462--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit c3c1a16


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (4)
src/axolotl/integrations/kernels/sonicmoe/patch.py (1)

112-132: Weight permutation occurs on every forward pass.

At Lines 115-116 and 165-166, the weight tensors are permuted on every forward call. For large models with many forward passes, this repeated permutation could impact performance. Consider caching the permuted weights on the module during the first forward pass.

⚡ Proposed optimization
     def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         from sonicmoe import moe_general_routing_inputs

         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_flat = hidden_states.view(-1, hidden_dim)

         # Shared expert (computed early, matching original model ordering)
         shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

         # Routing
         router_scores, token_indices, expert_indices, _router_logits = routing_fn(
             hidden_states_flat, self
         )

-        # Permute weights to SonicMoE layout:
-        #   gate_up: [E, 2*I, H] -> [2*I, H, E]
-        #   down:    [E, H, I]   -> [H, I, E]
-        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
-        down_weight = self.experts.down_proj.permute(1, 2, 0)
+        # Permute weights to SonicMoE layout (cached on first call):
+        #   gate_up: [E, 2*I, H] -> [2*I, H, E]
+        #   down:    [E, H, I]   -> [H, I, E]
+        if not hasattr(self, "_sonicmoe_gate_up_weight"):
+            self._sonicmoe_gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0).contiguous()
+            self._sonicmoe_down_weight = self.experts.down_proj.permute(1, 2, 0).contiguous()
+        gate_up_weight = self._sonicmoe_gate_up_weight
+        down_weight = self._sonicmoe_down_weight
         E = gate_up_weight.shape[-1]

Note: This assumes weights don't change during training. If LoRA or other adapters modify the base weights, this caching strategy would need adjustment.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/kernels/sonicmoe/patch.py` around lines 112 - 132,
The permute operations on self.experts.gate_up_proj and self.experts.down_proj
are being done every forward; cache the permuted tensors on the module (e.g.,
self._cached_gate_up_weight and self._cached_down_weight) and only compute them
if the cached attributes are None or if the original parameter tensors have
changed (compare .data_ptr() or a version flag); then pass the
cached_gate_up_weight and cached_down_weight into moe_general_routing_inputs
instead of recomputing each call (update this logic around the forward code that
currently calls self.experts.gate_up_proj.permute and
self.experts.down_proj.permute and before the moe_general_routing_inputs
invocation).
src/axolotl/integrations/kernels/sonicmoe/weight_converter.py (1)

42-44: The dim parameter is stored but not used in the conversion.

Both ConcatenatedToInterleaved and InterleavedToConcatenated accept a dim parameter in their constructors, but the actual convert methods call interleave_gate_up/deinterleave_gate_up which always operate along a fixed dimension pattern. If dim != 1 were ever passed, the conversion would still operate on the same dimension.

Either remove the dim parameter if it's not needed, or modify the interleave functions to respect it.

Also applies to: 92-94

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/kernels/sonicmoe/weight_converter.py` around lines
42 - 44, The constructors for ConcatenatedToInterleaved and
InterleavedToConcatenated currently store dim but never use it; update the
conversion flow so the chosen dim is respected: either remove the unused dim
parameter from __init__ of both classes (and any callers) if variable dimension
support is not needed, or modify interleave_gate_up and deinterleave_gate_up to
accept a dim argument and change the convert methods in
ConcatenatedToInterleaved.convert and InterleavedToConcatenated.convert to call
interleave_gate_up(..., dim=self.dim) / deinterleave_gate_up(..., dim=self.dim)
instead of the current fixed-dimension calls so self.dim is actually applied.
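
The second option in this comment (making the interleave functions respect `dim`) might look roughly like the sketch below. It is framework-free and works on nested lists rather than tensors. The `dim` argument on `interleave_gate_up` / `deinterleave_gate_up` is the reviewer's suggestion rather than the PR's current signature, and the assumed layouts (concatenated `[g0, g1, ..., u0, u1, ...]` versus interleaved `[g0, u0, g1, u1, ...]`) are an assumption based on the converter class names.

```python
from typing import Callable, List

Matrix = List[List[float]]


def _interleave(seq: list) -> list:
    """[g0, g1, ..., u0, u1, ...] -> [g0, u0, g1, u1, ...]."""
    if len(seq) % 2 != 0:
        raise ValueError("length along the gate/up dimension must be even")
    half = len(seq) // 2
    out: list = []
    for gate_item, up_item in zip(seq[:half], seq[half:]):
        out.extend([gate_item, up_item])
    return out


def _deinterleave(seq: list) -> list:
    """[g0, u0, g1, u1, ...] -> [g0, g1, ..., u0, u1, ...]."""
    if len(seq) % 2 != 0:
        raise ValueError("length along the gate/up dimension must be even")
    return seq[0::2] + seq[1::2]


def _apply_along(weight: Matrix, dim: int, fn: Callable[[list], list]) -> Matrix:
    """Apply a 1-D reordering along the requested dimension of a 2-D matrix."""
    if dim == 0:
        return fn(weight)  # reorder whole rows
    if dim == 1:
        return [fn(row) for row in weight]  # reorder entries within each row
    raise ValueError(f"dim must be 0 or 1 for a 2-D matrix, got {dim}")


def interleave_gate_up(weight: Matrix, dim: int = 0) -> Matrix:
    return _apply_along(weight, dim, _interleave)


def deinterleave_gate_up(weight: Matrix, dim: int = 0) -> Matrix:
    return _apply_along(weight, dim, _deinterleave)


# Gate entries are 1, 2 and up entries are 10, 20 in both layouts.
concat_rows = [[1.0], [2.0], [10.0], [20.0]]
concat_cols = [[1.0, 2.0, 10.0, 20.0]]
rows_interleaved = interleave_gate_up(concat_rows, dim=0)
cols_interleaved = interleave_gate_up(concat_cols, dim=1)
```

A round trip through `deinterleave_gate_up` with the same `dim` recovers the concatenated layout, which is the property the converter pair needs.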
src/axolotl/integrations/kernels/sonicmoe/routing.py (1)

177-192: Group-based selection assumes E is divisible by n_group.

The view operation at Line 180 (scores_for_choice.view(-1, n_group, E // n_group)) will silently truncate experts if E is not evenly divisible by n_group, potentially causing incorrect routing. While this is likely guaranteed by the model architecture, adding a defensive assertion would catch configuration errors early.

🛡️ Proposed defensive check
     # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
     if n_group > 1:
+        assert E % n_group == 0, f"n_routed_experts ({E}) must be divisible by n_group ({n_group})"
         group_scores = (
             scores_for_choice.view(-1, n_group, E // n_group)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/kernels/sonicmoe/routing.py` around lines 177 - 192,
The group-based selection assumes E is divisible by n_group; before the view
call that reshapes scores_for_choice (the block using scores_for_choice.view(-1,
n_group, E // n_group) inside the routing logic), add a defensive check that E %
n_group == 0 and raise a clear error (or assertion) if not, mentioning E,
n_group and moe_block.topk_group so misconfiguration is caught early and the
silent truncation is prevented.
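
To illustrate the defensive check in isolation, here is a minimal pure-Python sketch of group-limited selection for a single token. The function names are hypothetical. Scoring each group by its best expert and masking dropped groups with 0.0 are simplifying assumptions, not necessarily what the PR's routing code does.

```python
from typing import List


def split_into_groups(scores: List[float], n_group: int) -> List[List[float]]:
    """Split per-expert scores into n_group equal groups, failing loudly otherwise."""
    num_experts = len(scores)
    if num_experts % n_group != 0:
        raise ValueError(
            f"number of experts ({num_experts}) must be divisible by "
            f"n_group ({n_group})"
        )
    size = num_experts // n_group
    return [scores[g * size:(g + 1) * size] for g in range(n_group)]


def mask_to_top_groups(scores: List[float], n_group: int, topk_group: int) -> List[float]:
    """Keep scores only for experts in the topk_group best groups."""
    groups = split_into_groups(scores, n_group)
    # Assumption: a group's score is the score of its best expert.
    group_scores = [max(g) for g in groups]
    ranked = sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)
    keep = set(ranked[:topk_group])
    size = len(scores) // n_group
    return [s if (i // size) in keep else 0.0 for i, s in enumerate(scores)]


masked = mask_to_top_groups([0.1, 0.9, 0.3, 0.2, 0.8, 0.7], n_group=3, topk_group=2)

try:
    split_into_groups([0.1, 0.2, 0.3, 0.4, 0.5], n_group=2)
    bad_config_rejected = False
except ValueError:
    bad_config_rejected = True
```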
tests/integrations/test_sonicmoe.py (1)

215-218: test_register_unsupported_model_type_warns needs an assertion on warning output.

Right now this test is pass-through and won’t catch regressions in warning behavior.

💡 Proposed fix
-    def test_register_unsupported_model_type_warns(self):
-        # A model type with no conversion mapping should warn but not raise
-        register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
+    def test_register_unsupported_model_type_warns(self, caplog):
+        # A model type with no conversion mapping should warn but not raise
+        with caplog.at_level("WARNING"):
+            register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
+        assert any(
+            "No conversion mapping found for model type" in msg
+            for msg in caplog.messages
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integrations/test_sonicmoe.py` around lines 215 - 218, The test
test_register_unsupported_model_type_warns currently calls
register_sonicmoe_weight_converter("nonexistent_model_type_xyz") without
asserting any warning; update it to capture and assert the warning is emitted
(e.g., using pytest.warns or caplog) and verify the warning message contains a
clear indicator like "unsupported" or the passed model type; ensure the
assertion references the test name test_register_unsupported_model_type_warns
and the function register_sonicmoe_weight_converter so future regressions in
warning behavior are caught.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/integrations/kernels/plugin.py`:
- Around line 19-20: The function _check_sonicmoe_gpu_compat currently returns
silently when torch.cuda.is_available() is False; change it to fail fast by
checking if use_sonicmoe is enabled and, if so, raise a clear RuntimeError (or
similar) indicating SonicMoE requires a CUDA-capable GPU and CUDA is not
available; update the _check_sonicmoe_gpu_compat function (and any callers if
needed) to perform this conditional check and raise with a descriptive message
so users see the failure immediately instead of later during execution.

In `@src/axolotl/integrations/kernels/sonicmoe/routing.py`:
- Around line 78-126: In softmax_topk_routing, avoid a potential
division-by-zero when renormalizing top_values (the block using
gate.norm_topk_prob); change the renorm to divide by top_values.sum(dim=-1,
keepdim=True) plus a tiny epsilon (e.g. 1e-20) so the denominator can never be
zero, mirroring the sigmoid variant's protection—update the renormalization
expression where top_values is reassigned.

In `@tests/e2e/integrations/test_sonicmoe.py`:
- Around line 230-233: Tests that patch_sonicmoe("qwen3_moe") are missing the
gate-up weight interleaving step, causing training-path tests to use a different
expert weight layout than forward/gradient parity checks; call the helper that
performs the interleave after patching and before optimizer/forward steps by
invoking _interleave_gate_up_weights(model) (the same function used in the
forward/gradient parity tests) so both training-path tests (the block around
model = AutoModelForCausalLM.from_config(...); patch_sonicmoe("qwen3_moe");
optimizer = ...) and the analogous block at lines ~260-262 perform the
interleave and use the correct expert weight layout for convergence/update
checks.

In `@tests/integrations/test_sonicmoe.py`:
- Around line 164-171: The test creates an OOM-prone case (E,I,H) =
(128,768,2048) which builds huge tensors (concat/interleaved/recovered); update
test_various_shapes to remove or downscale that tuple: replace (128,768,2048)
with a smaller shape (for example (32,384,1024)) or drop it entirely so the loop
over E,I,H uses only safe sizes; locate the loop using variables E, I, H and the
tensors concat and calls to fwd.convert and rev.convert and change the tuple
list accordingly to avoid CI memory failures.


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2b6f4a6 and 466e1d5.

📒 Files selected for processing (12)
  • src/axolotl/integrations/kernels/README.md
  • src/axolotl/integrations/kernels/args.py
  • src/axolotl/integrations/kernels/constants.py
  • src/axolotl/integrations/kernels/plugin.py
  • src/axolotl/integrations/kernels/sonicmoe/__init__.py
  • src/axolotl/integrations/kernels/sonicmoe/patch.py
  • src/axolotl/integrations/kernels/sonicmoe/routing.py
  • src/axolotl/integrations/kernels/sonicmoe/weight_converter.py
  • tests/e2e/integrations/test_sonicmoe.py
  • tests/integrations/test_scattermoe_lora.py
  • tests/integrations/test_sonicmoe.py
  • tests/integrations/test_sonicmoe_gradients.py

Comment on lines +19 to +20
    if not torch.cuda.is_available():
        return

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# First, check if file exists and get content around lines 19-20
echo "=== File content around lines 19-20 ==="
cat -n src/axolotl/integrations/kernels/plugin.py | head -40

echo ""
echo "=== Full search for sonicmoe references ==="
rg -n -C3 '_check_sonicmoe_gpu_compat|use_sonicmoe|torch\.cuda\.is_available' src/axolotl/integrations/kernels/plugin.py || true

Repository: axolotl-ai-cloud/axolotl

Length of output: 2614


Fail fast when SonicMoE is enabled but CUDA is unavailable.

The _check_sonicmoe_gpu_compat() function returns silently when torch.cuda.is_available() is False (lines 19-20), allowing use_sonicmoe=True to proceed on CPU-only systems and fail later during execution. Since SonicMoE requires a CUDA-capable GPU (validated by subsequent GPU capability checks), this should raise immediately with a clear error.

Proposed fix
     if not torch.cuda.is_available():
-        return
+        raise RuntimeError("SonicMoE requires a CUDA-capable GPU.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/kernels/plugin.py` around lines 19 - 20, The
function _check_sonicmoe_gpu_compat currently returns silently when
torch.cuda.is_available() is False; change it to fail fast by checking if
use_sonicmoe is enabled and, if so, raise a clear RuntimeError (or similar)
indicating SonicMoE requires a CUDA-capable GPU and CUDA is not available;
update the _check_sonicmoe_gpu_compat function (and any callers if needed) to
perform this conditional check and raise with a descriptive message so users see
the failure immediately instead of later during execution.
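
A fail-fast version of the check could be structured as in the sketch below. It takes the CUDA state as plain arguments so it runs without a GPU; in real code the values would come from `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`. The function name and messages are hypothetical, and treating "Hopper" as compute capability major version 9 is an assumption (the PR's actual capability check is not shown here).

```python
from typing import Optional, Tuple


def check_sonicmoe_gpu_compat(
    cuda_available: bool, capability: Optional[Tuple[int, int]]
) -> None:
    """Raise immediately if the environment cannot run SonicMoE."""
    if not cuda_available:
        raise RuntimeError(
            "SonicMoE requires a CUDA-capable GPU, but CUDA is not available."
        )
    # Assumption: Hopper GPUs report compute capability 9.x.
    if capability is None or capability[0] != 9:
        raise RuntimeError(
            f"SonicMoE requires a Hopper GPU; got compute capability {capability}."
        )


def is_rejected(cuda_available: bool, capability: Optional[Tuple[int, int]]) -> bool:
    """Helper for the examples below: True if the check raises."""
    try:
        check_sonicmoe_gpu_compat(cuda_available, capability)
    except RuntimeError:
        return True
    return False


cpu_only_rejected = is_rejected(False, None)
ampere_rejected = is_rejected(True, (8, 0))
hopper_rejected = is_rejected(True, (9, 0))
```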

Comment on lines +78 to +126
def softmax_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = gate.top_k

    # Compute router logits and softmax over all experts
    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each

    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).
    # top_values = top_values.to(router_probs.dtype)

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout:
    # [0, 0, ..., 1, 1, ..., T-1, T-1, ...]; this is required by SonicMoE.
    # Expert sorting is handled internally by general_routing_router_metadata.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

⚠️ Potential issue | 🟡 Minor

Potential division by zero in renormalization.

At Line 107, when norm_topk_prob is True, the code divides by top_values.sum(dim=-1, keepdim=True). If all top-k values are zero (which could theoretically happen with extreme logits), this would cause a division by zero. The sigmoid variant at Line 203 correctly adds an epsilon (1e-20) to prevent this.

Consider adding similar protection for consistency:

🛡️ Proposed fix
     # Renormalize if configured (default True for models without the attribute,
     # e.g. Mixtral/MiniMax which always normalize)
     if getattr(gate, "norm_topk_prob", True):
-        top_values = top_values / top_values.sum(dim=-1, keepdim=True)
+        top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20)
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 94-94: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/kernels/sonicmoe/routing.py` around lines 78 - 126,
In softmax_topk_routing, avoid a potential division-by-zero when renormalizing
top_values (the block using gate.norm_topk_prob); change the renorm to divide by
top_values.sum(dim=-1, keepdim=True) plus a tiny epsilon (e.g. 1e-20) so the
denominator can never be zero, mirroring the sigmoid variant's protection—update
the renormalization expression where top_values is reassigned.
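
As a quick sanity check on how reachable the zero-sum case is, the sketch below computes the top-k probability sum for extreme and uniform logits in plain Python. It assumes finite logits and a max-subtracted softmax (the helper names are hypothetical). Under those assumptions the largest probability is at least 1/E, so the top-k sum stays at or above 1/E, which suggests the epsilon is a consistency measure rather than a fix for a commonly reachable failure. NaN or infinite logits are outside what this sketch covers.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    """Max-subtracted softmax; the largest entry always contributes exp(0) = 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk_sum(logits: List[float], k: int) -> float:
    """Sum of the k largest softmax probabilities (the renorm denominator)."""
    probs = softmax(logits)
    return sum(sorted(probs, reverse=True)[:k])


num_experts = 4
extreme_logits = [1000.0, -1000.0, -1000.0, -1000.0]  # one expert dominates
uniform_logits = [0.0] * num_experts  # worst case for the largest probability

extreme_sum = topk_sum(extreme_logits, k=2)
uniform_sum = topk_sum(uniform_logits, k=2)
```

With uniform logits every probability equals 1/E, which is the smallest value the maximum can take, so this is the case where the denominator is closest to its lower bound.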

Comment thread tests/e2e/integrations/test_sonicmoe.py
Comment thread tests/integrations/test_sonicmoe.py Outdated

codecov Bot commented Mar 4, 2026

@winglian winglian merged commit 6a8baf8 into main Mar 5, 2026
18 of 19 checks passed