feat: LoRA kernel support for bias, dropout, dora, embeddings#3528
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughIntroduces DoRA (Direction-only LoRA) support via a Triton-accelerated kernel, enhances LoRA implementations with bias and embedding layers, adds configuration options, updates validation logic, and provides comprehensive end-to-end tests validating kernel correctness against PEFT references. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
📖 Documentation Preview: https://69c002288510b6e3cbdc4228--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 10138a5 |
There was a problem hiding this comment.
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/kernels/lora.py`:
- Around line 1274-1284: The forward currently multiplies the combined embedding
by mag_scale before saving, so backward computes d_mag from the already-scaled
result and double-applies mag_scale; change the save to store the pre-scaled
activation instead: compute pre_scaled = result + s * lora_result (or the
equivalent unscaled combined tensor) and call ctx.save_for_backward(x,
A.to(dtype), B.to(dtype), after_A, magnitude, mag_scale, pre_scaled) so d_mag is
derived from the unscaled combination; update the same change in the second
occurrence (around lines 1309-1316) where has_dora is handled.
- Around line 1208-1224: The embedding branch must mirror the DTensor unshard
logic used for the linear path: before returning embed.lora_embedding_A and
embed.lora_embedding_B (and embed.lora_magnitude_vector when present), detect
DTensor-backed parameters and replace local shards with their unsharded/full
tensors the same way the linear path does (i.e., apply the same
unshard/redistribute logic used for W in the linear branch); update the handling
of A, B and magnitude (embed.lora_embedding_A[active_adapter],
embed.lora_embedding_B[active_adapter],
embed.lora_magnitude_vector[active_adapter]) to perform DTensor
unshard/extract-full-tensor when necessary and then return those full tensors
along with s, dropout and base_layer.
- Around line 160-175: The current DoRA cache uses A.data_ptr() / B.data_ptr()
to detect parameter changes which fails because optimizers mutate tensors
in-place; replace pointer-based invalidation with a content-version check or
disable the cache: update the cache key logic used around magnitude._dora_cache
(and the early-return that checks cached_a_ptr/cached_b_ptr) so it stores and
compares a lightweight content fingerprint (e.g., a small checksum/hash derived
from A.detach() and B.detach(), or a tensor version counter updated on change)
instead of data_ptr(), or simply remove the early-return and always recompute
the norm in the forward pass (while still optionally storing
magnitude._dora_cache = (fingerprint, weight_norm) after Triton path); adjust
code paths around triton_dora_scale, magnitude._dora_cache, and the initial
cache check to use the chosen content-based invalidation approach.
- Around line 1374-1375: Move the dropout so it only affects the LoRA
contribution by applying it inside LoRA_Embedding.forward() rather than after
apply() returns: locate LoRA_Embedding.forward() and within it apply dropout to
the LoRA path variable (after_A or lora_result) when self.training and dropout
is not None (mirroring the X_drop usage in other fused kernels), then return the
sum of the frozen base embedding and the now-dropped lora_result; remove the
external dropout call that currently wraps the combined result so the base
embedding is not subject to dropout.
- Around line 1320-1355: The custom embedding backward for LoRA ignores
padding_idx and scale_grad_by_freq: when padding_idx is set or
scale_grad_by_freq is true you must either reproduce F.embedding's
masking/scaling logic before accumulating into d_A_T or avoid the fused path and
return None so PyTorch's default embedding backward runs; to fix, in the
backward branch that computes d_after_A (symbol d_after_A) and accumulates with
d_A_T.index_add_(0, x_flat, d_after_A), filter out positions where x_flat ==
padding_idx (or zero-out corresponding rows in d_after_A) and, if
scale_grad_by_freq is true, divide each contribution by token frequency (compute
counts from x_flat) to match F.embedding behavior; alternatively early-check
padding_idx/scale_grad_by_freq and skip the fused gradient patch (return None
for d_A) when those features are active.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 74a3ac11-ca8c-4142-8e5f-8c04e89e65a8
📒 Files selected for processing (7)
src/axolotl/kernels/dora.pysrc/axolotl/kernels/lora.pysrc/axolotl/kernels/quantize.pysrc/axolotl/monkeypatch/lora_kernels.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/validation.pytests/e2e/kernels/test_lora_features.py
| # Check cache on magnitude tensor (avoids expensive norm recomputation) | ||
| cache = getattr(magnitude, "_dora_cache", None) | ||
| if cache is not None: | ||
| cached_a_ptr, cached_b_ptr, cached_norm = cache | ||
| if cached_a_ptr == A.data_ptr() and cached_b_ptr == B.data_ptr(): | ||
| return magnitude.to(dtype) / cached_norm | ||
|
|
||
| # Full recomputation - try Triton first | ||
| if _should_use_triton_dora() and W.is_cuda: | ||
| from .dora import triton_dora_scale | ||
|
|
||
| result = triton_dora_scale(W, W_quant, A, B, s, magnitude, dtype) | ||
| # Cache the weight_norm (derive from result: norm = mag / result) | ||
| weight_norm = (magnitude.to(dtype) / result).detach() | ||
| magnitude._dora_cache = (A.data_ptr(), B.data_ptr(), weight_norm) | ||
| return result |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# Read the relevant section of the file to understand the cache invalidation logic
wc -l src/axolotl/kernels/lora.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 100
🏁 Script executed:
# Get context around the flagged lines
sed -n '140,230p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 3955
🏁 Script executed:
# Look for the function signature to understand what A and B are
sed -n '130,180p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2296
🏁 Script executed:
# Look for the second location mentioned in the comment (lines 204-221)
sed -n '195,225p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 1461
🏁 Script executed:
# Search for where these functions are called and what A/B parameters are
rg "A\.data_ptr|B\.data_ptr|_compute_dora_scale" src/axolotl/kernels/lora.py | head -20Repository: axolotl-ai-cloud/axolotl
Length of output: 972
🏁 Script executed:
# Look for where A and B parameters come from in the LoRA context
rg -B 10 -A 5 "def _compute_dora_scale\(" src/axolotl/kernels/lora.py | head -40Repository: axolotl-ai-cloud/axolotl
Length of output: 472
🏁 Script executed:
# Search the broader codebase to understand what A and B are (LoRA adapter parameters)
rg "LoRA|adapter.*param" src/axolotl/kernels/lora.py | grep -E "(A|B)" | head -15Repository: axolotl-ai-cloud/axolotl
Length of output: 989
🏁 Script executed:
# Search for how A and B parameters are created and used in the LoRA context
rg -B 5 -A 10 "lora_A|lora_B|\.A\s|\.B\s" src/axolotl/kernels/lora.py | head -50Repository: axolotl-ai-cloud/axolotl
Length of output: 1206
🏁 Script executed:
# Look for where these functions are called to understand the data flow
rg -B 3 -A 3 "_compute_dora_scale\(" src/axolotl/kernels/lora.py | head -60Repository: axolotl-ai-cloud/axolotl
Length of output: 2578
🏁 Script executed:
# Check if there's any documentation or tests about the caching behavior
find . -type f \( -name "*.py" -o -name "*.md" \) -path "*/test*" | xargs rg -l "dora.*cache\|_dora_cache" 2>/dev/null | head -5Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Search the broader codebase to see how LoRA parameters are typically accessed
rg "get_lora.*param|LoRA.*param" src/axolotl/kernels/lora.py -B 2 -A 8 | head -50Repository: axolotl-ai-cloud/axolotl
Length of output: 1683
🏁 Script executed:
# Let's check what PyTorch version is being used and search for any tests about this behavior
rg -i "pytorch|torch" pyproject.toml setup.py requirements.txt 2>/dev/null | head -10Repository: axolotl-ai-cloud/axolotl
Length of output: 691
🏁 Script executed:
# Look for PyTorch documentation or comments about data_ptr behavior
rg "data_ptr|in.place.*optim|optim.*place" src/axolotl/ -i | head -20Repository: axolotl-ai-cloud/axolotl
Length of output: 1167
🏁 Script executed:
# Check if there are any actual tests for the DoRA cache mechanism
find . -name "*.py" -path "*/test*" -o -name "*test*.py" | xargs rg -l "dora\|DoRA" 2>/dev/nullRepository: axolotl-ai-cloud/axolotl
Length of output: 52
🏁 Script executed:
# Let's look at what happens during training with these parameters - how are A/B modified?
rg "lora_A|lora_B" src/axolotl/kernels/lora.py -B 5 -A 5 | head -80Repository: axolotl-ai-cloud/axolotl
Length of output: 919
🌐 Web query:
PyTorch optimizer in-place updates data_ptr() stability AdamW SGD
💡 Result:
PyTorch optimizers like AdamW and SGD perform in-place updates to model parameters during optimizer.step, modifying the tensor data directly via operations like param.add_(grad, alpha=-lr). This updates the parameter values without allocating new tensors or changing the data_ptr, ensuring the same underlying memory buffer is used before and after the update. The data_ptr remains stable across optimizer steps. This stability holds for both standard implementations and optimized variants (foreach, fused). Fused kernels (e.g., fused_adamw, fused_sgd) are designed for in-place parameter updates on the existing tensor storage, preserving data_ptr. No evidence of data_ptr changes or instability was found in official docs, source code, or discussions. Optimizers run under torch.no_grad context (via _use_grad_for_differentiable decorator with differentiable=False by default), allowing safe in-place modifications on leaf parameters without autograd errors. Users relying on stable data_ptr (e.g., for C++ extensions or custom ops) can expect consistency. Potential caveats: - Multiple optimizers sharing parameters can cause inplace errors during backprop if steps interleave improperly. - Specialized cases like checkpointing or torch.compile may indirectly affect behavior, but core inplace updates preserve data_ptr. - Recent PRs distinguish mutable_data_ptr vs const_data_ptr for COW safety, but Python-level data_ptr remains stable for parameters.
Citations:
- 1: https://docs.pytorch.org/docs/stable/optim.html
- 2: https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html
- 3: https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html
- 4: https://pytorch.org/docs/stable/_modules/torch/optim/adamw.html
- 5: https://stackoverflow.com/questions/75134619/how-does-optimizer-step-perform-an-in-place-operation
- 6: https://stackoverflow.com/questions/78942171/why-doesnt-the-in-place-operation-on-leaf-variables-in-pytorch-optimizers-cause
- 7: https://pytorch.org/docs/stable/generated/torch.Tensor.data_ptr.html
Fix DoRA cache invalidation: data_ptr() does not change on optimizer updates, causing stale norm caching.
Standard PyTorch optimizers (AdamW, SGD) perform in-place parameter updates via operations like param.add_(grad, alpha=-lr). These preserve the underlying memory pointer, so A.data_ptr() and B.data_ptr() remain constant even after values are modified. The cache comparison at line 165 and 213 will always match after the first step, causing the cached weight norm to become stale. The DoRA scale will drift from correct values throughout training.
Replace pointer-based invalidation with a content hash or version counter. Alternatively, invalidate the cache on every forward pass and rely on the computed norm being fast enough.
Also applies to: 204-221
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/kernels/lora.py` around lines 160 - 175, The current DoRA cache
uses A.data_ptr() / B.data_ptr() to detect parameter changes which fails because
optimizers mutate tensors in-place; replace pointer-based invalidation with a
content-version check or disable the cache: update the cache key logic used
around magnitude._dora_cache (and the early-return that checks
cached_a_ptr/cached_b_ptr) so it stores and compares a lightweight content
fingerprint (e.g., a small checksum/hash derived from A.detach() and B.detach(),
or a tensor version counter updated on change) instead of data_ptr(), or simply
remove the early-return and always recompute the norm in the forward pass (while
still optionally storing magnitude._dora_cache = (fingerprint, weight_norm)
after Triton path); adjust code paths around triton_dora_scale,
magnitude._dora_cache, and the initial cache check to use the chosen
content-based invalidation approach.
| A = embed.lora_embedding_A[active_adapter] # nn.Parameter [rank, vocab] | ||
| B = embed.lora_embedding_B[active_adapter] # nn.Parameter [hidden_dim, rank] | ||
| s = embed.scaling[active_adapter] | ||
|
|
||
| dropout = None | ||
| if hasattr(embed, "lora_dropout") and active_adapter in embed.lora_dropout: | ||
| dropout = embed.lora_dropout[active_adapter] | ||
|
|
||
| magnitude = None | ||
| if ( | ||
| hasattr(embed, "lora_magnitude_vector") | ||
| and embed.lora_magnitude_vector | ||
| and active_adapter in embed.lora_magnitude_vector | ||
| ): | ||
| magnitude = embed.lora_magnitude_vector[active_adapter].weight | ||
|
|
||
| return W, A, B, s, dropout, magnitude, base_layer |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's locate and examine the lora.py file to understand the context
find . -name "lora.py" -path "*/kernels/*" | head -5Repository: axolotl-ai-cloud/axolotl
Length of output: 97
🏁 Script executed:
# Get the file structure and locate the relevant code sections
git ls-files "*/kernels/lora.py"Repository: axolotl-ai-cloud/axolotl
Length of output: 95
🏁 Script executed:
# Let's examine the linear path implementation in the lora.py file
cat -n src/axolotl/kernels/lora.py | head -100Repository: axolotl-ai-cloud/axolotl
Length of output: 4163
🏁 Script executed:
# Search for DTensor-related code and unshard operations in lora.py
rg "DTensor|unshard" src/axolotl/kernels/lora.py -B 2 -A 2Repository: axolotl-ai-cloud/axolotl
Length of output: 803
🏁 Script executed:
# Find the linear path function to see how it handles DTensor unsharding
rg "def.*linear|def.*forward" src/axolotl/kernels/lora.py -A 5Repository: axolotl-ai-cloud/axolotl
Length of output: 735
🏁 Script executed:
# Search for the embedding function that matches the review comment lines
rg "def get_lora_embedding" src/axolotl/kernels/lora.py -A 30Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Find the lora_embedding_kernel function
rg "lora_embedding_kernel|def.*embedding.*kernel" src/axolotl/kernels/lora.py -B 2 -A 10Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Look at lines around 1208-1224 mentioned in the review
sed -n '1190,1240p' src/axolotl/kernels/lora.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 1302
🏁 Script executed:
# Search for all functions that return lora_embedding_A and lora_embedding_B
rg "lora_embedding_A|lora_embedding_B" src/axolotl/kernels/lora.py -B 5 -A 5Repository: axolotl-ai-cloud/axolotl
Length of output: 893
🏁 Script executed:
# Get the full get_embedding_lora_parameters function
sed -n '1180,1250p' src/axolotl/kernels/lora.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 1869
🏁 Script executed:
# Find where get_embedding_lora_parameters is called and used
rg "get_embedding_lora_parameters" src/axolotl/kernels/lora.py -B 2 -A 10Repository: axolotl-ai-cloud/axolotl
Length of output: 812
🏁 Script executed:
# Search for lora_embedding_kernel in the entire codebase
rg "lora_embedding_kernel" --type pyRepository: axolotl-ai-cloud/axolotl
Length of output: 665
🏁 Script executed:
# Let's check the total size of lora.py to understand structure
wc -l src/axolotl/kernels/lora.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 100
🏁 Script executed:
# Find the LoRA_Embedding class definition
rg "class LoRA_Embedding" src/axolotl/kernels/lora.py -A 50Repository: axolotl-ai-cloud/axolotl
Length of output: 1584
🏁 Script executed:
# Search for where LoRA_Embedding.apply is used and what forward does
rg "class LoRA_Embedding|def forward" src/axolotl/kernels/lora.py | head -20Repository: axolotl-ai-cloud/axolotl
Length of output: 182
Mirror the DTensor unshard handling for embedding LoRA params.
The linear path unshards DTensor-backed adapter weights before entering the fused autograd code (lines 81-88), but the embedding path returns lora_embedding_A/B directly without this check. Under FSDP2, the embedding kernel receives local shards instead of the full matrices it expects, creating a safety mismatch with the linear kernels.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/kernels/lora.py` around lines 1208 - 1224, The embedding branch
must mirror the DTensor unshard logic used for the linear path: before returning
embed.lora_embedding_A and embed.lora_embedding_B (and
embed.lora_magnitude_vector when present), detect DTensor-backed parameters and
replace local shards with their unsharded/full tensors the same way the linear
path does (i.e., apply the same unshard/redistribute logic used for W in the
linear branch); update the handling of A, B and magnitude
(embed.lora_embedding_A[active_adapter], embed.lora_embedding_B[active_adapter],
embed.lora_magnitude_vector[active_adapter]) to perform DTensor
unshard/extract-full-tensor when necessary and then return those full tensors
along with s, dropout and base_layer.
| if has_dora: | ||
| mag_scale = _compute_dora_scale( | ||
| W.t(), None, A, B, s, magnitude, dtype | ||
| ) | ||
| # DoRA: mag_scale * (base + s * lora) + bias | ||
| # base embedding has no bias | ||
| result = mag_scale.unsqueeze(0) * (result + s * lora_result) | ||
| ctx.save_for_backward( | ||
| x, A.to(dtype), B.to(dtype), after_A, | ||
| magnitude, mag_scale, result, # result = combined * mag_scale | ||
| ) |
There was a problem hiding this comment.
Save the pre-scaled embedding activation for DoRA backward.
d_mag is derived from the unscaled base + s * lora_result, but the forward path currently saves result after multiplying by mag_scale. That applies the scale twice in backward and makes embedding-DoRA magnitude gradients too large.
🐛 Suggested fix
if has_dora:
mag_scale = _compute_dora_scale(
W.t(), None, A, B, s, magnitude, dtype
)
- # DoRA: mag_scale * (base + s * lora) + bias
- # base embedding has no bias
- result = mag_scale.unsqueeze(0) * (result + s * lora_result)
+ combined = result + s * lora_result
+ result = mag_scale.unsqueeze(0) * combined
ctx.save_for_backward(
x, A.to(dtype), B.to(dtype), after_A,
- magnitude, mag_scale, result, # result = combined * mag_scale
+ magnitude, mag_scale, combined,
)Also applies to: 1309-1316
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/kernels/lora.py` around lines 1274 - 1284, The forward currently
multiplies the combined embedding by mag_scale before saving, so backward
computes d_mag from the already-scaled result and double-applies mag_scale;
change the save to store the pre-scaled activation instead: compute pre_scaled =
result + s * lora_result (or the equivalent unscaled combined tensor) and call
ctx.save_for_backward(x, A.to(dtype), B.to(dtype), after_A, magnitude,
mag_scale, pre_scaled) so d_mag is derived from the unscaled combination; update
the same change in the second occurrence (around lines 1309-1316) where has_dora
is handled.
| if has_lora: | ||
| # Use float32 for gradient computation (LoRA params are fp32) | ||
| compute_dtype = torch.float32 | ||
|
|
||
| after_A_flat = after_A.view(-1, after_A.shape[-1]).to(compute_dtype) | ||
| grad_flat = grad_output.view(-1, grad_output.shape[-1]).to(compute_dtype) | ||
| B_f = B.to(compute_dtype) | ||
|
|
||
| # B is [hidden_dim, rank], B_T = B.t() = [rank, hidden_dim] | ||
| # lora_result = after_A @ B_T → d/d(B_T) = s * after_A^T @ grad | ||
| B_T = B_f.t() # [rank, hidden_dim] | ||
| d_B_T = torch.empty_like(B_T) | ||
| d_B_T.addmm_(after_A_flat.t(), grad_flat, alpha=s, beta=0) | ||
| d_B = d_B_T.t() # [hidden_dim, rank] | ||
|
|
||
| # d_A: gradient flows through F.embedding lookup | ||
| # d_after_A = s * grad @ B = [T, hidden] @ [hidden, rank] = [T, rank] | ||
| d_after_A = s * grad_flat @ B_f | ||
|
|
||
| # F.embedding backward: scatter d_after_A into A^T gradient | ||
| x_flat = x.view(-1) | ||
| A_f = A.to(compute_dtype) | ||
| d_A_T = torch.zeros_like(A_f.t()) # [vocab, rank] | ||
| d_A_T.index_add_(0, x_flat, d_after_A) | ||
| d_A = d_A_T.t() # [rank, vocab] | ||
|
|
||
| # x, W, A, B, s, magnitude, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse | ||
| return ( | ||
| None, # x | ||
| None, # W (base embedding weight grad handled by PyTorch) | ||
| d_A, # A | ||
| d_B, # B | ||
| None, # s | ||
| d_mag, # magnitude | ||
| None, None, None, None, None, # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the full function definition around lines 1320-1355
sed -n '1280,1360p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2844
🏁 Script executed:
# Find the function definition to understand the full signature and inputs
sed -n '1200,1280p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2827
🏁 Script executed:
# Get the full function from the beginning
rg -A 200 "def.*backward" src/axolotl/kernels/lora.py | head -250Repository: axolotl-ai-cloud/axolotl
Length of output: 9123
🏁 Script executed:
# Get the exact lines 1320-1355
sed -n '1310,1365p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 2283
🏁 Script executed:
# Find the function that contains this code - look backwards for "def backward"
sed -n '1200,1360p' src/axolotl/kernels/lora.py | grep -n "def backward"Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Search for any embedding-related functions in the file
rg "embedding" src/axolotl/kernels/lora.py -i -B 3 -A 3Repository: axolotl-ai-cloud/axolotl
Length of output: 3576
🏁 Script executed:
# Get the complete LoRA_Embedding class definition
rg -A 300 "class LoRA_Embedding" src/axolotl/kernels/lora.py | head -400Repository: axolotl-ai-cloud/axolotl
Length of output: 5954
🏁 Script executed:
# Also check the backward function signature specifically
sed -n '1240,1365p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 4470
🏁 Script executed:
# Check what PyTorch's F.embedding backward does with padding_idx
# Let's also search for any W gradient handling or patching logic
rg "padding_idx\|scale_grad_by_freq" src/axolotl/kernels/lora.py -B 3 -A 3Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Check if there's any separate handling for W gradients in embedding
rg "embedding.*grad\|grad.*embedding" src/axolotl/kernels/ -i -B 2 -A 2Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Look for how this LoRA_Embedding is being used and patched
rg "apply_lora_embedding\|LoRA_Embedding" src/axolotl/ -B 3 -A 3Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Look at the broader context - check if W is saved in saved_tensors
sed -n '1200,1290p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 3028
🏁 Script executed:
# Search for any usage patterns to understand how this is patched
fd "\.py$" src/axolotl/ --exec grep -l "apply_lora_embedding" {} \;Repository: axolotl-ai-cloud/axolotl
Length of output: 135
🏁 Script executed:
# Check if there's any patching or module replacement logic
rg "forward_pre_hook\|forward_hook\|apply\|register" src/axolotl/kernels/lora.py | head -20Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Check the monkeypatch file to see how apply_lora_embedding is being used
cat src/axolotl/monkeypatch/lora_kernels.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 18197
🏁 Script executed:
# Also check if there's any documentation about F.embedding backward semantics
# Let's look at PyTorch F.embedding to understand what padding_idx and scale_grad_by_freq do
python3 << 'EOF'
import torch
import torch.nn.functional as F
# Check F.embedding documentation
help(F.embedding)
EOFRepository: axolotl-ai-cloud/axolotl
Length of output: 185
🏁 Script executed:
# Let's verify what happens with W - check if it's supposed to be computed
sed -n '1290,1330p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 1555
🏁 Script executed:
# Check if there's any W gradient handling elsewhere
rg "d_W\|grad.*W" src/axolotl/kernels/lora.py | grep -i embedRepository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Look at how F.embedding handles padding_idx semantics by checking PyTorch docs
python3 << 'EOF'
# Let's write a simple test to show how PyTorch's F.embedding backward handles padding_idx
import torch
# Create a simple embedding and test with padding_idx
embedding_weight = torch.randn(10, 5, requires_grad=True)
x = torch.tensor([0, 1, 2, 2, 2]) # 2 is padding_idx
# Forward with padding_idx
output = torch.nn.functional.embedding(x, embedding_weight, padding_idx=2)
# Simple backward
loss = output.sum()
loss.backward()
print("Gradient for padding_idx row (should be zero):")
print(embedding_weight.grad[2]) # Should be zero
print("\nGradient for non-padding rows (should be non-zero):")
print(embedding_weight.grad[0:2])
EOFRepository: axolotl-ai-cloud/axolotl
Length of output: 185
🏁 Script executed:
# Check if padding_idx is actually used in the backward - it's stored in ctx but let's verify
sed -n '1240,1360p' src/axolotl/kernels/lora.py | grep -n "padding_idx"Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Let's check the complete backward to see all uses of context variables
sed -n '1270,1360p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 3043
🏁 Script executed:
# Search for any tests or usage of LoRA_Embedding to understand expected behavior
fd "test.*lora\|lora.*test" src/axolotl --type f -x grep -l "embedding" {} \;Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Check if there are any comments about W gradient handling or padding
rg "base embedding\|W.*grad\|padding" src/axolotl/kernels/lora.py -B 2 -A 2 | head -40Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Let's look at what PyTorch's standard embedding backward actually does
# by checking any comments about it in the code
rg "F.embedding\|embedding.*backward" src/axolotl/kernels/lora.py -B 3 -A 3Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Check if W is ever supposed to have gradients in this LoRA kernel
# Look for how it's used in forward and what the comment says
sed -n '1200,1250p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 1796
🏁 Script executed:
# Most importantly, let's see if padding_idx is stored and how it should be used
sed -n '1200,1400p' src/axolotl/kernels/lora.py | cat -nRepository: axolotl-ai-cloud/axolotl
Length of output: 7036
🏁 Script executed:
# Check the complete LoRA_Embedding class one more time to see full backward
sed -n '1000,1365p' src/axolotl/kernels/lora.py | grep -A 200 "class LoRA_Embedding"Repository: axolotl-ai-cloud/axolotl
Length of output: 50
🏁 Script executed:
# Get file length first
wc -l src/axolotl/kernels/lora.pyRepository: axolotl-ai-cloud/axolotl
Length of output: 100
🏁 Script executed:
# Search for LoRA_Embedding class more broadly
rg "class LoRA_Embedding" src/axolotl/kernels/lora.py -A 350 | head -400Repository: axolotl-ai-cloud/axolotl
Length of output: 5954
Fix the embedding backward to respect padding_idx and scale_grad_by_freq, or skip patching when these features are active.
The custom embedding backward ignores padding_idx during gradient accumulation and doesn't apply scale_grad_by_freq scaling. The standard F.embedding backward masks gradients for padded positions, but index_add_ accumulates from all positions. This causes the fused LoRA_A and LoRA_B gradients to diverge from the unfused path when these features are enabled.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/kernels/lora.py` around lines 1320 - 1355, The custom embedding
backward for LoRA ignores padding_idx and scale_grad_by_freq: when padding_idx
is set or scale_grad_by_freq is true you must either reproduce F.embedding's
masking/scaling logic before accumulating into d_A_T or avoid the fused path and
return None so PyTorch's default embedding backward runs; to fix, in the
backward branch that computes d_after_A (symbol d_after_A) and accumulates with
d_A_T.index_add_(0, x_flat, d_after_A), filter out positions where x_flat ==
padding_idx (or zero-out corresponding rows in d_after_A) and, if
scale_grad_by_freq is true, divide each contribution by token frequency (compute
counts from x_flat) to match F.embedding behavior; alternatively early-check
padding_idx/scale_grad_by_freq and skip the fused gradient patch (return None
for d_A) when those features are active.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
Description
Summary by CodeRabbit
Release Notes
New Features
Tests