Lora kernels bias support #3025
Caution: Review failed. The pull request is closed.
📝 Walkthrough
The changes extend LoRA kernel support to handle bias tensors alongside weights during parameter extraction and matrix multiplications. All relevant LoRA autograd functions and application methods are updated to accept, propagate, and use bias tensors. Additionally, the LoRA kernel patching logic is relaxed to allow patching even when biases are present, with the related eligibility checks and warning messages updated to match.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
78e6c29 to ca2164a
Codecov Report❌ Patch coverage is
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/kernels/lora.py (1)
546-546: Remove unused ctx.biases assignment.
The ctx.biases tuple is saved in the forward pass but never used in the backward pass. This unnecessary storage can be removed to save memory.
- ctx.biases = (q_bias, k_bias, v_bias)
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/axolotl/kernels/lora.py (26 hunks)
src/axolotl/monkeypatch/lora_kernels.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/kernels/lora.py (1)
src/axolotl/kernels/quantize.py (1)
dequantize(20-150)
🔇 Additional comments (12)
src/axolotl/monkeypatch/lora_kernels.py (3)
402-403: LGTM! Warning message correctly updated for bias support.
The updated warning message accurately reflects the new requirements (LoRA adapters and no DoRA support), without the previous bias restriction.
420-421: LGTM! Consistent warning message update.
The output projection warning is consistently updated to match the QKV warning format.
437-438: LGTM! MLP warning message properly updated.
The MLP layer warning message is consistently updated to remove the bias restriction.
src/axolotl/kernels/lora.py (9)
25-80: LGTM! Bias extraction properly implemented in get_lora_parameters.
The function now correctly extracts and returns the bias tensor from the base layer, with proper handling for both cases (with and without adapters).
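To make the two cases concrete, here is a minimal sketch of the extraction pattern the comment describes. It is not the PR's code: FakeBaseLayer, FakeLoraLayer, and get_params_sketch are hypothetical stand-ins, and the quantization state that the real get_lora_parameters also handles is left out.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeBaseLayer:
    """Hypothetical stand-in for a linear layer; bias may be absent."""
    weight: Any
    bias: Optional[Any] = None


@dataclass
class FakeLoraLayer:
    """Hypothetical stand-in for a LoRA-wrapped projection."""
    base_layer: FakeBaseLayer
    lora_A: Optional[Any] = None
    lora_B: Optional[Any] = None
    scaling: Optional[float] = None


def get_params_sketch(proj):
    """Return (W, b, A, B, s); the bias is returned on every code path."""
    # A wrapped layer exposes the original module as base_layer;
    # an unwrapped layer is treated as its own base.
    base = getattr(proj, "base_layer", proj)
    W = base.weight
    b = getattr(base, "bias", None)  # None when the layer has no bias

    # No adapter attached: still return the bias alongside the weight.
    if getattr(proj, "lora_A", None) is None:
        return W, b, None, None, None

    return W, b, proj.lora_A, proj.lora_B, proj.scaling
```

The point of the sketch is only that the bias sits next to the weight in the returned tuple whether or not an adapter is attached, so callers can unpack it uniformly.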
82-127: LGTM! Bias handling correctly integrated in matmul_lora.
The function properly applies bias after the matrix multiplication and LoRA operations, which is the correct order. The parameter reordering to group weight and bias together improves readability.
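The order of operations mentioned above can be sketched in NumPy. This is an illustration under assumptions, not the kernel itself: matmul_lora_sketch is a hypothetical name, the tensor shapes follow the docstring quoted later in this review, and dequantization and the optional output buffer are omitted.

```python
import numpy as np


def matmul_lora_sketch(X, W, b, A, B, s):
    """Compute X @ W.T + s * (X @ A.T) @ B.T + b.

    Shapes (assumed from the reviewed docstring):
      X: [n, in_features], W: [out_features, in_features],
      b: [out_features] or None,
      A: [rank, in_features], B: [out_features, rank].
    """
    out = X @ W.T                        # base projection
    if A is not None:
        out = out + s * (X @ A.T) @ B.T  # low-rank LoRA update
    if b is not None:
        out = out + b                    # bias added last, broadcast over rows
    return out
```

Because the bias is a plain additive term, this matches merging the adapter into the weight first and then applying the biased linear layer, i.e. X @ (W + s * B @ A).T + b.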
216-244: Verify: Bias gradients are not computed in backward pass.
The backward method returns None for all bias gradients (lines 223, 224, 230, 236, 243). This means biases won't be updated during training. Please confirm this is intentional - typically in LoRA, base model parameters (including biases) are frozen while only LoRA parameters are trained.
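A short NumPy sketch may help show why skipping the bias gradient does not affect the other gradients. It assumes a single projection of the form Y = X W^T + s (X A^T) B^T + b; forward and grad_input are hypothetical helper names, not functions from the PR.

```python
import numpy as np


def forward(X, W, b, A, B, s):
    """Single LoRA projection with bias (illustrative only)."""
    return X @ W.T + s * (X @ A.T) @ B.T + b


def grad_input(dY, W, A, B, s):
    """Gradient of the loss with respect to X, given upstream gradient dY.

    The bias b does not appear here: it is an additive constant, so it
    contributes nothing to dX or to the LoRA gradients.
    """
    return dY @ W + s * (dY @ B) @ A


def grad_bias_if_trainable(dY):
    """What the bias gradient would be if the bias were trained.

    A frozen bias needs no gradient, which is consistent with a backward
    pass that returns None in the bias slots.
    """
    return dY.sum(axis=0)
```

If the bias were meant to be trainable, the backward pass would need to return something like the last helper's result instead of None.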
391-431: LGTM! Bias parameters correctly propagated in SwiGLU MLP.
The function properly extracts bias tensors and passes them to the LoRA_MLP autograd function.
434-473: LGTM! Consistent bias handling in GEGLU MLP.
The GEGLU variant maintains consistency with the SwiGLU implementation for bias handling.
551-691: Consistent with MLP: bias gradients not computed.
Similar to LoRA_MLP, the backward pass returns None for all bias gradients. This aligns with the typical LoRA training approach where base model parameters remain frozen.
694-733: LGTM! QKV bias parameters properly handled.
The function correctly extracts and passes bias tensors to the LoRA_QKV autograd function.
736-824: LGTM! Output projection bias support implemented consistently.
The LoRA_O class follows the same pattern as LoRA_QKV and LoRA_MLP, with proper bias propagation in forward and None gradients in backward.
826-839: LGTM! Output projection application with bias support.
The function correctly handles bias extraction and propagation.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/kernels/lora.py (1)
82-105: Update documentation to include bias parameter
The function signature includes a b parameter for bias, but it's not documented in the Args section.
  Args:
      X: Input tensor [*, in_features]
      W: Base weight matrix [out_features, in_features]
+     b: Optional bias tensor [out_features]
      W_quant: Quantization state for W
      A: LoRA A matrix [rank, in_features]
      B: LoRA B matrix [out_features, rank]
      s: LoRA scaling factor
      out: Optional output tensor for inplace operations
🔇 Additional comments (11)
src/axolotl/monkeypatch/lora_kernels.py (2)
420-421: LGTM!
The warning message correctly reflects the updated requirements for patching output projections.
437-438: LGTM!
The warning message correctly reflects the updated requirements for patching MLP layers.
src/axolotl/kernels/lora.py (9)
25-79: LGTM! Bias extraction properly implemented
The function correctly extracts and returns the bias tensor in all code paths, with appropriate documentation updates.
135-213: LGTM! Forward pass correctly handles biases
The forward method properly accepts and propagates bias tensors through all projection computations.
391-473: LGTM! MLP application functions correctly handle biases
Both apply_lora_mlp_swiglu and apply_lora_mlp_geglu consistently extract and propagate bias tensors for all projections.
486-549: LGTM! QKV forward pass properly handles biases
The forward method correctly accepts, propagates, and saves bias tensors for all three projections.
551-691: LGTM! QKV backward pass correctly handles gradients
The backward method properly returns None for bias gradients and includes necessary type annotations.
694-733: LGTM! QKV application function correctly handles biases
The function consistently extracts and propagates bias tensors for all three projections.
739-775: LGTM! Output projection forward pass handles bias correctly
The forward method properly accepts and propagates the bias tensor.
777-839: LGTM! Output projection backward and application correctly handle biases
Both the backward method and the apply_lora_o function handle bias tensors consistently with other projections.
349-351: dequantize correctly restores original weight orientation, so passing .t() is unnecessary
The dequantize function always returns a tensor of shape quant_state.shape (the original weight shape) and only re-transposes internally for 1×N inputs. Whether you call dequantize(W.t(), quant) or dequantize(W, quant), the output will match quant_state.shape, making explicit transposes redundant. The change at line 349 to drop .t() and use
gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
is therefore correct and preserves the intended dimensions. No further action required.
Likely an incorrect or invalid review comment.