tiled_mlp supports single gpu #2891
Walkthrough
This update refines distributed execution logic in the tiled MLP patch by checking for a distributed environment before performing all-reduce operations. It also adjusts validation so that DeepSpeed is required only when tiled MLP is enabled with multiple GPUs. Additionally, it introduces a new CPU offload checkpointing class and updates gradient checkpointing patching to use it. The DeepSpeed environment setup now initializes distributed state even for single-GPU runs. Documentation is updated with a tip for using ZeRO Stage 3 on a single GPU.
Sequence Diagram(s)
sequenceDiagram
    participant User
    participant Validation
    participant Env
    participant DeepSpeed
    User->>Validation: check_tiled_mlp_deepspeed(data)
    Validation->>Env: Get n_gpu from capabilities
    alt n_gpu > 1 and tiled_mlp enabled
        Validation->>DeepSpeed: Is deepspeed enabled?
        alt deepspeed not enabled
            Validation-->>User: Raise ValueError
        else deepspeed enabled
            Validation-->>User: Return data
        end
    else n_gpu <= 1 or tiled_mlp not enabled
        Validation-->>User: Return data
    end
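The branch structure in the diagram can be sketched in Python roughly as follows. This is a minimal illustration, not the PR's actual code: the keys `tiled_mlp`, `deepspeed`, and `capabilities.n_gpu` mirror the names used in this review, while the plain-dict input and the error message are assumptions.

```python
from typing import Any, Dict


def check_tiled_mlp_deepspeed(data: Dict[str, Any]) -> Dict[str, Any]:
    """Require DeepSpeed only when tiled MLP runs on more than one GPU."""
    # `capabilities` may be absent or None; treat both as "no GPUs reported".
    capabilities = data.get("capabilities") or {}
    n_gpu = capabilities.get("n_gpu", 0)

    if data.get("tiled_mlp") and n_gpu > 1 and not data.get("deepspeed"):
        # Hypothetical wording; the real message may differ.
        raise ValueError("tiled_mlp on multiple GPUs requires deepspeed")
    return data
```

With this shape, a single-GPU config with `tiled_mlp` enabled passes through unchanged, which is the behavior the PR title describes.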
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)
481-490: Logic improvement aligns with single GPU support objective.
The updated validation logic correctly distinguishes between single- and multi-GPU scenarios for tiled MLP usage. The conditional check ensures DeepSpeed is required only when multiple GPUs are present, enabling single GPU support as intended.
Consider using capabilities.get("n_gpu", 0) consistently instead of the conditional check:
- capabilities = data.get("capabilities")
- n_gpu = 0
- if capabilities and capabilities.get("n_gpu", 0) >= 1:
-     n_gpu = capabilities.get("n_gpu", 0)
+ capabilities = data.get("capabilities", {})
+ n_gpu = capabilities.get("n_gpu", 0)
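One difference between the two versions is worth checking before adopting the shorter one: the original guards against `capabilities` being `None`, whereas a `dict.get` default only applies when the key is missing. The sketch below contrasts the suggested lookup with a None-safe variant; the function names are hypothetical, and whether the validator ever receives `capabilities: None` is an assumption.

```python
from typing import Any, Dict


def n_gpu_with_default_arg(data: Dict[str, Any]) -> int:
    # Mirrors the suggested two-liner. The {} default is used only when the
    # key is absent, so an explicit None still reaches .get() and fails.
    capabilities = data.get("capabilities", {})
    return capabilities.get("n_gpu", 0)


def n_gpu_none_safe(data: Dict[str, Any]) -> int:
    # `or {}` covers both a missing key and an explicit None value.
    capabilities = data.get("capabilities") or {}
    return capabilities.get("n_gpu", 0)
```

Both functions agree when `capabilities` is a dict or is missing entirely; they differ only for an explicit `None`.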
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/monkeypatch/tiled_mlp.py (2 hunks)
- src/axolotl/utils/schemas/validation.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/tiled_mlp.py (1)
src/axolotl/utils/distributed.py (1)
is_distributed(60-67)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/monkeypatch/tiled_mlp.py (2)
4-4: Import added for environment variable access.
The os import is necessary for accessing the WORLD_SIZE environment variable to detect distributed execution.
41-44: Conditional all-reduce correctly prevents unnecessary distributed operations.
The conditional logic properly guards the all-reduce operation, preventing collective communication in single GPU scenarios. This change is essential for enabling single GPU support while maintaining correct behavior in distributed environments.
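The guard described above can be sketched as follows. This is an illustration under stated assumptions, not the patched code: `world_size` and `maybe_all_reduce` are hypothetical helpers, and the collective is injected as a callable so the sketch runs without a process group (in a real run it could be `torch.distributed.all_reduce`).

```python
import os
from typing import Callable, Mapping, Optional, TypeVar

T = TypeVar("T")


def world_size(environ: Optional[Mapping[str, str]] = None) -> int:
    """Read WORLD_SIZE, defaulting to 1 when the launcher did not set it."""
    env = os.environ if environ is None else environ
    return int(env.get("WORLD_SIZE", 1))


def maybe_all_reduce(
    value: T,
    all_reduce: Callable[[T], None],
    environ: Optional[Mapping[str, str]] = None,
) -> T:
    """Run the in-place collective only when more than one process is expected."""
    if world_size(environ) > 1:
        all_reduce(value)
    return value
```

On a single GPU the callable is never invoked, so no collective communication is attempted.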
mlp_forward = torch.compile(generic_mlp_forward)

is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
💡 Verification agent
🧩 Analysis chain
Consider using existing distributed utility function.
The distributed detection logic is correct and aligns with the PR objective. However, there's an existing utility function is_distributed() in src/axolotl/utils/distributed.py that provides similar functionality.
Consider using the existing utility for consistency:
+ from axolotl.utils.distributed import is_distributed
+
- is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+ is_distributed_env = is_distributed()
Note that the existing function checks for initialized distributed state, which might be more robust than just checking the environment variable.
🏁 Script executed:
#!/bin/bash
# Description: Check if the existing is_distributed function is more comprehensive than WORLD_SIZE check
# Search for is_distributed function usage patterns
rg -A 10 "def is_distributed" src/axolotl/utils/distributed.py
# Check how WORLD_SIZE is used elsewhere in the codebase
rg -n "WORLD_SIZE" --type py
Length of output: 931
Use the centralized is_distributed() helper for consistency
The manual WORLD_SIZE check can be replaced with the existing utility in axolotl.utils.distributed, which also verifies that the distributed backend is initialized.
• File: src/axolotl/monkeypatch/tiled_mlp.py
+ from axolotl.utils.distributed import is_distributed
@@ src/axolotl/monkeypatch/tiled_mlp.py:33
- is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+ is_distributed_env = is_distributed()
This change ensures all distributed-mode checks share the same robust logic.
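The two checks answer slightly different questions: the environment variable reflects what the launcher intends, while an initialized-state check reflects what has actually been set up at the moment it runs. The sketch below shows both signals side by side. It uses `torch.distributed` directly as a stand-in and does not reproduce axolotl's `is_distributed()`, whose body is not shown in this thread; the helper names are hypothetical.

```python
import os
from typing import Mapping, Optional

try:  # torch is optional for this sketch
    import torch.distributed as dist
except ImportError:
    dist = None


def env_world_size_gt_one(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Launcher intent: more than one process was requested."""
    env = os.environ if environ is None else environ
    return int(env.get("WORLD_SIZE", 1)) > 1


def process_group_initialized() -> bool:
    """Runtime state: a default process group exists right now."""
    return dist is not None and dist.is_available() and dist.is_initialized()
```

If the patch is applied before the process group is created, the second check can return False even in a multi-GPU launch, so the timing of the call matters when choosing between them.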
📝 Committable suggestion
- is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+ # at the top of src/axolotl/monkeypatch/tiled_mlp.py
+ import os
+ from axolotl.utils.distributed import is_distributed
+ # …later, around line 33…
+ # Replace the manual WORLD_SIZE check with the centralized helper
+ is_distributed_env = is_distributed()
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py at line 33, replace the manual check of
the WORLD_SIZE environment variable with the centralized is_distributed() helper
from axolotl.utils.distributed. Import the is_distributed function at the top of
the file and use it to determine distributed mode, ensuring consistent and
robust distributed environment detection across the codebase.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
119-129: Consider making tensor offloading strategy configurable
The current implementation only offloads the first tensor to CPU while keeping others (like attention masks) on GPU. While the reasoning is sound (attention masks can be large), consider making this configurable for different use cases where users might want to offload different tensors based on their memory constraints.
Would you like me to suggest an implementation that accepts a list of tensor indices to offload as a parameter?
📒 Files selected for processing (3)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/monkeypatch/gradient_checkpointing/__init__.py (1 hunks)
- src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
CheckpointFunctionWithCPUOffload(83-227)
🔇 Additional comments (2)
src/axolotl/loaders/patch_manager.py (1)
167-173: Improved gradient checkpointing patch architecture
The switch from wrapper-based patching to directly replacing the CheckpointFunction class is a cleaner approach that better aligns with PyTorch's checkpoint architecture.
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
142-227: Well-implemented backward pass with proper compatibility handling
The backward method correctly handles tensor restoration, RNG state preservation, and autocast contexts. The version-specific handling for set_device_states (lines 185-194) ensures compatibility across different PyTorch versions.
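Replacing a class on a module, as opposed to wrapping a function, can be sketched as below. The stand-in module and both class names are hypothetical placeholders for `torch.utils.checkpoint.CheckpointFunction` and the PR's offloading class; the context manager with automatic restore is an assumption made for testability, since the patch in the PR may simply be permanent.

```python
import contextlib
import types
from typing import Any, Iterator


@contextlib.contextmanager
def replace_attr(module: Any, name: str, replacement: Any) -> Iterator[Any]:
    """Temporarily swap `module.<name>` and restore the original on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield original
    finally:
        setattr(module, name, original)


class StockCheckpointFunction:
    """Placeholder for the library's class."""


class OffloadingCheckpointFunction:
    """Placeholder for a CPU-offloading replacement."""


# Stand-in for the module that owns the class being patched.
fake_checkpoint_module = types.SimpleNamespace(
    CheckpointFunction=StockCheckpointFunction
)
```

One general caveat with this style of patch: code that ran `from module import CheckpointFunction` before the swap keeps its reference to the original class, so the replacement only affects lookups made through the module attribute.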
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2)
84-142: Well-implemented forward method with good memory optimization.
The forward method correctly implements CPU offloading for the first tensor argument, which can provide significant memory savings for long sequences. The implementation properly:
- Handles device type inference and autocast settings
- Manages RNG state preservation
- Offloads only the first tensor (index 0) to CPU while keeping others on original device
- Maintains proper tensor metadata (device, requires_grad)
Consider adding a configuration parameter to control which tensor indices to offload, as mentioned in the comment on line 122:
  # upstream could accept a list of arg indices to offload
+ def __init__(self, offload_indices=None):
+     self.offload_indices = offload_indices or [0]
This would make the implementation more flexible for different use cases.
123-130: Consider documenting the tensor selection logic.
The current implementation only offloads the first tensor (index 0) with a comment about the second tensor being "usually a shared 4D attention mask which is huge [seq,seq]". This assumption might not hold for all use cases.
Add more detailed documentation about the tensor selection criteria:
  # cpu-offload
- # we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
- # upstream could accept a list of arg indices to offload
+ # Only offload the first tensor (typically hidden states) to avoid
+ # offloading large shared tensors like attention masks at index 1
+ # TODO: Make offload indices configurable for different model architectures
  if i == 0:
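A configurable selection could look roughly like the sketch below. It is a hypothetical helper, not code from the PR: `select_for_offload`, its parameters, and the injected `to_cpu` callable are all assumptions, with `to_cpu` standing in for something like `lambda t: t.to("cpu")` on real tensors. Note also that `torch.autograd.Function` subclasses are normally invoked through the static `apply` method rather than instantiated, so the indices would more plausibly travel as an argument than through the `__init__` shown in the suggestion.

```python
from typing import Any, Callable, Iterable, List, Sequence, Tuple


def _identity(x: Any) -> Any:
    return x


def select_for_offload(
    args: Sequence[Any],
    offload_indices: Iterable[int] = (0,),
    to_cpu: Callable[[Any], Any] = _identity,
) -> Tuple[List[Any], List[int]]:
    """Apply `to_cpu` only to the arguments whose position is listed.

    Returns the (possibly moved) arguments plus the positions that were
    moved, so a backward pass would know which ones to bring back.
    """
    wanted = set(offload_indices)
    saved: List[Any] = []
    moved: List[int] = []
    for i, arg in enumerate(args):
        if i in wanted:
            saved.append(to_cpu(arg))
            moved.append(i)
        else:
            saved.append(arg)  # e.g. a large shared attention mask stays put
    return saved, moved
```

The default of `(0,)` reproduces the behavior described in the review, where only the first tensor is offloaded.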
📒 Files selected for processing (4)
- docs/multi-gpu.qmd (1 hunks)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2 hunks)
- src/axolotl/utils/trainer.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- docs/multi-gpu.qmd
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/loaders/patch_manager.py
🔇 Additional comments (2)
src/axolotl/utils/trainer.py (1)
549-558: Ensure NCCL backend compatibility across non-CUDA scenarios
While the single-GPU distributed initialization in src/axolotl/utils/trainer.py (lines 549–558) correctly forces WORLD_SIZE=1 and calls dist.init_distributed(dist_backend="nccl"), please verify that:
- NCCL is available and appropriate when running on MPS or XPU devices
- CPU-only runs either skip this block or fall back to a compatible backend (e.g. gloo)
- Any existing tests or scripts (e.g., in tests/e2e) don't assume NCCL on non-CUDA hardware
Consider adding a conditional backend choice or a clear error message/fallback for environments where NCCL isn't supported.
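A conditional backend choice along the lines suggested above might be sketched as follows. Everything here is illustrative: `pick_dist_backend` and `single_process_env` are hypothetical helpers, only the CUDA and CPU cases are handled (MPS and XPU are deliberately left out), and the rendezvous defaults other than `WORLD_SIZE=1` (address, port, ranks) are assumptions rather than values taken from the PR.

```python
import os
from typing import MutableMapping, Optional


def pick_dist_backend(cuda_available: bool) -> str:
    """Use NCCL only when CUDA is present; otherwise fall back to gloo."""
    return "nccl" if cuda_available else "gloo"


def single_process_env(
    environ: Optional[MutableMapping[str, str]] = None,
) -> MutableMapping[str, str]:
    """Fill in the variables a one-process run needs, without overriding
    anything a launcher has already set."""
    env = os.environ if environ is None else environ
    defaults = {
        "WORLD_SIZE": "1",
        "RANK": "0",
        "LOCAL_RANK": "0",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",  # assumed port; any free port would do
    }
    for key, value in defaults.items():
        env.setdefault(key, value)
    return env
```

The string returned by `pick_dist_backend` could then be passed as the `dist_backend` argument of the initialization call mentioned in the comment, in place of the hard-coded "nccl".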
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
143-228: Backward method with robust state restoration and PyTorch compatibility
The backward implementation not only restores CPU-offloaded tensors and RNG state correctly, but also uses the has_device_type guard to support both pre-May 2025 and post-May 2025 PyTorch releases (the device_type parameter in set_device_states was introduced around May 29, 2025).
Key strengths:
- Proper device restoration for offloaded tensors
- Correct RNG state management via torch.random.fork_rng
- Compatibility with PyTorch versions before and after set_device_states(device_type=…) was added
- Proper autocast context management for both device and CPU
- Correct gradient recomputation and return signature
No further changes needed.
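A signature-based guard of the kind described above can be sketched with `inspect.signature`. This is a generic illustration, not the PR's code: `accepts_kwarg` and `call_set_states` are hypothetical names, and the function being probed is passed in so the sketch runs without PyTorch.

```python
import inspect
from typing import Any, Callable, Sequence


def accepts_kwarg(fn: Callable[..., Any], name: str) -> bool:
    """True when `fn` explicitly declares a parameter called `name`.

    Functions that only take **kwargs are not detected by this check.
    """
    return name in inspect.signature(fn).parameters


def call_set_states(
    set_states: Callable[..., Any],
    devices: Sequence[Any],
    states: Sequence[Any],
    device_type: str,
) -> Any:
    """Pass `device_type` only to versions of the function that accept it."""
    if accepts_kwarg(set_states, "device_type"):
        return set_states(devices, states, device_type=device_type)
    return set_states(devices, states)
```

Probing the signature once, rather than comparing version strings, keeps the guard correct even for builds whose version number does not line up with the date the parameter landed.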
* tiled_mlp supports single gpu
* use checkpoint offloading for arctic training
* patch torch checkpoint too
* support for single gpu zero3
* add linkback to where it was copied from