feat: add sonicmoe fused lora support #3519

Merged: winglian merged 12 commits into main from feat/sonicmoe-lora on Apr 2, 2026

Conversation

@NanoCode012
Collaborator

@NanoCode012 NanoCode012 commented Mar 20, 2026

Description

Purely experimental. We recommend that anyone interested in MoE kernel + LoRA check out ScatterMoE fused LoRA instead, as it is validated.

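Concept: since we cannot patch the internal CUTLASS code, the forward pass merges the adapter into the weights before calling the kernel (adapter -> weights -> kernel), whereas the backward pass takes the weight gradient produced by the kernel and extracts the adapter gradients from it (kernel -> weight grad -> adapter grad).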
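The idea above can be sketched as a custom autograd function. This is a minimal illustration only, not the PR's `MoELoRAMaterialize`: the class name `MaterializeLoRA`, the tensor layouts (`W` as `[E, out, in]`, `A` as `[E, r, in]`, `B` as `[E, out, r]`), and the plain `W + scaling * (B @ A)` merge are all assumptions made for the example.

```python
import torch


class MaterializeLoRA(torch.autograd.Function):
    """Hypothetical sketch: build W_eff = W + scaling * (B @ A) per expert."""

    @staticmethod
    def forward(ctx, W, A, B, scaling):
        # Assumed shapes: W [E, out, in], A [E, r, in], B [E, out, r].
        ctx.save_for_backward(A, B)
        ctx.scaling = scaling
        return W + scaling * torch.bmm(B, A)

    @staticmethod
    def backward(ctx, grad_W_eff):
        A, B = ctx.saved_tensors
        # The kernel only produces a gradient for the merged weight;
        # project it back onto the two low-rank factors.
        grad_A = ctx.scaling * torch.bmm(B.transpose(1, 2), grad_W_eff)
        grad_B = ctx.scaling * torch.bmm(grad_W_eff, A.transpose(1, 2))
        # The base weight is frozen (no gradient); scaling is a plain float.
        return None, grad_A, grad_B, None


E, out_dim, in_dim, r = 2, 4, 3, 2
W = torch.randn(E, out_dim, in_dim, dtype=torch.float64)
A = torch.randn(E, r, in_dim, dtype=torch.float64, requires_grad=True)
B = torch.randn(E, out_dim, r, dtype=torch.float64, requires_grad=True)

# In the real integration W_eff would be handed to the fused kernel;
# here a simple scalar loss stands in for it.
W_eff = MaterializeLoRA.apply(W, A, B, 0.5)
W_eff.pow(2).sum().backward()
```

Only `A` and `B` receive gradients in this sketch, which matches the frozen-base-weight behaviour the description implies.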

Summary by CodeRabbit

  • New Features

    • Added SonicMoE + LoRA integration enabling parameter-efficient fine-tuning of expert modules
    • Added configuration validation to warn about potential runtime overhead when adapting expert layers
  • Tests

    • Added comprehensive integration and end-to-end test coverage for SonicMoE + LoRA training scenarios

@coderabbitai
Contributor

coderabbitai Bot commented Mar 20, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d4cb2bde-a14b-4a55-b349-6237828ec958

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Added SonicMoE + PEFT LoRA integration with runtime weight materialization. Introduces LoRA utilities for expert weight unwrapping and gradient computation, updates expert and router weight extraction to support LoRA deltas, adds configuration validation, reorganizes imports to new libs.sonicmoe structure, and provides comprehensive test coverage.

Changes

| Cohort / File(s) | Summary |
|---|---|
| SonicMoE LoRA Core Implementation: src/axolotl/integrations/kernels/libs/sonicmoe/lora.py | New module implementing LoRA utilities: helpers for detecting and unwrapping PEFT-wrapped parameters (has_lora, get_lora_params_from_wrapper, unwrap_gate_lora, unwrap_experts_lora), a custom autograd function MoELoRAMaterialize for runtime weight materialization with gradient computation, and a materialize_expert_lora helper. |
| SonicMoE Integration Updates: src/axolotl/integrations/kernels/libs/sonicmoe/patch.py, src/axolotl/integrations/kernels/libs/sonicmoe/routing.py | Updated expert weight extraction and routing to support LoRA-aware operations: patch.py adds a _get_expert_weights helper and modifies forward functions to materialize LoRA deltas; routing.py refactors routing functions to unwrap gate LoRA and incorporate the router LoRA delta via a separate linear projection. |
| Configuration & Plugin: src/axolotl/integrations/kernels/args.py, src/axolotl/integrations/kernels/plugin.py | Added a warn_sonicmoe_lora_overhead validator to KernelsArgs that warns when LoRA targets expert modules in SonicMoE configurations; updated the import path in plugin.py to reference the new libs.sonicmoe structure. |
| Test Import Updates: tests/e2e/integrations/test_sonicmoe.py, tests/integrations/test_sonicmoe.py, tests/integrations/test_sonicmoe_gradients.py | Updated import paths from axolotl.integrations.kernels.sonicmoe to axolotl.integrations.kernels.libs.sonicmoe across existing test files, with no logic changes. |
| New Test Coverage: tests/e2e/integrations/test_sonicmoe_lora.py, tests/integrations/test_sonicmoe_lora.py | New end-to-end test module validating SonicMoE + LoRA training on Qwen3MoE with expert and gate-only LoRA targeting, plus a regression test without LoRA; new integration test module with unit tests for the unwrapping functions and the MoELoRAMaterialize autograd function, including gradient checking. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • #3411: The current PR extends the SonicMoE integration introduced in #3411 by adding PEFT/LoRA-aware tooling and LoRA materialization logic.
  • #3439: Both PRs modify MoE + PEFT/LoRA handling with complementary approaches for PEFT target-parameter matching and LoRA integration.

Suggested reviewers

  • winglian
  • djsaunde
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 53.97%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main change: adding SonicMoE fused LoRA support. It directly corresponds to the PR objectives and the substantial code additions across multiple files implementing this feature. |


@NanoCode012 NanoCode012 marked this pull request as ready for review April 2, 2026 04:45
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
tests/integrations/test_sonicmoe_lora.py (1)

268-275: Make the base-weight freeze test actually exercise backward().

W is loaded from the fixture with requires_grad=False, so this stays green even if MoELoRAMaterialize.backward() accidentally emits a base-weight gradient. Re-enable grads for W inside this test and keep asserting that W.grad stays None.

Suggested tightening
     def test_no_grad_for_base_weight(self, setup):
         W, A, B, scaling, E, r = setup
+        W = W.detach().clone().requires_grad_(True)
         W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
         loss = W_eff.sum()
         loss.backward()
         assert W.grad is None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integrations/test_sonicmoe_lora.py` around lines 268 - 275, The test
test_no_grad_for_base_weight currently uses a W tensor with requires_grad=False
so it won't catch accidental gradients; before calling
MoELoRAMaterialize.apply(W, A, B, scaling) set W.requires_grad_(True) to enable
gradient tracking, then run loss.backward() and keep asserting that W.grad is
None while A.grad and B.grad are not None; this ensures
MoELoRAMaterialize.backward() is actually exercised and does not populate
base-weight gradients.
tests/e2e/integrations/test_sonicmoe_lora.py (1)

267-271: Consider adding Inf check for consistency.

The test_loss_decreases test checks for both NaN and Inf (lines 127-128), but this test only checks for NaN. Consider adding the Inf check for consistency across tests.

♻️ Proposed fix
         for step in range(20):
             out = model(input_ids, labels=input_ids)
             loss = out.loss
             assert not math.isnan(loss.item()), f"NaN loss at step {step}"
+            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
             losses.append(loss.item())
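As a small aside on the check above: `math.isfinite` rejects NaN, +Inf, and -Inf in one call, so a single assertion can cover both cases. The helper name `assert_finite_loss` and the loss values below are made up for illustration and are not taken from the test file.

```python
import math


def assert_finite_loss(loss_value: float, step: int) -> None:
    """Hypothetical helper: fail fast on NaN, +Inf, or -Inf."""
    assert math.isfinite(loss_value), f"Non-finite loss {loss_value} at step {step}"


losses = []
# Placeholder values standing in for `out.loss.item()` at each step.
for step, loss_value in enumerate([2.31, 1.87, 1.52]):
    assert_finite_loss(loss_value, step)
    losses.append(loss_value)
```

Keeping two separate assertions, as in the proposed fix, has the advantage that the failure message says which of the two conditions was hit.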
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/e2e/integrations/test_sonicmoe_lora.py` around lines 267 - 271, The
loop that computes loss (for step in range(20): using model(input_ids,
labels=input_ids) and variable loss appended to losses) only asserts NaN; add an
Inf check for consistency by asserting the loss is finite (e.g., use
math.isfinite(loss.item()) or assert not math.isinf(loss.item())) right after
computing loss and before appending to losses, and include a clear assertion
message such as "Inf loss at step {step}" referencing the step, loss, model,
input_ids, and losses variables to locate the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/integrations/kernels/libs/sonicmoe/lora.py`:
- Around line 50-55: The code currently only checks adapter_name in lora_A_dict
before accessing lora_B_dict[adapter_name] and scaling_dict[adapter_name], which
can raise a KeyError if the dicts are inconsistent; update the logic around
lora_A_dict, lora_B_dict and scaling_dict to defensively verify that
adapter_name exists in all three (e.g., check "if adapter_name not in
lora_A_dict or adapter_name not in lora_B_dict or adapter_name not in
scaling_dict") and return the same safe fallback (None, None, None) or log an
error when any are missing so the function that uses lora_A, lora_B and scaling
never encounters an unexpected KeyError.

In `@src/axolotl/integrations/kernels/libs/sonicmoe/patch.py`:
- Around line 192-205: The current fused branch recombines the DTensor router
and LoRA delta by doing router_weight + router_lora_delta (via
effective_router_weight), which reintroduces DTensor+Tensor ops and breaks
sharded/FSDP runs; instead, when unwrap_gate_lora() returns a non-None
router_lora_delta, either (a) materialize a local router tensor before adding
the delta (so the add is between local tensors) or (b) follow the two-operation
pattern used in general_routing_inputs()/routing.py by keeping router_weight and
router_lora_delta separate and applying two F.linear calls rather than summing
them, or (c) disable the fused gate LoRA path until the kernel supports
consuming the delta separately; update the code around unwrap_gate_lora,
effective_router_weight, and the call site to moe_TC_softmax_topk_layer to
implement one of these fixes.

In `@src/axolotl/integrations/kernels/libs/sonicmoe/routing.py`:
- Around line 98-106: The softmax path uses F.linear(hidden_states,
gate_lora_delta) while hidden_states may be bf16 and gate_lora_delta is fp32
(from PEFT), causing a dtype mismatch; update the softmax routes (the F.linear
calls that add gate_lora_delta to router_logits) to mirror the sigmoid branch:
cast both hidden_states and gate_lora_delta to float32 before the F.linear,
perform the matmul in fp32, then cast the result back to the router_logits dtype
(or add it back in float32 consistently) so unwrap_gate_lora/gate_lora_delta,
hidden_states, and router_logits all use compatible dtypes during the LoRA
contribution.
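For the `lora.py` finding in the list above, a defensive lookup might look like the following. This is only a sketch: the function name `lookup_adapter` and its signature are hypothetical, and plain dicts stand in for whatever mapping types the PEFT wrapper actually exposes.

```python
from typing import Any, Mapping, Optional, Tuple


def lookup_adapter(
    lora_A: Mapping[str, Any],
    lora_B: Mapping[str, Any],
    scaling: Mapping[str, float],
    adapter_name: str,
) -> Tuple[Optional[Any], Optional[Any], Optional[float]]:
    """Return (A, B, scaling) only when all three mappings know the adapter."""
    if any(adapter_name not in d for d in (lora_A, lora_B, scaling)):
        # Same safe fallback as the "adapter not found" case.
        return None, None, None
    return lora_A[adapter_name], lora_B[adapter_name], scaling[adapter_name]


# Consistent mappings return the triple; an inconsistent one falls back.
ok = lookup_adapter({"default": "A"}, {"default": "B"}, {"default": 0.5}, "default")
missing = lookup_adapter({"default": "A"}, {}, {"default": 0.5}, "default")
```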

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3274f4ae-8b22-468d-97ae-0fc5cdb0f1a7

📥 Commits

Reviewing files that changed from the base of the PR and between b55706b and b01dcfb.

📒 Files selected for processing (12)
  • src/axolotl/integrations/kernels/args.py
  • src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py
  • src/axolotl/integrations/kernels/libs/sonicmoe/lora.py
  • src/axolotl/integrations/kernels/libs/sonicmoe/patch.py
  • src/axolotl/integrations/kernels/libs/sonicmoe/routing.py
  • src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py
  • src/axolotl/integrations/kernels/plugin.py
  • tests/e2e/integrations/test_sonicmoe.py
  • tests/e2e/integrations/test_sonicmoe_lora.py
  • tests/integrations/test_sonicmoe.py
  • tests/integrations/test_sonicmoe_gradients.py
  • tests/integrations/test_sonicmoe_lora.py

Comment thread src/axolotl/integrations/kernels/libs/sonicmoe/lora.py Outdated
Comment on lines +192 to +205
+    # Unwrap router for attribute access + optional LoRA delta
+    raw_router = getattr(self, router_attr)
+    base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
+    if router_lora_delta is not None:
+        effective_router_weight = router_weight + router_lora_delta
+    else:
+        effective_router_weight = router_weight

-    # Permute weights to SonicMoE layout:
-    # gate_up: [E, 2*I, H] -> [2*I, H, E]
-    # down: [E, H, I] -> [H, I, E]
-    gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
-    down_weight = self.experts.down_proj.permute(1, 2, 0)
+    # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
+    gate_up_weight, down_weight = _get_expert_weights(self.experts)

     output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
         hidden_states_flat,
-        router.weight,
+        effective_router_weight,

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In PyTorch DTensor/FSDP, is adding a DTensor to a regular Tensor directly supported, or must both operands be the same distributed tensor type?

💡 Result:

In PyTorch DTensor and FSDP (fully_shard), adding a DTensor to a regular torch.Tensor is not directly supported. Both operands must be DTensors for distributed operators.

Official PyTorch DTensor documentation explicitly states: "We disallow mixed Tensor-DTensor operations: if the input to any operations (e.g. torch.add) is a DTensor, then all Tensor inputs must be DTensors." Users must convert regular Tensors to DTensors using DTensor.from_local with appropriate placements (e.g., Replicate() for constants). A mixed operation raises a RuntimeError: "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!"

Numerous GitHub issues and examples confirm this behavior in DTensor and FSDP2 contexts, where FSDP parameters are DTensors but activations or biases may remain regular Tensors, requiring explicit conversion.

Example fix:

    bias_dtensor = DTensor.from_local(bias_tensor, device_mesh=mesh, placements=[Replicate()])
    result = dtensor_param + bias_dtensor


🏁 Script executed:

# First, let's examine the patch.py file around the flagged lines
cat -n src/axolotl/integrations/kernels/libs/sonicmoe/patch.py | head -220 | tail -60

Repository: axolotl-ai-cloud/axolotl

Length of output: 2884


🏁 Script executed:

# Find the unwrap_gate_lora function definition
rg -A 20 "def unwrap_gate_lora" src/axolotl/integrations/kernels/libs/sonicmoe/

Repository: axolotl-ai-cloud/axolotl

Length of output: 2040


🏁 Script executed:

# Check imports and understand tensor types used
head -50 src/axolotl/integrations/kernels/libs/sonicmoe/patch.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1845


🏁 Script executed:

# Look at the lora.py file to understand the context mentioned in the review
cat -n src/axolotl/integrations/kernels/libs/sonicmoe/lora.py | grep -A 30 "router_lora_delta"

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Get the full unwrap_gate_lora function implementation
rg -A 40 "def unwrap_gate_lora" src/axolotl/integrations/kernels/libs/sonicmoe/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1537


🏁 Script executed:

# Check for FSDP imports and usage in the sonicmoe module
rg "FSDP|DTensor|fully_shard" src/axolotl/integrations/kernels/libs/sonicmoe/

Repository: axolotl-ai-cloud/axolotl

Length of output: 336


🏁 Script executed:

# Look at get_lora_params_from_wrapper to understand tensor types
rg -B 5 -A 20 "def get_lora_params_from_wrapper" src/axolotl/integrations/kernels/libs/sonicmoe/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 928


🏁 Script executed:

# Check routing.py to see the workaround mentioned
rg -B 10 -A 10 "Two F.linear calls avoid mixing DTensor" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1130


🏁 Script executed:

# Check if there's any context about when FSDP is actually active
rg -B 5 -A 5 "FSDP\|fully_shard" src/axolotl/integrations/kernels/libs/sonicmoe/ || echo "No direct FSDP usage found"

Repository: axolotl-ai-cloud/axolotl

Length of output: 94


🏁 Script executed:

# Look at the general forward path to see how it handles router_lora_delta
rg -B 5 -A 15 "def _make_forward\(" src/axolotl/integrations/kernels/libs/sonicmoe/patch.py | head -60

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


Don't recombine the router DTensor and LoRA delta in fused mode.

unwrap_gate_lora() already splits these tensors to avoid DTensor + Tensor under FSDP. Doing the add here reintroduces that failure mode. The general routing path in routing.py avoids this correctly using two separate F.linear calls. Sharded LoRA runs will fail as soon as this fused branch is exercised.

Materialize a local router tensor explicitly before the add, or use the two-operation pattern from general_routing_inputs(), or gate fused router LoRA off until the kernel can consume the delta separately.
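The two-operation pattern relies on linearity: projecting with the base weight and with the delta separately, then adding the outputs, gives the same logits as projecting once with the summed weight. The sketch below checks this with plain CPU tensors, so it shows the arithmetic only and does not exercise DTensor or FSDP; all tensor names and shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(5, 8)  # [T, H]
router_weight = torch.randn(4, 8)  # [E, H], stands in for the sharded base weight
lora_delta = torch.randn(4, 8)     # [E, H], stands in for the router LoRA delta

# Merged form: the two weights must be added together first.
merged_logits = F.linear(hidden_states, router_weight + lora_delta)

# Split form: the weights never meet; only the [T, E] outputs are added.
split_logits = F.linear(hidden_states, router_weight) + F.linear(
    hidden_states, lora_delta
)
```

Note that the split form only helps where the caller computes the logits itself. The fused path passes a single router weight into the kernel, which is why the comment also lists materializing a local tensor or disabling the fused gate LoRA path as alternatives.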

Comment on lines +98 to +106
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k

# Compute router logits and softmax over all experts.
# Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
router_logits = F.linear(hidden_states, gate_weight) # [T, E]
if gate_lora_delta is not None:
    router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does torch.nn.functional.linear require input and weight to have the same dtype, and does PEFT keep LoRA adapter weights in fp32 by default when the base model runs in bf16?

💡 Result:

Yes to both questions.

torch.nn.functional.linear requires input and weight to have the same dtype, as evidenced by the RuntimeError "mat1 and mat2 must have the same dtype" raised when they differ; this is reported on the PyTorch forums and follows from the underlying matmul operations, which enforce dtype matching.

PEFT keeps LoRA adapter weights in fp32 by default when loading, even if the base model is in bf16. This is a deliberate design choice for training stability and better performance with higher-precision adapters, as stated in PEFT GitHub issues #2421 and #1249. Users can disable this with autocast_adapter_dtype=False.


🏁 Script executed:

cat -n src/axolotl/integrations/kernels/libs/sonicmoe/routing.py | sed -n '90,160p'

Repository: axolotl-ai-cloud/axolotl

Length of output: 3642


🏁 Script executed:

rg -n "def sigmoid_topk_routing" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py -A 30

Repository: axolotl-ai-cloud/axolotl

Length of output: 1617


🏁 Script executed:

rg -n "def sigmoid_topk_routing" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py -A 60 | tail -40

Repository: axolotl-ai-cloud/axolotl

Length of output: 2151


Run the new softmax LoRA matmuls in fp32.

gate_lora_delta comes back in adapter dtype (fp32 by default via PEFT), and these routes execute with bf16 activations. The F.linear(hidden_states, gate_lora_delta) calls at lines 99-101 and 142-144 are dtype-unsafe: PyTorch's F.linear requires matching input and weight dtypes, and will throw a RuntimeError when hidden_states are bf16 and gate_lora_delta is fp32. The sigmoid path already casts both operands to float32 (line 223-227); apply the same pattern to both softmax functions.

The current test helper in tests/integrations/test_sonicmoe_gradients.py still builds plain .weight gates only, so this code path is not exercised.

Suggested pattern
-    router_logits = F.linear(hidden_states, gate_weight)  # [T, E]
+    hidden_states_fp32 = hidden_states.float()
+    router_logits = F.linear(hidden_states_fp32, gate_weight.float())  # [T, E]
     if gate_lora_delta is not None:
-        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
+        router_logits = router_logits + F.linear(
+            hidden_states_fp32, gate_lora_delta.float()
+        )
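The suggested pattern can be tried in isolation as follows. This is a sketch under assumed shapes and dtypes (bf16 activations and base gate weight, fp32 adapter delta); the tensors are random stand-ins, not values from the routing code.

```python
import torch
import torch.nn.functional as F

hidden_states = torch.randn(5, 8, dtype=torch.bfloat16)   # bf16 activations [T, H]
gate_weight = torch.randn(4, 8, dtype=torch.bfloat16)     # bf16 base gate weight [E, H]
gate_lora_delta = torch.randn(4, 8, dtype=torch.float32)  # fp32 adapter delta [E, H]

# Upcast once, then run both projections in fp32 so the dtypes always match.
hidden_fp32 = hidden_states.float()
router_logits = F.linear(hidden_fp32, gate_weight.float())
router_logits = router_logits + F.linear(hidden_fp32, gate_lora_delta.float())
```

The logits come out in fp32 here; whether they should be cast back to the activation dtype afterwards depends on what the downstream softmax and top-k code expects.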
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 99-99: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

@winglian winglian merged commit 842fa03 into main Apr 2, 2026
1 check passed
@winglian winglian deleted the feat/sonicmoe-lora branch April 2, 2026 12:53