FSDP2 + LoRA kernels#2992
Conversation
📝 WalkthroughWalkthroughThis update introduces a new monkeypatch module to support bitsandbytes 4-bit quantized parameters with PyTorch FSDP v2, modifies LoRA kernel handling for compatibility, and adds both integration and end-to-end tests for these changes. It also updates a dependency version in requirements.txt and refines validation logic for LoRA kernel flags. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Suggested labels
Suggested reviewers
Note ⚡️ Unit Test Generation is now available in beta!Learn more here, or try it out under "Finishing Touches" below. 📜 Recent review detailsConfiguration used: .coderabbit.yaml 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
Documentation and Community
|
|
FSDP2 + LoRA is working, QLoRA currently erroring. Fixing now. |
|
I think I've discovered an issue with FSDP2 + QLoRA, even without LoRA kernels enabled. It seems that FSDP2 is converting Linear4Bit parameters to regular torch.Tensors. Running a training with this config and toggling FSDP2, and debugging from this point in the (Pdb) ll
223 def forward(
224 self,
225 hidden_states: torch.Tensor,
226 position_embeddings: tuple[torch.Tensor, torch.Tensor],
227 attention_mask: Optional[torch.Tensor],
228 past_key_value: Optional[Cache] = None,
229 cache_position: Optional[torch.LongTensor] = None,
230 **kwargs: Unpack[TransformersKwargs],
231 ) -> tuple[torch.Tensor, torch.Tensor]:
232 import torch.distributed as dist
233 -> dist.breakpoint()
234
235 input_shape = hidden_states.shape[:-1]
236 hidden_shape = (*input_shape, -1, self.head_dim)
...With FSDP2: (Pdb) self.q_proj.base_layer.weight
Parameter containing:
tensor([[-1.1614e+13],
[ 1.4241e+22],
[ 8.9369e+31],
...,
[-1.5348e-26],
[ 3.4903e+17],
[-1.5272e-34]], device='cuda:0', dtype=torch.bfloat16)Without FSDP2: (Pdb) self.q_proj.base_layer.weight
Parameter containing:
Parameter(Params4bit([[-1.1614e+13],
[ 1.4241e+22],
[ 8.9369e+31],
...,
[-1.5348e-26],
[ 3.4903e+17],
[-1.5272e-34]], device='cuda:0', dtype=torch.bfloat16))I discovered this since the LoRA kernels explicitly try to access the |
6429342 to
cb11314
Compare
|
Testing with llama 3.2 1b (FSDP2 + QLoRA + kernels config variant): FSDP2 + LoRA
FSDP2 + LoRA + kernels
FSDP2 + QLoRA
FSDP2 + QLoRA + kernels
|
There was a problem hiding this comment.
Actionable comments posted: 3
🔭 Outside diff range comments (1)
src/axolotl/monkeypatch/fsdp2_qlora.py (1)
142-205: Add return value and document _local_tensor dependency.Similar to the previous function, this needs a return value. Also, the dependency on the private
_local_tensorattribute should be documented.def apply_init_unsharded_param_patch(): - """Apply patch to FSDPParam.init_unsharded_param to support Params4bit.""" + """Apply patch to FSDPParam.init_unsharded_param to support Params4bit. + + Note: This patch depends on the private _local_tensor attribute of sharded_param. + + Returns: + True if patching succeeded, False otherwise. + """And at the end:
# Replace the method FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821 LOG.info("Successfully applied surgical FSDP patch") + return True else: LOG.warning("Could not find target code for patching") + return False +
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
requirements.txt(1 hunks)src/axolotl/kernels/lora.py(3 hunks)src/axolotl/loaders/patch_manager.py(2 hunks)src/axolotl/monkeypatch/fsdp2_qlora.py(1 hunks)tests/e2e/multigpu/test_fsdp2.py(2 hunks)tests/e2e/patched/test_fsdp2_qlora.py(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/e2e/multigpu/test_fsdp2.py (3)
tests/e2e/utils.py (1)
require_torch_2_7_0(80-89)src/axolotl/utils/dict.py (1)
DictDefault(6-38)tests/e2e/multigpu/test_fsdp1.py (1)
verify_training_success(22-48)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/fsdp2_qlora.py (3)
apply_bnb_torch_function_patch(64-75)apply_init_sharded_param_patch(79-139)apply_init_unsharded_param_patch(142-204)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
🔇 Additional comments (12)
requirements.txt (1)
4-4: LGTM! Dependency update aligns with new FSDP2 patches.The bitsandbytes version bump to 0.46.1 is necessary to support the new FSDP2 + bitsandbytes integration patches introduced in this PR.
src/axolotl/loaders/patch_manager.py (2)
68-68: LGTM! Patch application follows established patterns.The new
_apply_fsdp2_bnb_patches()call is appropriately placed in the pre-model load patching sequence.
264-279: Conditional patching logic is well-structured.The method correctly applies FSDP2 + bitsandbytes patches only when all required conditions are met:
- FSDP is configured
- FSDP version is 2
- Adapter is set to "qlora"
This prevents unnecessary patching and follows the existing pattern in the class.
tests/e2e/patched/test_fsdp2_qlora.py (4)
16-28: Well-structured fixture for test data.The mock_params4bit fixture appropriately creates a mock with all the necessary attributes that would be present on a real Params4bit instance.
34-39: Good verification of patch application.The test correctly verifies that the patch was applied by checking that the
__torch_function__attribute exists and is a classmethod.
42-65: Comprehensive test of torch.chunk behavior preservation.The test effectively verifies that:
- torch.chunk preserves Params4bit attributes during the patched torch function
- The correct number of chunks are created
- All required attributes are properly passed to the constructor
88-131: Thorough integration test validates all patches.The test properly:
- Stores original method references before patching
- Applies all three patches in sequence
- Verifies each patch was applied by comparing method references
- Handles both cases where original methods exist or don't exist
This provides good coverage of the complete patching workflow.
tests/e2e/multigpu/test_fsdp2.py (2)
177-240: Excellent e2e test for LoRA kernels with FSDP2.The test properly validates the integration of:
- FSDP version 2
- LoRA adapter with DoRA parameterization
- All three kernel types (mlp, qkv, o)
The configuration mirrors existing tests while adding the critical kernel flags, providing good coverage for the new functionality.
304-366: Comprehensive QLoRA + kernels integration test.This test validates the complete QLoRA + FSDP2 + kernels workflow that was specifically mentioned in the PR objectives. The combination of:
- 4-bit quantization (
load_in_4bit: True)- QLoRA adapter
- FSDP version 2
- All kernel optimizations enabled
provides excellent end-to-end coverage of the new functionality.
src/axolotl/kernels/lora.py (3)
17-17: Good addition of DTensor import for FSDP2 support.The import is correctly placed and necessary for the manual unsharding logic introduced later in the file.
59-72: Manual unsharding implementation is well-documented and necessary.The comments clearly explain why manual unsharding is required for FSDP2 + LoRA kernels compatibility. The implementation correctly:
- Extracts linear layers before checking DTensor type
- Only unshards when needed (DTensor instances)
- Applies unsharding to both A and B parameters
The note about not resharding later due to complexity is reasonable since LoRA parameters are typically small.
119-120: Optimization combines tensor operations efficiently.The change combines transpose and dtype conversion into a single step, then reorders the multiplication and scaling operations. This maintains the same mathematical result while potentially improving performance.
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
We already discussed, but adding this to link the issues: bitsandbytes-foundation/bitsandbytes#1612 |
Ah, new PR: bitsandbytes-foundation/bitsandbytes#1719 |
| original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) | ||
| self.sharded_param.requires_grad_(param.requires_grad)""" | ||
|
|
||
| patched_param_creation = """ import bitsandbytes as bnb | ||
| if isinstance(param, bnb.nn.modules.Params4bit): | ||
| self.sharded_param = bnb.nn.modules.Params4bit( | ||
| data=sharded_param, | ||
| requires_grad=param.requires_grad, | ||
| quant_state=param.quant_state, | ||
| blocksize=param.blocksize, | ||
| compress_statistics=param.compress_statistics, | ||
| quant_type=param.quant_type, | ||
| quant_storage=param.quant_storage, | ||
| module=param.module, | ||
| bnb_quantized=param.bnb_quantized, | ||
| ) | ||
| self.sharded_param = self.to_sharded_dtensor(self.sharded_param) |
There was a problem hiding this comment.
Any ideas on how this could be managed upstream in the long term?
There was a problem hiding this comment.
This exact code would be a hard sell obviously, but we could try to upstream more generic torch.nn.Parameter subclass support. E.g.:
cls = type(param)
if isinstance(param, torch.nn.Parameter) and cls is not torch.nn.Parameter:
self.sharded_param = cls(
data=sharded_param,
**cls_kwargs,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)|
Btw, I will probably add 8bit LoRA support and test with 8bit optims prior to merge. |
|
📖 Documentation Preview: https://688feb57336ce20273243b95--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 646de8a |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/utils/schemas/validation.py (2)
608-619: LGTM with improvement suggestion.The method renaming for consistency and the typo fix in the error message are good improvements. However, consider the static analysis suggestion to simplify the nested if statements.
Apply this diff to combine the nested if statements:
- def check_lora_kernels_8bit(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ): - if data.get("adapter") == "lora" and data.get("load_in_8bit"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " - "compatible with 8-bit LoRA a the moment." - ) + def check_lora_kernels_8bit(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ) and data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " + "compatible with 8-bit LoRA at the moment." + )
623-633: LGTM - Good addition for DoRA compatibility.The new validation method correctly prevents incompatible combinations of LoRA kernels with DoRA, following the established validation pattern.
For consistency with the refactoring suggestion above, consider combining the if statements:
- def check_lora_kernels_dora(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ) and data.get("peft_use_dora"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " - "compatible with DoRA at the moment." - ) + def check_lora_kernels_dora(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ) and data.get("peft_use_dora"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " + "compatible with DoRA at the moment." + )
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/axolotl/utils/schemas/validation.py(1 hunks)tests/e2e/multigpu/test_fsdp2.py(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/e2e/multigpu/test_fsdp2.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
609-614: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: docker-e2e-tests-1st (126, 12.6.3, 3.11, 2.6.0, 1, Dockerfile-uv.jinja)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
🔇 Additional comments (1)
src/axolotl/utils/schemas/validation.py (1)
637-647: LGTM - Consistent naming and organization.The method renaming and reordering improve code organization by grouping all LoRA kernel validation methods together with consistent naming.
|
8 bit LoRA support will require some more work, probably best to do in another PR. 8 bit optim appears to work, but I didn't see a big diff in it/s or VRAM. |
Description
Enabling LoRA kernels with FSDP2. Couple of problems solved here:
torch.nn.Parametersubclasses (e.g.,bitsandbytes.nn.modules.Params4bit) get unsharded into regulartorch.nn.Parameters and hence lose their metadata. We needParams4bitquant state, etc. in order to dequantization. FSDP2 + LoRA is already ungated on main, but was failing silently due to this with little to no memory savings.Props to @ved1beta for the initial Params4Bit handling code (now modified in a patch in this PR)!
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Chores
Documentation & Validation