Skip to content

ScatterMoE LoRA support#3410

Merged
winglian merged 8 commits into
mainfrom
scattermoe-lora
Feb 24, 2026
Merged

ScatterMoE LoRA support#3410
winglian merged 8 commits into
mainfrom
scattermoe-lora

Conversation

@winglian

@winglian winglian commented Feb 15, 2026

Copy link
Copy Markdown
Collaborator

Description

Uses local layer repo for ScatterMoE implementation instead of remote repo, making this easier to modify.
ScatterMoE from https://arxiv.org/abs/2403.08245 and https://github.com/shawntan/scattermoe
Fuses LoRA computation with ScatterMoE
Uses some optimizations recommended in in SonicMoE such as fused dX and fused gather

@coderabbitai

coderabbitai Bot commented Feb 15, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR introduces ScatterMoE LoRA kernel integration with Triton implementations, updates the kernels dependency to 0.12.1, simplifies model saving logic in base trainer, adds kernel validation, and refines configuration handling. The bulk of changes comprise new LoRA-aware layer implementations and fused kernel operations for mixture-of-experts architectures.

Changes

Cohort / File(s) Summary
Triton Kernel Implementations
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py, kernels/ops.py, kernels/single.py, kernels/__init__.py
Introduces low-level Triton kernels for fused ScatterMoE with LoRA forward/backward passes, scatter-to-scatter operations, grouping, and XtY computations with extensive autotuning and masking heuristics.
LoRA and MoE Layer Modules
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py, lora_ops.py, parallel_experts.py, parallel_linear_lora.py, __init__.py
Implements LoRA-aware MoE layer replacements, custom autograd functions for fused operations, parameter unwrapping utilities, and per-expert linear transformations with gating support.
Kernel Configuration and Plugin
src/axolotl/integrations/kernels/args.py, plugin.py
Adds pre-validation hook to disable MLU kernel when ScatterMoE is enabled; switches kernel registry from remote to local path-based loading via LocalLayerRepository.
Core Trainer Simplifications
src/axolotl/core/trainers/base.py
Removes pre-save state_dict normalization, simplifies is_main_process handling in save paths, and consolidates tokenizer saving logic.
Utility Patches
src/axolotl/loaders/patch_manager.py, src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
Improves trust_remote_code override logic; handles both plain tensor and tuple returns from forward functions in gradient checkpointing.
Dependency Update
requirements.txt
Updates kernels package from 0.11.5 to 0.12.1.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested reviewers

  • NanoCode012
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 43.75% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'ScatterMoE LoRA support' clearly summarizes the main change across the changeset, which adds comprehensive LoRA support for ScatterMoE kernels and layers.
Merge Conflict Detection ✅ Passed ✅ No merge conflicts detected when merging into main

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch scattermoe-lora

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

🤖 Fix all issues with AI agents
In `@src/axolotl/integrations/kernels/args.py`:
- Around line 39-45: The warning in disable_mlp_kernel_scattermoe refers to
"Disabling lora_mlp_kernel" but the code only sets data["mlp_kernel"] = False;
either set data["lora_mlp_kernel"] = False as well inside
disable_mlp_kernel_scattermoe when data.get("lora_mlp_kernel") is True, or
change the LOG.warning text to accurately state that mlp_kernel is being
disabled; update the message and/or assignment in the
disable_mlp_kernel_scattermoe method accordingly.

In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py`:
- Around line 272-332: The custom op decorator for groupXtY_compileable
incorrectly declares only DW as mutated; update the `@torch.library.custom_op`
annotation to include Db in mutates_args (e.g., mutates_args={"DW", "Db"}) so
the compiler knows the kernel may write to Db when Db is not None; keep the
function signature and internal logic unchanged, just adjust the decorator to
reflect that Db can be mutated at runtime.
- Around line 557-567: The function signature of group_compileable currently
types coeff as torch.Tensor but callers pass None; update the parameter to
coeff: Optional[torch.Tensor] (and add Optional to imports if missing) so the
`@torch.library.custom_op` annotation and runtime accept None; ensure any internal
uses of coeff in group_compileable handle None safely (same pattern used for b:
Optional[torch.Tensor] in scatter2scatter_compileable).

In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py`:
- Around line 56-62: The accumulation loop in the single-kernel (acc, loop over
_K_block_id) uses tl.load on X_blk_ptrs and W_blk_ptrs without bounds masks, so
boundary tiles produce garbage when K or N are not multiples of BLOCK_K/BLOCK_N;
add boolean masks (e.g., k_mask = K_block[:, None] < K and n_mask =
N_block[None, :] < N) and pass them to tl.load (use other=0.0) for X and W (W
uses k_mask & n_mask), and ensure K_block (and any N_block) is updated each
iteration after advancing X_blk_ptrs/W_blk_ptrs so the mask stays correct.
Ensure acc uses only masked loads so out-of-bounds elements contribute zero.
- Line 74: The grid calculation uses floor division which drops trailing N-sized
blocks when ydim % BLOCK_N != 0; replace the use of integer floor division for
computing grid with a ceiling division (use triton.cdiv or equivalent) so grid =
triton.cdiv(ydim, BLOCK_N), k to match ops.py behavior and ensure the final
partial block is included for functions referencing grid, ydim, and BLOCK_N in
single.py.

In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py`:
- Around line 228-281: The ScatterMoEGatedMLP.forward docstring incorrectly
claims it returns two values (output tensor and router logits) while the
implementation only returns layer_output; update the docstring in the
ScatterMoEGatedMLP.forward method to document the single return value (output
tensor) and remove or rephrase any mention of router_logits (or, if you prefer
to also return router logits, instead modify the return statement to return
(layer_output, router_logits) and update call sites) — locate the forward method
and adjust its Returns section accordingly to match the actual behavior.
- Around line 298-441: The forward staticmethod currently returns only
expert_output but must follow Olmoe's contract and return (final_hidden_states,
router_logits); modify the forward in MoeSparseMoeBlock (the staticmethod named
forward) to return a tuple (expert_output, router_logits) so callers like
OlmoeDecoderLayer (which does hidden_states, router_logits =
self.mlp(hidden_states)) can unpack correctly—use the existing router_logits
variable computed during the Router Computation step (ensure it is the
pre-softmax logits you already compute) and return it alongside expert_output.

In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py`:
- Line 94: The code currently uses "scaling=self._lora_scaling or 1.0" which
treats 0.0 as falsy and overrides an explicit zero; change the assignment to use
an explicit None check so that 0.0 is preserved (e.g., use self._lora_scaling if
self._lora_scaling is not None else 1.0) in the place where scaling is passed
(look for the scaling= argument in lora_ops.py, e.g., inside the method creating
the LoRA-wrapped module or function that references self._lora_scaling) so that
None defaults to 1.0 but 0.0 remains a valid value.
- Around line 62-72: The set_lora / clear_lora helpers are dead code because
HFScatterMoEGatedMLP.forward passes LoRA tensors directly to
parallel_linear_lora and never uses _lora_A/_lora_B/_lora_scaling; either remove
these methods and the unused attributes to avoid confusion, or if you intend to
store tensors on the module make them proper module buffers/parameters (use
self.register_buffer for tensors and ensure they are handled in
state_dict()/to()/half()); update or delete set_lora, clear_lora and the
attributes _lora_A/_lora_B/_lora_scaling accordingly so the code matches
HFScatterMoEGatedMLP.forward and parallel_linear_lora usage.
🧹 Nitpick comments (12)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)

62-67: The unwrapping logic is correct but the comment is slightly misleading.

(output,) = output on line 67 will raise ValueError if the tuple has more than one element. The comment on lines 63-65 references older models returning (hidden_states, present_kv, ...), but such multi-element tuples would already break this backward (which accepts only a single dY). The destructuring is fine — it's just the comment that overstates what's actually handled here.

Consider tightening the comment to avoid confusion for future readers:

Suggested comment tweak
             output = ctx.forward_function(hidden_states, *ctx.args)
-            # Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
-            # return a plain tensor, not a tuple.  Older models return tuples
-            # like (hidden_states, present_kv, ...).  Unwrap if needed.
+            # Some HF models (e.g. via GradientCheckpointingLayer) wrap the
+            # output in a single-element tuple.  Unwrap if needed so we can
+            # call backward on the plain tensor.
             if isinstance(output, (tuple, list)):
                 (output,) = output
src/axolotl/core/trainers/base.py (1)

754-755: Inconsistent is_main_process parameter across save_pretrained calls.

Line 743 explicitly passes is_main_process=self.accelerator.is_main_process to the unwrapped model's save_pretrained, but line 755 omits this parameter for the direct model path. For defensive consistency and to align with the pattern used elsewhere in this method, add the parameter here as well.

Proposed fix
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            self.model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                is_main_process=self.accelerator.is_main_process,
+            )
src/axolotl/integrations/kernels/plugin.py (1)

3-5: Remove commented-out import.

# LayerRepository, is dead code. Clean it up.

Proposed fix
 from kernels import (
-    # LayerRepository,
     LocalLayerRepository,
     Mode,
src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py (1)

102-102: Remove debug print artifacts.

Lines 102 and 161 contain commented-out print statements that should be cleaned up before merging.

Also applies to: 161-161

src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py (2)

120-339: Backward return tuple is correct (15 values for 15 forward params), but comments are slightly misleading.

The inline comment on line 332 (# lora_A, lora_B, scaling) groups three items, but d_lora_A and d_lora_B on lines 330–331 are actual non-None gradients — only the None on line 332 corresponds to scaling. Consider reformatting for clarity:

-            d_lora_A,
-            d_lora_B,
-            None,  # lora_A, lora_B, scaling
+            d_lora_A,  # lora_A
+            d_lora_B,  # lora_B
+            None,      # scaling

390-417: get_lora_params_from_wrapper — redundant hasattr + getattr pattern.

Line 397 checks hasattr(module, "lora_A"), but line 406 re-fetches with getattr(module, "lora_A", {}). Since you already verified the attribute exists, you can use module.lora_A directly. Harmless but slightly inconsistent.

src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py (1)

326-328: F.sigmoid is deprecated; prefer torch.sigmoid.

torch.nn.functional.sigmoid has been deprecated in favor of torch.sigmoid since PyTorch 1.x. While it still works, using the preferred API avoids future deprecation warnings.

Proposed fix
-            shared_expert_gate_output = F.sigmoid(
+            shared_expert_gate_output = torch.sigmoid(
                 self.shared_expert_gate(hidden_states_flat)
             )
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py (2)

109-110: Heuristic-computed NO_K_MASK / NO_N_MASK are unused — kernel recomputes masking locally.

The @triton.heuristics decorator on lines 72–77 computes NO_K_MASK and NO_N_MASK, and they are accepted as kernel parameters (lines 109–110), but the kernel body at line 124 recomputes no_k_mask = K % BLOCK_K == 0 instead of using NO_K_MASK. Additionally, NO_N_MASK is never used at all — N_mask is always computed unconditionally on line 120.

For comparison, the LoRA variant in kernels/lora_ops.py correctly uses the heuristic value (no_k_mask = NO_K_MASK).

Proposed fix
-    no_k_mask = K % BLOCK_K == 0
+    no_k_mask = NO_K_MASK

And consider using NO_N_MASK to skip N_mask in loads where applicable, matching the pattern in _compute_expert_block.

Also applies to: 124-124


224-228: Misleading variable name stride_bk for the N-dimension stride of bias.

Bias has shape [E, N], so its strides are (stride_be, stride_bn). The Python wrapper names the second stride stride_bk, which is confusing since it maps to the stride_bn kernel parameter (line 92). This is only a naming issue — values are positionally correct.

Rename for consistency
     if b is None:
-        b = None
-        stride_be = stride_bk = 0
+        stride_be = stride_bn = 0
     else:
-        stride_be, stride_bk = b.stride()
+        stride_be, stride_bn = b.stride()

And update references on lines 246–247 accordingly.

src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py (3)

101-159: round_expert_counts uses a Python loop with GPU→CPU transfers per expert.

Lines 107–110 and 129–152 iterate over all experts with .item() calls, each incurring a CUDA synchronization. For models with many experts (e.g., 64 in OLMoE), this serializes GPU operations. Since this runs on every backward pass, consider vectorizing with torch tensor ops to eliminate the loop.

Vectorized alternative for the count computation (lines 105-110)
-    counts = torch.zeros(E, dtype=torch.int64, device=device)
-    prev = 0
-    for e in range(E):
-        curr = expert_offsets[e].item()
-        counts[e] = curr - prev
-        prev = curr
+    counts = torch.empty(E, dtype=torch.int64, device=device)
+    counts[0] = expert_offsets[0]
+    counts[1:] = expert_offsets[1:] - expert_offsets[:-1]

The copy loop (lines 129–152) is harder to vectorize due to variable-length segments but could use torch.cat or a small Triton kernel if profiling shows it's a bottleneck.


1087-1118: Backward dA/dB autotune search space is very large (384 configs).

4 × 4 × 4 × 2 × 3 = 384 configs before pruning. The SMEM pruning function helps, but the first kernel launch will incur a significant autotuning cost. The forward (72 configs) and dX (96 configs) are more conservative. Consider reducing the search space or adding warmup / rep tuning to the autotune decorator if cold-start latency is a concern.


1333-1343: Remove unused parameters sorted_scattered_idxs and k from group_bwd_lora.

These parameters are accepted but never used in the function body and are not passed to the kernel. While they exist for API symmetry with group_bwd_lora_fused (which requires both), that design pattern is not necessary here since the caller does not pass them. Removing them clarifies the actual interface.

Remove unused params
 def group_bwd_lora(
     DY: torch.Tensor,
     X: torch.Tensor,
     lora_A: torch.Tensor,
     lora_B: torch.Tensor,
     expert_offsets: torch.Tensor,
     E: int,
     scaling: float,
-    sorted_scattered_idxs: Optional[torch.Tensor] = None,
-    k: int = 1,
 ) -> tuple[torch.Tensor, torch.Tensor]:

Comment on lines +39 to +45
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["mlp_kernel"] = False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Warning says "Disabling lora_mlp_kernel" but the code only disables mlp_kernel.

The warning on line 43 tells the user lora_mlp_kernel is being disabled, but data["lora_mlp_kernel"] is never set to False — only data["mlp_kernel"] is. If both should be disabled, add the missing assignment; otherwise, fix the warning text.

Proposed fix (if both should be disabled)
     def disable_mlp_kernel_scattermoe(cls, data):
         if data.get("use_scattermoe") is True:
             if data.get("lora_mlp_kernel") is True:
                 LOG.warning(
                     "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
                 )
+                data["lora_mlp_kernel"] = False
             data["mlp_kernel"] = False
🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/args.py` around lines 39 - 45, The warning
in disable_mlp_kernel_scattermoe refers to "Disabling lora_mlp_kernel" but the
code only sets data["mlp_kernel"] = False; either set data["lora_mlp_kernel"] =
False as well inside disable_mlp_kernel_scattermoe when
data.get("lora_mlp_kernel") is True, or change the LOG.warning text to
accurately state that mlp_kernel is being disabled; update the message and/or
assignment in the disable_mlp_kernel_scattermoe method accordingly.

Comment on lines +56 to +62
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
x = tl.load(X_blk_ptrs)
w = tl.load(W_blk_ptrs)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing out-of-bounds masking for boundary blocks.

When K is not a multiple of BLOCK_K (or N not a multiple of BLOCK_N), tl.load reads past the valid range without a mask. This produces garbage values in the accumulator for boundary tiles. After switching the grid to triton.cdiv, the N boundary case becomes reachable too.

Consider adding masks similar to other kernels in this codebase:

k_mask = K_block[:, None] < K
n_mask = N_block[None, :] < N
x = tl.load(X_blk_ptrs, mask=k_mask, other=0.0)
w = tl.load(W_blk_ptrs, mask=k_mask & n_mask, other=0.0)

and updating K_block after each iteration for the mask to stay correct.

🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py`
around lines 56 - 62, The accumulation loop in the single-kernel (acc, loop over
_K_block_id) uses tl.load on X_blk_ptrs and W_blk_ptrs without bounds masks, so
boundary tiles produce garbage when K or N are not multiples of BLOCK_K/BLOCK_N;
add boolean masks (e.g., k_mask = K_block[:, None] < K and n_mask =
N_block[None, :] < N) and pass them to tl.load (use other=0.0) for X and W (W
uses k_mask & n_mask), and ensure K_block (and any N_block) is updated each
iteration after advancing X_blk_ptrs/W_blk_ptrs so the mask stays correct.
Ensure acc uses only masked loads so out-of-bounds elements contribute zero.

Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = ydim // BLOCK_N, k

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Use ceiling division for grid computation — truncation drops trailing N-blocks.

ydim // BLOCK_N silently drops the last block when ydim is not a multiple of BLOCK_N. Every analogous grid in ops.py uses triton.cdiv. This will produce silently incorrect (incomplete) output for any model whose expert output dimension is not a multiple of 128.

Proposed fix
-    grid = ydim // BLOCK_N, k
+    grid = triton.cdiv(ydim, BLOCK_N), k
🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py` at
line 74, The grid calculation uses floor division which drops trailing N-sized
blocks when ydim % BLOCK_N != 0; replace the use of integer floor division for
computing grid with a ceiling division (use triton.cdiv or equivalent) so grid =
triton.cdiv(ydim, BLOCK_N), k to match ops.py behavior and ensure the final
partial block is included for functions referencing grid, ydim, and BLOCK_N in
single.py.

Comment thread src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Comment thread src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Comment on lines +62 to +72
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
"""Attach LoRA parameters for fused computation."""
self._lora_A = lora_A
self._lora_B = lora_B
self._lora_scaling = scaling

def clear_lora(self):
"""Remove LoRA parameters."""
self._lora_A = None
self._lora_B = None
self._lora_scaling = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

rg -n 'set_lora|clear_lora' --type=py -C2

Repository: axolotl-ai-cloud/axolotl

Length of output: 2745


🏁 Script executed:

cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 3755


🏁 Script executed:

rg -n 'ParallelExperts' --type=py -C2

Repository: axolotl-ai-cloud/axolotl

Length of output: 4422


🏁 Script executed:

cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py | head -100

Repository: axolotl-ai-cloud/axolotl

Length of output: 4023


🏁 Script executed:

cat -n src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 20295


🏁 Script executed:

rg -n 'set_lora\(' --type=py -C3

Repository: axolotl-ai-cloud/axolotl

Length of output: 1549


🏁 Script executed:

rg -n 'clear_lora\(' --type=py -C3

Repository: axolotl-ai-cloud/axolotl

Length of output: 738


🏁 Script executed:

rg -n '\.set_lora|\.clear_lora' --type=py

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

rg -n 'state_dict|_load_from_state_dict|to\(device' src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Check if there's any external code that might call set_lora/clear_lora
rg -rn 'import.*ParallelExperts' --type=py

Repository: axolotl-ai-cloud/axolotl

Length of output: 150


🏁 Script executed:

# Check if ParallelExperts is ever actually instantiated and these methods called
rg -n 'ParallelExperts\(' --type=py -C2

Repository: axolotl-ai-cloud/axolotl

Length of output: 1040


🏁 Script executed:

rg -n '_lora_A|_lora_B|_lora_scaling' --type=py -C2

Repository: axolotl-ai-cloud/axolotl

Length of output: 5729


🏁 Script executed:

grep -r "getattr.*set_lora\|hasattr.*set_lora" --include="*.py" src/

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


These set_lora / clear_lora methods appear unused — LoRA parameters are managed at the call site, not stored on the module.

The plain attributes _lora_A, _lora_B, and _lora_scaling won't be saved/restored by state_dict() or moved by .to(device) / .half(). However, this is likely not a concern: the actual LoRA integration in HFScatterMoEGatedMLP.forward() (in layers.py) extracts LoRA parameters from PEFT wrappers at runtime and passes them directly as function arguments to parallel_linear_lora(), bypassing these setter methods entirely. No code path calls set_lora() or clear_lora().

🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py` around
lines 62 - 72, The set_lora / clear_lora helpers are dead code because
HFScatterMoEGatedMLP.forward passes LoRA tensors directly to
parallel_linear_lora and never uses _lora_A/_lora_B/_lora_scaling; either remove
these methods and the unused attributes to avoid confusion, or if you intend to
store tensors on the module make them proper module buffers/parameters (use
self.register_buffer for tensors and ensure they are handled in
state_dict()/to()/half()); update or delete set_lora, clear_lora and the
attributes _lora_A/_lora_B/_lora_scaling accordingly so the code matches
HFScatterMoEGatedMLP.forward and parallel_linear_lora usage.

expert_offsets,
lora_A=self._lora_A,
lora_B=self._lora_B,
scaling=self._lora_scaling or 1.0,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

self._lora_scaling or 1.0 silently overrides an explicit scaling=0.0.

0.0 is falsy in Python, so if someone intentionally sets scaling=0.0 (to disable the LoRA delta), this expression evaluates to 1.0 instead. Use self._lora_scaling if self._lora_scaling is not None else 1.0 to distinguish None (unset) from 0.0 (intentionally zero).

Proposed fix
-            scaling=self._lora_scaling or 1.0,
+            scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
scaling=self._lora_scaling or 1.0,
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py` at line
94, The code currently uses "scaling=self._lora_scaling or 1.0" which treats 0.0
as falsy and overrides an explicit zero; change the assignment to use an
explicit None check so that 0.0 is preserved (e.g., use self._lora_scaling if
self._lora_scaling is not None else 1.0) in the place where scaling is passed
(look for the scaling= argument in lora_ops.py, e.g., inside the method creating
the LoRA-wrapped module or function that references self._lora_scaling) so that
None defaults to 1.0 but 0.0 remains a valid value.

output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added recently to solve some saving issue . Do the changes below solve it?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just did a quick pass on this, .clone() may be unintentionally placing tensors on GPU.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? This was a cleanup from changes upstream.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just added by ved a week ago to fix saving in context parallelism 97a4f28

@winglian winglian merged commit 68f1b70 into main Feb 24, 2026
19 of 22 checks passed
@winglian winglian deleted the scattermoe-lora branch February 24, 2026 19:59
@coderabbitai coderabbitai Bot mentioned this pull request Mar 3, 2026
13 tasks
@coderabbitai coderabbitai Bot mentioned this pull request Apr 30, 2026
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants