ScatterMoE LoRA support #3410
📝 Walkthrough
This PR introduces ScatterMoE LoRA kernel integration with Triton implementations, updates the kernels dependency to 0.12.1, simplifies model saving logic in base trainer, adds kernel validation, and refines configuration handling. The bulk of changes comprise new LoRA-aware layer implementations and fused kernel operations for mixture-of-experts architectures.
Estimated code review effort: 4 (Complex), ~75 minutes
Pre-merge checks: 3 passed, 1 failed (1 warning)
Actionable comments posted: 9
🤖 Fix all issues with AI agents
In `@src/axolotl/integrations/kernels/args.py`:
- Around line 39-45: The warning in disable_mlp_kernel_scattermoe refers to
"Disabling lora_mlp_kernel" but the code only sets data["mlp_kernel"] = False;
either set data["lora_mlp_kernel"] = False as well inside
disable_mlp_kernel_scattermoe when data.get("lora_mlp_kernel") is True, or
change the LOG.warning text to accurately state that mlp_kernel is being
disabled; update the message and/or assignment in the
disable_mlp_kernel_scattermoe method accordingly.
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py`:
- Around line 272-332: The custom op decorator for groupXtY_compileable
incorrectly declares only DW as mutated; update the `@torch.library.custom_op`
annotation to include Db in mutates_args (e.g., mutates_args={"DW", "Db"}) so
the compiler knows the kernel may write to Db when Db is not None; keep the
function signature and internal logic unchanged, just adjust the decorator to
reflect that Db can be mutated at runtime.
- Around line 557-567: The function signature of group_compileable currently
types coeff as torch.Tensor but callers pass None; update the parameter to
coeff: Optional[torch.Tensor] (and add Optional to imports if missing) so the
`@torch.library.custom_op` annotation and runtime accept None; ensure any internal
uses of coeff in group_compileable handle None safely (same pattern used for b:
Optional[torch.Tensor] in scatter2scatter_compileable).
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py`:
- Around line 56-62: The accumulation loop in the single-kernel (acc, loop over
_K_block_id) uses tl.load on X_blk_ptrs and W_blk_ptrs without bounds masks, so
boundary tiles produce garbage when K or N are not multiples of BLOCK_K/BLOCK_N;
add boolean masks (e.g., k_mask = K_block[:, None] < K and n_mask =
N_block[None, :] < N) and pass them to tl.load (use other=0.0) for X and W (W
uses k_mask & n_mask), and ensure K_block (and any N_block) is updated each
iteration after advancing X_blk_ptrs/W_blk_ptrs so the mask stays correct.
Ensure acc uses only masked loads so out-of-bounds elements contribute zero.
- Line 74: The grid calculation uses floor division which drops trailing N-sized
blocks when ydim % BLOCK_N != 0; replace the use of integer floor division for
computing grid with a ceiling division (use triton.cdiv or equivalent) so grid =
triton.cdiv(ydim, BLOCK_N), k to match ops.py behavior and ensure the final
partial block is included for functions referencing grid, ydim, and BLOCK_N in
single.py.
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py`:
- Around line 228-281: The ScatterMoEGatedMLP.forward docstring incorrectly
claims it returns two values (output tensor and router logits) while the
implementation only returns layer_output; update the docstring in the
ScatterMoEGatedMLP.forward method to document the single return value (output
tensor) and remove or rephrase any mention of router_logits (or, if you prefer
to also return router logits, instead modify the return statement to return
(layer_output, router_logits) and update call sites) — locate the forward method
and adjust its Returns section accordingly to match the actual behavior.
- Around line 298-441: The forward staticmethod currently returns only
expert_output but must follow Olmoe's contract and return (final_hidden_states,
router_logits); modify the forward in MoeSparseMoeBlock (the staticmethod named
forward) to return a tuple (expert_output, router_logits) so callers like
OlmoeDecoderLayer (which does hidden_states, router_logits =
self.mlp(hidden_states)) can unpack correctly—use the existing router_logits
variable computed during the Router Computation step (ensure it is the
pre-softmax logits you already compute) and return it alongside expert_output.
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py`:
- Line 94: The code currently uses "scaling=self._lora_scaling or 1.0" which
treats 0.0 as falsy and overrides an explicit zero; change the assignment to use
an explicit None check so that 0.0 is preserved (e.g., use self._lora_scaling if
self._lora_scaling is not None else 1.0) in the place where scaling is passed
(look for the scaling= argument in lora_ops.py, e.g., inside the method creating
the LoRA-wrapped module or function that references self._lora_scaling) so that
None defaults to 1.0 but 0.0 remains a valid value.
- Around line 62-72: The set_lora / clear_lora helpers are dead code because
HFScatterMoEGatedMLP.forward passes LoRA tensors directly to
parallel_linear_lora and never uses _lora_A/_lora_B/_lora_scaling; either remove
these methods and the unused attributes to avoid confusion, or if you intend to
store tensors on the module make them proper module buffers/parameters (use
self.register_buffer for tensors and ensure they are handled in
state_dict()/to()/half()); update or delete set_lora, clear_lora and the
attributes _lora_A/_lora_B/_lora_scaling accordingly so the code matches
HFScatterMoEGatedMLP.forward and parallel_linear_lora usage.
🧹 Nitpick comments (12)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
62-67: The unwrapping logic is correct but the comment is slightly misleading.
`(output,) = output` on line 67 will raise `ValueError` if the tuple has more than one element. The comment on lines 63-65 references older models returning `(hidden_states, present_kv, ...)`, but such multi-element tuples would already break this `backward` (which accepts only a single `dY`). The destructuring is fine; it's just the comment that overstates what's actually handled here.
Consider tightening the comment to avoid confusion for future readers:
Suggested comment tweak
      output = ctx.forward_function(hidden_states, *ctx.args)
    - # Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
    - # return a plain tensor, not a tuple. Older models return tuples
    - # like (hidden_states, present_kv, ...). Unwrap if needed.
    + # Some HF models (e.g. via GradientCheckpointingLayer) wrap the
    + # output in a single-element tuple. Unwrap if needed so we can
    + # call backward on the plain tensor.
      if isinstance(output, (tuple, list)):
          (output,) = output
src/axolotl/core/trainers/base.py (1)
754-755: Inconsistent `is_main_process` parameter across `save_pretrained` calls.
Line 743 explicitly passes `is_main_process=self.accelerator.is_main_process` to the unwrapped model's `save_pretrained`, but line 755 omits this parameter for the direct model path. For defensive consistency and to align with the pattern used elsewhere in this method, add the parameter here as well.
Proposed fix
    - self.model.save_pretrained(output_dir, state_dict=state_dict)
    + self.model.save_pretrained(
    +     output_dir,
    +     state_dict=state_dict,
    +     is_main_process=self.accelerator.is_main_process,
    + )
src/axolotl/integrations/kernels/plugin.py (1)
3-5: Remove commented-out import.
`# LayerRepository,` is dead code. Clean it up.
Proposed fix
      from kernels import (
    -     # LayerRepository,
          LocalLayerRepository,
          Mode,
src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py (1)
102-102: Remove debug print artifacts.
Lines 102 and 161 contain commented-out debug print statements.
Also applies to: 161-161
src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py (2)
120-339: Backward return tuple is correct (15 values for 15 forward params), but comments are slightly misleading.
The inline comment on line 332 (`# lora_A, lora_B, scaling`) groups three items, but `d_lora_A` and `d_lora_B` on lines 330–331 are actual non-None gradients; only the `None` on line 332 corresponds to `scaling`. Consider reformatting for clarity:
    - d_lora_A,
    - d_lora_B,
    - None,  # lora_A, lora_B, scaling
    + d_lora_A,  # lora_A
    + d_lora_B,  # lora_B
    + None,  # scaling
390-417: `get_lora_params_from_wrapper`: redundant `hasattr` + `getattr` pattern.
Line 397 checks `hasattr(module, "lora_A")`, but line 406 re-fetches with `getattr(module, "lora_A", {})`. Since you already verified the attribute exists, you can use `module.lora_A` directly. Harmless but slightly inconsistent.
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py (1)
326-328: `F.sigmoid` is deprecated; prefer `torch.sigmoid`.
`torch.nn.functional.sigmoid` has been deprecated in favor of `torch.sigmoid` since PyTorch 1.x. While it still works, using the preferred API avoids future deprecation warnings.
Proposed fix
    - shared_expert_gate_output = F.sigmoid(
    + shared_expert_gate_output = torch.sigmoid(
          self.shared_expert_gate(hidden_states_flat)
      )
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py (2)
109-110: Heuristic-computed `NO_K_MASK`/`NO_N_MASK` are unused; the kernel recomputes masking locally.
The `@triton.heuristics` decorator on lines 72–77 computes `NO_K_MASK` and `NO_N_MASK`, and they are accepted as kernel parameters (lines 109–110), but the kernel body at line 124 recomputes `no_k_mask = K % BLOCK_K == 0` instead of using `NO_K_MASK`. Additionally, `NO_N_MASK` is never used at all: `N_mask` is always computed unconditionally on line 120.
For comparison, the LoRA variant in `kernels/lora_ops.py` correctly uses the heuristic value (`no_k_mask = NO_K_MASK`).
Proposed fix
    - no_k_mask = K % BLOCK_K == 0
    + no_k_mask = NO_K_MASK
And consider using `NO_N_MASK` to skip `N_mask` in loads where applicable, matching the pattern in `_compute_expert_block`.
Also applies to: 124-124
224-228: Misleading variable name `stride_bk` for the N-dimension stride of bias.
Bias has shape `[E, N]`, so its strides are `(stride_be, stride_bn)`. The Python wrapper names the second stride `stride_bk`, which is confusing since it maps to the `stride_bn` kernel parameter (line 92). This is only a naming issue; values are positionally correct.
Rename for consistency
      if b is None:
    -     b = None
    -     stride_be = stride_bk = 0
    +     stride_be = stride_bn = 0
      else:
    -     stride_be, stride_bk = b.stride()
    +     stride_be, stride_bn = b.stride()
And update references on lines 246–247 accordingly.
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py (3)
101-159: `round_expert_counts` uses a Python loop with GPU→CPU transfers per expert.
Lines 107–110 and 129–152 iterate over all experts with `.item()` calls, each incurring a CUDA synchronization. For models with many experts (e.g., 64 in OLMoE), this serializes GPU operations. Since this runs on every backward pass, consider vectorizing with torch tensor ops to eliminate the loop.
Vectorized alternative for the count computation (lines 105-110)
    - counts = torch.zeros(E, dtype=torch.int64, device=device)
    - prev = 0
    - for e in range(E):
    -     curr = expert_offsets[e].item()
    -     counts[e] = curr - prev
    -     prev = curr
    + counts = torch.empty(E, dtype=torch.int64, device=device)
    + counts[0] = expert_offsets[0]
    + counts[1:] = expert_offsets[1:] - expert_offsets[:-1]
The copy loop (lines 129–152) is harder to vectorize due to variable-length segments but could use `torch.cat` or a small Triton kernel if profiling shows it's a bottleneck.
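The differencing identity behind the vectorized count computation can be checked without a GPU. The sketch below uses plain Python lists as a stand-in for the tensors; `counts_loop` and `counts_diff` are hypothetical names, and `expert_offsets` is assumed to hold cumulative end offsets per expert, as the quoted loop implies.

```python
def counts_loop(expert_offsets):
    """Per-expert counts via the sequential loop quoted in the review."""
    counts = []
    prev = 0
    for curr in expert_offsets:
        counts.append(curr - prev)
        prev = curr
    return counts


def counts_diff(expert_offsets):
    """Same counts via adjacent differences (mirrors the slice-based form)."""
    if not expert_offsets:
        return []
    head = [expert_offsets[0]]
    tail = [b - a for a, b in zip(expert_offsets[:-1], expert_offsets[1:])]
    return head + tail


offsets = [3, 3, 7, 12]  # cumulative token counts; the second expert is empty
loop_counts = counts_loop(offsets)
diff_counts = counts_diff(offsets)
```

Both functions agree on every input, including empty experts, which is why the slice-based tensor form can replace the per-expert `.item()` loop without changing results.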
1087-1118: Backward dA/dB autotune search space is very large (384 configs).
`4 × 4 × 4 × 2 × 3 = 384` configs before pruning. The SMEM pruning function helps, but the first kernel launch will incur a significant autotuning cost. The forward (72 configs) and dX (96 configs) are more conservative. Consider reducing the search space or adding `warmup`/`rep` tuning to the autotune decorator if cold-start latency is a concern.
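To see where a count like 384 comes from, here is a small sketch that enumerates a Cartesian product of tuning options. Only the option-list sizes (4, 4, 4, 2, 3) come from the review; the parameter names and values below are illustrative assumptions, not the PR's actual autotune configuration.

```python
from itertools import product

# Hypothetical option lists: only their lengths (4, 4, 4, 2, 3) are taken
# from the review comment; names and values are placeholders.
search_space = {
    "BLOCK_M": [16, 32, 64, 128],
    "BLOCK_K": [32, 64, 128, 256],
    "BLOCK_N": [32, 64, 128, 256],
    "num_warps": [4, 8],
    "num_stages": [3, 4, 5],
}

# One dict per candidate configuration, before any pruning is applied.
configs = [
    dict(zip(search_space, combo)) for combo in product(*search_space.values())
]
num_configs = len(configs)
```

Because the count is a product, dropping a single value from one of the four-element lists removes a quarter of the candidates, which is usually the cheapest way to cut cold-start tuning time.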
1333-1343: Remove unused parameters `sorted_scattered_idxs` and `k` from `group_bwd_lora`.
These parameters are accepted but never used in the function body and are not passed to the kernel. While they exist for API symmetry with `group_bwd_lora_fused` (which requires both), that design pattern is not necessary here since the caller does not pass them. Removing them clarifies the actual interface.
Remove unused params
      def group_bwd_lora(
          DY: torch.Tensor,
          X: torch.Tensor,
          lora_A: torch.Tensor,
          lora_B: torch.Tensor,
          expert_offsets: torch.Tensor,
          E: int,
          scaling: float,
    -     sorted_scattered_idxs: Optional[torch.Tensor] = None,
    -     k: int = 1,
      ) -> tuple[torch.Tensor, torch.Tensor]:
    def disable_mlp_kernel_scattermoe(cls, data):
        if data.get("use_scattermoe") is True:
            if data.get("lora_mlp_kernel") is True:
                LOG.warning(
                    "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
                )
            data["mlp_kernel"] = False
Warning says "Disabling lora_mlp_kernel" but the code only disables mlp_kernel.
The warning on line 43 tells the user lora_mlp_kernel is being disabled, but data["lora_mlp_kernel"] is never set to False — only data["mlp_kernel"] is. If both should be disabled, add the missing assignment; otherwise, fix the warning text.
Proposed fix (if both should be disabled)
      def disable_mlp_kernel_scattermoe(cls, data):
          if data.get("use_scattermoe") is True:
              if data.get("lora_mlp_kernel") is True:
                  LOG.warning(
                      "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
                  )
    +             data["lora_mlp_kernel"] = False
              data["mlp_kernel"] = False
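A minimal standalone sketch of the "disable both" option, written as a plain function over a dict so it can be run in isolation. It drops the `cls` argument and any validator decorators from the real method, and the placement of the `mlp_kernel` assignment under the outer `if` is an assumption, since the quoted snippet lost its indentation.

```python
import logging

LOG = logging.getLogger(__name__)


def disable_mlp_kernel_scattermoe(data: dict) -> dict:
    """Turn off MLP kernel flags when ScatterMoE is enabled (sketch)."""
    if data.get("use_scattermoe") is True:
        if data.get("lora_mlp_kernel") is True:
            LOG.warning(
                "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
            )
            # The assignment the review found missing: make the flag match the warning.
            data["lora_mlp_kernel"] = False
        # Assumed to apply whenever ScatterMoE is on, regardless of the LoRA flag.
        data["mlp_kernel"] = False
    return data


cfg = disable_mlp_kernel_scattermoe({"use_scattermoe": True, "lora_mlp_kernel": True})
untouched = disable_mlp_kernel_scattermoe({"use_scattermoe": False, "lora_mlp_kernel": True})
```

With the extra assignment, the logged warning and the resulting config agree; without it, `lora_mlp_kernel` would still be `True` after the warning claimed it had been disabled.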
    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
    for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
        x = tl.load(X_blk_ptrs)
        w = tl.load(W_blk_ptrs)
        acc += tl.sum(x * w, axis=0)[None, :]
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
Missing out-of-bounds masking for boundary blocks.
When K is not a multiple of BLOCK_K (or N not a multiple of BLOCK_N), tl.load reads past the valid range without a mask. This produces garbage values in the accumulator for boundary tiles. After switching the grid to triton.cdiv, the N boundary case becomes reachable too.
Consider adding masks similar to other kernels in this codebase:
    k_mask = K_block[:, None] < K
    n_mask = N_block[None, :] < N
    x = tl.load(X_blk_ptrs, mask=k_mask, other=0.0)
    w = tl.load(W_blk_ptrs, mask=k_mask & n_mask, other=0.0)
and updating `K_block` after each iteration for the mask to stay correct.
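The effect of a missing boundary mask can be reproduced in plain Python. The sketch below emulates the tiled accumulation over `K` with buffers padded by junk values, standing in for whatever memory lies past the end of a row; `tiled_dot` and `JUNK` are hypothetical names for illustration, and this is not the Triton kernel itself.

```python
def tiled_dot(x_buf, w_buf, K, block_k, masked):
    """Accumulate sum(x[k] * w[k]) for k < K, one BLOCK_K-sized tile at a time.

    x_buf and w_buf are padded past K with junk to mimic out-of-bounds memory.
    """
    acc = 0.0
    num_tiles = -(-K // block_k)  # ceiling division, like tl.cdiv(K, BLOCK_K)
    for tile in range(num_tiles):
        for lane in range(block_k):
            k = tile * block_k + lane  # recomputed per tile so the mask stays current
            in_bounds = k < K
            if masked and not in_bounds:
                x, w = 0.0, 0.0  # mirrors tl.load(..., mask=k_mask, other=0.0)
            else:
                x, w = x_buf[k], w_buf[k]
            acc += x * w
    return acc


K, BLOCK_K = 5, 4
JUNK = 1e6  # stand-in for unrelated data past the end of the row
x_buf = [1.0, 2.0, 3.0, 4.0, 5.0] + [JUNK] * 3
w_buf = [1.0] * 5 + [JUNK] * 3

expected = sum(x_buf[:K])  # 15.0
masked_result = tiled_dot(x_buf, w_buf, K, BLOCK_K, masked=True)
unmasked_result = tiled_dot(x_buf, w_buf, K, BLOCK_K, masked=False)
```

The masked run matches the reference sum, while the unmasked run folds the three junk lanes of the last tile into the accumulator. Note that the index `k` is rebuilt from the tile number on every iteration, which corresponds to the review's point about updating `K_block` as the pointers advance.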
    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
    BLOCK_N = 128
    BLOCK_K = 128
    grid = ydim // BLOCK_N, k
Use ceiling division for grid computation — truncation drops trailing N-blocks.
ydim // BLOCK_N silently drops the last block when ydim is not a multiple of BLOCK_N. Every analogous grid in ops.py uses triton.cdiv. This will produce silently incorrect (incomplete) output for any model whose expert output dimension is not a multiple of 128.
Proposed fix
    - grid = ydim // BLOCK_N, k
    + grid = triton.cdiv(ydim, BLOCK_N), k
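A quick arithmetic sketch of why floor division under-covers the output. `cdiv` here is a local pure-Python helper using the usual `(a + b - 1) // b` formula, and `ydim = 320` is a hypothetical width chosen only because it is not a multiple of 128.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def covered_columns(ydim: int, block_n: int, num_blocks: int) -> int:
    """Number of output columns that num_blocks programs can write."""
    return min(ydim, num_blocks * block_n)


BLOCK_N = 128
ydim = 320  # hypothetical expert output width, not a multiple of BLOCK_N

floor_blocks = ydim // BLOCK_N     # 2 programs launched
ceil_blocks = cdiv(ydim, BLOCK_N)  # 3 programs launched

floor_covered = covered_columns(ydim, BLOCK_N, floor_blocks)  # 256
ceil_covered = covered_columns(ydim, BLOCK_N, ceil_blocks)    # 320
```

With floor division the last 64 columns of `Y` are never written and keep whatever `torch.empty` left there. The ceiling-division grid launches a partial final block, which is why it needs the N-dimension mask discussed in the previous comment.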
    def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
        """Attach LoRA parameters for fused computation."""
        self._lora_A = lora_A
        self._lora_B = lora_B
        self._lora_scaling = scaling

    def clear_lora(self):
        """Remove LoRA parameters."""
        self._lora_A = None
        self._lora_B = None
        self._lora_scaling = None
🧩 Analysis chain
🏁 Script executed:
    rg -n 'set_lora|clear_lora' --type=py -C2
Repository: axolotl-ai-cloud/axolotl
Length of output: 2745
🏁 Script executed:
    cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 3755
🏁 Script executed:
    rg -n 'ParallelExperts' --type=py -C2
Repository: axolotl-ai-cloud/axolotl
Length of output: 4422
🏁 Script executed:
    cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py | head -100
Repository: axolotl-ai-cloud/axolotl
Length of output: 4023
🏁 Script executed:
    cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 20295
🏁 Script executed:
    rg -n 'set_lora\(' --type=py -C3
Repository: axolotl-ai-cloud/axolotl
Length of output: 1549
🏁 Script executed:
    rg -n 'clear_lora\(' --type=py -C3
Repository: axolotl-ai-cloud/axolotl
Length of output: 738
🏁 Script executed:
    rg -n '\.set_lora|\.clear_lora' --type=py
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
    rg -n 'state_dict|_load_from_state_dict|to\(device' src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
    # Check if there's any external code that might call set_lora/clear_lora
    rg -rn 'import.*ParallelExperts' --type=py
Repository: axolotl-ai-cloud/axolotl
Length of output: 150
🏁 Script executed:
    # Check if ParallelExperts is ever actually instantiated and these methods called
    rg -n 'ParallelExperts\(' --type=py -C2
Repository: axolotl-ai-cloud/axolotl
Length of output: 1040
🏁 Script executed:
    rg -n '_lora_A|_lora_B|_lora_scaling' --type=py -C2
Repository: axolotl-ai-cloud/axolotl
Length of output: 5729
🏁 Script executed:
    grep -r "getattr.*set_lora\|hasattr.*set_lora" --include="*.py" src/
Repository: axolotl-ai-cloud/axolotl
Length of output: 50
These set_lora / clear_lora methods appear unused — LoRA parameters are managed at the call site, not stored on the module.
The plain attributes _lora_A, _lora_B, and _lora_scaling won't be saved/restored by state_dict() or moved by .to(device) / .half(). However, this is likely not a concern: the actual LoRA integration in HFScatterMoEGatedMLP.forward() (in layers.py) extracts LoRA parameters from PEFT wrappers at runtime and passes them directly as function arguments to parallel_linear_lora(), bypassing these setter methods entirely. No code path calls set_lora() or clear_lora().
    expert_offsets,
    lora_A=self._lora_A,
    lora_B=self._lora_B,
    scaling=self._lora_scaling or 1.0,
self._lora_scaling or 1.0 silently overrides an explicit scaling=0.0.
0.0 is falsy in Python, so if someone intentionally sets scaling=0.0 (to disable the LoRA delta), this expression evaluates to 1.0 instead. Use self._lora_scaling if self._lora_scaling is not None else 1.0 to distinguish None (unset) from 0.0 (intentionally zero).
Proposed fix
    - scaling=self._lora_scaling or 1.0,
    + scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
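The difference between the two defaulting patterns is easy to check in isolation. The helper names below are hypothetical; they only wrap the two expressions from the diff so their behavior on `None`, `0.0`, and a normal value can be compared side by side.

```python
from typing import Optional


def scaling_with_or(lora_scaling: Optional[float]) -> float:
    """Original pattern: any falsy value, including 0.0, falls back to 1.0."""
    return lora_scaling or 1.0


def scaling_with_none_check(lora_scaling: Optional[float]) -> float:
    """Suggested pattern: only None falls back to 1.0."""
    return lora_scaling if lora_scaling is not None else 1.0


# Maps each input to (result of `or` pattern, result of explicit None check).
results = {
    value: (scaling_with_or(value), scaling_with_none_check(value))
    for value in (None, 0.0, 0.5)
}
```

The two patterns agree for `None` and for `0.5`, and differ only for `0.0`, where the `or` form silently re-enables the LoRA delta at full strength.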
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    LOG.info(f"Saving model checkpoint to {output_dir}")
    if state_dict is None:
This was added recently to solve some saving issue. Do the changes below solve it?
Just did a quick pass on this; `.clone()` may be unintentionally placing tensors on GPU.
What do you mean? This was a cleanup from changes upstream.
This was just added by ved a week ago to fix saving in context parallelism (97a4f28).
Force-pushed from 3a49d44 to a00b11e.
Description
Uses local layer repo for ScatterMoE implementation instead of remote repo, making this easier to modify.
ScatterMoE from https://arxiv.org/abs/2403.08245 and https://github.com/shawntan/scattermoe
Fuses LoRA computation with ScatterMoE
Uses some optimizations recommended in SonicMoE, such as fused dX and fused gather