consolidate behaviour of routing in scattermoe kernels #3475
📝 Walkthrough
This PR refactors the scattermoe-lora kernel integration by introducing centralized routing helper functions (`_softmax_topk_route`, `_sigmoid_topk_route`, and the `_route` dispatcher) and a shared-expert computation helper (`_compute_shared_expert`). The `HFScatterMoEGatedMLP` forward path is updated to use these helpers, reducing in-method complexity. End-to-end and unit tests are added to validate the sigmoid and softmax routing variants.
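To make the softmax variant concrete, here is a minimal sketch of softmax-then-top-k routing for a single token. It uses plain Python lists rather than torch tensors and is not the PR's `_softmax_topk_route`; the function name `softmax_topk_route_sketch` and the `renormalize` flag are hypothetical, and whether the real helper renormalizes the selected weights is an assumption.

```python
import math


def softmax_topk_route_sketch(logits, top_k, renormalize=True):
    """Pick top_k experts for one token from a list of router logits."""
    # Numerically stable softmax over all experts.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Indices of the top_k largest probabilities, highest first.
    selected = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    weights = [probs[i] for i in selected]

    # Assumption: optionally rescale the kept weights so they sum to 1.
    if renormalize:
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return selected, weights


selected, weights = softmax_topk_route_sketch([0.0, 2.0, 1.0, -1.0], top_k=2)
```

With `renormalize=False` the returned weights are the raw softmax probabilities of the selected experts, so they sum to less than 1 whenever `top_k` is smaller than the number of experts.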
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🧹 Nitpick comments (4)
tests/integrations/test_scattermoe_lora.py (1)
369-380: Consider adding missing attributes to the `bias_on_gate=False` mock.
When `bias_on_gate=False`, the mock `moe_block` is missing `n_routed_experts`, `n_group`, `norm_topk_prob`, and `routed_scaling_factor`. While `_sigmoid_topk_route` handles these with `getattr` defaults, explicitly setting them would make the test more representative of real model structures and exercise more code paths consistently.
🧪 Suggested enhancement
     else:  # minimax_m2 style: bias on block, not gate
         gate = SimpleNamespace(
             weight=torch.randn(E, H),
             top_k=K,
         )
         moe_block = SimpleNamespace(
             gate=gate,
             top_k=K,
             e_score_correction_bias=torch.zeros(E),
+            n_routed_experts=E,
+            n_group=1,
+            norm_topk_prob=True,
+            routed_scaling_factor=1.0,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integrations/test_scattermoe_lora.py` around lines 369 - 380, The mock created for the bias_on_gate=False branch lacks several attributes present on real blocks; update the SimpleNamespace `moe_block` (and/or `gate`) in that branch to include `n_routed_experts`, `n_group`, `norm_topk_prob`, and `routed_scaling_factor` with sensible default tensors/scalars (e.g., ints or torch tensors matching expected shapes) so tests exercise the same code paths as real models and avoid relying on getattr defaults in `_sigmoid_topk_route`.
tests/e2e/integrations/test_scattermoe_lora_kernels.py (2)
1489-1537: Unused parameters in reference implementation.
The `gate_weight` and `num_experts` parameters (flagged by static analysis) are unused. The routing decision is already encoded in `routing_weights`/`selected_experts`, and `num_experts` can be inferred from `gate_up_proj.shape[0]`. Consider removing them or prefixing with underscore if kept for API consistency.
🧹 Suggested cleanup
 def _reference_moe_forward(
     hidden_states,
-    gate_weight,
     gate_up_proj,
     down_proj,
     act_fn,
     routing_weights,
     selected_experts,
-    num_experts,
 ):
Then update call sites accordingly.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/integrations/test_scattermoe_lora_kernels.py` around lines 1489 - 1537, The _reference_moe_forward function currently has unused parameters gate_weight and num_experts; remove these parameters from the signature and all call sites (or if you must keep them for API compatibility, rename them to _gate_weight and _num_experts to mark as intentionally unused) and derive number of experts from gate_up_proj.shape[0] where needed; update any calls to _reference_moe_forward to match the new signature or continue passing the arguments if renamed (keeping callers consistent with the change).
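As a rough illustration of why the two parameters are redundant, the sketch below is a dependency-free stand-in for a reference MoE forward. It is not the test's `_reference_moe_forward`: experts are modelled as hypothetical Python callables instead of `gate_up_proj`/`down_proj` tensors, and the name `reference_moe_forward_sketch` is invented here. The number of experts is read from the expert container, and no gate weight is needed because the routing decision arrives precomputed.

```python
def reference_moe_forward_sketch(hidden_states, expert_fns, routing_weights, selected_experts):
    """Weighted sum of the selected experts' outputs, one token at a time."""
    num_experts = len(expert_fns)  # inferred from the experts, not passed in
    outputs = []
    for token, weights, experts in zip(hidden_states, routing_weights, selected_experts):
        mixed = [0.0] * len(token)
        for weight, expert_idx in zip(weights, experts):
            if not 0 <= expert_idx < num_experts:
                raise IndexError(f"expert index {expert_idx} out of range")
            expert_out = expert_fns[expert_idx](token)
            # Accumulate this expert's contribution, scaled by its routing weight.
            mixed = [m + weight * y for m, y in zip(mixed, expert_out)]
        outputs.append(mixed)
    return outputs


# Two toy experts: one doubles its input, the other negates it.
toy_experts = [lambda x: [2.0 * v for v in x], lambda x: [-v for v in x]]
result = reference_moe_forward_sketch(
    hidden_states=[[1.0, 2.0]],
    expert_fns=toy_experts,
    routing_weights=[[0.75, 0.25]],
    selected_experts=[[0, 1]],
)
```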
1570-1583: Mock for `bias_on_gate=False` case is incomplete but works for current tests.
The minimax_m2 style mock (lines 1570-1581) is missing several attributes (`n_routed_experts`, `n_group`, `norm_topk_prob`, `routed_scaling_factor`). This works because the test at line 1640 uses `n_group=1` and `_sigmoid_topk_route` handles missing attributes with `getattr` defaults. Consider adding these for consistency with the `bias_on_gate=True` case.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/integrations/test_scattermoe_lora_kernels.py` around lines 1570 - 1583, The minimax_m2-style mock in the test creates gate and moe_block SimpleNamespace objects but omits attributes expected in the bias_on_gate=True case; add the missing attributes n_routed_experts, n_group, norm_topk_prob, and routed_scaling_factor to the moe_block (and any needed defaults on gate) so the mock mirrors the other branch; update the SimpleNamespace construction for gate/moe_block used in test_scattermoe_lora_kernels.py (the minimax_m2 mock) to include these attributes with sensible defaults (e.g., n_group=1 and zeros/ones as appropriate) so code paths like _sigmoid_topk_route see consistent fields.
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py (1)
228-248: Unused `moe_block` parameter is acceptable for API consistency.
The `moe_block` parameter is unused here (as flagged by static analysis) but provides API symmetry with `_sigmoid_topk_route`, allowing the `_route` dispatcher to call both functions with the same signature. Consider prefixing with underscore to silence the linter.
🔧 Suggested fix to silence linter
 def _softmax_topk_route(
-    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
+    _moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
 ):
Note: You'd also need to update the call site in `_route` to pass the argument positionally or update the parameter name there.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py` around lines 228 - 248, The unused parameter moe_block in _softmax_topk_route should be renamed with a leading underscore (e.g., _moe_block) to silence the linter while keeping API symmetry with _sigmoid_topk_route; update the function signature in _softmax_topk_route and ensure the dispatcher _route still passes the argument correctly (either positionally or by matching the new name) so call sites remain consistent.
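The call-site caveat can be seen in a small stand-alone sketch. All three functions below are hypothetical stand-ins with shortened signatures, not the PR's `_route`, `_softmax_topk_route`, or `_sigmoid_topk_route`; the point is only that a positional call survives the rename while a keyword call using the old name does not.

```python
def softmax_route_sketch(_moe_block, hidden_states):
    # Leading underscore: the argument is accepted for symmetry but unused.
    return ("softmax", len(hidden_states))


def sigmoid_route_sketch(moe_block, hidden_states):
    return ("sigmoid", len(hidden_states))


def route_sketch(moe_block, hidden_states, use_sigmoid):
    """Dispatch positionally so parameter renames in the helpers are harmless."""
    route_fn = sigmoid_route_sketch if use_sigmoid else softmax_route_sketch
    return route_fn(moe_block, hidden_states)


def keyword_call_fails():
    """Return True if calling with the old keyword name raises TypeError."""
    try:
        softmax_route_sketch(moe_block=None, hidden_states=[1, 2])
    except TypeError:
        return True
    return False
```

Passing the first argument positionally in the dispatcher keeps both helpers interchangeable regardless of what each one names the parameter.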
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py`:
- Around line 275-300: If e_score_correction_bias is missing, avoid adding None
to router_probs by defaulting e_score_correction_bias to a zero tensor with the
same shape/device/dtype as router_probs (use getattr on base_gate and moe_block
as already done, then if None create torch.zeros_like(router_probs)); also guard
access to moe_block.topk_group by using getattr(moe_block, "topk_group", 1) (and
optionally validate it's an int >0) before using it in the group selection logic
so _sigmoid_topk_route / scores_for_choice won't raise TypeError/AttributeError
when moe_block is incomplete.
---
Nitpick comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py`:
- Around line 228-248: The unused parameter moe_block in _softmax_topk_route
should be renamed with a leading underscore (e.g., _moe_block) to silence the
linter while keeping API symmetry with _sigmoid_topk_route; update the function
signature in _softmax_topk_route and ensure the dispatcher _route still passes
the argument correctly (either positionally or by matching the new name) so call
sites remain consistent.
In `@tests/e2e/integrations/test_scattermoe_lora_kernels.py`:
- Around line 1489-1537: The _reference_moe_forward function currently has
unused parameters gate_weight and num_experts; remove these parameters from the
signature and all call sites (or if you must keep them for API compatibility,
rename them to _gate_weight and _num_experts to mark as intentionally unused)
and derive number of experts from gate_up_proj.shape[0] where needed; update any
calls to _reference_moe_forward to match the new signature or continue passing
the arguments if renamed (keeping callers consistent with the change).
- Around line 1570-1583: The minimax_m2-style mock in the test creates gate and
moe_block SimpleNamespace objects but omits attributes expected in the
bias_on_gate=True case; add the missing attributes n_routed_experts, n_group,
norm_topk_prob, and routed_scaling_factor to the moe_block (and any needed
defaults on gate) so the mock mirrors the other branch; update the
SimpleNamespace construction for gate/moe_block used in
test_scattermoe_lora_kernels.py (the minimax_m2 mock) to include these
attributes with sensible defaults (e.g., n_group=1 and zeros/ones as
appropriate) so code paths like _sigmoid_topk_route see consistent fields.
In `@tests/integrations/test_scattermoe_lora.py`:
- Around line 369-380: The mock created for the bias_on_gate=False branch lacks
several attributes present on real blocks; update the SimpleNamespace
`moe_block` (and/or `gate`) in that branch to include `n_routed_experts`,
`n_group`, `norm_topk_prob`, and `routed_scaling_factor` with sensible default
tensors/scalars (e.g., ints or torch tensors matching expected shapes) so tests
exercise the same code paths as real models and avoid relying on getattr
defaults in `_sigmoid_topk_route`.
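The difference between a sparse mock and an explicit one can be shown with the standard library alone. In this sketch, `read_routing_config` is a hypothetical helper written for illustration; only the `n_group` default of 1 appears in the reviewed code, and the other defaults are assumptions chosen to match the suggested mock values.

```python
from types import SimpleNamespace


def read_routing_config(moe_block, num_experts):
    """Collect routing attributes, falling back to defaults when absent."""
    return {
        "n_routed_experts": getattr(moe_block, "n_routed_experts", num_experts),
        "n_group": getattr(moe_block, "n_group", 1),
        "norm_topk_prob": getattr(moe_block, "norm_topk_prob", True),  # assumed default
        "routed_scaling_factor": getattr(moe_block, "routed_scaling_factor", 1.0),  # assumed default
    }


# Sparse mock: every value below comes from a getattr fallback.
sparse_block = SimpleNamespace(top_k=2)
sparse_config = read_routing_config(sparse_block, num_experts=8)

# Explicit mock: the test controls each value, so non-default paths
# (for example n_group > 1) can be exercised deliberately.
explicit_block = SimpleNamespace(
    top_k=2,
    n_routed_experts=8,
    n_group=2,
    norm_topk_prob=False,
    routed_scaling_factor=2.5,
)
explicit_config = read_routing_config(explicit_block, num_experts=8)
```

A test built on the sparse mock passes even if a fallback value silently changes, whereas the explicit mock pins the configuration the test intends to cover.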
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4437119c-ec65-4b85-990f-4375371326b5
📒 Files selected for processing (3)
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
tests/e2e/integrations/test_scattermoe_lora_kernels.py
tests/integrations/test_scattermoe_lora.py
    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    scores_for_choice = router_probs + e_score_correction_bias

    # Group-based selection: pick top groups, mask the rest
    n_group = getattr(moe_block, "n_group", 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, num_experts // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, n_group, num_experts // n_group)
            .reshape(-1, num_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
Potential runtime errors if called with incomplete `moe_block`.
Two concerns:
- Line 280: If `e_score_correction_bias` is `None` (both `getattr` calls fail), this line will raise `TypeError: unsupported operand type(s) for +: 'Tensor' and 'NoneType'`. While `_route` guards against this, direct calls to `_sigmoid_topk_route` could crash.
- Line 291: `moe_block.topk_group` is accessed directly without `getattr`. If `n_group > 1` but the `topk_group` attribute is missing, this raises `AttributeError`.
🛡️ Proposed defensive fix
     e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
     if e_score_correction_bias is None:
         e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
+    if e_score_correction_bias is None:
+        raise ValueError(
+            "_sigmoid_topk_route requires e_score_correction_bias on gate or moe_block"
+        )
     scores_for_choice = router_probs + e_score_correction_bias

     # Group-based selection: pick top groups, mask the rest
     n_group = getattr(moe_block, "n_group", 1)
     if n_group > 1:
         group_scores = (
             scores_for_choice.view(-1, n_group, num_experts // n_group)
             .topk(2, dim=-1)[0]
             .sum(dim=-1)
         )  # [T, n_group]
+        topk_group = getattr(moe_block, "topk_group", 1)
         group_idx = torch.topk(
-            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
+            group_scores, k=topk_group, dim=-1, sorted=False
         )[1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py` around lines
275 - 300, If e_score_correction_bias is missing, avoid adding None to
router_probs by defaulting e_score_correction_bias to a zero tensor with the
same shape/device/dtype as router_probs (use getattr on base_gate and moe_block
as already done, then if None create torch.zeros_like(router_probs)); also guard
access to moe_block.topk_group by using getattr(moe_block, "topk_group", 1) (and
optionally validate it's an int >0) before using it in the group selection logic
so _sigmoid_topk_route / scores_for_choice won't raise TypeError/AttributeError
when moe_block is incomplete.
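For readers who want to trace the guarded logic without torch, the following is a single-token sketch in plain Python lists. It mirrors the steps visible in the reviewed snippet (bias lookup on the gate and then the block, top-2 group scoring, masking unselected groups to 0.0) and applies the fail-fast `ValueError` guard together with the `getattr` default for `topk_group`. The function name is hypothetical, and everything after the masking step (final top-k, weights taken from the unbiased probabilities, `norm_topk_prob` and `routed_scaling_factor` handling and their defaults) is an assumption modelled on DeepSeek-V3-style routing rather than taken from the PR.

```python
import math
from types import SimpleNamespace


def sigmoid_topk_route_sketch(logits, base_gate, moe_block, top_k):
    """Select top_k experts for one token using bias-corrected sigmoid scores."""
    num_experts = len(logits)
    probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]

    # The bias may live on the gate or on the block; fail fast if on neither.
    bias = getattr(base_gate, "e_score_correction_bias", None)
    if bias is None:
        bias = getattr(moe_block, "e_score_correction_bias", None)
    if bias is None:
        raise ValueError("e_score_correction_bias is required on gate or moe_block")
    scores = [p + b for p, b in zip(probs, bias)]

    n_group = getattr(moe_block, "n_group", 1)
    if n_group > 1:
        size = num_experts // n_group
        # Score each group by the sum of its two best experts.
        group_scores = [
            sum(sorted(scores[g * size:(g + 1) * size], reverse=True)[:2])
            for g in range(n_group)
        ]
        topk_group = getattr(moe_block, "topk_group", 1)  # guarded lookup
        ranked = sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)
        kept_groups = set(ranked[:topk_group])
        # Zero the scores of experts whose group was not kept.
        scores = [s if i // size in kept_groups else 0.0 for i, s in enumerate(scores)]

    selected = sorted(range(num_experts), key=lambda i: scores[i], reverse=True)[:top_k]
    # Assumption: final weights come from the unbiased probabilities.
    weights = [probs[i] for i in selected]
    if getattr(moe_block, "norm_topk_prob", True):
        kept = sum(weights)
        weights = [w / kept for w in weights]
    scale = getattr(moe_block, "routed_scaling_factor", 1.0)
    return selected, [w * scale for w in weights]


example_logits = [2.0, 1.0, -1.0, -2.0, 0.5, 0.4, 3.0, -3.0]
gate = SimpleNamespace()
grouped_block = SimpleNamespace(e_score_correction_bias=[0.0] * 8, n_group=2)
grouped = sigmoid_topk_route_sketch(example_logits, gate, grouped_block, top_k=2)
```

In this example the highest-scoring expert overall (index 6) sits in the second group, but the first group wins on its top-2 sum, so group-limited routing picks experts from the first group only. Whether raising or defaulting the bias to zeros is preferable is left open by the review; the sketch follows the proposed diff and raises.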
Description
Match sonicmoe's behaviour for softmax and sigmoid routing in scattermoe.
Capture the best-kernel metrics from scattermoe autotuning so the candidate kernels can be pruned more effectively later.
Motivation and Context
How has this been tested?
AI Usage Disclaimer
Claude
Summary by CodeRabbit
Refactor
Tests