feature: Mixlora addition #3535

Open
OnePunchMonk wants to merge 6 commits into axolotl-ai-cloud:main from OnePunchMonk:mixlora
@OnePunchMonk

@OnePunchMonk OnePunchMonk commented Mar 22, 2026

Contributor

Description

Adds MixLoRA fine-tuning support. Closes #1880

Motivation and Context

How has this been tested?

Not yet tested.

AI Usage Disclaimer

Claude

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Adds MixLoRA adapter support for expert-routed FFN fine-tuning, router+expert training, and MixLoRA-only checkpoint save/load.
    • New config fields for experts, top-k routing, router aux loss, init range, jitter, and per-expert LoRA overrides.
    • Loader/trainer updated to handle mixlora adapter flows and k-bit/quant preparation.
  • Bug Fixes

    • Validation added to reject incompatible adapter/config combinations.
  • Documentation

    • Added MixLoRA integration guide and configuration examples.
  • Tests

    • New integration tests covering MixLoRA functionality, training, and checkpoint round-trip.

…models, including router, experts, FFN patching, and related schemas.
- #1: Rename original_ffn -> base_ffn (fixes AttributeError in tests)
- #2: Fix device mismatch in collect_mixlora_aux_loss zero tensor
- #3: Remove > 0 guard on aux loss (always add and log)
- #4: Fix or-fallback bug for falsy numeric config values
- #5: Remove unnecessary .clone() in MixLoraFFN.forward
- #6: Fix MixLoraExpert.forward return type annotation
- #7: Update comments for down_proj delta computation
- #8: Add inference guard (eval mode) for MixLoRA in adapter loader
- #9: Replace @pytest.mark.skip with @pytest.mark.slow
- #10: Add ge/gt field constraints and top_k <= num_experts validator
- #11: Block flash_attn_fuse_qkv with mixlora
- #12: Add mixlora_state_dict/load_mixlora_state_dict for checkpointing
- #13: Fix typo intermmediate -> intermediate
@coderabbitai

coderabbitai Bot commented Mar 22, 2026

Contributor
📝 Walkthrough

Adds a MixLoRA integration: router/expert MoE FFN components, patching to replace SwiGLU FFNs with MixLoRA wrappers, loader/plugin/trainer support, config/schema updates, auxiliary-loss collection, tests, and README documentation.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Core MixLoRA implementation: `src/axolotl/integrations/mixlora/model.py`, `src/axolotl/integrations/mixlora/loss.py` | New router, expert, and MixLoraFFN classes; aux-loss collection and reset helpers; mixlora_state_dict/load_mixlora_state_dict utilities. Base FFNs are frozen; router/expert weights are trainable. |
| Patching & integration glue: `src/axolotl/integrations/mixlora/patching.py`, `src/axolotl/integrations/mixlora/__init__.py`, `src/axolotl/integrations/mixlora/constants.py` | FFN detection and in-place replacement with MixLoraFFN; defaults and constant names added; package exports updated. Validates config and device/dtype placement during patching. |
| Trainer & plugin: `src/axolotl/integrations/mixlora/trainer.py`, `src/axolotl/integrations/mixlora/plugin.py` | MixLoraTrainer augments loss with collected router aux loss and saves MixLoRA-only safetensors sidecar. Plugin resolves MixLoRA trainer class when adapter == "mixlora". |
| Loaders: `src/axolotl/loaders/adapter.py`, `src/axolotl/loaders/model.py` | load_adapter gains adapter == "mixlora" branch that validates incompatibilities, applies patching, and optionally loads MixLoRA weights; ModelLoader treats mixlora like lora/qlora for dtype/quant/k-bit prep. |
| Schemas & validation: `src/axolotl/utils/schemas/peft.py`, `src/axolotl/utils/schemas/config.py`, `src/axolotl/utils/schemas/validation.py` | Adds MixLoraConfig, extends LoraConfig.adapter to include "mixlora", adds validators (top-k bounds, required LoRA params), includes MixLoraConfig in AxolotlInputConfig, and disallows certain flash-attn fusions with mixlora. |
| Tests & docs: `tests/integrations/test_mixlora.py`, `src/axolotl/integrations/mixlora/README.md` | Integration tests for FFN wrapping, training step, state dict roundtrip, plugin registration, and slow patching test; README documents usage, config fields, limitations, and citation. |
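
For readers unfamiliar with the router/expert structure summarized above, here is a minimal pure-Python sketch of top-k routing and a load-balance auxiliary loss. It is not the PR's implementation: the function names are hypothetical, and the loss follows the common Switch-Transformer-style form (number of experts times the sum over experts of routed-token fraction times mean router probability), which may differ from what `model.py` actually computes.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one token's router logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(logits: List[float], top_k: int) -> List[Tuple[int, float]]:
    """Return (expert_index, weight) pairs for the top_k experts of one token.

    Weights are the selected experts' softmax probabilities,
    renormalized so that they sum to 1.
    """
    probs = softmax(logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in chosen)
    return [(i, probs[i] / norm) for i in chosen]


def load_balance_loss(all_logits: List[List[float]], top_k: int) -> float:
    """Assumed Switch-style auxiliary loss: num_experts * sum_i f_i * p_i."""
    num_experts = len(all_logits[0])
    num_tokens = len(all_logits)
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for logits in all_logits:
        for i, p in enumerate(softmax(logits)):
            prob_sums[i] += p
        for i, _ in route_top_k(logits, top_k):
            counts[i] += 1
    # f_i: share of routing slots given to expert i; p_i: mean router probability.
    frac = [c / (num_tokens * top_k) for c in counts]
    mean_prob = [s / num_tokens for s in prob_sums]
    return num_experts * sum(f * p for f, p in zip(frac, mean_prob))
```

Under this formulation the loss equals 1.0 when the router probabilities are uniform, and it grows as both the routing decisions and the probability mass concentrate on a few experts.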

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • winglian
  • djsaunde
  • NanoCode012
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 57.14% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Title check | ❓ Inconclusive | The title 'feature: Mixlora addition' is vague and generic, using non-descriptive terms that don't convey the specific nature of the MixLoRA feature being added to the codebase. | Consider a more descriptive title such as 'feature: Add MixLoRA MoE-style LoRA finetuning integration' to better reflect the specific functionality being implemented. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Linked Issues check | ✅ Passed | The PR comprehensively implements the MixLoRA MoE-style finetuning feature requested in issue #1880, including router/expert modules, patching logic, training integration, and configuration schemas. |
| Out of Scope Changes check | ✅ Passed | All changes directly support MixLoRA implementation and its integration; no unrelated or out-of-scope modifications were detected in the PR. |


@OnePunchMonk OnePunchMonk changed the title from "Mixlora" to "feature: Mixlora addition" on Mar 22, 2026
@OnePunchMonk

Contributor Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Mar 22, 2026

Contributor
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (5)
src/axolotl/integrations/mixlora/loss.py (1)

30-61: Implementation is correct, but document the single-step assumption.

The function correctly traverses model wrappers, collects auxiliary losses from all MixLoraFFN modules, computes the mean, and resets the accumulators. The device-aligned zero tensor fallback is a good defensive practice.

However, the docstring states "called once per training step, after the forward pass" but doesn't clarify behavior under gradient accumulation. Since MixLoraFFN.forward() overwrites (rather than accumulates) _aux_loss, only the final micro-batch's auxiliary loss will be captured when gradient accumulation is used.

Consider either:

  1. Documenting this limitation explicitly, or
  2. Modifying MixLoraFFN to accumulate losses across micro-batches
📝 Suggested docstring clarification
 def collect_mixlora_aux_loss(
     model: torch.nn.Module,
     router_aux_loss_coef: float = 0.01,
 ) -> torch.Tensor:
     """Collect and reset auxiliary load-balance losses from all MixLoRA FFN blocks.
 
     This function should be called once per training step, after the forward pass.
     It walks all MixLoraFFN modules, collects their accumulated auxiliary losses,
     computes the mean, scales by the coefficient, and resets the accumulators.
 
+    Note: Under gradient accumulation, only the last micro-batch's auxiliary loss
+    is captured since MixLoraFFN.forward() overwrites (not accumulates) aux_loss.
+
     Args:
         model: The model (may be wrapped in PeftModel, DataParallel, etc.).
         router_aux_loss_coef: Coefficient for the auxiliary loss.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/loss.py` around lines 30 - 61, The docstring
of collect_mixlora_aux_loss should explicitly state its single-step/micro-batch
assumption because MixLoraFFN.forward() overwrites module.aux_loss (not
accumulates) so with gradient accumulation only the last micro-batch's aux loss
will be collected; fix by either (A) updating the collect_mixlora_aux_loss
docstring to note that it must be called once per micro-batch (or that
MixLoraFFN.aux_loss is overwritten between micro-batches), referencing
collect_mixlora_aux_loss and MixLoraFFN.aux_loss/reset_aux_loss, or (B) change
MixLoraFFN.forward to accumulate into aux_loss (e.g., add to existing value) and
ensure reset_aux_loss clears the accumulator so collect_mixlora_aux_loss can
still be called once per optimizer step—pick one approach and implement the
corresponding docstring or code changes.
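
Option (B) above can be sketched with plain floats standing in for loss tensors. `AuxLossAccumulator` is a hypothetical class, not part of the PR; it only illustrates the difference between overwriting and accumulating, and the "collect once, then reset" contract that `collect_mixlora_aux_loss` is described as following.

```python
class AuxLossAccumulator:
    """Toy stand-in for a module's auxiliary-loss state (hypothetical class)."""

    def __init__(self) -> None:
        self._aux_loss = 0.0
        self._num_calls = 0

    def record(self, aux_loss: float) -> None:
        # Accumulate (+=) instead of overwriting (=), so the aux loss of
        # earlier micro-batches is not discarded by later forward passes.
        self._aux_loss += aux_loss
        self._num_calls += 1

    def collect_and_reset(self) -> float:
        """Return the mean aux loss since the last reset, then clear state."""
        if self._num_calls == 0:
            return 0.0
        mean = self._aux_loss / self._num_calls
        self._aux_loss = 0.0
        self._num_calls = 0
        return mean
```

Whether the real code should average or sum, and at which point relative to each micro-batch's backward pass the value is collected, depends on how the trainer invokes `compute_loss`; this sketch does not settle that.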
tests/integrations/test_mixlora.py (1)

127-127: Add trailing newline.

The file is missing a trailing newline at the end, which is a common style convention for text files.

📝 Add trailing newline
         router_aux_loss_coef=mock_cfg.mixlora_router_aux_loss_coef
     )
     assert aux_loss.item() > 0
+
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integrations/test_mixlora.py` at line 127, Add a trailing newline at
the end of the file so the final line "assert aux_loss.item() > 0" is followed
by a single newline character; open the test file containing that assertion
(tests/integrations/test_mixlora.py) and ensure the file ends with a newline to
satisfy POSIX/text-file conventions and linters.
src/axolotl/integrations/mixlora/model.py (2)

184-223: Docstring inconsistency: return type mismatch.

The docstring at lines 198-199 states the return is "[T_expert, hidden_dim] expert output" but the method actually returns a tuple (intermediate, down_delta). The type annotation is correct.

📝 Fix docstring
         Returns:
-            [T_expert, hidden_dim] expert output.
+            Tuple of:
+              - intermediate: [T_expert, intermediate_dim] activated intermediate.
+              - down_delta: [T_expert, hidden_dim] LoRA delta for down_proj.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/model.py` around lines 184 - 223, The
forward method's docstring is inaccurate: it claims a single "[T_expert,
hidden_dim] expert output" while forward actually returns a tuple (intermediate,
down_delta). Update the forward docstring for the method forward to describe the
two returned tensors (intermediate: the SwiGLU-activated intermediate of shape
[T_expert, intermediate_dim] or appropriate shape, and down_delta: the LoRA
down-projection delta of shape [T_expert, hidden_dim]), and clarify that the
final down projection is applied elsewhere (MixLoraFFN); ensure the return
section matches the function signature and types.

331-335: Unused variables from unpacking.

Static analysis (Ruff RUF059) flags batch_size and seq_len as unused. Since only the shape is needed to detect 3D input and hidden_dim for reshaping:

♻️ Optional cleanup
         if x.dim() == 3:
-            batch_size, seq_len, hidden_dim = x.shape
-            x_flat = x.reshape(-1, hidden_dim)
+            x_flat = x.reshape(-1, x.shape[-1])
         else:
             x_flat = x
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/model.py` around lines 331 - 335, The
unpacking assigns unused variables (batch_size, seq_len) causing lint warnings;
instead obtain only the last dimension and reshape accordingly. In the block
where x.dim() == 3 (variable x in model.py), replace "batch_size, seq_len,
hidden_dim = x.shape" with either "hidden_dim = x.shape[-1]" or use an
underscore for unused elements ("_, _, hidden_dim = x.shape"), then call x_flat
= x.reshape(-1, hidden_dim) so only the needed hidden_dim is referenced.
src/axolotl/integrations/mixlora/patching.py (1)

147-153: Consider defensive check for parameter-less modules.

next(original_ffn.parameters()) will raise StopIteration if the module has no parameters. While this is unlikely for valid SwiGLU FFNs with nn.Linear layers, a defensive approach would be safer.

🛡️ Optional defensive fix
         # Move to the same device/dtype as the original
-        device = next(original_ffn.parameters()).device
-        dtype = next(original_ffn.parameters()).dtype
+        param_iter = iter(original_ffn.parameters())
+        first_param = next(param_iter, None)
+        if first_param is None:
+            LOG.warning(f"MixLoRA: FFN {attr_name} has no parameters, skipping device placement")
+            continue
+        device = first_param.device
+        dtype = first_param.dtype
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/patching.py` around lines 147 - 153, Check
for modules without parameters before calling next(...) on original_ffn: obtain
first_param = next(original_ffn.parameters(), None) and if it's None, skip the
device/dtype lookup and avoid moving mixlora_ffn.router/expert tensors (or
return early), otherwise set device = first_param.device and dtype =
first_param.dtype and then .to(...) the mixlora_ffn.router and
mixlora_ffn.experts; reference original_ffn, mixlora_ffn.router, and
mixlora_ffn.experts when making the guard.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/core/trainers/base.py`:
- Around line 399-413: compute_loss currently never invokes
_add_mixlora_aux_loss so MixLoRA router auxiliary loss is never added; update
the training loss path by calling self._add_mixlora_aux_loss(loss, model)
immediately after obtaining loss from super().compute_loss(...) inside
compute_loss so the aux loss is included and logged via store_metrics; also fix
MixLoraFFN.forward to accumulate per-microbatch auxiliary loss instead of
overwriting (use += into self._aux_loss and ensure it is zeroed at start of an
accumulation cycle or after optimizer step) so gradient-accumulated micro-steps
correctly aggregate the router aux loss.

In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 350-385: The code assigns final_output = base_output by reference
which causes in-place updates to base_output when accumulating expert deltas in
the experts loop (for self.experts), corrupting subsequent expert_delta
computations when top_k > 1; fix by making final_output an explicit copy of
base_output (preserving device/dtype) before the loop so updates to
final_output[token_indices] do not modify base_output used in
self.base_ffn.down_proj(...) and expert_delta calculation.

In `@src/axolotl/utils/schemas/peft.py`:
- Around line 299-310: The validator validate_mixlora_top_k only checks when
both fields are explicitly set, so default values (from MIXLORA_DEFAULTS in
patching.py) can create invalid combos; update validation to resolve defaults
before checking or move the check into patch_model_with_mixlora after defaults
are applied. Specifically, in validate_mixlora_top_k (or in
patch_model_with_mixlora) read mixlora_top_k and mixlora_num_experts using the
default resolution logic (reference MIXLORA_DEFAULTS) and then raise the same
ValueError if resolved mixlora_top_k > resolved mixlora_num_experts; ensure the
error message references mixlora_top_k and mixlora_num_experts so users see the
effective values.
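
The last inline comment (validating after defaults are resolved) can be sketched as follows. The default values in this `MIXLORA_DEFAULTS` dict are placeholders chosen for illustration, since the PR's real defaults are not shown in this conversation, and `resolve`/`validate_top_k` are hypothetical helper names.

```python
from typing import Dict, Optional

# Placeholder defaults for illustration only; not the PR's actual values.
MIXLORA_DEFAULTS: Dict[str, int] = {"mixlora_num_experts": 8, "mixlora_top_k": 2}


def resolve(cfg: Dict[str, Optional[int]], key: str) -> int:
    """Return the configured value, falling back to the default only on None.

    Using `is None` rather than `or` keeps falsy-but-valid values such as 0.
    """
    value = cfg.get(key)
    return MIXLORA_DEFAULTS[key] if value is None else value


def validate_top_k(cfg: Dict[str, Optional[int]]) -> None:
    """Check top_k <= num_experts on the effective (default-resolved) values."""
    top_k = resolve(cfg, "mixlora_top_k")
    num_experts = resolve(cfg, "mixlora_num_experts")
    if top_k > num_experts:
        raise ValueError(
            f"mixlora_top_k ({top_k}) must be <= mixlora_num_experts ({num_experts})"
        )
```

With these placeholder defaults, a config that sets only `mixlora_num_experts: 1` is rejected, which is the case a validator that checks only explicitly set fields would miss.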

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d81e0164-2a70-48c1-a42b-0150f6d6bae1

📥 Commits

Reviewing files that changed from the base of the PR and between 0e583ef and af20d0b.

📒 Files selected for processing (11)
  • src/axolotl/core/trainers/base.py
  • src/axolotl/integrations/mixlora/__init__.py
  • src/axolotl/integrations/mixlora/loss.py
  • src/axolotl/integrations/mixlora/model.py
  • src/axolotl/integrations/mixlora/patching.py
  • src/axolotl/loaders/adapter.py
  • src/axolotl/loaders/model.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/peft.py
  • src/axolotl/utils/schemas/validation.py
  • tests/integrations/test_mixlora.py

Comment thread src/axolotl/core/trainers/base.py Outdated
Comment thread src/axolotl/integrations/mixlora/model.py Outdated
Comment thread src/axolotl/utils/schemas/peft.py
Critical:
- Wire _add_mixlora_aux_loss into compute_loss (was defined but never called)
- Restore .clone() on final_output to fix aliasing bug with top_k > 1

Minor:
- Accumulate aux_loss across gradient accumulation micro-batches
- Add runtime top_k <= num_experts validation after resolving defaults

Nitpicks:
- Fix MixLoraExpert.forward docstring to describe tuple return
- Remove unused batch_size/seq_len variables in shape unpacking
- Add defensive check for parameter-less FFN modules in patching
- Document gradient accumulation behavior in collect_mixlora_aux_loss
- Ensure test file ends with trailing newline
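
The aliasing bug behind the restored `.clone()` can be illustrated with plain Python lists standing in for tensors. This is only an analogy under that assumption: the function names are hypothetical, and `list(...)` plays the role that `torch.Tensor.clone()` plays in the real code.

```python
from typing import List


def accumulate_aliased(base_output: List[float], deltas: List[float]) -> List[float]:
    """Buggy pattern: final_output is the same object as base_output."""
    final_output = base_output  # alias, not a copy
    for delta in deltas:
        final_output[0] += delta  # in-place update also changes base_output
    return final_output


def accumulate_copied(base_output: List[float], deltas: List[float]) -> List[float]:
    """Fixed pattern: copy first, so base_output stays intact for later reads."""
    final_output = list(base_output)  # explicit copy, analogous to .clone()
    for delta in deltas:
        final_output[0] += delta
    return final_output
```

In the aliased version, any later computation that reads `base_output` (for example, a second expert's delta when top_k > 1) sees values already modified by the first expert.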
@OnePunchMonk

Contributor Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Mar 22, 2026

Contributor
✅ Actions performed

Review triggered.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/core/trainers/base.py`:
- Around line 418-420: The auxiliary loss is being stored unconditionally as
train metrics; change the call to store_metrics to choose train vs eval based on
the model's training flag (e.g., use self.model.training) so that when
self.model.training is True you pass "train" and otherwise pass "eval" (update
the call site that currently calls store_metrics({"mixlora_aux_loss":
aux_loss.item()}, train_eval="train") to compute the phase from
self.model.training and pass that value).

In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 304-321: The mixlora_state_dict and load_mixlora_state_dict helper
methods are dead/duplicative (router and experts are already saved via normal
nn.Module state_dict recursion); either delete both methods (mixlora_state_dict
and load_mixlora_state_dict) to remove dead code, or, if you intend them as an
explicit API for exporting/importing just MixLoRA weights, keep them but add a
clear docstring explaining their purpose and intended usage, plus update any
public API docs and tests to cover their behavior (and ensure callers use
mixlora_state_dict/load_mixlora_state_dict where intended).
- Around line 294-302: The module-level accumulator _aux_loss (accessed via
aux_loss and reset_aux_loss) is unsafe with activation checkpointing because
MixLoraFFN.forward() mutates it during forward replays; fix by making the
accumulator forward-local or by resetting _aux_loss at the start of the real
forward: inside MixLoraFFN.forward(), clear self._aux_loss (or use a local aux
variable and return/attach it to the output) before any accumulation so
checkpoint replays don't leave stale state that compute_loss() misses. Ensure
compute_loss() still reads the correct per-step aux term (remove reliance on
persistent module state if you convert to a local accumulator).
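
The checkpointing concern in the last inline comment can be sketched with two toy blocks. Both classes are hypothetical and use floats instead of tensors; the second forward call simulates the replay that activation checkpointing performs during the backward pass. How the real `MixLoraFFN` interacts with a given checkpointing implementation is not shown in this conversation.

```python
from typing import Tuple


class StatefulBlock:
    """Accumulates aux loss on the module, so a replayed forward double-counts."""

    def __init__(self) -> None:
        self.aux_loss = 0.0

    def forward(self, x: float) -> float:
        self.aux_loss += 0.1 * x  # side effect repeated on every replay
        return 2.0 * x


class LocalBlock:
    """Returns the aux loss with the output; no persistent state to go stale."""

    def forward(self, x: float) -> Tuple[float, float]:
        return 2.0 * x, 0.1 * x
```

Returning the auxiliary term as part of the forward result makes each call self-contained, at the cost of changing the block's return signature for its callers.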

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2463935a-6a83-4856-82a6-11094f8cea2b

📥 Commits

Reviewing files that changed from the base of the PR and between af20d0b and c6b568f.

📒 Files selected for processing (4)
  • src/axolotl/core/trainers/base.py
  • src/axolotl/integrations/mixlora/loss.py
  • src/axolotl/integrations/mixlora/model.py
  • src/axolotl/integrations/mixlora/patching.py
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/integrations/mixlora/loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/integrations/mixlora/patching.py

Comment thread src/axolotl/core/trainers/base.py Outdated
Comment thread src/axolotl/integrations/mixlora/model.py
Comment thread src/axolotl/integrations/mixlora/model.py
Comment thread src/axolotl/core/trainers/base.py Outdated
Comment thread src/axolotl/integrations/mixlora/__init__.py
@OnePunchMonk OnePunchMonk marked this pull request as ready for review March 24, 2026 17:47
@codecov

codecov Bot commented Mar 24, 2026


@OnePunchMonk

Contributor Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Mar 25, 2026

Contributor
✅ Actions performed

Review triggered.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 185-226: forward in MixLora expert currently hardcodes F.silu for
the SwiGLU activation (seen where intermediate = F.silu(gate_total) * up_total),
which can silently change behavior when patching non-SiLU FFNs; change the
expert to use an activation attribute (e.g., self.activation or self.silu_fn)
supplied/validated from the base MixLoraFFN instead of calling F.silu directly,
and update the patching/compatibility logic in patching.py to verify the base
FFN exposes the same activation (or an activation callable) before patching;
locate usages in the forward method, MixLoraFFN constructor/initializer, and
patching.py activation checks and wire the activation through those symbols so
both base and expert paths use the identical activation function.
- Around line 403-426: In load_mixlora_state_dict: detect and reject extra
MixLoRA blocks in the checkpoint when strict=True by comparing the set of
MixLora module prefixes present in the model to the set of prefixes present in
the state_dict; compute model_prefixes = {module_name for module_name, module in
model.named_modules() if isinstance(module, MixLoraFFN)} and checkpoint_prefixes
= {key.split('.',1)[0] for key in state_dict.keys()}, and if strict and
checkpoint_prefixes - model_prefixes is non-empty raise a KeyError listing the
unexpected prefixes so extra/renamed MixLoRA blocks in the checkpoint are not
silently ignored.

In `@src/axolotl/integrations/mixlora/README.md`:
- Around line 54-61: Remove trailing whitespace from the BibTeX block in
src/axolotl/integrations/mixlora/README.md: edit the `@misc` entry (starting with
li2024mixloraenhancinglarge) and remove any spaces at the ends of lines such as
the author field and the url field so the lines no longer end with whitespace;
ensure the file passes "trim trailing whitespace" checks.

In `@src/axolotl/integrations/mixlora/trainer.py`:
- Around line 62-79: The MixLoRA sidecar is being built from live self.model
parameters after calling super()._save(), causing the saved sidecar to diverge
when a prepared snapshot was passed via the state_dict param; modify _save so it
builds the mixlora state from the provided state_dict when present (falling back
to self.model only if state_dict is None) instead of always calling
mixlora_state_dict(self.model). Locate the _save method and the
mixlora_state_dict usage and change it to derive state from the state_dict
argument (or extract the MixLoRA-related tensors from state_dict using the same
key patterns mixlora_state_dict expects) before converting to cpu and writing
MIXLORA_WEIGHTS_NAME. Ensure super()._save(output_dir=output_dir,
state_dict=state_dict) remains called and that the saved sidecar uses the same
snapshot data that was written by super()._save.

In `@src/axolotl/utils/schemas/peft.py`:
- Around line 241-248: MixLoraConfig currently defines DEFAULT_* class
attributes without typing, so Pydantic v2 treats them as model fields and mypy
widens numeric types; annotate each DEFAULT_* as a typing.ClassVar with the
correct narrow type (e.g., DEFAULT_NUM_EXPERTS and DEFAULT_TOP_K as
ClassVar[int], DEFAULT_ROUTER_AUX_LOSS_COEF, DEFAULT_ROUTER_INIT_RANGE, and
DEFAULT_JITTER_NOISE as ClassVar[float]) to ensure they are class-only constants
and keep type checking correct; update the declarations on MixLoraConfig (and
the analogous DEFAULT_* attributes in the following block mentioned lines
250–263) to use ClassVar annotations.

In `@tests/integrations/test_mixlora.py`:
- Around line 28-33: Replace the re-exported import of MIXLORA_WEIGHTS_NAME from
model.py with a direct import from the constants module: import
MIXLORA_WEIGHTS_NAME from axolotl.integrations.mixlora.constants while keeping
MixLoraFFN, load_mixlora_state_dict, and mixlora_state_dict imported from
model.py so mypy's attr-defined check uses the symbol's original definition.
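
The strict-loading check requested for `load_mixlora_state_dict` can be sketched without any model objects. This assumes checkpoint keys have the form `<module name>.<parameter path>` (for example `model.layers.0.mlp.router.weight`), which is an assumption about the key layout, and the helper names are hypothetical. It matches keys against full module-name prefixes rather than only the first dotted component.

```python
from typing import Iterable, List


def unexpected_mixlora_keys(
    model_prefixes: Iterable[str], checkpoint_keys: Iterable[str]
) -> List[str]:
    """Return checkpoint keys that belong to no MixLoRA module in the model."""
    prefixes = tuple(prefix + "." for prefix in model_prefixes)
    return sorted(key for key in checkpoint_keys if not key.startswith(prefixes))


def check_strict(
    model_prefixes: Iterable[str],
    checkpoint_keys: Iterable[str],
    strict: bool = True,
) -> None:
    """Raise KeyError in strict mode if the checkpoint has extra MixLoRA keys."""
    extra = unexpected_mixlora_keys(model_prefixes, checkpoint_keys)
    if strict and extra:
        raise KeyError(f"Unexpected MixLoRA keys in checkpoint: {extra}")
```

In a real loader, `model_prefixes` would come from something like the names yielded by `model.named_modules()` for modules that are `MixLoraFFN` instances, as the comment above suggests.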

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 02c08a13-b936-46d8-afd4-8407f6c9632b

📥 Commits

Reviewing files that changed from the base of the PR and between c6b568f and cb9a3b6.

📒 Files selected for processing (10)
  • src/axolotl/integrations/mixlora/README.md
  • src/axolotl/integrations/mixlora/__init__.py
  • src/axolotl/integrations/mixlora/constants.py
  • src/axolotl/integrations/mixlora/model.py
  • src/axolotl/integrations/mixlora/patching.py
  • src/axolotl/integrations/mixlora/plugin.py
  • src/axolotl/integrations/mixlora/trainer.py
  • src/axolotl/loaders/adapter.py
  • src/axolotl/utils/schemas/peft.py
  • tests/integrations/test_mixlora.py
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/integrations/mixlora/constants.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/loaders/adapter.py
  • src/axolotl/integrations/mixlora/patching.py

Comment thread src/axolotl/integrations/mixlora/model.py
Comment thread src/axolotl/integrations/mixlora/model.py
Comment thread src/axolotl/integrations/mixlora/README.md Outdated
Comment thread src/axolotl/integrations/mixlora/trainer.py
Comment thread src/axolotl/utils/schemas/peft.py Outdated
Comment thread tests/integrations/test_mixlora.py
@OnePunchMonk

Contributor Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 12, 2026

Contributor
✅ Actions performed

Review triggered.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
src/axolotl/integrations/mixlora/trainer.py (1)

93-109: Consider skipping auxiliary loss addition during evaluation.

_add_mixlora_aux_loss is called regardless of whether the model is in training or evaluation mode. Adding the router load-balance loss during validation may artificially inflate validation loss metrics, making training vs. validation comparisons less meaningful. Typically, auxiliary losses are only used during training.

♻️ Proposed fix
     def _add_mixlora_aux_loss(self, loss, model):
         """Add MixLoRA router auxiliary load-balance loss if applicable."""
+        # Only add auxiliary loss during training
+        if not model.training:
+            return loss
+
         coef = getattr(self.axolotl_cfg, "mixlora_router_aux_loss_coef", None)
         router_aux_loss_coef = (
             coef if coef is not None else MIXLORA_DEFAULTS["mixlora_router_aux_loss_coef"]
         )

         aux_loss = collect_mixlora_aux_loss(
             model, router_aux_loss_coef=router_aux_loss_coef
         )

         loss = loss + aux_loss.to(loss.device)

         train_eval = "train" if model.training else "eval"
         self.store_metrics(
             {"mixlora_aux_loss": aux_loss.item()}, train_eval=train_eval
         )
         return loss
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/trainer.py` around lines 93 - 109, The
_add_mixlora_aux_loss function is adding the MixLoRA router aux loss during
eval; change it to only compute and add aux loss when model.training is True:
check model.training at the start of _add_mixlora_aux_loss and return the
original loss (and optionally log nothing or log mixlora_aux_loss as 0) when not
training; when training, compute collect_mixlora_aux_loss using
mixlora_router_aux_loss_coef (from self.axolotl_cfg or MIXLORA_DEFAULTS), add
aux_loss.to(loss.device) to loss, and call self.store_metrics with the
aux_loss.item() and correct train_eval ("train").
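The training-only guard suggested above can be illustrated in isolation. This is a minimal sketch, not the axolotl trainer: `ToyModel`, `ToyTrainer`, `collect_aux_loss`, and the metric key are hypothetical stand-ins, and plain floats replace tensors so the example runs without PyTorch.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for an nn.Module; only the `training` flag matters."""

    training: bool = True


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for the trainer mixin, not the real axolotl class."""

    aux_loss_coef: float = 0.01
    metrics: dict = field(default_factory=dict)

    def collect_aux_loss(self, model: ToyModel) -> float:
        # Placeholder for collect_mixlora_aux_loss: assume a raw
        # load-balance loss of 2.0, scaled by the coefficient.
        return self.aux_loss_coef * 2.0

    def add_aux_loss(self, loss: float, model: ToyModel) -> float:
        # Skip the router load-balance term during evaluation so that
        # validation loss stays comparable to the plain task loss.
        if not model.training:
            return loss
        aux_loss = self.collect_aux_loss(model)
        self.metrics["train/mixlora_aux_loss"] = aux_loss
        return loss + aux_loss
```

With the early return in place, the metric is only ever recorded for training steps, so the `"train" if model.training else "eval"` branch in the original function would always resolve to `"train"`.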
src/axolotl/integrations/mixlora/model.py (1)

278-281: Consider logging a warning when falling back to default activation.

The silent fallback to F.silu when act_fn is not callable could mask misconfigured FFN modules. A warning would aid debugging.

♻️ Proposed enhancement
         # Resolve activation function from the base FFN (default to F.silu)
         self.activation_fn = getattr(self.base_ffn, "act_fn", F.silu)
         if not callable(self.activation_fn):
+            LOG.warning(
+                f"MixLoraFFN: base_ffn.act_fn is not callable ({type(self.activation_fn)}), "
+                "falling back to F.silu"
+            )
             self.activation_fn = F.silu
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/model.py` around lines 278 - 281, The code
silently falls back to F.silu if self.base_ffn.act_fn is missing or not
callable; update the block around resolving self.activation_fn (the getattr on
self.base_ffn and the callable check) to log a warning when this fallback is
used, e.g., use the module/class logger to warn that base_ffn.act_fn is missing
or not callable and that F.silu will be used, then assign F.silu as now; ensure
the log includes the actual value/type of self.base_ffn.act_fn for easier
debugging.
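One way to make the fallback visible is sketched below. It is an illustration under stated assumptions: `resolve_activation` and the pure-Python `silu` are hypothetical helpers standing in for the logic inside `MixLoraFFN.__init__` and for `F.silu`. Unlike `getattr(base_ffn, "act_fn", F.silu)`, this version reads the attribute with a `None` default, so a missing `act_fn` also triggers the warning.

```python
import logging
import math
from types import SimpleNamespace

LOG = logging.getLogger(__name__)


def silu(x: float) -> float:
    """Scalar SiLU, used here only as a stand-in for F.silu."""
    return x / (1.0 + math.exp(-x))


def resolve_activation(base_ffn, default=silu):
    """Return base_ffn.act_fn if callable, else warn and return the default."""
    act_fn = getattr(base_ffn, "act_fn", None)
    if callable(act_fn):
        return act_fn
    # Covers both a missing attribute (None) and a non-callable value,
    # e.g. a config string such as "silu" left on the module.
    LOG.warning(
        "base_ffn.act_fn is missing or not callable (%s); falling back to %s",
        type(act_fn).__name__,
        default.__name__,
    )
    return default


# Usage: a module carrying a string instead of a function falls back with a warning.
misconfigured_ffn = SimpleNamespace(act_fn="silu")
activation = resolve_activation(misconfigured_ffn)
```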
src/axolotl/integrations/mixlora/README.md (1)

6-6: Consider hyphenating "LoRA-based" for grammatical correctness.

The phrase "LoRA based Mixture of Experts" should use a hyphen: "LoRA-based Mixture of Experts" since "LoRA-based" is a compound adjective modifying "Mixture."

📝 Suggested fix
-See [MixLoRA: Enhancing Large Language Models Fine-Tuning with LoRA based Mixture of Experts](https://arxiv.org/abs/2404.15159)
+See [MixLoRA: Enhancing Large Language Models Fine-Tuning with LoRA-based Mixture of Experts](https://arxiv.org/abs/2404.15159)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/README.md` at line 6, Update the phrase
"LoRA based Mixture of Experts" in the README line that references the MixLoRA
paper to use the hyphenated compound adjective "LoRA-based Mixture of Experts"
so the grammar is correct; locate the exact string "LoRA based Mixture of
Experts" in src/axolotl/integrations/mixlora/README.md and replace it with
"LoRA-based Mixture of Experts".
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c8765227-c016-4399-9cb0-c5ca52a708e3

📥 Commits

Reviewing files that changed from the base of the PR and between cb9a3b6 and 76f15a3.

📒 Files selected for processing (6)
  • src/axolotl/integrations/mixlora/README.md
  • src/axolotl/integrations/mixlora/loss.py
  • src/axolotl/integrations/mixlora/model.py
  • src/axolotl/integrations/mixlora/trainer.py
  • src/axolotl/utils/schemas/peft.py
  • tests/integrations/test_mixlora.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/integrations/mixlora/loss.py
  • src/axolotl/utils/schemas/peft.py

Comment on lines +160 to +164
lora_dropout: float = 0.0,
activation_fn: callable = F.silu,
):
super().__init__()
self.activation_fn = activation_fn

Contributor

⚠️ Potential issue | 🔴 Critical

Fix mypy error: annotate with Callable instead of the built-in callable.

The pipeline is failing because callable is a built-in function, not a valid type annotation. Use collections.abc.Callable instead (typing.Callable, which the mypy message suggests, is an equivalent alias).

🐛 Proposed fix

Add the import at the top of the file:

from collections.abc import Callable

Then update the parameter type:

     def __init__(
         self,
         hidden_dim: int,
         intermediate_dim: int,
         lora_r: int,
         lora_alpha: int,
         lora_dropout: float = 0.0,
-        activation_fn: callable = F.silu,
+        activation_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
     ):
         super().__init__()
         self.activation_fn = activation_fn
🧰 Tools
🪛 GitHub Actions: lint

[error] 160-160: mypy: Function "builtins.callable" is not valid as a type [valid-type]. Perhaps you meant "typing.Callable" instead of "callable"?

🪛 GitHub Actions: Tests

[error] 160-160: mypy: Function "builtins.callable" is not valid as a type [valid-type]. note: Perhaps you meant "typing.Callable" instead of "callable"?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/mixlora/model.py` around lines 160 - 164, The
__init__ parameter annotation uses the built-in "callable" which causes mypy
errors; import Callable (from collections.abc) and change the parameter type
annotation from "callable" to "Callable" (e.g., update the activation_fn
parameter in the class initializer shown with signature containing lora_dropout
and activation_fn) so mypy recognizes it as a proper callable type; add the
import line at the top of the module and update the annotation in the
constructor signature and any other occurrences of "callable" in this file.
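The annotation change can be seen in a stripped-down form. This is a sketch only: `ToyExpert` and `identity` are hypothetical stand-ins for `MixLoraExpert` and `F.silu`, and `float` replaces `torch.Tensor` so the snippet has no third-party dependencies. Subscripting `collections.abc.Callable` requires Python 3.9 or later.

```python
from collections.abc import Callable


def identity(x: float) -> float:
    """Stand-in default activation (the PR defaults to F.silu)."""
    return x


class ToyExpert:
    """Hypothetical minimal class; only the annotation pattern is the point."""

    def __init__(
        self,
        lora_dropout: float = 0.0,
        # `callable` is a built-in function, so mypy rejects it as a type.
        # `Callable[[arg_types], return_type]` is the valid annotation.
        activation_fn: Callable[[float], float] = identity,
    ) -> None:
        self.lora_dropout = lora_dropout
        self.activation_fn = activation_fn


# The built-in callable() remains the right tool for runtime checks.
expert = ToyExpert(activation_fn=abs)
is_callable_at_runtime = callable(expert.activation_fn)
```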

Development

Successfully merging this pull request may close these issues.

MixLoRA finetuning

2 participants