feature: Mixlora addition#3535
Conversation
…models, including router, experts, FFN patching, and related schemas.
- #1: Rename original_ffn -> base_ffn (fixes AttributeError in tests) - axolotl-ai-cloud#2: Fix device mismatch in collect_mixlora_aux_loss zero tensor - axolotl-ai-cloud#3: Remove > 0 guard on aux loss (always add and log) - axolotl-ai-cloud#4: Fix or-fallback bug for falsy numeric config values - axolotl-ai-cloud#5: Remove unnecessary .clone() in MixLoraFFN.forward - axolotl-ai-cloud#6: Fix MixLoraExpert.forward return type annotation - axolotl-ai-cloud#7: Update comments for down_proj delta computation - axolotl-ai-cloud#8: Add inference guard (eval mode) for MixLoRA in adapter loader - axolotl-ai-cloud#9: Replace @pytest.mark.skip with @pytest.mark.slow - axolotl-ai-cloud#10: Add ge/gt field constraints and top_k <= num_experts validator - axolotl-ai-cloud#11: Block flash_attn_fuse_qkv with mixlora - axolotl-ai-cloud#12: Add mixlora_state_dict/load_mixlora_state_dict for checkpointing - axolotl-ai-cloud#13: Fix typo intermmediate -> intermediate
📝 WalkthroughWalkthroughAdds a MixLoRA integration: router/expert MoE FFN components, patching to replace SwiGLU FFNs with MixLoRA wrappers, loader/plugin/trainer support, config/schema updates, auxiliary-loss collection, tests, and README documentation. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (5)
src/axolotl/integrations/mixlora/loss.py (1)
30-61: Implementation is correct, but document the single-step assumption.The function correctly traverses model wrappers, collects auxiliary losses from all
MixLoraFFNmodules, computes the mean, and resets the accumulators. The device-aligned zero tensor fallback is a good defensive practice.However, the docstring states "called once per training step, after the forward pass" but doesn't clarify behavior under gradient accumulation. Since
MixLoraFFN.forward()overwrites (rather than accumulates)_aux_loss, only the final micro-batch's auxiliary loss will be captured when gradient accumulation is used.Consider either:
- Documenting this limitation explicitly, or
- Modifying
MixLoraFFNto accumulate losses across micro-batches📝 Suggested docstring clarification
def collect_mixlora_aux_loss( model: torch.nn.Module, router_aux_loss_coef: float = 0.01, ) -> torch.Tensor: """Collect and reset auxiliary load-balance losses from all MixLoRA FFN blocks. This function should be called once per training step, after the forward pass. It walks all MixLoraFFN modules, collects their accumulated auxiliary losses, computes the mean, scales by the coefficient, and resets the accumulators. + Note: Under gradient accumulation, only the last micro-batch's auxiliary loss + is captured since MixLoraFFN.forward() overwrites (not accumulates) aux_loss. + Args: model: The model (may be wrapped in PeftModel, DataParallel, etc.). router_aux_loss_coef: Coefficient for the auxiliary loss.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/loss.py` around lines 30 - 61, The docstring of collect_mixlora_aux_loss should explicitly state its single-step/micro-batch assumption because MixLoraFFN.forward() overwrites module.aux_loss (not accumulates) so with gradient accumulation only the last micro-batch's aux loss will be collected; fix by either (A) updating the collect_mixlora_aux_loss docstring to note that it must be called once per micro-batch (or that MixLoraFFN.aux_loss is overwritten between micro-batches), referencing collect_mixlora_aux_loss and MixLoraFFN.aux_loss/reset_aux_loss, or (B) change MixLoraFFN.forward to accumulate into aux_loss (e.g., add to existing value) and ensure reset_aux_loss clears the accumulator so collect_mixlora_aux_loss can still be called once per optimizer step—pick one approach and implement the corresponding docstring or code changes.tests/integrations/test_mixlora.py (1)
127-127: Add trailing newline.The file is missing a trailing newline at the end, which is a common style convention for text files.
📝 Add trailing newline
router_aux_loss_coef=mock_cfg.mixlora_router_aux_loss_coef ) assert aux_loss.item() > 0 +🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integrations/test_mixlora.py` at line 127, Add a trailing newline at the end of the file so the final line "assert aux_loss.item() > 0" is followed by a single newline character; open the test file containing that assertion (tests/integrations/test_mixlora.py) and ensure the file ends with a newline to satisfy POSIX/text-file conventions and linters.src/axolotl/integrations/mixlora/model.py (2)
184-223: Docstring inconsistency: return type mismatch.The docstring at lines 198-199 states the return is
"[T_expert, hidden_dim] expert output"but the method actually returns a tuple(intermediate, down_delta). The type annotation is correct.📝 Fix docstring
Returns: - [T_expert, hidden_dim] expert output. + Tuple of: + - intermediate: [T_expert, intermediate_dim] activated intermediate. + - down_delta: [T_expert, hidden_dim] LoRA delta for down_proj.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/model.py` around lines 184 - 223, The forward method's docstring is inaccurate: it claims a single "[T_expert, hidden_dim] expert output" while forward actually returns a tuple (intermediate, down_delta). Update the forward docstring for the method forward to describe the two returned tensors (intermediate: the SwiGLU-activated intermediate of shape [T_expert, intermediate_dim] or appropriate shape, and down_delta: the LoRA down-projection delta of shape [T_expert, hidden_dim]), and clarify that the final down projection is applied elsewhere (MixLoraFFN); ensure the return section matches the function signature and types.
331-335: Unused variables from unpacking.Static analysis (Ruff RUF059) flags
batch_sizeandseq_lenas unused. Since only the shape is needed to detect 3D input andhidden_dimfor reshaping:♻️ Optional cleanup
if x.dim() == 3: - batch_size, seq_len, hidden_dim = x.shape - x_flat = x.reshape(-1, hidden_dim) + x_flat = x.reshape(-1, x.shape[-1]) else: x_flat = x🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/model.py` around lines 331 - 335, The unpacking assigns unused variables (batch_size, seq_len) causing lint warnings; instead obtain only the last dimension and reshape accordingly. In the block where x.dim() == 3 (variable x in model.py), replace "batch_size, seq_len, hidden_dim = x.shape" with either "hidden_dim = x.shape[-1]" or use an underscore for unused elements ("_, _, hidden_dim = x.shape"), then call x_flat = x.reshape(-1, hidden_dim) so only the needed hidden_dim is referenced.src/axolotl/integrations/mixlora/patching.py (1)
147-153: Consider defensive check for parameter-less modules.
next(original_ffn.parameters())will raiseStopIterationif the module has no parameters. While this is unlikely for valid SwiGLU FFNs withnn.Linearlayers, a defensive approach would be safer.🛡️ Optional defensive fix
# Move to the same device/dtype as the original - device = next(original_ffn.parameters()).device - dtype = next(original_ffn.parameters()).dtype + param_iter = iter(original_ffn.parameters()) + first_param = next(param_iter, None) + if first_param is None: + LOG.warning(f"MixLoRA: FFN {attr_name} has no parameters, skipping device placement") + continue + device = first_param.device + dtype = first_param.dtype🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/patching.py` around lines 147 - 153, Check for modules without parameters before calling next(...) on original_ffn: obtain first_param = next(original_ffn.parameters(), None) and if it's None, skip the device/dtype lookup and avoid moving mixlora_ffn.router/expert tensors (or return early), otherwise set device = first_param.device and dtype = first_param.dtype and then .to(...) the mixlora_ffn.router and mixlora_ffn.experts; reference original_ffn, mixlora_ffn.router, and mixlora_ffn.experts when making the guard.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/trainers/base.py`:
- Around line 399-413: compute_loss currently never invokes
_add_mixlora_aux_loss so MixLoRA router auxiliary loss is never added; update
the training loss path by calling self._add_mixlora_aux_loss(loss, model)
immediately after obtaining loss from super().compute_loss(...) inside
compute_loss so the aux loss is included and logged via store_metrics; also fix
MixLoraFFN.forward to accumulate per-microbatch auxiliary loss instead of
overwriting (use += into self._aux_loss and ensure it is zeroed at start of an
accumulation cycle or after optimizer step) so gradient-accumulated micro-steps
correctly aggregate the router aux loss.
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 350-385: The code assigns final_output = base_output by reference
which causes in-place updates to base_output when accumulating expert deltas in
the experts loop (for self.experts), corrupting subsequent expert_delta
computations when top_k > 1; fix by making final_output an explicit copy of
base_output (preserving device/dtype) before the loop so updates to
final_output[token_indices] do not modify base_output used in
self.base_ffn.down_proj(...) and expert_delta calculation.
In `@src/axolotl/utils/schemas/peft.py`:
- Around line 299-310: The validator validate_mixlora_top_k only checks when
both fields are explicitly set, so default values (from MIXLORA_DEFAULTS in
patching.py) can create invalid combos; update validation to resolve defaults
before checking or move the check into patch_model_with_mixlora after defaults
are applied. Specifically, in validate_mixlora_top_k (or in
patch_model_with_mixlora) read mixlora_top_k and mixlora_num_experts using the
default resolution logic (reference MIXLORA_DEFAULTS) and then raise the same
ValueError if resolved mixlora_top_k > resolved mixlora_num_experts; ensure the
error message references mixlora_top_k and mixlora_num_experts so users see the
effective values.
---
Nitpick comments:
In `@src/axolotl/integrations/mixlora/loss.py`:
- Around line 30-61: The docstring of collect_mixlora_aux_loss should explicitly
state its single-step/micro-batch assumption because MixLoraFFN.forward()
overwrites module.aux_loss (not accumulates) so with gradient accumulation only
the last micro-batch's aux loss will be collected; fix by either (A) updating
the collect_mixlora_aux_loss docstring to note that it must be called once per
micro-batch (or that MixLoraFFN.aux_loss is overwritten between micro-batches),
referencing collect_mixlora_aux_loss and MixLoraFFN.aux_loss/reset_aux_loss, or
(B) change MixLoraFFN.forward to accumulate into aux_loss (e.g., add to existing
value) and ensure reset_aux_loss clears the accumulator so
collect_mixlora_aux_loss can still be called once per optimizer step—pick one
approach and implement the corresponding docstring or code changes.
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 184-223: The forward method's docstring is inaccurate: it claims a
single "[T_expert, hidden_dim] expert output" while forward actually returns a
tuple (intermediate, down_delta). Update the forward docstring for the method
forward to describe the two returned tensors (intermediate: the SwiGLU-activated
intermediate of shape [T_expert, intermediate_dim] or appropriate shape, and
down_delta: the LoRA down-projection delta of shape [T_expert, hidden_dim]), and
clarify that the final down projection is applied elsewhere (MixLoraFFN); ensure
the return section matches the function signature and types.
- Around line 331-335: The unpacking assigns unused variables (batch_size,
seq_len) causing lint warnings; instead obtain only the last dimension and
reshape accordingly. In the block where x.dim() == 3 (variable x in model.py),
replace "batch_size, seq_len, hidden_dim = x.shape" with either "hidden_dim =
x.shape[-1]" or use an underscore for unused elements ("_, _, hidden_dim =
x.shape"), then call x_flat = x.reshape(-1, hidden_dim) so only the needed
hidden_dim is referenced.
In `@src/axolotl/integrations/mixlora/patching.py`:
- Around line 147-153: Check for modules without parameters before calling
next(...) on original_ffn: obtain first_param = next(original_ffn.parameters(),
None) and if it's None, skip the device/dtype lookup and avoid moving
mixlora_ffn.router/expert tensors (or return early), otherwise set device =
first_param.device and dtype = first_param.dtype and then .to(...) the
mixlora_ffn.router and mixlora_ffn.experts; reference original_ffn,
mixlora_ffn.router, and mixlora_ffn.experts when making the guard.
In `@tests/integrations/test_mixlora.py`:
- Line 127: Add a trailing newline at the end of the file so the final line
"assert aux_loss.item() > 0" is followed by a single newline character; open the
test file containing that assertion (tests/integrations/test_mixlora.py) and
ensure the file ends with a newline to satisfy POSIX/text-file conventions and
linters.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d81e0164-2a70-48c1-a42b-0150f6d6bae1
📒 Files selected for processing (11)
src/axolotl/core/trainers/base.pysrc/axolotl/integrations/mixlora/__init__.pysrc/axolotl/integrations/mixlora/loss.pysrc/axolotl/integrations/mixlora/model.pysrc/axolotl/integrations/mixlora/patching.pysrc/axolotl/loaders/adapter.pysrc/axolotl/loaders/model.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/peft.pysrc/axolotl/utils/schemas/validation.pytests/integrations/test_mixlora.py
Critical: - Wire _add_mixlora_aux_loss into compute_loss (was defined but never called) - Restore .clone() on final_output to fix aliasing bug with top_k > 1 Minor: - Accumulate aux_loss across gradient accumulation micro-batches - Add runtime top_k <= num_experts validation after resolving defaults Nitpicks: - Fix MixLoraExpert.forward docstring to describe tuple return - Remove unused batch_size/seq_len variables in shape unpacking - Add defensive check for parameter-less FFN modules in patching - Document gradient accumulation behavior in collect_mixlora_aux_loss - Ensure test file ends with trailing newline
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/trainers/base.py`:
- Around line 418-420: The auxiliary loss is being stored unconditionally as
train metrics; change the call to store_metrics to choose train vs eval based on
the model's training flag (e.g., use self.model.training) so that when
self.model.training is True you pass "train" and otherwise pass "eval" (update
the call site that currently calls store_metrics({"mixlora_aux_loss":
aux_loss.item()}, train_eval="train") to compute the phase from
self.model.training and pass that value).
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 304-321: The mixlora_state_dict and load_mixlora_state_dict helper
methods are dead/duplicative (router and experts are already saved via normal
nn.Module state_dict recursion); either delete both methods (mixlora_state_dict
and load_mixlora_state_dict) to remove dead code, or, if you intend them as an
explicit API for exporting/importing just MixLoRA weights, keep them but add a
clear docstring explaining their purpose and intended usage, plus update any
public API docs and tests to cover their behavior (and ensure callers use
mixlora_state_dict/load_mixlora_state_dict where intended).
- Around line 294-302: The module-level accumulator _aux_loss (accessed via
aux_loss and reset_aux_loss) is unsafe with activation checkpointing because
MixLoraFFN.forward() mutates it during forward replays; fix by making the
accumulator forward-local or by resetting _aux_loss at the start of the real
forward: inside MixLoraFFN.forward(), clear self._aux_loss (or use a local aux
variable and return/attach it to the output) before any accumulation so
checkpoint replays don't leave stale state that compute_loss() misses. Ensure
compute_loss() still reads the correct per-step aux term (remove reliance on
persistent module state if you convert to a local accumulator).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 2463935a-6a83-4856-82a6-11094f8cea2b
📒 Files selected for processing (4)
src/axolotl/core/trainers/base.pysrc/axolotl/integrations/mixlora/loss.pysrc/axolotl/integrations/mixlora/model.pysrc/axolotl/integrations/mixlora/patching.py
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/integrations/mixlora/loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/integrations/mixlora/patching.py
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
…ng, and validation checks
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 185-226: forward in MixLora expert currently hardcodes F.silu for
the SwiGLU activation (seen where intermediate = F.silu(gate_total) * up_total),
which can silently change behavior when patching non-SiLU FFNs; change the
expert to use an activation attribute (e.g., self.activation or self.silu_fn)
supplied/validated from the base MixLoraFFN instead of calling F.silu directly,
and update the patching/compatibility logic in patching.py to verify the base
FFN exposes the same activation (or an activation callable) before patching;
locate usages in the forward method, MixLoraFFN constructor/initializer, and
patching.py activation checks and wire the activation through those symbols so
both base and expert paths use the identical activation function.
- Around line 403-426: In load_mixlora_state_dict: detect and reject extra
MixLoRA blocks in the checkpoint when strict=True by comparing the set of
MixLora module prefixes present in the model to the set of prefixes present in
the state_dict; compute model_prefixes = {module_name for module_name, module in
model.named_modules() if isinstance(module, MixLoraFFN)} and checkpoint_prefixes
= {key.split('.',1)[0] for key in state_dict.keys()}, and if strict and
checkpoint_prefixes - model_prefixes is non-empty raise a KeyError listing the
unexpected prefixes so extra/renamed MixLoRA blocks in the checkpoint are not
silently ignored.
In `@src/axolotl/integrations/mixlora/README.md`:
- Around line 54-61: Remove trailing whitespace from the BibTeX block in
src/axolotl/integrations/mixlora/README.md: edit the `@misc` entry (starting with
li2024mixloraenhancinglarge) and remove any spaces at the ends of lines such as
the author field and the url field so the lines no longer end with whitespace;
ensure the file passes "trim trailing whitespace" checks.
In `@src/axolotl/integrations/mixlora/trainer.py`:
- Around line 62-79: The MixLoRA sidecar is being built from live self.model
parameters after calling super()._save(), causing the saved sidecar to diverge
when a prepared snapshot was passed via the state_dict param; modify _save so it
builds the mixlora state from the provided state_dict when present (falling back
to self.model only if state_dict is None) instead of always calling
mixlora_state_dict(self.model). Locate the _save method and the
mixlora_state_dict usage and change it to derive state from the state_dict
argument (or extract the MixLoRA-related tensors from state_dict using the same
key patterns mixlora_state_dict expects) before converting to cpu and writing
MIXLORA_WEIGHTS_NAME. Ensure super()._save(output_dir=output_dir,
state_dict=state_dict) remains called and that the saved sidecar uses the same
snapshot data that was written by super()._save.
In `@src/axolotl/utils/schemas/peft.py`:
- Around line 241-248: MixLoraConfig currently defines DEFAULT_* class
attributes without typing, so Pydantic v2 treats them as model fields and mypy
widens numeric types; annotate each DEFAULT_* as a typing.ClassVar with the
correct narrow type (e.g., DEFAULT_NUM_EXPERTS and DEFAULT_TOP_K as
ClassVar[int], DEFAULT_ROUTER_AUX_LOSS_COEF, DEFAULT_ROUTER_INIT_RANGE, and
DEFAULT_JITTER_NOISE as ClassVar[float]) to ensure they are class-only constants
and keep type checking correct; update the declarations on MixLoraConfig (and
the analogous DEFAULT_* attributes in the following block mentioned lines
250–263) to use ClassVar annotations.
In `@tests/integrations/test_mixlora.py`:
- Around line 28-33: Replace the re-exported import of MIXLORA_WEIGHTS_NAME from
model.py with a direct import from the constants module: import
MIXLORA_WEIGHTS_NAME from axolotl.integrations.mixlora.constants while keeping
MixLoraFFN, load_mixlora_state_dict, and mixlora_state_dict imported from
model.py so mypy's attr-defined check uses the symbol's original definition.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 02c08a13-b936-46d8-afd4-8407f6c9632b
📒 Files selected for processing (10)
src/axolotl/integrations/mixlora/README.mdsrc/axolotl/integrations/mixlora/__init__.pysrc/axolotl/integrations/mixlora/constants.pysrc/axolotl/integrations/mixlora/model.pysrc/axolotl/integrations/mixlora/patching.pysrc/axolotl/integrations/mixlora/plugin.pysrc/axolotl/integrations/mixlora/trainer.pysrc/axolotl/loaders/adapter.pysrc/axolotl/utils/schemas/peft.pytests/integrations/test_mixlora.py
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/integrations/mixlora/constants.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/loaders/adapter.py
- src/axolotl/integrations/mixlora/patching.py
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/axolotl/integrations/mixlora/trainer.py (1)
93-109: Consider skipping auxiliary loss addition during evaluation.
_add_mixlora_aux_lossis called regardless of whether the model is in training or evaluation mode. Adding the router load-balance loss during validation may artificially inflate validation loss metrics, making training vs. validation comparisons less meaningful. Typically, auxiliary losses are only used during training.♻️ Proposed fix
def _add_mixlora_aux_loss(self, loss, model): """Add MixLoRA router auxiliary load-balance loss if applicable.""" + # Only add auxiliary loss during training + if not model.training: + return loss + coef = getattr(self.axolotl_cfg, "mixlora_router_aux_loss_coef", None) router_aux_loss_coef = ( coef if coef is not None else MIXLORA_DEFAULTS["mixlora_router_aux_loss_coef"] ) aux_loss = collect_mixlora_aux_loss( model, router_aux_loss_coef=router_aux_loss_coef ) loss = loss + aux_loss.to(loss.device) train_eval = "train" if model.training else "eval" self.store_metrics( {"mixlora_aux_loss": aux_loss.item()}, train_eval=train_eval ) return loss🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/trainer.py` around lines 93 - 109, The _add_mixlora_aux_loss function is adding the MixLoRA router aux loss during eval; change it to only compute and add aux loss when model.training is True: check model.training at the start of _add_mixlora_aux_loss and return the original loss (and optionally log nothing or log mixlora_aux_loss as 0) when not training; when training, compute collect_mixlora_aux_loss using mixlora_router_aux_loss_coef (from self.axolotl_cfg or MIXLORA_DEFAULTS), add aux_loss.to(loss.device) to loss, and call self.store_metrics with the aux_loss.item() and correct train_eval ("train").src/axolotl/integrations/mixlora/model.py (1)
278-281: Consider logging a warning when falling back to default activation.The silent fallback to
F.siluwhenact_fnis not callable could mask misconfigured FFN modules. A warning would aid debugging.♻️ Proposed enhancement
# Resolve activation function from the base FFN (default to F.silu) self.activation_fn = getattr(self.base_ffn, "act_fn", F.silu) if not callable(self.activation_fn): + LOG.warning( + f"MixLoraFFN: base_ffn.act_fn is not callable ({type(self.activation_fn)}), " + "falling back to F.silu" + ) self.activation_fn = F.silu🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/model.py` around lines 278 - 281, The code silently falls back to F.silu if self.base_ffn.act_fn is missing or not callable; update the block around resolving self.activation_fn (the getattr on self.base_ffn and the callable check) to log a warning when this fallback is used, e.g., use the module/class logger to warn that base_ffn.act_fn is missing or not callable and that F.silu will be used, then assign F.silu as now; ensure the log includes the actual value/type of self.base_ffn.act_fn for easier debugging.src/axolotl/integrations/mixlora/README.md (1)
6-6: Consider hyphenating "LoRA-based" for grammatical correctness.The phrase "LoRA based Mixture of Experts" should use a hyphen: "LoRA-based Mixture of Experts" since "LoRA-based" is a compound adjective modifying "Mixture."
📝 Suggested fix
-See [MixLoRA: Enhancing Large Language Models Fine-Tuning with LoRA based Mixture of Experts](https://arxiv.org/abs/2404.15159) +See [MixLoRA: Enhancing Large Language Models Fine-Tuning with LoRA-based Mixture of Experts](https://arxiv.org/abs/2404.15159)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/integrations/mixlora/README.md` at line 6, Update the phrase "LoRA based Mixture of Experts" in the README line that references the MixLoRA paper to use the hyphenated compound adjective "LoRA-based Mixture of Experts" so the grammar is correct; locate the exact string "LoRA based Mixture of Experts" in src/axolotl/integrations/mixlora/README.md and replace it with "LoRA-based Mixture of Experts".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 160-164: The __init__ parameter annotation uses the built-in
"callable" which causes mypy errors; import Callable (from collections.abc) and
change the parameter type annotation from "callable" to "Callable" (e.g., update
the activation_fn parameter in the class initializer shown with signature
containing lora_dropout and activation_fn) so mypy recognizes it as a proper
callable type; add the import line at the top of the module and update the
annotation in the constructor signature and any other occurrences of "callable"
in this file.
---
Nitpick comments:
In `@src/axolotl/integrations/mixlora/model.py`:
- Around line 278-281: The code silently falls back to F.silu if
self.base_ffn.act_fn is missing or not callable; update the block around
resolving self.activation_fn (the getattr on self.base_ffn and the callable
check) to log a warning when this fallback is used, e.g., use the module/class
logger to warn that base_ffn.act_fn is missing or not callable and that F.silu
will be used, then assign F.silu as now; ensure the log includes the actual
value/type of self.base_ffn.act_fn for easier debugging.
In `@src/axolotl/integrations/mixlora/README.md`:
- Line 6: Update the phrase "LoRA based Mixture of Experts" in the README line
that references the MixLoRA paper to use the hyphenated compound adjective
"LoRA-based Mixture of Experts" so the grammar is correct; locate the exact
string "LoRA based Mixture of Experts" in
src/axolotl/integrations/mixlora/README.md and replace it with "LoRA-based
Mixture of Experts".
In `@src/axolotl/integrations/mixlora/trainer.py`:
- Around line 93-109: The _add_mixlora_aux_loss function is adding the MixLoRA
router aux loss during eval; change it to only compute and add aux loss when
model.training is True: check model.training at the start of
_add_mixlora_aux_loss and return the original loss (and optionally log nothing
or log mixlora_aux_loss as 0) when not training; when training, compute
collect_mixlora_aux_loss using mixlora_router_aux_loss_coef (from
self.axolotl_cfg or MIXLORA_DEFAULTS), add aux_loss.to(loss.device) to loss, and
call self.store_metrics with the aux_loss.item() and correct train_eval
("train").
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c8765227-c016-4399-9cb0-c5ca52a708e3
📒 Files selected for processing (6)
src/axolotl/integrations/mixlora/README.mdsrc/axolotl/integrations/mixlora/loss.pysrc/axolotl/integrations/mixlora/model.pysrc/axolotl/integrations/mixlora/trainer.pysrc/axolotl/utils/schemas/peft.pytests/integrations/test_mixlora.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/integrations/mixlora/loss.py
- src/axolotl/utils/schemas/peft.py
| lora_dropout: float = 0.0, | ||
| activation_fn: callable = F.silu, | ||
| ): | ||
| super().__init__() | ||
| self.activation_fn = activation_fn |
There was a problem hiding this comment.
Fix mypy error: Use Callable from typing instead of callable.
The pipeline is failing because callable is a built-in function, not a valid type annotation. Use typing.Callable instead.
🐛 Proposed fix
Add the import at the top of the file:
from collections.abc import CallableThen update the parameter type:
def __init__(
self,
hidden_dim: int,
intermediate_dim: int,
lora_r: int,
lora_alpha: int,
lora_dropout: float = 0.0,
- activation_fn: callable = F.silu,
+ activation_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
):
super().__init__()
self.activation_fn = activation_fn🧰 Tools
🪛 GitHub Actions: lint
[error] 160-160: mypy: Function "builtins.callable" is not valid as a type [valid-type]. Perhaps you meant "typing.Callable" instead of "callable"?
🪛 GitHub Actions: Tests
[error] 160-160: mypy: Function "builtins.callable" is not valid as a type [valid-type]. note: Perhaps you meant "typing.Callable" instead of "callable"?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/integrations/mixlora/model.py` around lines 160 - 164, The
__init__ parameter annotation uses the built-in "callable" which causes mypy
errors; import Callable (from collections.abc) and change the parameter type
annotation from "callable" to "Callable" (e.g., update the activation_fn
parameter in the class initializer shown with signature containing lora_dropout
and activation_fn) so mypy recognizes it as a proper callable type; add the
import line at the top of the module and update the annotation in the
constructor signature and any other occurrences of "callable" in this file.
Description
adds mix lora finetuning support. Closes #1880
Motivation and Context
How has this been tested?
yet to be tested
AI Usage Disclaimer
claude
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests