update lora kernels docs #3186
📝 Walkthrough
Updates documentation to reflect FSDP2 support and the revised LoRA dropout/bias policy. Simplifies kernel patch eligibility to require only `lora_dropout: 0`. Refactors tests to validate the new condition and confirm that patches apply with bias enabled. No public APIs changed.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (7)
src/axolotl/monkeypatch/lora_kernels.py (4)
325-328: Clarify Note: also call out DoRA and multi-adapter limitations.
The code skips patching when dropout is non-zero, and per-layer patching also rejects DoRA (`lora_magnitude_vector`). Document both, plus the single-adapter constraint, to reduce confusion.
Apply this doc tweak:

    -    Note:
    -        The optimizations require LoRA adapters with no dropout. The function will skip
    -        patching if that condition isn't met.
    +    Note:
    +        - The optimizations require LoRA adapters with no dropout (`lora_dropout: 0`);
    +          otherwise, patching is skipped.
    +        - DoRA (`lora_magnitude_vector`) is not currently supported; affected projections
    +          will log a warning and remain unpatched.
    +        - Multiple active adapters are not supported (single active adapter required).
343-349: Include actual dropout value in warning.
Improves debuggability when patches are skipped.

    -    if not can_patch:
    -        LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
    +    if not can_patch:
    +        LOG.warning(
    +            f"Cannot patch layers - requires `lora_dropout: 0` (got {lora_config.lora_dropout})"
    +        )
             LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
             return model
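A minimal sketch of the simplified eligibility check with the suggested warning, written against a stub rather than the real PEFT config. `StubLoraConfig` and `check_can_patch` are hypothetical names used only for illustration; the PR's actual code lives inside `apply_lora_kernel_patches`.

```python
import logging
from dataclasses import dataclass

LOG = logging.getLogger("lora_kernels_sketch")


@dataclass
class StubLoraConfig:
    """Hypothetical stand-in for a LoRA config; only the fields used here."""

    lora_dropout: float = 0.0
    bias: str = "none"


def check_can_patch(lora_config: StubLoraConfig) -> bool:
    """Return True when kernel patching is allowed under the rule described above."""
    # Per the walkthrough, only dropout gates patching; bias is no longer checked.
    can_patch = lora_config.lora_dropout == 0
    if not can_patch:
        # Report the offending value so a skipped patch is easy to diagnose.
        LOG.warning(
            "Cannot patch layers - requires `lora_dropout: 0` (got %s)",
            lora_config.lora_dropout,
        )
        LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
    return can_patch
```

With this shape, a config using `bias="lora_only"` and zero dropout passes, while any non-zero dropout is rejected and logged with its value.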
350-354: Ensure log level is restored on exceptions.
Wrap the temporary log level change in a try/finally to avoid leaking INFO level on early errors (e.g., unsupported activation).
Example:

    original_level = LOG.getEffectiveLevel()
    try:
        LOG.setLevel(logging.INFO)
        # ... existing patching logic ...
    finally:
        LOG.setLevel(original_level)

Also applies to: 440-442
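One way to package the try/finally pattern above is a small context manager, sketched here with the standard library only. `temporary_log_level` is a hypothetical helper name, not something in the axolotl codebase. This sketch saves `logger.level` rather than `getEffectiveLevel()`, on the assumption that a logger left at `NOTSET` should go back to inheriting from its parent instead of being pinned to an explicit level.

```python
import logging
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporary_log_level(logger: logging.Logger, level: int) -> Iterator[None]:
    """Set `level` on `logger` for the duration of the block, then restore it."""
    # Save the logger's own level (may be NOTSET), not the inherited effective level.
    original_level = logger.level
    logger.setLevel(level)
    try:
        yield
    finally:
        # Runs on normal exit and when the block raises.
        logger.setLevel(original_level)
```

Usage would look like `with temporary_log_level(LOG, logging.INFO): ...` around the patching logic, so an early exception cannot leave the logger at INFO.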
236-246: Type hint nit: yielded type is nn.Module, not Tuple[nn.Module].
The generator yields a single module. Adjust the annotation.

    -def find_self_attn_in_layer(
    -    layer: nn.Module,
    -) -> Generator[Tuple[nn.Module], None, None]:
    +def find_self_attn_in_layer(
    +    layer: nn.Module,
    +) -> Generator[nn.Module, None, None]:

docs/lora_optims.qmd (1)
96-101: Document DoRA limitation and single-adapter constraint.
Code rejects DoRA per-layer and enforces a single active adapter. Add bullets so users don't discover this at runtime.

     - Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
       - This may limit model expressivity
     - Adapters that already include bias terms are supported.
    + - DoRA (`lora_magnitude_vector`) is not supported; affected projections will remain unpatched.
    + - Multiple active adapters are not supported (only one active LoRA adapter).

tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (2)
224-247: Make the test offline and strengthen assertions.
Avoid network/model downloads by using the existing small LLaMA fixture. Also assert that attention methods weren't injected when skipping patches.

    -def test_kernel_patch_requires_zero_dropout():
    +def test_kernel_patch_requires_zero_dropout(small_llama_model):
         """Kernel patching should be skipped when dropout is enabled."""
         config = {
             "peft_type": "LORA",
             "task_type": "CAUSAL_LM",
             "r": 8,
             "lora_alpha": 16,
             "target_modules": ["gate_proj", "up_proj", "down_proj"],
             "lora_dropout": 0.1,
             "bias": "none",
         }
    -    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    -    peft_config = get_peft_config(config)
    -    model = PeftModelForCausalLM(model, peft_config)
    +    peft_config = get_peft_config(config)
    +    model = PeftModelForCausalLM(small_llama_model, peft_config)
         cfg = DictDefault({"lora_mlp_kernel": True})
         patched_model = apply_lora_kernel_patches(model, cfg)
         layer = patched_model.model.model.layers[0].mlp
         # Verify no patches applied when dropout is non-zero
         assert layer.forward.__func__ is not apply_lora_mlp_swiglu
         assert layer.forward.__func__ is not apply_lora_mlp_geglu
    +    # Attention should also remain unmodified (methods not injected)
    +    self_attn = patched_model.model.model.layers[0].self_attn
    +    assert not hasattr(self_attn, "apply_qkv")
    +    assert not hasattr(self_attn, "apply_o")
249-271: Same: avoid downloads; rely on the small fixture.
Keeps the suite faster and more reliable while still validating bias support.

    -def test_kernel_patch_with_bias_enabled():
    +def test_kernel_patch_with_bias_enabled(small_llama_model):
         """Kernel patching should succeed when LoRA bias is enabled."""
         config = {
             "peft_type": "LORA",
             "task_type": "CAUSAL_LM",
             "r": 8,
             "lora_alpha": 16,
             "target_modules": ["gate_proj", "up_proj", "down_proj"],
             "lora_dropout": 0,
             "bias": "lora_only",
         }
    -    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    -    peft_config = get_peft_config(config)
    -    model = PeftModelForCausalLM(model, peft_config)
    +    peft_config = get_peft_config(config)
    +    model = PeftModelForCausalLM(small_llama_model, peft_config)
         cfg = DictDefault({"lora_mlp_kernel": True})
         patched_model = apply_lora_kernel_patches(model, cfg)
         layer = patched_model.model.model.layers[0].mlp
         # Verify patches applied when bias support is enabled
         assert layer.forward.__func__ is apply_lora_mlp_swiglu
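The `layer.forward.__func__ is ...` assertions in both tests rely on `forward` being a bound method whose underlying function is the kernel. The sketch below shows that identity check on stub objects, assuming the patch binds the function to a single instance with `types.MethodType`; that mechanism is an assumption, not something this review confirms. `StubMLP`, `fused_forward`, `patch_forward`, and `is_patched_with` are hypothetical names.

```python
import types


class StubMLP:
    """Hypothetical stand-in for a transformer MLP block."""

    def forward(self, x):
        return x


def fused_forward(self, x):
    """Hypothetical stand-in for a fused kernel such as apply_lora_mlp_swiglu."""
    return x * 2


def patch_forward(layer, fn):
    """Bind `fn` to this one instance; other instances keep the class-level forward."""
    layer.forward = types.MethodType(fn, layer)
    return layer


def is_patched_with(layer, fn) -> bool:
    """Bound methods expose the wrapped function as `__func__`; compare by identity."""
    return getattr(layer.forward, "__func__", None) is fn
```

Because the comparison is by identity, it distinguishes "patched with this exact kernel" from "patched with something else", which is why the skip test asserts against both the SwiGLU and GEGLU functions.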
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
docs/lora_optims.qmd (3 hunks)
src/axolotl/monkeypatch/lora_kernels.py (2 hunks)
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (3)
src/axolotl/utils/dict.py (1)
DictDefault (6-38)
src/axolotl/monkeypatch/lora_kernels.py (1)
apply_lora_kernel_patches (303-442)
src/axolotl/kernels/lora.py (5)
forward (133-211)
forward (484-547)
forward (739-773)
apply_lora_mlp_swiglu (389-429)
apply_lora_mlp_geglu (432-471)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: preview
🔇 Additional comments (2)
docs/lora_optims.qmd (2)
8-12: LGTM: FSDP2 mention and clearer kernel overview.
The expanded scope and phrasing look good.
131-135: LGTM: "Support for dropout" future work aligns with code.
Matches the current gating in monkeypatch logic.
📖 Documentation Preview: https://68d5723b14499b64527dbcdf--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit 3299f18
Codecov Report❌ Patch coverage is
ignore, not quite right