feat: add sonicmoe#3411
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAdds SonicMoE kernel integration with mutual exclusivity against ScatterMoE, including dynamic MoE block resolution, GPU compatibility checks, forward patching, routing implementations, weight converters, and comprehensive unit and end-to-end tests. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
📖 Documentation Preview: https://69a8006696e3e160904ba462--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit c3c1a16 |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (4)
src/axolotl/integrations/kernels/sonicmoe/patch.py (1)
112-132: Weight permutation occurs on every forward pass.At Lines 115-116 and 165-166, the weight tensors are permuted on every forward call. For large models with many forward passes, this repeated permutation could impact performance. Consider caching the permuted weights on the module during the first forward pass.
⚡ Proposed optimization
def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: from sonicmoe import moe_general_routing_inputs batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) # Shared expert (computed early, matching original model ordering) shared_expert_output = _compute_shared_expert(self, hidden_states_flat) # Routing router_scores, token_indices, expert_indices, _router_logits = routing_fn( hidden_states_flat, self ) - # Permute weights to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) - down_weight = self.experts.down_proj.permute(1, 2, 0) + # Permute weights to SonicMoE layout (cached on first call): + # gate_up: [E, 2*I, H] -> [2*I, H, E] + # down: [E, H, I] -> [H, I, E] + if not hasattr(self, "_sonicmoe_gate_up_weight"): + self._sonicmoe_gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0).contiguous() + self._sonicmoe_down_weight = self.experts.down_proj.permute(1, 2, 0).contiguous() + gate_up_weight = self._sonicmoe_gate_up_weight + down_weight = self._sonicmoe_down_weight E = gate_up_weight.shape[-1]Note: This assumes weights don't change during training. If LoRA or other adapters modify the base weights, this caching strategy would need adjustment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/kernels/sonicmoe/patch.py` around lines 112 - 132, The permute operations on self.experts.gate_up_proj and self.experts.down_proj are being done every forward; cache the permuted tensors on the module (e.g., self._cached_gate_up_weight and self._cached_down_weight) and only compute them if the cached attributes are None or if the original parameter tensors have changed (compare .data_ptr() or a version flag); then pass the cached_gate_up_weight and cached_down_weight into moe_general_routing_inputs instead of recomputing each call (update this logic around the forward code that currently calls self.experts.gate_up_proj.permute and self.experts.down_proj.permute and before the moe_general_routing_inputs invocation).src/axolotl/integrations/kernels/sonicmoe/weight_converter.py (1)
42-44: Thedimparameter is stored but not used in the conversion.Both
ConcatenatedToInterleavedandInterleavedToConcatenatedaccept adimparameter in their constructors, but the actualconvertmethods callinterleave_gate_up/deinterleave_gate_upwhich always operate along a fixed dimension pattern. Ifdim != 1were ever passed, the conversion would still operate on the same dimension.Either remove the
dimparameter if it's not needed, or modify the interleave functions to respect it.Also applies to: 92-94
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/kernels/sonicmoe/weight_converter.py` around lines 42 - 44, The constructors for ConcatenatedToInterleaved and InterleavedToConcatenated currently store dim but never use it; update the conversion flow so the chosen dim is respected: either remove the unused dim parameter from __init__ of both classes (and any callers) if variable dimension support is not needed, or modify interleave_gate_up and deinterleave_gate_up to accept a dim argument and change the convert methods in ConcatenatedToInterleaved.convert and InterleavedToConcatenated.convert to call interleave_gate_up(..., dim=self.dim) / deinterleave_gate_up(..., dim=self.dim) instead of the current fixed-dimension calls so self.dim is actually applied.src/axolotl/integrations/kernels/sonicmoe/routing.py (1)
177-192: Group-based selection assumes E is divisible by n_group.The view operation at Line 180 (
scores_for_choice.view(-1, n_group, E // n_group)) will silently truncate experts ifEis not evenly divisible byn_group, potentially causing incorrect routing. While this is likely guaranteed by the model architecture, adding a defensive assertion would catch configuration errors early.🛡️ Proposed defensive check
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1) if n_group > 1: + assert E % n_group == 0, f"n_routed_experts ({E}) must be divisible by n_group ({n_group})" group_scores = ( scores_for_choice.view(-1, n_group, E // n_group)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/kernels/sonicmoe/routing.py` around lines 177 - 192, The group-based selection assumes E is divisible by n_group; before the view call that reshapes scores_for_choice (the block using scores_for_choice.view(-1, n_group, E // n_group) inside the routing logic), add a defensive check that E % n_group == 0 and raise a clear error (or assertion) if not, mentioning E, n_group and moe_block.topk_group so misconfiguration is caught early and the silent truncation is prevented.tests/integrations/test_sonicmoe.py (1)
215-218:test_register_unsupported_model_type_warnsneeds an assertion on warning output.Right now this test is pass-through and won’t catch regressions in warning behavior.
💡 Proposed fix
- def test_register_unsupported_model_type_warns(self): - # A model type with no conversion mapping should warn but not raise - register_sonicmoe_weight_converter("nonexistent_model_type_xyz") + def test_register_unsupported_model_type_warns(self, caplog): + # A model type with no conversion mapping should warn but not raise + with caplog.at_level("WARNING"): + register_sonicmoe_weight_converter("nonexistent_model_type_xyz") + assert any( + "No conversion mapping found for model type" in msg + for msg in caplog.messages + )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integrations/test_sonicmoe.py` around lines 215 - 218, The test test_register_unsupported_model_type_warns currently calls register_sonicmoe_weight_converter("nonexistent_model_type_xyz") without asserting any warning; update it to capture and assert the warning is emitted (e.g., using pytest.warns or caplog) and verify the warning message contains a clear indicator like "unsupported" or the passed model type; ensure the assertion references the test name test_register_unsupported_model_type_warns and the function register_sonicmoe_weight_converter so future regressions in warning behavior are caught.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/integrations/kernels/plugin.py`:
- Around line 19-20: The function _check_sonicmoe_gpu_compat currently returns
silently when torch.cuda.is_available() is False; change it to fail fast by
checking if use_sonicmoe is enabled and, if so, raise a clear RuntimeError (or
similar) indicating SonicMoE requires a CUDA-capable GPU and CUDA is not
available; update the _check_sonicmoe_gpu_compat function (and any callers if
needed) to perform this conditional check and raise with a descriptive message
so users see the failure immediately instead of later during execution.
In `@src/axolotl/integrations/kernels/sonicmoe/routing.py`:
- Around line 78-126: In softmax_topk_routing, avoid a potential
division-by-zero when renormalizing top_values (the block using
gate.norm_topk_prob); change the renorm to divide by top_values.sum(dim=-1,
keepdim=True) plus a tiny epsilon (e.g. 1e-20) so the denominator can never be
zero, mirroring the sigmoid variant's protection—update the renormalization
expression where top_values is reassigned.
In `@tests/e2e/integrations/test_sonicmoe.py`:
- Around line 230-233: Tests that patch_sonicmoe("qwen3_moe") are missing the
gate-up weight interleaving step, causing training-path tests to use a different
expert weight layout than forward/gradient parity checks; call the helper that
performs the interleave after patching and before optimizer/forward steps by
invoking _interleave_gate_up_weights(model) (the same function used in the
forward/gradient parity tests) so both training-path tests (the block around
model = AutoModelForCausalLM.from_config(...); patch_sonicmoe("qwen3_moe");
optimizer = ...) and the analogous block at lines ~260-262 perform the
interleave and use the correct expert weight layout for convergence/update
checks.
In `@tests/integrations/test_sonicmoe.py`:
- Around line 164-171: The test creates an OOM-prone case (E,I,H) =
(128,768,2048) which builds huge tensors (concat/interleaved/recovered); update
test_various_shapes to remove or downscale that tuple: replace (128,768,2048)
with a smaller shape (for example (32,384,1024)) or drop it entirely so the loop
over E,I,H uses only safe sizes; locate the loop using variables E, I, H and the
tensors concat and calls to fwd.convert and rev.convert and change the tuple
list accordingly to avoid CI memory failures.
---
Nitpick comments:
In `@src/axolotl/integrations/kernels/sonicmoe/patch.py`:
- Around line 112-132: The permute operations on self.experts.gate_up_proj and
self.experts.down_proj are being done every forward; cache the permuted tensors
on the module (e.g., self._cached_gate_up_weight and self._cached_down_weight)
and only compute them if the cached attributes are None or if the original
parameter tensors have changed (compare .data_ptr() or a version flag); then
pass the cached_gate_up_weight and cached_down_weight into
moe_general_routing_inputs instead of recomputing each call (update this logic
around the forward code that currently calls self.experts.gate_up_proj.permute
and self.experts.down_proj.permute and before the moe_general_routing_inputs
invocation).
In `@src/axolotl/integrations/kernels/sonicmoe/routing.py`:
- Around line 177-192: The group-based selection assumes E is divisible by
n_group; before the view call that reshapes scores_for_choice (the block using
scores_for_choice.view(-1, n_group, E // n_group) inside the routing logic), add
a defensive check that E % n_group == 0 and raise a clear error (or assertion)
if not, mentioning E, n_group and moe_block.topk_group so misconfiguration is
caught early and the silent truncation is prevented.
In `@src/axolotl/integrations/kernels/sonicmoe/weight_converter.py`:
- Around line 42-44: The constructors for ConcatenatedToInterleaved and
InterleavedToConcatenated currently store dim but never use it; update the
conversion flow so the chosen dim is respected: either remove the unused dim
parameter from __init__ of both classes (and any callers) if variable dimension
support is not needed, or modify interleave_gate_up and deinterleave_gate_up to
accept a dim argument and change the convert methods in
ConcatenatedToInterleaved.convert and InterleavedToConcatenated.convert to call
interleave_gate_up(..., dim=self.dim) / deinterleave_gate_up(..., dim=self.dim)
instead of the current fixed-dimension calls so self.dim is actually applied.
In `@tests/integrations/test_sonicmoe.py`:
- Around line 215-218: The test test_register_unsupported_model_type_warns
currently calls register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
without asserting any warning; update it to capture and assert the warning is
emitted (e.g., using pytest.warns or caplog) and verify the warning message
contains a clear indicator like "unsupported" or the passed model type; ensure
the assertion references the test name
test_register_unsupported_model_type_warns and the function
register_sonicmoe_weight_converter so future regressions in warning behavior are
caught.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
src/axolotl/integrations/kernels/README.mdsrc/axolotl/integrations/kernels/args.pysrc/axolotl/integrations/kernels/constants.pysrc/axolotl/integrations/kernels/plugin.pysrc/axolotl/integrations/kernels/sonicmoe/__init__.pysrc/axolotl/integrations/kernels/sonicmoe/patch.pysrc/axolotl/integrations/kernels/sonicmoe/routing.pysrc/axolotl/integrations/kernels/sonicmoe/weight_converter.pytests/e2e/integrations/test_sonicmoe.pytests/integrations/test_scattermoe_lora.pytests/integrations/test_sonicmoe.pytests/integrations/test_sonicmoe_gradients.py
| if not torch.cuda.is_available(): | ||
| return |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# First, check if file exists and get content around lines 19-20
echo "=== File content around lines 19-20 ==="
cat -n src/axolotl/integrations/kernels/plugin.py | head -40
echo ""
echo "=== Full search for sonicmoe references ==="
rg -n -C3 '_check_sonicmoe_gpu_compat|use_sonicmoe|torch\.cuda\.is_available' src/axolotl/integrations/kernels/plugin.py || trueRepository: axolotl-ai-cloud/axolotl
Length of output: 2614
Fail fast when SonicMoE is enabled but CUDA is unavailable.
The _check_sonicmoe_gpu_compat() function returns silently when torch.cuda.is_available() is False (lines 19-20), allowing use_sonicmoe=True to proceed on CPU-only systems and fail later during execution. Since SonicMoE requires a CUDA-capable GPU (validated by subsequent GPU capability checks), this should raise immediately with a clear error.
Proposed fix
if not torch.cuda.is_available():
- return
+ raise RuntimeError("SonicMoE requires a CUDA-capable GPU.")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if not torch.cuda.is_available(): | |
| return | |
| if not torch.cuda.is_available(): | |
| raise RuntimeError("SonicMoE requires a CUDA-capable GPU.") |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/integrations/kernels/plugin.py` around lines 19 - 20, The
function _check_sonicmoe_gpu_compat currently returns silently when
torch.cuda.is_available() is False; change it to fail fast by checking if
use_sonicmoe is enabled and, if so, raise a clear RuntimeError (or similar)
indicating SonicMoE requires a CUDA-capable GPU and CUDA is not available;
update the _check_sonicmoe_gpu_compat function (and any callers if needed) to
perform this conditional check and raise with a descriptive message so users see
the failure immediately instead of later during execution.
| def softmax_topk_routing( | ||
| hidden_states: torch.Tensor, moe_block | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm. | ||
|
|
||
| Args: | ||
| hidden_states: [T, H] flattened token representations | ||
| moe_block: MoE block module (accesses moe_block.gate.*) | ||
|
|
||
| Returns: | ||
| router_scores: [T*K] flattened scores (float32) | ||
| token_indices: [T*K] which token each entry belongs to (int32), sorted ascending | ||
| expert_indices: [T*K] which expert (int32) | ||
| router_logits: [T, E] original logits for aux loss | ||
| """ | ||
| gate = moe_block.gate | ||
| T, H = hidden_states.shape | ||
| K = gate.top_k | ||
|
|
||
| # Compute router logits and softmax over all experts | ||
| router_logits = F.linear(hidden_states, gate.weight) # [T, E] | ||
| router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] | ||
|
|
||
| # Select top-k experts per token | ||
| top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each | ||
|
|
||
| # Renormalize if configured (default True for models without the attribute, | ||
| # e.g. Mixtral/MiniMax which always normalize) | ||
| if getattr(gate, "norm_topk_prob", True): | ||
| top_values = top_values / top_values.sum(dim=-1, keepdim=True) | ||
|
|
||
| # no-op: matches transformers which casts to softmax output dtype (float32). | ||
| # top_values = top_values.to(router_probs.dtype) | ||
|
|
||
| # Flatten for moe_general_routing_inputs. | ||
| # Token indices are naturally sorted ascending from the [T, K] layout: | ||
| # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE. | ||
| # Expert sorting is handled internally by general_routing_router_metadata. | ||
| token_indices = ( | ||
| torch.arange(T, device=hidden_states.device, dtype=torch.int32) | ||
| .unsqueeze(1) | ||
| .expand(T, K) | ||
| ) | ||
|
|
||
| flat_scores = top_values.reshape(-1) # [T*K] | ||
| flat_token_idx = token_indices.reshape(-1) # [T*K] | ||
| flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] | ||
|
|
||
| return flat_scores, flat_token_idx, flat_expert_idx, router_logits |
There was a problem hiding this comment.
Potential division by zero in renormalization.
At Line 107, when norm_topk_prob is True, the code divides by top_values.sum(dim=-1, keepdim=True). If all top-k values are zero (which could theoretically happen with extreme logits), this would cause a division by zero. The sigmoid variant at Line 203 correctly adds an epsilon (1e-20) to prevent this.
Consider adding similar protection for consistency:
🛡️ Proposed fix
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
- top_values = top_values / top_values.sum(dim=-1, keepdim=True)
+ top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20)🧰 Tools
🪛 Ruff (0.15.2)
[warning] 94-94: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/integrations/kernels/sonicmoe/routing.py` around lines 78 - 126,
In softmax_topk_routing, avoid a potential division-by-zero when renormalizing
top_values (the block using gate.norm_topk_prob); change the renorm to divide by
top_values.sum(dim=-1, keepdim=True) plus a tiny epsilon (e.g. 1e-20) so the
denominator can never be zero, mirroring the sigmoid variant's protection—update
the renormalization expression where top_values is reassigned.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
Description
Adds sonicmoe kernel integration with per MoE class routing. Faster than scattermoe, less vram cost as well. But, requires Hopper, see https://github.com/Dao-AILab/sonic-moe/tree/main?tab=readme-ov-file#-installation
How it works: replaces forward for MoE sparse block with our custom wrapper and reshapes weight to follow sonicmoe's requirements. It would apply model-specific routing -> moe kernel if fused not available.
How to test meanwhile (require Hopper and pip install below):
TODOS:
norm_topk_prob=True, softmax -> topk is equivalent to the fused topk -> softmaxExtras:
Limitations (unplanned):
Motivation and Context
How has this been tested?
Not yet tested, but verified routing against existing modeling code.
AI Usage Disclaimer
Manual initial routing -> Claude integration
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Documentation
Bug Fixes