feat: add sonicmoe fused lora support #3519
📝 Walkthrough
Added SonicMoE + PEFT LoRA integration with runtime weight materialization. Introduces LoRA utilities for expert weight unwrapping and gradient computation, updates expert and router weight extraction to support LoRA deltas, adds configuration validation, reorganizes imports to new [...]
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 3
🧹 Nitpick comments (2)
tests/integrations/test_sonicmoe_lora.py (1)
268-275: Make the base-weight freeze test actually exercise backward().
W is loaded from the fixture with requires_grad=False, so this stays green even if MoELoRAMaterialize.backward() accidentally emits a base-weight gradient. Re-enable grads for W inside this test and keep asserting that W.grad stays None.
Suggested tightening
  def test_no_grad_for_base_weight(self, setup):
      W, A, B, scaling, E, r = setup
+     W = W.detach().clone().requires_grad_(True)
      W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
      loss = W_eff.sum()
      loss.backward()
      assert W.grad is None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integrations/test_sonicmoe_lora.py` around lines 268 - 275, The test test_no_grad_for_base_weight currently uses a W tensor with requires_grad=False so it won't catch accidental gradients; before calling MoELoRAMaterialize.apply(W, A, B, scaling) set W.requires_grad_(True) to enable gradient tracking, then run loss.backward() and keep asserting that W.grad is None while A.grad and B.grad are not None; this ensures MoELoRAMaterialize.backward() is actually exercised and does not populate base-weight gradients.
tests/e2e/integrations/test_sonicmoe_lora.py (1)
267-271: Consider adding Inf check for consistency.
The test_loss_decreases test checks for both NaN and Inf (lines 127-128), but this test only checks for NaN. Consider adding the Inf check for consistency across tests.
♻️ Proposed fix
  for step in range(20):
      out = model(input_ids, labels=input_ids)
      loss = out.loss
      assert not math.isnan(loss.item()), f"NaN loss at step {step}"
+     assert not math.isinf(loss.item()), f"Inf loss at step {step}"
      losses.append(loss.item())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/integrations/test_sonicmoe_lora.py` around lines 267 - 271, The loop that computes loss (for step in range(20): using model(input_ids, labels=input_ids) and variable loss appended to losses) only asserts NaN; add an Inf check for consistency by asserting the loss is finite (e.g., use math.isfinite(loss.item()) or assert not math.isinf(loss.item())) right after computing loss and before appending to losses, and include a clear assertion message such as "Inf loss at step {step}" referencing the step, loss, model, input_ids, and losses variables to locate the change.
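The finite-loss check suggested in this nitpick can be folded into one assertion. A minimal sketch, assuming the loss is already a Python float; `assert_finite_loss` is a hypothetical helper name and is not part of the PR's tests:

```python
import math


def assert_finite_loss(loss_value: float, step: int) -> None:
    """Fail with a step-tagged message if the loss is NaN or +/-Inf."""
    # math.isfinite is False for NaN, +Inf and -Inf, so one check covers
    # what separate isnan/isinf assertions would catch.
    assert math.isfinite(loss_value), f"Non-finite loss {loss_value} at step {step}"
```

In a training-loop test this would be called as `assert_finite_loss(loss.item(), step)` before appending to `losses`.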
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/integrations/kernels/libs/sonicmoe/lora.py`:
- Around line 50-55: The code currently only checks adapter_name in lora_A_dict
before accessing lora_B_dict[adapter_name] and scaling_dict[adapter_name], which
can raise a KeyError if the dicts are inconsistent; update the logic around
lora_A_dict, lora_B_dict and scaling_dict to defensively verify that
adapter_name exists in all three (e.g., check "if adapter_name not in
lora_A_dict or adapter_name not in lora_B_dict or adapter_name not in
scaling_dict") and return the same safe fallback (None, None, None) or log an
error when any are missing so the function that uses lora_A, lora_B and scaling
never encounters an unexpected KeyError.
In `@src/axolotl/integrations/kernels/libs/sonicmoe/patch.py`:
- Around line 192-205: The current fused branch recombines the DTensor router
and LoRA delta by doing router_weight + router_lora_delta (via
effective_router_weight), which reintroduces DTensor+Tensor ops and breaks
sharded/FSDP runs; instead, when unwrap_gate_lora() returns a non-None
router_lora_delta, either (a) materialize a local router tensor before adding
the delta (so the add is between local tensors) or (b) follow the two-operation
pattern used in general_routing_inputs()/routing.py by keeping router_weight and
router_lora_delta separate and applying two F.linear calls rather than summing
them, or (c) disable the fused gate LoRA path until the kernel supports
consuming the delta separately; update the code around unwrap_gate_lora,
effective_router_weight, and the call site to moe_TC_softmax_topk_layer to
implement one of these fixes.
In `@src/axolotl/integrations/kernels/libs/sonicmoe/routing.py`:
- Around line 98-106: The softmax path uses F.linear(hidden_states,
gate_lora_delta) while hidden_states may be bf16 and gate_lora_delta is fp32
(from PEFT), causing a dtype mismatch; update the softmax routes (the F.linear
calls that add gate_lora_delta to router_logits) to mirror the sigmoid branch:
cast both hidden_states and gate_lora_delta to float32 before the F.linear,
perform the matmul in fp32, then cast the result back to the router_logits dtype
(or add it back in float32 consistently) so unwrap_gate_lora/gate_lora_delta,
hidden_states, and router_logits all use compatible dtypes during the LoRA
contribution.
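The `lora.py` finding above amounts to checking one adapter name against all three dictionaries before indexing any of them. A minimal sketch, assuming the three dicts are plain mappings keyed by adapter name; `lookup_lora_params` is a hypothetical helper for illustration, not the function in the PR:

```python
from typing import Any, Mapping, Optional, Tuple

LoraParams = Tuple[Optional[Any], Optional[Any], Optional[Any]]


def lookup_lora_params(
    adapter_name: str,
    lora_A_dict: Mapping[str, Any],
    lora_B_dict: Mapping[str, Any],
    scaling_dict: Mapping[str, Any],
) -> LoraParams:
    """Return (lora_A, lora_B, scaling), or (None, None, None) if any is missing."""
    # Hypothetical helper: require the adapter in all three dicts so that an
    # inconsistent state falls back safely instead of raising KeyError.
    if any(adapter_name not in d for d in (lora_A_dict, lora_B_dict, scaling_dict)):
        return None, None, None
    return lora_A_dict[adapter_name], lora_B_dict[adapter_name], scaling_dict[adapter_name]


# Consistent dicts return the triple; a missing entry yields the fallback.
complete = lookup_lora_params("default", {"default": "A"}, {"default": "B"}, {"default": 2.0})
partial = lookup_lora_params("default", {"default": "A"}, {}, {"default": 2.0})
```

Whether a silent fallback or a logged error is preferable depends on how the caller treats a missing adapter; the review comment allows either.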
📒 Files selected for processing (12)
- src/axolotl/integrations/kernels/args.py
- src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py
- src/axolotl/integrations/kernels/libs/sonicmoe/lora.py
- src/axolotl/integrations/kernels/libs/sonicmoe/patch.py
- src/axolotl/integrations/kernels/libs/sonicmoe/routing.py
- src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py
- src/axolotl/integrations/kernels/plugin.py
- tests/e2e/integrations/test_sonicmoe.py
- tests/e2e/integrations/test_sonicmoe_lora.py
- tests/integrations/test_sonicmoe.py
- tests/integrations/test_sonicmoe_gradients.py
- tests/integrations/test_sonicmoe_lora.py
+ # Unwrap router for attribute access + optional LoRA delta
+ raw_router = getattr(self, router_attr)
+ base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
+ if router_lora_delta is not None:
+     effective_router_weight = router_weight + router_lora_delta
+ else:
+     effective_router_weight = router_weight

- # Permute weights to SonicMoE layout:
- # gate_up: [E, 2*I, H] -> [2*I, H, E]
- # down: [E, H, I] -> [H, I, E]
- gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
- down_weight = self.experts.down_proj.permute(1, 2, 0)
+ # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
+ gate_up_weight, down_weight = _get_expert_weights(self.experts)

  output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
      hidden_states_flat,
-     router.weight,
+     effective_router_weight,
🧩 Analysis chain
🌐 Web query:
In PyTorch DTensor/FSDP, is adding a DTensor to a regular Tensor directly supported, or must both operands be the same distributed tensor type?
💡 Result:
In PyTorch DTensor and FSDP (fully_shard), adding a DTensor to a regular torch.Tensor is not directly supported. Both operands must be DTensors for distributed operators.
Official PyTorch DTensor documentation explicitly states: "We disallow mixed Tensor-DTensor operations: if the input to any operations (e.g. torch.add) is a DTensor, then all Tensor inputs must be DTensors." Users must convert regular Tensors to DTensors using DTensor.from_local with appropriate placements (e.g., Replicate for constants).
A mixed operation raises a RuntimeError: "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!" Numerous GitHub issues and examples confirm this behavior in DTensor and FSDP2 contexts, where FSDP parameters are DTensors but activations or biases may remain regular Tensors, requiring explicit conversion.
Example fix:
bias_dtensor = DTensor.from_local(bias_tensor, device_mesh=mesh, placements=[Replicate()])
result = dtensor_param + bias_dtensor
Citations:
- 1: https://docs.pytorch.org/docs/stable/_sources/distributed.tensor.md.txt
- 2: https://dev-discuss.pytorch.org/t/dtensor-status-design-and-looking-forward/2749
- 3: FSDP2 "got mixed torch.Tensor and DTensor" pytorch/pytorch#153354
- 4: RuntimeError: Got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators pytorch/torchtitan#875
- 5: https://runebook.dev/en/docs/pytorch/distributed.tensor/torch.distributed.tensor.DTensor
🏁 Script executed:
# First, let's examine the patch.py file around the flagged lines
cat -n src/axolotl/integrations/kernels/libs/sonicmoe/patch.py | head -220 | tail -60
Repository: axolotl-ai-cloud/axolotl
Length of output: 2884
🏁 Script executed:
# Find the unwrap_gate_lora function definition
rg -A 20 "def unwrap_gate_lora" src/axolotl/integrations/kernels/libs/sonicmoe/
Repository: axolotl-ai-cloud/axolotl
Length of output: 2040
🏁 Script executed:
# Check imports and understand tensor types used
head -50 src/axolotl/integrations/kernels/libs/sonicmoe/patch.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 1845
🏁 Script executed:
# Look at the lora.py file to understand the context mentioned in the review
cat -n src/axolotl/integrations/kernels/libs/sonicmoe/lora.py | grep -A 30 "router_lora_delta"
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Get the full unwrap_gate_lora function implementation
rg -A 40 "def unwrap_gate_lora" src/axolotl/integrations/kernels/libs/sonicmoe/lora.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 1537
🏁 Script executed:
# Check for FSDP imports and usage in the sonicmoe module
rg "FSDP|DTensor|fully_shard" src/axolotl/integrations/kernels/libs/sonicmoe/
Repository: axolotl-ai-cloud/axolotl
Length of output: 336
🏁 Script executed:
# Look at get_lora_params_from_wrapper to understand tensor types
rg -B 5 -A 20 "def get_lora_params_from_wrapper" src/axolotl/integrations/kernels/libs/sonicmoe/lora.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 928
🏁 Script executed:
# Check routing.py to see the workaround mentioned
rg -B 10 -A 10 "Two F.linear calls avoid mixing DTensor" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 1130
🏁 Script executed:
# Check if there's any context about when FSDP is actually active
rg -B 5 -A 5 "FSDP\|fully_shard" src/axolotl/integrations/kernels/libs/sonicmoe/ || echo "No direct FSDP usage found"
Repository: axolotl-ai-cloud/axolotl
Length of output: 94
🏁 Script executed:
# Look at the general forward path to see how it handles router_lora_delta
rg -B 5 -A 15 "def _make_forward\(" src/axolotl/integrations/kernels/libs/sonicmoe/patch.py | head -60
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
Don't recombine the router DTensor and LoRA delta in fused mode.
unwrap_gate_lora() already splits these tensors to avoid DTensor + Tensor under FSDP. Doing the add here reintroduces that failure mode. The general routing path in routing.py avoids this correctly using two separate F.linear calls. Sharded LoRA runs will fail as soon as this fused branch is exercised.
Materialize a local router tensor explicitly before the add, or use the two-operation pattern from general_routing_inputs(), or gate fused router LoRA off until the kernel can consume the delta separately.
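The two-operation pattern mentioned above relies on the linearity of `F.linear` in its weight: applying the base weight and the delta separately and summing the outputs matches applying their sum. A minimal sketch with plain local tensors and made-up shapes; it does not reproduce the DTensor failure and says nothing about what `moe_TC_softmax_topk_layer` accepts:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, H, E = 4, 8, 3  # illustrative token count, hidden size, expert count

hidden_states = torch.randn(T, H)
router_weight = torch.randn(E, H)             # stands in for the base router weight
router_lora_delta = torch.randn(E, H) * 0.01  # stands in for the LoRA delta

# Pattern flagged in the review: add the weights first, then one matmul.
fused = F.linear(hidden_states, router_weight + router_lora_delta)

# Two-operation pattern: one matmul per weight, then add the logits.
# The weights themselves are never added, so a sharded weight and a local
# delta would not meet in a single elementwise op.
split = F.linear(hidden_states, router_weight) + F.linear(
    hidden_states, router_lora_delta
)
```

The two results agree up to floating-point rounding, which is why keeping the tensors separate changes only where the addition happens, not the router logits.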
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k

# Compute router logits and softmax over all experts.
# Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
router_logits = F.linear(hidden_states, gate_weight)  # [T, E]
if gate_lora_delta is not None:
    router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
🧩 Analysis chain
🌐 Web query:
Does torch.nn.functional.linear require input and weight to have the same dtype, and does PEFT keep LoRA adapter weights in fp32 by default when the base model runs in bf16?
💡 Result:
Yes to both questions. torch.nn.functional.linear requires input and weight to have the same dtype, as evidenced by RuntimeError "mat1 and mat2 must have the same dtype" when they differ, from PyTorch forums and code behavior using underlying matmul operations that enforce dtype matching. PEFT keeps LoRA adapter weights in fp32 by default when loading, even if the base model is in bf16. This is a deliberate design choice for training stability and better performance with higher precision adapters, as stated in PEFT GitHub issues #2421 and #1249. Users can disable this with autocast_adapter_dtype=False.
Citations:
- 1: https://discuss.pytorch.org/t/runtimeerror-mat1-and-mat2-must-have-the-same-dtype/166759
- 2: About the dtype of trainable params huggingface/peft#1249
- 3: Adapters saved in float16 are loaded in float32 huggingface/peft#2421
🏁 Script executed:
cat -n src/axolotl/integrations/kernels/libs/sonicmoe/routing.py | sed -n '90,160p'
Repository: axolotl-ai-cloud/axolotl
Length of output: 3642
🏁 Script executed:
rg -n "def sigmoid_topk_routing" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py -A 30
Repository: axolotl-ai-cloud/axolotl
Length of output: 1617
🏁 Script executed:
rg -n "def sigmoid_topk_routing" src/axolotl/integrations/kernels/libs/sonicmoe/routing.py -A 60 | tail -40
Repository: axolotl-ai-cloud/axolotl
Length of output: 2151
Run the new softmax LoRA matmuls in fp32.
gate_lora_delta comes back in adapter dtype (fp32 by default via PEFT), and these routes execute with bf16 activations. The F.linear(hidden_states, gate_lora_delta) calls at lines 99-101 and 142-144 are dtype-unsafe: PyTorch's F.linear requires matching input and weight dtypes, and will throw a RuntimeError when hidden_states is bf16 and gate_lora_delta is fp32. The sigmoid path already casts both operands to float32 (lines 223-227); apply the same pattern to both softmax functions.
The current test helper in tests/integrations/test_sonicmoe_gradients.py still builds plain .weight gates only, so this code path is not exercised.
Suggested pattern
- router_logits = F.linear(hidden_states, gate_weight)  # [T, E]
+ hidden_states_fp32 = hidden_states.float()
+ router_logits = F.linear(hidden_states_fp32, gate_weight.float())  # [T, E]
  if gate_lora_delta is not None:
-     router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
+     router_logits = router_logits + F.linear(
+         hidden_states_fp32, gate_lora_delta.float()
+     )
🧰 Tools
🪛 Ruff (0.15.7)
[warning] 99-99: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
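The suggested fp32 pattern can be exercised in isolation. A minimal sketch, assuming bf16 activations and base gate weight with an fp32 LoRA delta; `softmax_router_logits` is a hypothetical stand-in for the softmax routing code in `routing.py`, not its actual function:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def softmax_router_logits(
    hidden_states: torch.Tensor,
    gate_weight: torch.Tensor,
    gate_lora_delta: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute router logits in float32 regardless of the input dtypes."""
    # Cast once and reuse, so both matmuls see matching float32 operands.
    hidden_fp32 = hidden_states.float()
    logits = F.linear(hidden_fp32, gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        # Separate matmul for the delta; only the logits are added.
        logits = logits + F.linear(hidden_fp32, gate_lora_delta.float())
    return logits


torch.manual_seed(0)
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)          # bf16 activations
gate_weight = torch.randn(3, 8, dtype=torch.bfloat16)            # bf16 base gate
gate_lora_delta = torch.randn(3, 8, dtype=torch.float32) * 0.01  # fp32 adapter delta
logits = softmax_router_logits(hidden_states, gate_weight, gate_lora_delta)
```

The sketch returns float32 logits; whether to cast back to the activation dtype afterwards is left open, as in the review comment.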
Description
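Purely experimental. If you are interested in an MoE kernel with LoRA, we recommend checking out ScatterMoE fused LoRA instead, as it's validated.
Concept: since we cannot patch the internal CUTLASS code, the forward pass merges the adapter into the weights before calling the kernel (adapter -> weights -> kernel), whereas the backward pass takes the weight gradients coming out of the kernel and extracts the adapter gradients from them (kernel -> weights -> adapter grads).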
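The concept above can be sketched as a custom autograd function. This is a minimal illustration under stated assumptions, not the PR's `MoELoRAMaterialize`: the class name `MaterializeLoRASketch`, the shapes `W [E, out, in]`, `A [E, r, in]`, `B [E, out, r]`, and the formula `W + scaling * (B @ A)` are all assumed here.

```python
import torch


class MaterializeLoRASketch(torch.autograd.Function):
    """Hypothetical sketch: merge per-expert LoRA factors into a frozen weight."""

    @staticmethod
    def forward(ctx, W, A, B, scaling):
        # Forward: adapter -> weights. The merged tensor is what a kernel would use.
        ctx.save_for_backward(A, B)
        ctx.scaling = scaling
        return W + scaling * torch.bmm(B, A)

    @staticmethod
    def backward(ctx, grad_W_eff):
        # Backward: the gradient w.r.t. the merged weight is projected onto A and B.
        A, B = ctx.saved_tensors
        grad_A = ctx.scaling * torch.bmm(B.transpose(1, 2), grad_W_eff)  # [E, r, in]
        grad_B = ctx.scaling * torch.bmm(grad_W_eff, A.transpose(1, 2))  # [E, out, r]
        # None for W keeps the base weight frozen; None for the float scaling.
        return None, grad_A, grad_B, None


torch.manual_seed(0)
E, out_f, in_f, r = 2, 6, 5, 3
W = torch.randn(E, out_f, in_f, requires_grad=True)  # grads enabled on purpose
A = torch.randn(E, r, in_f, requires_grad=True)
B = torch.randn(E, out_f, r, requires_grad=True)
scaling = 0.5

W_eff = MaterializeLoRASketch.apply(W, A, B, scaling)
(W_eff ** 2).sum().backward()
```

Because `W` has `requires_grad=True` here, `W.grad` staying `None` after `backward()` shows that the function itself withholds the base-weight gradient, which is the property the tightened `test_no_grad_for_base_weight` is meant to check.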