Ensure device mesh patching is applied #2842
Walkthrough
The sequence parallel patching logic has been moved from the
Changes
Sequence Diagram(s)
sequenceDiagram
    participant User
    participant PatchManager
    participant PatchUtils
    User->>PatchManager: apply_pre_model_load_patches(config)
    PatchManager->>PatchManager: _apply_sequence_parallel_patches(config)
    alt sequence_parallel_degree > 1
        PatchManager->>PatchUtils: patch_prepare_data_loader()
        PatchManager->>PatchUtils: patch_prepare_device_mesh(sequence_parallel_degree, fsdp)
    end
Looks like you need a guard in the _apply_sequence_parallel_patches function in case sequence_parallel_degree is None.
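A minimal sketch of the kind of guard being requested, not the actual Axolotl code. The two patch helpers below are stand-ins that only record that they were called; their names come from the sequence diagram, but their bodies, the standalone function `apply_sequence_parallel_patches`, the boolean return value, and the attribute access on `cfg` are all assumptions made for illustration.

```python
from types import SimpleNamespace

calls = []  # records which stand-in patch helpers ran


def patch_prepare_data_loader():
    # Stand-in for the helper named in the sequence diagram.
    calls.append("patch_prepare_data_loader")


def patch_prepare_device_mesh(sequence_parallel_degree, fsdp):
    # Stand-in for the helper named in the sequence diagram.
    calls.append(("patch_prepare_device_mesh", sequence_parallel_degree, fsdp))


def apply_sequence_parallel_patches(cfg):
    """Apply SP patches only when the degree is set and greater than 1."""
    degree = getattr(cfg, "sequence_parallel_degree", None)
    # Guard: an unset degree would make `degree > 1` raise a TypeError.
    if degree is None or degree <= 1:
        return False
    patch_prepare_data_loader()
    # Assumption: the FSDP setting is passed through unchanged.
    patch_prepare_device_mesh(degree, getattr(cfg, "fsdp", None))
    return True


# Hypothetical configs, for illustration only.
skipped = apply_sequence_parallel_patches(SimpleNamespace(sequence_parallel_degree=None))
applied = apply_sequence_parallel_patches(SimpleNamespace(sequence_parallel_degree=2, fsdp=True))
```

With the guard in place, a config that leaves the degree unset is skipped rather than failing on the comparison.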
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (1)
429-435: Consider the implications of introducing config loading pipeline into kernel patching tests.
While this change aligns with testing the actual usage pattern, it significantly changes the test scope. The load_cfg function performs extensive validation, normalization, plugin preparation, and environment setup that wasn't part of the original test. This makes the test less isolated and introduces dependencies on the entire configuration loading pipeline.
Consider whether these tests should remain focused on kernel patching functionality specifically, or if the broader config loading integration should be tested separately.
Additionally, add error handling for the file operations:
 # Write cfg to yaml file
 path = Path(temp_dir) / "config.yaml"
-with open(path, "w", encoding="utf-8") as fout:
-    fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+try:
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+except Exception as e:
+    pytest.fail(f"Failed to write test config: {e}")
📒 Files selected for processing (4)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (3 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (0 hunks)
- tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (4 hunks)
💤 Files with no reviewable changes (1)
- src/axolotl/utils/ctx_managers/sequence_parallel.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/monkeypatch/ring_attn/patch.py
- src/axolotl/loaders/patch_manager.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (1)
src/axolotl/cli/config.py (1)
load_cfg(164-249)
🔇 Additional comments (2)
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (2)
399-399: Function signature change looks appropriate.
The addition of the temp_dir fixture parameter aligns with the new file-based configuration approach.
516-516: Function signature change looks appropriate.
The addition of the temp_dir fixture parameter is consistent with the first modified function.
Description
Our accelerate patching to enable sequence parallelism (SP) seems to have broken recently; this PR fixes it. I am also more explicit about using FSDP when it is enabled.
Motivation and Context
How has this been tested?
Confirmed working with a 2 x 2 FSDP x SP device mesh on 4x H100 SXM. It also works with various optimizations (Liger optimizations, CCE).
Example config (modified from user's): https://gist.github.com/djsaunde/aca6285273cf9d476e69baa1cdcab6c7. This uses ~29GB VRAM per GPU.
I'll reproduce the training losses and post them here as well.
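To picture the 2 x 2 FSDP x SP mesh on 4 GPUs, here is a small pure-Python sketch of one possible rank layout. It assumes a row-major ordering with SP as the innermost (fastest-varying) dimension; the PR does not state the actual ordering, and the helper names `mesh_coords` and `sp_groups` are hypothetical.

```python
def mesh_coords(world_size, sp_degree):
    """Map each global rank to (fsdp_index, sp_index), assuming row-major order."""
    if world_size % sp_degree != 0:
        raise ValueError("world_size must be divisible by sp_degree")
    # divmod(rank, sp_degree) gives (row, column) in a grid with sp_degree columns.
    return {rank: divmod(rank, sp_degree) for rank in range(world_size)}


def sp_groups(world_size, sp_degree):
    """List the ranks that would share one sequence (same fsdp_index)."""
    return [
        list(range(start, start + sp_degree))
        for start in range(0, world_size, sp_degree)
    ]


# 4 GPUs, SP degree 2, which leaves an FSDP dimension of size 2.
coords = mesh_coords(4, 2)
groups = sp_groups(4, 2)
```

Under this assumed layout, ranks 0 and 1 split one sequence between them, ranks 2 and 3 split another, and the FSDP dimension runs across the two groups.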