Skip to content

feat: LoRA kernel support for bias, dropout, dora, embeddings#3528

Merged
winglian merged 6 commits into
mainfrom
lora-kernels-v2
Mar 22, 2026
Merged

feat: LoRA kernel support for bias, dropout, dora, embeddings#3528
winglian merged 6 commits into
mainfrom
lora-kernels-v2

Conversation

@winglian
Copy link
Copy Markdown
Collaborator

@winglian winglian commented Mar 22, 2026

Description

  • Bias support for LoRA kernels: bias="lora_only" and bias="all" now work with lora_*_kernel=True
  • Dropout support: lora_dropout > 0 now works with kernel patches (shared mask across fused projections for efficiency)
  • DoRA support: Weight-Decomposed LoRA (peft_use_dora: true) now works with all kernel types, with weight norm caching to minimize overhead
  • Embedding LoRA kernel (lora_embedding_kernel): New fused autograd function for nn.Embedding layers with LoRA, including DoRA support (7-9x speedup for embedding DoRA)
  • Triton DoRA kernel: Fused row-wise weight norm computation that avoids materializing the full [out_features, in_features] B@A matrix
  Linear LoRA Kernels, bsz=1, seq_len=2048, rank=32

  ┌───────────────────┬──────────┬──────────┬────────────┬─────────┬──────────┬───────────┐
  │      Config       │ Fwd (ms) │ Bwd (ms) │ Total (ms) │ Speedup │ Peak Mem │ Mem Saved │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ PEFT ref (LoRA)   │ 118.0    │ 185.8    │ 303.8      │ 1.00x   │ 24539 MB │ —         │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernels (LoRA)    │ 106.0    │ 171.4    │ 277.4      │ 1.10x   │ 20525 MB │ -4.0 GB   │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernels + dropout │ 106.7    │ 172.8    │ 279.6      │ 1.09x   │ 21529 MB │ -3.0 GB   │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernels + bias    │ 106.1    │ 171.6    │ 277.8      │ 1.09x   │ 20525 MB │ -4.0 GB   │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ PEFT ref (DoRA)   │ 180.2    │ 236.6    │ 416.7      │ 0.73x   │ 31268 MB │ +6.7 GB   │
  ├───────────────────┼──────────┼──────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernels (DoRA)    │ 112.2    │ 181.1    │ 293.2      │ 1.04x   │ 22769 MB │ -1.8 GB   │
  └───────────────────┴──────────┴──────────┴────────────┴─────────┴──────────┴───────────┘

  Embedding LoRA Kernel, bsz=1, seq_len-2048, vocab=151_936, rank=32

  ┌─────────────────┬────────────┬─────────┬──────────┬───────────┐
  │     Config      │ Total (μs) │ Speedup │ Peak Mem │ Mem Saved │
  ├─────────────────┼────────────┼─────────┼──────────┼───────────┤
  │ PEFT ref (LoRA) │ 425        │ 1.00x   │ 3960 MB  │ —         │
  ├─────────────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernel (LoRA)   │ 525        │ 0.81x   │ 3983 MB  │ ~same     │
  ├─────────────────┼────────────┼─────────┼──────────┼───────────┤
  │ PEFT ref (DoRA) │ 5871       │ 1.00x   │ 7491 MB  │ —         │
  ├─────────────────┼────────────┼─────────┼──────────┼───────────┤
  │ Kernel (DoRA)   │ 650        │ 9.03x   │ 4015 MB  │ -3.5 GB   │
  └─────────────────┴────────────┴─────────┴──────────┴───────────┘

Summary by CodeRabbit

Release Notes

  • New Features

    • Added DoRA (Dimension-wise Low-Rank Adaptation) support for enhanced parameter-efficient fine-tuning.
    • Introduced LoRA embedding layer kernel optimization with new configuration option.
    • Extended LoRA to support bias parameters in projections.
    • Removed compatibility restrictions allowing DoRA to work alongside other LoRA kernel optimizations.
  • Tests

    • Added comprehensive end-to-end test coverage for LoRA kernel implementations against reference behavior.

@winglian winglian requested a review from NanoCode012 March 22, 2026 06:04
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 22, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8a336142-2737-4043-9169-20c067849ccb

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Introduces DoRA (Direction-only LoRA) support via a Triton-accelerated kernel, enhances LoRA implementations with bias and embedding layers, adds configuration options, updates validation logic, and provides comprehensive end-to-end tests validating kernel correctness against PEFT references.

Changes

Cohort / File(s) Summary
DoRA Kernel
src/axolotl/kernels/dora.py
Adds Triton JIT-compiled _dora_fused_norm_kernel computing per-output denominator norms and magnitude scaling, with Python wrapper triton_dora_scale handling dequantization, block sizing, and kernel launch.
LoRA Kernel Refactor
src/axolotl/kernels/lora.py
Extends get_lora_parameters to return LoRA biases, dropout modules, and DoRA magnitudes; updates matmul_lora to handle dropout-applied inputs and biases; refactors LoRA_MLP, LoRA_QKV, LoRA_O with conditional DoRA paths; adds embedding LoRA support; introduces DoRA scale computation utilities and caching.
Quantization Support
src/axolotl/kernels/quantize.py
Updates dequantize() to detect and handle non-double-quantized QuantState instances by bypassing double-quantization flow and directly calling standard dequantize_4bit.
Monkeypatch & Config
src/axolotl/monkeypatch/lora_kernels.py, src/axolotl/utils/schemas/config.py, src/axolotl/utils/schemas/validation.py
Adds embedding layer LoRA kernel patching; removes DoRA blocking validation; relaxes lora_magnitude_vector absence requirement for kernel eligibility; adds lora_embedding_kernel config field with auto-enable logic; removes early-exit on dropout for kernel activation.
Tests
tests/e2e/kernels/test_lora_features.py
Comprehensive end-to-end test suite validating fused kernel correctness (forward/backward) against PEFT across LoRA bias, dropout, DoRA, embeddings, quantization, and Triton DoRA scaling scenarios with diagnostic comparison helper.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

ready to merge, scheduled_release

Suggested reviewers

  • NanoCode012
  • djsaunde
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 52.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main changes: adding LoRA kernel support for bias, dropout, DoRA, and embeddings, which aligns with the substantial feature additions across multiple files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch lora-kernels-v2

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Mar 22, 2026

📖 Documentation Preview: https://69c002288510b6e3cbdc4228--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 10138a5

@winglian winglian added the scheduled_release This PR is slated for the upcoming release label Mar 22, 2026
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/kernels/lora.py`:
- Around line 1274-1284: The forward currently multiplies the combined embedding
by mag_scale before saving, so backward computes d_mag from the already-scaled
result and double-applies mag_scale; change the save to store the pre-scaled
activation instead: compute pre_scaled = result + s * lora_result (or the
equivalent unscaled combined tensor) and call ctx.save_for_backward(x,
A.to(dtype), B.to(dtype), after_A, magnitude, mag_scale, pre_scaled) so d_mag is
derived from the unscaled combination; update the same change in the second
occurrence (around lines 1309-1316) where has_dora is handled.
- Around line 1208-1224: The embedding branch must mirror the DTensor unshard
logic used for the linear path: before returning embed.lora_embedding_A and
embed.lora_embedding_B (and embed.lora_magnitude_vector when present), detect
DTensor-backed parameters and replace local shards with their unsharded/full
tensors the same way the linear path does (i.e., apply the same
unshard/redistribute logic used for W in the linear branch); update the handling
of A, B and magnitude (embed.lora_embedding_A[active_adapter],
embed.lora_embedding_B[active_adapter],
embed.lora_magnitude_vector[active_adapter]) to perform DTensor
unshard/extract-full-tensor when necessary and then return those full tensors
along with s, dropout and base_layer.
- Around line 160-175: The current DoRA cache uses A.data_ptr() / B.data_ptr()
to detect parameter changes which fails because optimizers mutate tensors
in-place; replace pointer-based invalidation with a content-version check or
disable the cache: update the cache key logic used around magnitude._dora_cache
(and the early-return that checks cached_a_ptr/cached_b_ptr) so it stores and
compares a lightweight content fingerprint (e.g., a small checksum/hash derived
from A.detach() and B.detach(), or a tensor version counter updated on change)
instead of data_ptr(), or simply remove the early-return and always recompute
the norm in the forward pass (while still optionally storing
magnitude._dora_cache = (fingerprint, weight_norm) after Triton path); adjust
code paths around triton_dora_scale, magnitude._dora_cache, and the initial
cache check to use the chosen content-based invalidation approach.
- Around line 1374-1375: Move the dropout so it only affects the LoRA
contribution by applying it inside LoRA_Embedding.forward() rather than after
apply() returns: locate LoRA_Embedding.forward() and within it apply dropout to
the LoRA path variable (after_A or lora_result) when self.training and dropout
is not None (mirroring the X_drop usage in other fused kernels), then return the
sum of the frozen base embedding and the now-dropped lora_result; remove the
external dropout call that currently wraps the combined result so the base
embedding is not subject to dropout.
- Around line 1320-1355: The custom embedding backward for LoRA ignores
padding_idx and scale_grad_by_freq: when padding_idx is set or
scale_grad_by_freq is true you must either reproduce F.embedding's
masking/scaling logic before accumulating into d_A_T or avoid the fused path and
return None so PyTorch's default embedding backward runs; to fix, in the
backward branch that computes d_after_A (symbol d_after_A) and accumulates with
d_A_T.index_add_(0, x_flat, d_after_A), filter out positions where x_flat ==
padding_idx (or zero-out corresponding rows in d_after_A) and, if
scale_grad_by_freq is true, divide each contribution by token frequency (compute
counts from x_flat) to match F.embedding behavior; alternatively early-check
padding_idx/scale_grad_by_freq and skip the fused gradient patch (return None
for d_A) when those features are active.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 74a3ac11-ca8c-4142-8e5f-8c04e89e65a8

📥 Commits

Reviewing files that changed from the base of the PR and between fc3b3d1 and 50beb04.

📒 Files selected for processing (7)
  • src/axolotl/kernels/dora.py
  • src/axolotl/kernels/lora.py
  • src/axolotl/kernels/quantize.py
  • src/axolotl/monkeypatch/lora_kernels.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/validation.py
  • tests/e2e/kernels/test_lora_features.py

Comment on lines +160 to +175
# Check cache on magnitude tensor (avoids expensive norm recomputation)
cache = getattr(magnitude, "_dora_cache", None)
if cache is not None:
cached_a_ptr, cached_b_ptr, cached_norm = cache
if cached_a_ptr == A.data_ptr() and cached_b_ptr == B.data_ptr():
return magnitude.to(dtype) / cached_norm

# Full recomputation - try Triton first
if _should_use_triton_dora() and W.is_cuda:
from .dora import triton_dora_scale

result = triton_dora_scale(W, W_quant, A, B, s, magnitude, dtype)
# Cache the weight_norm (derive from result: norm = mag / result)
weight_norm = (magnitude.to(dtype) / result).detach()
magnitude._dora_cache = (A.data_ptr(), B.data_ptr(), weight_norm)
return result
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Read the relevant section of the file to understand the cache invalidation logic
wc -l src/axolotl/kernels/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 100


🏁 Script executed:

# Get context around the flagged lines
sed -n '140,230p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 3955


🏁 Script executed:

# Look for the function signature to understand what A and B are
sed -n '130,180p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2296


🏁 Script executed:

# Look for the second location mentioned in the comment (lines 204-221)
sed -n '195,225p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 1461


🏁 Script executed:

# Search for where these functions are called and what A/B parameters are
rg "A\.data_ptr|B\.data_ptr|_compute_dora_scale" src/axolotl/kernels/lora.py | head -20

Repository: axolotl-ai-cloud/axolotl

Length of output: 972


🏁 Script executed:

# Look for where A and B parameters come from in the LoRA context
rg -B 10 -A 5 "def _compute_dora_scale\(" src/axolotl/kernels/lora.py | head -40

Repository: axolotl-ai-cloud/axolotl

Length of output: 472


🏁 Script executed:

# Search the broader codebase to understand what A and B are (LoRA adapter parameters)
rg "LoRA|adapter.*param" src/axolotl/kernels/lora.py | grep -E "(A|B)" | head -15

Repository: axolotl-ai-cloud/axolotl

Length of output: 989


🏁 Script executed:

# Search for how A and B parameters are created and used in the LoRA context
rg -B 5 -A 10 "lora_A|lora_B|\.A\s|\.B\s" src/axolotl/kernels/lora.py | head -50

Repository: axolotl-ai-cloud/axolotl

Length of output: 1206


🏁 Script executed:

# Look for where these functions are called to understand the data flow
rg -B 3 -A 3 "_compute_dora_scale\(" src/axolotl/kernels/lora.py | head -60

Repository: axolotl-ai-cloud/axolotl

Length of output: 2578


🏁 Script executed:

# Check if there's any documentation or tests about the caching behavior
find . -type f \( -name "*.py" -o -name "*.md" \) -path "*/test*" | xargs rg -l "dora.*cache\|_dora_cache" 2>/dev/null | head -5

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Search the broader codebase to see how LoRA parameters are typically accessed
rg "get_lora.*param|LoRA.*param" src/axolotl/kernels/lora.py -B 2 -A 8 | head -50

Repository: axolotl-ai-cloud/axolotl

Length of output: 1683


🏁 Script executed:

# Let's check what PyTorch version is being used and search for any tests about this behavior
rg -i "pytorch|torch" pyproject.toml setup.py requirements.txt 2>/dev/null | head -10

Repository: axolotl-ai-cloud/axolotl

Length of output: 691


🏁 Script executed:

# Look for PyTorch documentation or comments about data_ptr behavior
rg "data_ptr|in.place.*optim|optim.*place" src/axolotl/ -i | head -20

Repository: axolotl-ai-cloud/axolotl

Length of output: 1167


🏁 Script executed:

# Check if there are any actual tests for the DoRA cache mechanism
find . -name "*.py" -path "*/test*" -o -name "*test*.py" | xargs rg -l "dora\|DoRA" 2>/dev/null

Repository: axolotl-ai-cloud/axolotl

Length of output: 52


🏁 Script executed:

# Let's look at what happens during training with these parameters - how are A/B modified?
rg "lora_A|lora_B" src/axolotl/kernels/lora.py -B 5 -A 5 | head -80

Repository: axolotl-ai-cloud/axolotl

Length of output: 919


🌐 Web query:

PyTorch optimizer in-place updates data_ptr() stability AdamW SGD

💡 Result:

PyTorch optimizers like AdamW and SGD perform in-place updates to model parameters during optimizer.step, modifying the tensor data directly via operations like param.add_(grad, alpha=-lr). This updates the parameter values without allocating new tensors or changing the data_ptr, ensuring the same underlying memory buffer is used before and after the update. The data_ptr remains stable across optimizer steps. This stability holds for both standard implementations and optimized variants (foreach, fused). Fused kernels (e.g., fused_adamw, fused_sgd) are designed for in-place parameter updates on the existing tensor storage, preserving data_ptr. No evidence of data_ptr changes or instability was found in official docs, source code, or discussions. Optimizers run under torch.no_grad context (via _use_grad_for_differentiable decorator with differentiable=False by default), allowing safe in-place modifications on leaf parameters without autograd errors. Users relying on stable data_ptr (e.g., for C++ extensions or custom ops) can expect consistency. Potential caveats: - Multiple optimizers sharing parameters can cause inplace errors during backprop if steps interleave improperly. - Specialized cases like checkpointing or torch.compile may indirectly affect behavior, but core inplace updates preserve data_ptr. - Recent PRs distinguish mutable_data_ptr vs const_data_ptr for COW safety, but Python-level data_ptr remains stable for parameters.

Citations:


Fix DoRA cache invalidation: data_ptr() does not change on optimizer updates, causing stale norm caching.

Standard PyTorch optimizers (AdamW, SGD) perform in-place parameter updates via operations like param.add_(grad, alpha=-lr). These preserve the underlying memory pointer, so A.data_ptr() and B.data_ptr() remain constant even after values are modified. The cache comparison at line 165 and 213 will always match after the first step, causing the cached weight norm to become stale. The DoRA scale will drift from correct values throughout training.

Replace pointer-based invalidation with a content hash or version counter. Alternatively, invalidate the cache on every forward pass and rely on the computed norm being fast enough.

Also applies to: 204-221

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/kernels/lora.py` around lines 160 - 175, The current DoRA cache
uses A.data_ptr() / B.data_ptr() to detect parameter changes which fails because
optimizers mutate tensors in-place; replace pointer-based invalidation with a
content-version check or disable the cache: update the cache key logic used
around magnitude._dora_cache (and the early-return that checks
cached_a_ptr/cached_b_ptr) so it stores and compares a lightweight content
fingerprint (e.g., a small checksum/hash derived from A.detach() and B.detach(),
or a tensor version counter updated on change) instead of data_ptr(), or simply
remove the early-return and always recompute the norm in the forward pass (while
still optionally storing magnitude._dora_cache = (fingerprint, weight_norm)
after Triton path); adjust code paths around triton_dora_scale,
magnitude._dora_cache, and the initial cache check to use the chosen
content-based invalidation approach.

Comment on lines +1208 to +1224
A = embed.lora_embedding_A[active_adapter] # nn.Parameter [rank, vocab]
B = embed.lora_embedding_B[active_adapter] # nn.Parameter [hidden_dim, rank]
s = embed.scaling[active_adapter]

dropout = None
if hasattr(embed, "lora_dropout") and active_adapter in embed.lora_dropout:
dropout = embed.lora_dropout[active_adapter]

magnitude = None
if (
hasattr(embed, "lora_magnitude_vector")
and embed.lora_magnitude_vector
and active_adapter in embed.lora_magnitude_vector
):
magnitude = embed.lora_magnitude_vector[active_adapter].weight

return W, A, B, s, dropout, magnitude, base_layer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's locate and examine the lora.py file to understand the context
find . -name "lora.py" -path "*/kernels/*" | head -5

Repository: axolotl-ai-cloud/axolotl

Length of output: 97


🏁 Script executed:

# Get the file structure and locate the relevant code sections
git ls-files "*/kernels/lora.py"

Repository: axolotl-ai-cloud/axolotl

Length of output: 95


🏁 Script executed:

# Let's examine the linear path implementation in the lora.py file
cat -n src/axolotl/kernels/lora.py | head -100

Repository: axolotl-ai-cloud/axolotl

Length of output: 4163


🏁 Script executed:

# Search for DTensor-related code and unshard operations in lora.py
rg "DTensor|unshard" src/axolotl/kernels/lora.py -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 803


🏁 Script executed:

# Find the linear path function to see how it handles DTensor unsharding
rg "def.*linear|def.*forward" src/axolotl/kernels/lora.py -A 5

Repository: axolotl-ai-cloud/axolotl

Length of output: 735


🏁 Script executed:

# Search for the embedding function that matches the review comment lines
rg "def get_lora_embedding" src/axolotl/kernels/lora.py -A 30

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Find the lora_embedding_kernel function
rg "lora_embedding_kernel|def.*embedding.*kernel" src/axolotl/kernels/lora.py -B 2 -A 10

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Look at lines around 1208-1224 mentioned in the review
sed -n '1190,1240p' src/axolotl/kernels/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1302


🏁 Script executed:

# Search for all functions that return lora_embedding_A and lora_embedding_B
rg "lora_embedding_A|lora_embedding_B" src/axolotl/kernels/lora.py -B 5 -A 5

Repository: axolotl-ai-cloud/axolotl

Length of output: 893


🏁 Script executed:

# Get the full get_embedding_lora_parameters function
sed -n '1180,1250p' src/axolotl/kernels/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1869


🏁 Script executed:

# Find where get_embedding_lora_parameters is called and used
rg "get_embedding_lora_parameters" src/axolotl/kernels/lora.py -B 2 -A 10

Repository: axolotl-ai-cloud/axolotl

Length of output: 812


🏁 Script executed:

# Search for lora_embedding_kernel in the entire codebase
rg "lora_embedding_kernel" --type py

Repository: axolotl-ai-cloud/axolotl

Length of output: 665


🏁 Script executed:

# Let's check the total size of lora.py to understand structure
wc -l src/axolotl/kernels/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 100


🏁 Script executed:

# Find the LoRA_Embedding class definition
rg "class LoRA_Embedding" src/axolotl/kernels/lora.py -A 50

Repository: axolotl-ai-cloud/axolotl

Length of output: 1584


🏁 Script executed:

# Search for where LoRA_Embedding.apply is used and what forward does
rg "class LoRA_Embedding|def forward" src/axolotl/kernels/lora.py | head -20

Repository: axolotl-ai-cloud/axolotl

Length of output: 182


Mirror the DTensor unshard handling for embedding LoRA params.

The linear path unshards DTensor-backed adapter weights before entering the fused autograd code (lines 81-88), but the embedding path returns lora_embedding_A/B directly without this check. Under FSDP2, the embedding kernel receives local shards instead of the full matrices it expects, creating a safety mismatch with the linear kernels.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/kernels/lora.py` around lines 1208 - 1224, The embedding branch
must mirror the DTensor unshard logic used for the linear path: before returning
embed.lora_embedding_A and embed.lora_embedding_B (and
embed.lora_magnitude_vector when present), detect DTensor-backed parameters and
replace local shards with their unsharded/full tensors the same way the linear
path does (i.e., apply the same unshard/redistribute logic used for W in the
linear branch); update the handling of A, B and magnitude
(embed.lora_embedding_A[active_adapter], embed.lora_embedding_B[active_adapter],
embed.lora_magnitude_vector[active_adapter]) to perform DTensor
unshard/extract-full-tensor when necessary and then return those full tensors
along with s, dropout and base_layer.

Comment on lines +1274 to +1284
if has_dora:
mag_scale = _compute_dora_scale(
W.t(), None, A, B, s, magnitude, dtype
)
# DoRA: mag_scale * (base + s * lora) + bias
# base embedding has no bias
result = mag_scale.unsqueeze(0) * (result + s * lora_result)
ctx.save_for_backward(
x, A.to(dtype), B.to(dtype), after_A,
magnitude, mag_scale, result, # result = combined * mag_scale
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Save the pre-scaled embedding activation for DoRA backward.

d_mag is derived from the unscaled base + s * lora_result, but the forward path currently saves result after multiplying by mag_scale. That applies the scale twice in backward and makes embedding-DoRA magnitude gradients too large.

🐛 Suggested fix
             if has_dora:
                 mag_scale = _compute_dora_scale(
                     W.t(), None, A, B, s, magnitude, dtype
                 )
-                # DoRA: mag_scale * (base + s * lora) + bias
-                # base embedding has no bias
-                result = mag_scale.unsqueeze(0) * (result + s * lora_result)
+                combined = result + s * lora_result
+                result = mag_scale.unsqueeze(0) * combined
                 ctx.save_for_backward(
                     x, A.to(dtype), B.to(dtype), after_A,
-                    magnitude, mag_scale, result,  # result = combined * mag_scale
+                    magnitude, mag_scale, combined,
                 )

Also applies to: 1309-1316

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/kernels/lora.py` around lines 1274 - 1284, The forward currently
multiplies the combined embedding by mag_scale before saving, so backward
computes d_mag from the already-scaled result and double-applies mag_scale;
change the save to store the pre-scaled activation instead: compute pre_scaled =
result + s * lora_result (or the equivalent unscaled combined tensor) and call
ctx.save_for_backward(x, A.to(dtype), B.to(dtype), after_A, magnitude,
mag_scale, pre_scaled) so d_mag is derived from the unscaled combination; update
the same change in the second occurrence (around lines 1309-1316) where has_dora
is handled.

Comment on lines +1320 to +1355
if has_lora:
# Use float32 for gradient computation (LoRA params are fp32)
compute_dtype = torch.float32

after_A_flat = after_A.view(-1, after_A.shape[-1]).to(compute_dtype)
grad_flat = grad_output.view(-1, grad_output.shape[-1]).to(compute_dtype)
B_f = B.to(compute_dtype)

# B is [hidden_dim, rank], B_T = B.t() = [rank, hidden_dim]
# lora_result = after_A @ B_T → d/d(B_T) = s * after_A^T @ grad
B_T = B_f.t() # [rank, hidden_dim]
d_B_T = torch.empty_like(B_T)
d_B_T.addmm_(after_A_flat.t(), grad_flat, alpha=s, beta=0)
d_B = d_B_T.t() # [hidden_dim, rank]

# d_A: gradient flows through F.embedding lookup
# d_after_A = s * grad @ B = [T, hidden] @ [hidden, rank] = [T, rank]
d_after_A = s * grad_flat @ B_f

# F.embedding backward: scatter d_after_A into A^T gradient
x_flat = x.view(-1)
A_f = A.to(compute_dtype)
d_A_T = torch.zeros_like(A_f.t()) # [vocab, rank]
d_A_T.index_add_(0, x_flat, d_after_A)
d_A = d_A_T.t() # [rank, vocab]

# x, W, A, B, s, magnitude, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
return (
None, # x
None, # W (base embedding weight grad handled by PyTorch)
d_A, # A
d_B, # B
None, # s
d_mag, # magnitude
None, None, None, None, None, # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the full function definition around lines 1320-1355
sed -n '1280,1360p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2844


🏁 Script executed:

# Find the function definition to understand the full signature and inputs
sed -n '1200,1280p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2827


🏁 Script executed:

# Get the full function from the beginning
rg -A 200 "def.*backward" src/axolotl/kernels/lora.py | head -250

Repository: axolotl-ai-cloud/axolotl

Length of output: 9123


🏁 Script executed:

# Get the exact lines 1320-1355
sed -n '1310,1365p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 2283


🏁 Script executed:

# Find the function that contains this code - look backwards for "def backward"
sed -n '1200,1360p' src/axolotl/kernels/lora.py | grep -n "def backward"

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Search for any embedding-related functions in the file
rg "embedding" src/axolotl/kernels/lora.py -i -B 3 -A 3

Repository: axolotl-ai-cloud/axolotl

Length of output: 3576


🏁 Script executed:

# Get the complete LoRA_Embedding class definition
rg -A 300 "class LoRA_Embedding" src/axolotl/kernels/lora.py | head -400

Repository: axolotl-ai-cloud/axolotl

Length of output: 5954


🏁 Script executed:

# Also check the backward function signature specifically
sed -n '1240,1365p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 4470


🏁 Script executed:

# Check what PyTorch's F.embedding backward does with padding_idx
# Let's also search for any W gradient handling or patching logic
rg "padding_idx\|scale_grad_by_freq" src/axolotl/kernels/lora.py -B 3 -A 3

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Check if there's any separate handling for W gradients in embedding
rg "embedding.*grad\|grad.*embedding" src/axolotl/kernels/ -i -B 2 -A 2

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Look for how this LoRA_Embedding is being used and patched
rg "apply_lora_embedding\|LoRA_Embedding" src/axolotl/ -B 3 -A 3

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Look at the broader context - check if W is saved in saved_tensors
sed -n '1200,1290p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 3028


🏁 Script executed:

# Search for any usage patterns to understand how this is patched
fd "\.py$" src/axolotl/ --exec grep -l "apply_lora_embedding" {} \;

Repository: axolotl-ai-cloud/axolotl

Length of output: 135


🏁 Script executed:

# Check if there's any patching or module replacement logic
rg "forward_pre_hook\|forward_hook\|apply\|register" src/axolotl/kernels/lora.py | head -20

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Check the monkeypatch file to see how apply_lora_embedding is being used
cat src/axolotl/monkeypatch/lora_kernels.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 18197


🏁 Script executed:

# Also check if there's any documentation about F.embedding backward semantics
# Let's look at PyTorch F.embedding to understand what padding_idx and scale_grad_by_freq do
python3 << 'EOF'
import torch
import torch.nn.functional as F

# Check F.embedding documentation
help(F.embedding)
EOF

Repository: axolotl-ai-cloud/axolotl

Length of output: 185


🏁 Script executed:

# Let's verify what happens with W - check if it's supposed to be computed
sed -n '1290,1330p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 1555


🏁 Script executed:

# Check if there's any W gradient handling elsewhere
rg "d_W\|grad.*W" src/axolotl/kernels/lora.py | grep -i embed

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Look at how F.embedding handles padding_idx semantics by checking PyTorch docs
python3 << 'EOF'
# Let's write a simple test to show how PyTorch's F.embedding backward handles padding_idx
import torch

# Create a simple embedding and test with padding_idx
embedding_weight = torch.randn(10, 5, requires_grad=True)
x = torch.tensor([0, 1, 2, 2, 2])  # 2 is padding_idx

# Forward with padding_idx
output = torch.nn.functional.embedding(x, embedding_weight, padding_idx=2)

# Simple backward
loss = output.sum()
loss.backward()

print("Gradient for padding_idx row (should be zero):")
print(embedding_weight.grad[2])  # Should be zero
print("\nGradient for non-padding rows (should be non-zero):")
print(embedding_weight.grad[0:2])
EOF

Repository: axolotl-ai-cloud/axolotl

Length of output: 185


🏁 Script executed:

# Check if padding_idx is actually used in the backward - it's stored in ctx but let's verify
sed -n '1240,1360p' src/axolotl/kernels/lora.py | grep -n "padding_idx"

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Let's check the complete backward to see all uses of context variables
sed -n '1270,1360p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 3043


🏁 Script executed:

# Search for any tests or usage of LoRA_Embedding to understand expected behavior
fd "test.*lora\|lora.*test" src/axolotl --type f -x grep -l "embedding" {} \;

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Check if there are any comments about W gradient handling or padding
rg "base embedding\|W.*grad\|padding" src/axolotl/kernels/lora.py -B 2 -A 2 | head -40

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Let's look at what PyTorch's standard embedding backward actually does
# by checking any comments about it in the code
rg "F.embedding\|embedding.*backward" src/axolotl/kernels/lora.py -B 3 -A 3

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Check if W is ever supposed to have gradients in this LoRA kernel
# Look for how it's used in forward and what the comment says
sed -n '1200,1250p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 1796


🏁 Script executed:

# Most importantly, let's see if padding_idx is stored and how it should be used
sed -n '1200,1400p' src/axolotl/kernels/lora.py | cat -n

Repository: axolotl-ai-cloud/axolotl

Length of output: 7036


🏁 Script executed:

# Check the complete LoRA_Embedding class one more time to see full backward
sed -n '1000,1365p' src/axolotl/kernels/lora.py | grep -A 200 "class LoRA_Embedding"

Repository: axolotl-ai-cloud/axolotl

Length of output: 50


🏁 Script executed:

# Get file length first
wc -l src/axolotl/kernels/lora.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 100


🏁 Script executed:

# Search for LoRA_Embedding class more broadly
rg "class LoRA_Embedding" src/axolotl/kernels/lora.py -A 350 | head -400

Repository: axolotl-ai-cloud/axolotl

Length of output: 5954


Fix the embedding backward to respect padding_idx and scale_grad_by_freq, or skip patching when these features are active.

The custom embedding backward ignores padding_idx during gradient accumulation and doesn't apply scale_grad_by_freq scaling. The standard F.embedding backward masks gradients for padded positions, but index_add_ accumulates from all positions. This causes the fused LoRA_A and LoRA_B gradients to diverge from the unfused path when these features are enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/kernels/lora.py` around lines 1320 - 1355, The custom embedding
backward for LoRA ignores padding_idx and scale_grad_by_freq: when padding_idx
is set or scale_grad_by_freq is true you must either reproduce F.embedding's
masking/scaling logic before accumulating into d_A_T or avoid the fused path and
return None so PyTorch's default embedding backward runs; to fix, in the
backward branch that computes d_after_A (symbol d_after_A) and accumulates with
d_A_T.index_add_(0, x_flat, d_after_A), filter out positions where x_flat ==
padding_idx (or zero-out corresponding rows in d_after_A) and, if
scale_grad_by_freq is true, divide each contribution by token frequency (compute
counts from x_flat) to match F.embedding behavior; alternatively early-check
padding_idx/scale_grad_by_freq and skip the fused gradient patch (return None
for d_A) when those features are active.

Comment thread src/axolotl/kernels/lora.py Outdated
@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 22, 2026

Codecov Report

❌ Patch coverage is 5.49020% with 482 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/kernels/lora.py 5.39% 421 Missing ⚠️
src/axolotl/kernels/dora.py 0.00% 35 Missing ⚠️
src/axolotl/monkeypatch/lora_kernels.py 0.00% 24 Missing ⚠️
src/axolotl/kernels/quantize.py 0.00% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

@winglian winglian merged commit b3289fd into main Mar 22, 2026
14 of 17 checks passed
@winglian winglian deleted the lora-kernels-v2 branch March 22, 2026 17:53
@winglian winglian removed the scheduled_release This PR is slated for the upcoming release label Mar 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant