Skip to content

FSDP2 + LoRA kernels#2992

Merged
djsaunde merged 13 commits into
mainfrom
lora-kernels-fsdp
Aug 4, 2025
Merged

FSDP2 + LoRA kernels#2992
djsaunde merged 13 commits into
mainfrom
lora-kernels-fsdp

Conversation

@djsaunde
Copy link
Copy Markdown
Collaborator

@djsaunde djsaunde commented Jul 30, 2025

Description

Enabling LoRA kernels with FSDP2. Couple of problems solved here:

  1. LoRA parameters were not being unsharded during the decoder block forward pass, so they're weren't unsharded at the time of the fused QKV + LoRA or MLP + LoRA functions. This is because the LoRA parameters are their own FSDP modules and get unsharded in the normal flow (i.e., sans kernels) during their own forward pass. @salmanmohammadi and I chatted about this briefly; I wonder if they should get rolled into the FDSP parameter group of the decoder layer, but apparently this causes problems?
  2. In FSDP2, all torch.nn.Parameter subclasses (e.g., bitsandbytes.nn.modules.Params4bit) get unsharded into regular torch.nn.Parameters and hence lose their metadata. We need Params4bit quant state, etc. in order to dequantization. FSDP2 + LoRA is already ungated on main, but was failing silently due to this with little to no memory savings.

Props to @ved1beta for the initial Params4Bit handling code (now modified in a patch in this PR)!

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

Summary by CodeRabbit

  • New Features

    • Improved support for Fully Sharded Data Parallel (FSDP) version 2 with LoRA and QLoRA adapters, including enhanced compatibility with quantized parameters.
    • Added kernel-level LoRA optimizations for more efficient training.
  • Bug Fixes

    • Addressed issues with parameter handling for 4-bit quantized models in distributed training setups.
  • Tests

    • Introduced new end-to-end and integration tests to verify FSDP2 compatibility and correct patching of quantized parameter handling.
  • Chores

    • Updated the bitsandbytes package to version 0.46.1.
  • Documentation & Validation

    • Enhanced configuration validation to prevent incompatible combinations of LoRA kernel flags with 8-bit loading, DoRA, and RL training modes.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jul 30, 2025

📝 Walkthrough

Walkthrough

This update introduces a new monkeypatch module to support bitsandbytes 4-bit quantized parameters with PyTorch FSDP v2, modifies LoRA kernel handling for compatibility, and adds both integration and end-to-end tests for these changes. It also updates a dependency version in requirements.txt and refines validation logic for LoRA kernel flags.

Changes

Cohort / File(s) Change Summary
Dependency Update
requirements.txt
Updated bitsandbytes version from 0.46.0 to 0.46.1.
LoRA Kernel Compatibility
src/axolotl/kernels/lora.py
Refactored LoRA parameter access to ensure manual unsharding with FSDP v2; combined transpose and dtype conversion steps; reordered LoRA matmul scaling.
Patch Manager Extension
src/axolotl/loaders/patch_manager.py
Added _apply_fsdp2_bnb_patches method to apply FSDP v2 and bitsandbytes patches during pre-model load when using qlora adapter and FSDP enabled.
FSDP2 QLoRA Monkeypatch
src/axolotl/monkeypatch/fsdp2_qlora.py
New module patching FSDPParam and Params4bit to support bitsandbytes 4-bit quantized parameters with PyTorch FSDP v2, including torch function protocol patching for Params4bit.
E2E FSDP2 Kernel Tests
tests/e2e/multigpu/test_fsdp2.py
Added tests for LoRA and QLoRA SFT with kernel optimizations enabled, verifying successful training and checkpoint creation under FSDP v2.
Unit & Integration Tests for Patches
tests/e2e/patched/test_fsdp2_qlora.py
Added tests verifying Params4bit and FSDPParam monkeypatches, including attribute preservation and patch application, using mocks and integration checks.
LoRA Kernel Flags Validation
src/axolotl/utils/schemas/validation.py
Removed deprecated validator; renamed and refined existing validators; added new validators enforcing mutual exclusivity of LoRA kernel flags with DoRA and RL modes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

ready to merge

Suggested reviewers

  • winglian
  • SalmanMohammadi

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.


📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 588a365 and 646de8a.

📒 Files selected for processing (1)
  • src/axolotl/monkeypatch/fsdp2_qlora.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/monkeypatch/fsdp2_qlora.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: pre-commit
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: preview
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch lora-kernels-fsdp

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@djsaunde
Copy link
Copy Markdown
Collaborator Author

FSDP2 + LoRA is working, QLoRA currently erroring. Fixing now.

Comment thread src/axolotl/kernels/lora.py Outdated
@djsaunde
Copy link
Copy Markdown
Collaborator Author

I think I've discovered an issue with FSDP2 + QLoRA, even without LoRA kernels enabled.

It seems that FSDP2 is converting Linear4Bit parameters to regular torch.Tensors. Running a training with this config and toggling FSDP2, and debugging from this point in the transformers/models/llama/modeling_llama.py module:

(Pdb) ll
223         def forward(
224             self,
225             hidden_states: torch.Tensor,
226             position_embeddings: tuple[torch.Tensor, torch.Tensor],
227             attention_mask: Optional[torch.Tensor],
228             past_key_value: Optional[Cache] = None,
229             cache_position: Optional[torch.LongTensor] = None,
230             **kwargs: Unpack[TransformersKwargs],
231         ) -> tuple[torch.Tensor, torch.Tensor]:
232             import torch.distributed as dist
233  ->         dist.breakpoint()
234
235             input_shape = hidden_states.shape[:-1]
236             hidden_shape = (*input_shape, -1, self.head_dim)
...

With FSDP2:

(Pdb) self.q_proj.base_layer.weight
Parameter containing:
tensor([[-1.1614e+13],
        [ 1.4241e+22],
        [ 8.9369e+31],
        ...,
        [-1.5348e-26],
        [ 3.4903e+17],
        [-1.5272e-34]], device='cuda:0', dtype=torch.bfloat16)

Without FSDP2:

(Pdb) self.q_proj.base_layer.weight
Parameter containing:
Parameter(Params4bit([[-1.1614e+13],
            [ 1.4241e+22],
            [ 8.9369e+31],
            ...,
            [-1.5348e-26],
            [ 3.4903e+17],
            [-1.5272e-34]], device='cuda:0', dtype=torch.bfloat16))

I discovered this since the LoRA kernels explicitly try to access the base_layer.quant_state attribute, but this fails in the former case.

Comment thread src/axolotl/kernels/lora.py
@djsaunde djsaunde force-pushed the lora-kernels-fsdp branch from 6429342 to cb11314 Compare August 3, 2025 01:48
@djsaunde
Copy link
Copy Markdown
Collaborator Author

djsaunde commented Aug 3, 2025

Testing with llama 3.2 1b (FSDP2 + QLoRA + kernels config variant):

FSDP2 + LoRA

  • 2.4s/it
  • 6.1GB VRAM

FSDP2 + LoRA + kernels

  • 1.83s/it
  • 5GB VRAM

FSDP2 + QLoRA

  • 2.4s/it
  • 5.2GB VRAM

FSDP2 + QLoRA + kernels

  • 1.65s/it
  • 4.2GB VRAM

@djsaunde djsaunde marked this pull request as ready for review August 3, 2025 03:35
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🔭 Outside diff range comments (1)
src/axolotl/monkeypatch/fsdp2_qlora.py (1)

142-205: Add return value and document _local_tensor dependency.

Similar to the previous function, this needs a return value. Also, the dependency on the private _local_tensor attribute should be documented.

 def apply_init_unsharded_param_patch():
-    """Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
+    """Apply patch to FSDPParam.init_unsharded_param to support Params4bit.
+    
+    Note: This patch depends on the private _local_tensor attribute of sharded_param.
+    
+    Returns:
+        True if patching succeeded, False otherwise.
+    """

And at the end:

         # Replace the method
         FSDPParam.init_unsharded_param = patched_init_unsharded_param  # pylint: disable=undefined-variable  # noqa: F821
         LOG.info("Successfully applied surgical FSDP patch")
+        return True
     else:
         LOG.warning("Could not find target code for patching")
+        return False
+
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between deac7b1 and 27dc994.

📒 Files selected for processing (6)
  • requirements.txt (1 hunks)
  • src/axolotl/kernels/lora.py (3 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/fsdp2_qlora.py (1 hunks)
  • tests/e2e/multigpu/test_fsdp2.py (2 hunks)
  • tests/e2e/patched/test_fsdp2_qlora.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/e2e/multigpu/test_fsdp2.py (3)
tests/e2e/utils.py (1)
  • require_torch_2_7_0 (80-89)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/multigpu/test_fsdp1.py (1)
  • verify_training_success (22-48)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/fsdp2_qlora.py (3)
  • apply_bnb_torch_function_patch (64-75)
  • apply_init_sharded_param_patch (79-139)
  • apply_init_unsharded_param_patch (142-204)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
🔇 Additional comments (12)
requirements.txt (1)

4-4: LGTM! Dependency update aligns with new FSDP2 patches.

The bitsandbytes version bump to 0.46.1 is necessary to support the new FSDP2 + bitsandbytes integration patches introduced in this PR.

src/axolotl/loaders/patch_manager.py (2)

68-68: LGTM! Patch application follows established patterns.

The new _apply_fsdp2_bnb_patches() call is appropriately placed in the pre-model load patching sequence.


264-279: Conditional patching logic is well-structured.

The method correctly applies FSDP2 + bitsandbytes patches only when all required conditions are met:

  • FSDP is configured
  • FSDP version is 2
  • Adapter is set to "qlora"

This prevents unnecessary patching and follows the existing pattern in the class.

tests/e2e/patched/test_fsdp2_qlora.py (4)

16-28: Well-structured fixture for test data.

The mock_params4bit fixture appropriately creates a mock with all the necessary attributes that would be present on a real Params4bit instance.


34-39: Good verification of patch application.

The test correctly verifies that the patch was applied by checking that the __torch_function__ attribute exists and is a classmethod.


42-65: Comprehensive test of torch.chunk behavior preservation.

The test effectively verifies that:

  • torch.chunk preserves Params4bit attributes during the patched torch function
  • The correct number of chunks are created
  • All required attributes are properly passed to the constructor

88-131: Thorough integration test validates all patches.

The test properly:

  • Stores original method references before patching
  • Applies all three patches in sequence
  • Verifies each patch was applied by comparing method references
  • Handles both cases where original methods exist or don't exist

This provides good coverage of the complete patching workflow.

tests/e2e/multigpu/test_fsdp2.py (2)

177-240: Excellent e2e test for LoRA kernels with FSDP2.

The test properly validates the integration of:

  • FSDP version 2
  • LoRA adapter with DoRA parameterization
  • All three kernel types (mlp, qkv, o)

The configuration mirrors existing tests while adding the critical kernel flags, providing good coverage for the new functionality.


304-366: Comprehensive QLoRA + kernels integration test.

This test validates the complete QLoRA + FSDP2 + kernels workflow that was specifically mentioned in the PR objectives. The combination of:

  • 4-bit quantization (load_in_4bit: True)
  • QLoRA adapter
  • FSDP version 2
  • All kernel optimizations enabled

provides excellent end-to-end coverage of the new functionality.

src/axolotl/kernels/lora.py (3)

17-17: Good addition of DTensor import for FSDP2 support.

The import is correctly placed and necessary for the manual unsharding logic introduced later in the file.


59-72: Manual unsharding implementation is well-documented and necessary.

The comments clearly explain why manual unsharding is required for FSDP2 + LoRA kernels compatibility. The implementation correctly:

  1. Extracts linear layers before checking DTensor type
  2. Only unshards when needed (DTensor instances)
  3. Applies unsharding to both A and B parameters

The note about not resharding later due to complexity is reasonable since LoRA parameters are typically small.


119-120: Optimization combines tensor operations efficiently.

The change combines transpose and dtype conversion into a single step, then reorders the multiplication and scaling operations. This maintains the same mathematical result while potentially improving performance.

Comment thread src/axolotl/monkeypatch/fsdp2_qlora.py
Comment thread src/axolotl/monkeypatch/fsdp2_qlora.py
Comment thread src/axolotl/monkeypatch/fsdp2_qlora.py
@codecov
Copy link
Copy Markdown

codecov Bot commented Aug 3, 2025

Codecov Report

❌ Patch coverage is 11.23596% with 79 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/monkeypatch/fsdp2_qlora.py 0.00% 64 Missing ⚠️
src/axolotl/kernels/lora.py 0.00% 10 Missing ⚠️
src/axolotl/loaders/patch_manager.py 42.85% 4 Missing ⚠️
src/axolotl/utils/schemas/validation.py 87.50% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@winglian
Copy link
Copy Markdown
Collaborator

winglian commented Aug 3, 2025

I think I've discovered an issue with FSDP2 + QLoRA, even without LoRA kernels enabled.

It seems that FSDP2 is converting Linear4Bit parameters to regular torch.Tensors. Running a training with this config and toggling FSDP2, and debugging from this point in the transformers/models/llama/modeling_llama.py module:

We already discussed, but adding this to link the issues: bitsandbytes-foundation/bitsandbytes#1612

Comment thread tests/e2e/multigpu/test_fsdp2.py Outdated
@djsaunde
Copy link
Copy Markdown
Collaborator Author

djsaunde commented Aug 3, 2025

I think I've discovered an issue with FSDP2 + QLoRA, even without LoRA kernels enabled.
It seems that FSDP2 is converting Linear4Bit parameters to regular torch.Tensors. Running a training with this config and toggling FSDP2, and debugging from this point in the transformers/models/llama/modeling_llama.py module:

We already discussed, but adding this to link the issues: bitsandbytes-foundation/bitsandbytes#1612

Ah, new PR: bitsandbytes-foundation/bitsandbytes#1719

Comment thread src/axolotl/monkeypatch/fsdp2_qlora.py
Comment on lines +88 to +104
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""

patched_param_creation = """ import bitsandbytes as bnb
if isinstance(param, bnb.nn.modules.Params4bit):
self.sharded_param = bnb.nn.modules.Params4bit(
data=sharded_param,
requires_grad=param.requires_grad,
quant_state=param.quant_state,
blocksize=param.blocksize,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
quant_storage=param.quant_storage,
module=param.module,
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any ideas on how this could be managed upstream in the long term?

Copy link
Copy Markdown
Collaborator Author

@djsaunde djsaunde Aug 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exact code would be a hard sell obviously, but we could try to upstream more generic torch.nn.Parameter subclass support. E.g.:

cls = type(param)
if isinstance(param, torch.nn.Parameter) and cls is not torch.nn.Parameter:
    self.sharded_param = cls(
        data=sharded_param,
        **cls_kwargs,
    )
    self.sharded_param = self.to_sharded_dtensor(self.sharded_param)

Copy link
Copy Markdown
Collaborator

@winglian winglian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

amazing! thank you!

@djsaunde
Copy link
Copy Markdown
Collaborator Author

djsaunde commented Aug 3, 2025

Btw, I will probably add 8bit LoRA support and test with 8bit optims prior to merge.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Aug 3, 2025

📖 Documentation Preview: https://688feb57336ce20273243b95--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 646de8a

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/axolotl/utils/schemas/validation.py (2)

608-619: LGTM with improvement suggestion.

The method renaming for consistency and the typo fix in the error message are good improvements. However, consider the static analysis suggestion to simplify the nested if statements.

Apply this diff to combine the nested if statements:

-    def check_lora_kernels_8bit(cls, data):
-        if (
-            data.get("lora_mlp_kernel")
-            or data.get("lora_qkv_kernel")
-            or data.get("lora_o_kernel")
-        ):
-            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
-                raise ValueError(
-                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
-                    "compatible with 8-bit LoRA a the moment."
-                )
+    def check_lora_kernels_8bit(cls, data):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ) and data.get("adapter") == "lora" and data.get("load_in_8bit"):
+            raise ValueError(
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                "compatible with 8-bit LoRA at the moment."
+            )

623-633: LGTM - Good addition for DoRA compatibility.

The new validation method correctly prevents incompatible combinations of LoRA kernels with DoRA, following the established validation pattern.

For consistency with the refactoring suggestion above, consider combining the if statements:

-    def check_lora_kernels_dora(cls, data):
-        if (
-            data.get("lora_mlp_kernel")
-            or data.get("lora_qkv_kernel")
-            or data.get("lora_o_kernel")
-        ) and data.get("peft_use_dora"):
-            raise ValueError(
-                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
-                "compatible with DoRA at the moment."
-            )
+    def check_lora_kernels_dora(cls, data):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ) and data.get("peft_use_dora"):
+            raise ValueError(
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                "compatible with DoRA at the moment."
+            )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27dc994 and e119043.

📒 Files selected for processing (2)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
  • tests/e2e/multigpu/test_fsdp2.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/e2e/multigpu/test_fsdp2.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

609-614: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: docker-e2e-tests-1st (126, 12.6.3, 3.11, 2.6.0, 1, Dockerfile-uv.jinja)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
🔇 Additional comments (1)
src/axolotl/utils/schemas/validation.py (1)

637-647: LGTM - Consistent naming and organization.

The method renaming and reordering improve code organization by grouping all LoRA kernel validation methods together with consistent naming.

@djsaunde
Copy link
Copy Markdown
Collaborator Author

djsaunde commented Aug 3, 2025

8 bit LoRA support will require some more work, probably best to do in another PR.

8 bit optim appears to work, but I didn't see a big diff in it/s or VRAM.

@djsaunde djsaunde merged commit e758343 into main Aug 4, 2025
15 of 16 checks passed
@djsaunde djsaunde deleted the lora-kernels-fsdp branch August 4, 2025 00:05
@coderabbitai coderabbitai Bot mentioned this pull request Aug 25, 2025
@coderabbitai coderabbitai Bot mentioned this pull request Mar 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants