Skip to content

tiled_mlp supports single gpu#2891

Merged
winglian merged 5 commits into
mainfrom
tiled-mlp-single-gpu
Jul 9, 2025
Merged

tiled_mlp supports single gpu#2891
winglian merged 5 commits into
mainfrom
tiled-mlp-single-gpu

Conversation

@winglian
Copy link
Copy Markdown
Collaborator

@winglian winglian commented Jul 9, 2025

Summary by CodeRabbit

  • Bug Fixes
    • Improved distributed environment detection to prevent unnecessary operations when not running in distributed mode.
    • Enhanced validation logic to only require DeepSpeed when using tiled MLP with multiple GPUs, reducing false error messages for single-GPU or CPU setups.
  • New Features
    • Introduced a new memory-efficient checkpointing method that offloads data to CPU during training to save GPU memory on long sequences.
  • Improvements
    • Updated gradient checkpointing to use a new CPU offload checkpoint function for specific configurations, enhancing memory management during training.
    • Added explicit initialization of the distributed environment for single-GPU setups to improve DeepSpeed compatibility.
  • Documentation
    • Added a tip explaining how to use DeepSpeed ZeRO Stage 3 for single-GPU training by setting specific environment variables.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jul 9, 2025

Walkthrough

This update refines distributed execution logic in the tiled MLP patch by checking for a distributed environment before performing all-reduce operations. It also adjusts validation to only require DeepSpeed when tiled MLP is enabled with multiple GPUs. Additionally, it introduces a new CPU offload checkpointing class and updates gradient checkpointing patching to use it. The DeepSpeed environment setup is enhanced to initialize distributed state even for single-GPU runs. Documentation is updated with a tip for using ZeRO Stage 3 on single GPU.

Changes

File(s) Change Summary
src/axolotl/monkeypatch/tiled_mlp.py Added distributed environment detection via WORLD_SIZE; all-reduce now only occurs if distributed.
src/axolotl/utils/schemas/validation.py Validation now checks GPU count; only enforces DeepSpeed if tiled MLP enabled with multiple GPUs.
src/axolotl/loaders/patch_manager.py Updated gradient checkpointing patch to assign CheckpointFunctionWithCPUOffload unless use_reentrant=False.
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py Added CheckpointFunctionWithCPUOffload class to offload first tensor input to CPU during checkpointing.
src/axolotl/monkeypatch/gradient_checkpointing/init.py Added # noqa: F401 comment to suppress linter warning on import of CheckpointFunctionWithCPUOffload.
src/axolotl/utils/trainer.py Added explicit distributed initialization for single-GPU runs in setup_deepspeed_env.
docs/multi-gpu.qmd Added tip for running ZeRO Stage 3 on single GPU by setting environment variables.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Validation
    participant Env
    participant DeepSpeed

    User->>Validation: check_tiled_mlp_deepspeed(data)
    Validation->>Env: Get n_gpu from capabilities
    alt n_gpu > 1 and tiled_mlp enabled
        Validation->>DeepSpeed: Is deepspeed enabled?
        alt deepspeed not enabled
            Validation-->>User: Raise ValueError
        else deepspeed enabled
            Validation-->>User: Return data
        end
    else n_gpu <= 1 or tiled_mlp not enabled
        Validation-->>User: Return data
    end
Loading

Possibly related PRs

  • TiledMLP support #2865: Modifies the internal logic of the patch_tiled_mlp function to conditionally perform distributed all-reduce based on environment variables, directly related to the changes in this PR.

Suggested labels

scheduled_release

Poem

In the code where the MLPs play,
A rabbit hops through shards today.
Checks for GPUs, counts with care,
Only calls DeepSpeed if they're there.
Offloads tensors, saves the day,
Gradient checkpoints find their way.
🐇✨

✨ Finishing Touches
  • 📝 Generate Docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jul 9, 2025

@github-actions github-actions Bot temporarily deployed to preview July 9, 2025 14:46 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)

481-490: Logic improvement aligns with single GPU support objective.

The updated validation logic correctly distinguishes between single and multi-GPU scenarios for tiled MLP usage. The conditional check ensures deepspeed is only required when multiple GPUs are present, enabling single GPU support as intended.

Consider using capabilities.get("n_gpu", 0) consistently instead of the conditional check:

-        capabilities = data.get("capabilities")
-        n_gpu = 0
-        if capabilities and capabilities.get("n_gpu", 0) >= 1:
-            n_gpu = capabilities.get("n_gpu", 0)
+        capabilities = data.get("capabilities", {})
+        n_gpu = capabilities.get("n_gpu", 0)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7c5ea00 and d7671e8.

📒 Files selected for processing (2)
  • src/axolotl/monkeypatch/tiled_mlp.py (2 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/tiled_mlp.py (1)
src/axolotl/utils/distributed.py (1)
  • is_distributed (60-67)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/monkeypatch/tiled_mlp.py (2)

4-4: Import added for environment variable access.

The os import is necessary for accessing the WORLD_SIZE environment variable to detect distributed execution.


41-44: Conditional all-reduce correctly prevents unnecessary distributed operations.

The conditional logic properly guards the all-reduce operation, preventing collective communication in single GPU scenarios. This change is essential for enabling single GPU support while maintaining correct behavior in distributed environments.


mlp_forward = torch.compile(generic_mlp_forward)

is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Consider using existing distributed utility function.

The distributed detection logic is correct and aligns with the PR objective. However, there's an existing utility function is_distributed() in src/axolotl/utils/distributed.py that provides similar functionality.

Consider using the existing utility for consistency:

+from axolotl.utils.distributed import is_distributed
+
-        is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+        is_distributed_env = is_distributed()

Note that the existing function checks for initialized distributed state, which might be more robust than just checking the environment variable.


🏁 Script executed:

#!/bin/bash
# Description: Check if the existing is_distributed function is more comprehensive than WORLD_SIZE check

# Search for is_distributed function usage patterns
rg -A 10 "def is_distributed" src/axolotl/utils/distributed.py

# Check how WORLD_SIZE is used elsewhere in the codebase
rg -n "WORLD_SIZE" --type py

Length of output: 931


Use the centralized is_distributed() helper for consistency

The manual WORLD_SIZE check can be replaced with the existing utility in axolotl.utils.distributed, which also verifies that the distributed backend is initialized.

• File: src/axolotl/monkeypatch/tiled_mlp.py

+ from axolotl.utils.distributed import is_distributed
@@ src/axolotl/monkeypatch/tiled_mlp.py:33
-        is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+        is_distributed_env = is_distributed()

This change ensures all distributed‐mode checks share the same robust logic.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
# at the top of src/axolotl/monkeypatch/tiled_mlp.py
import os
from axolotl.utils.distributed import is_distributed
# …later, around line 33…
# Replace the manual WORLD_SIZE check with the centralized helper
is_distributed_env = is_distributed()
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py at line 33, replace the manual check of
the WORLD_SIZE environment variable with the centralized is_distributed() helper
from axolotl.utils.distributed. Import the is_distributed function at the top of
the file and use it to determine distributed mode, ensuring consistent and
robust distributed environment detection across the codebase.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jul 9, 2025

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)

119-129: Consider making tensor offloading strategy configurable

The current implementation only offloads the first tensor to CPU while keeping others (like attention masks) on GPU. While the reasoning is sound (attention masks can be large), consider making this configurable for different use cases where users might want to offload different tensors based on their memory constraints.

Would you like me to suggest an implementation that accepts a list of tensor indices to offload as a parameter?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d7671e8 and 9a0aabf.

📒 Files selected for processing (3)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/gradient_checkpointing/__init__.py (1 hunks)
  • src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/monkeypatch/gradient_checkpointing/init.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)
  • CheckpointFunctionWithCPUOffload (83-227)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/loaders/patch_manager.py (1)

167-173: Improved gradient checkpointing patch architecture

The switch from wrapper-based patching to directly replacing the CheckpointFunction class is a cleaner approach that better aligns with PyTorch's checkpoint architecture.

src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)

142-227: Well-implemented backward pass with proper compatibility handling

The backward method correctly handles tensor restoration, RNG state preservation, and autocast contexts. The version-specific handling for set_device_states (lines 185-194) ensures compatibility across different PyTorch versions.

@github-actions github-actions Bot temporarily deployed to preview July 9, 2025 15:16 Inactive
@github-actions github-actions Bot temporarily deployed to preview July 9, 2025 15:32 Inactive
@github-actions github-actions Bot temporarily deployed to preview July 9, 2025 16:24 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2)

84-142: Well-implemented forward method with good memory optimization.

The forward method correctly implements CPU offloading for the first tensor argument, which can provide significant memory savings for long sequences. The implementation properly:

  • Handles device type inference and autocast settings
  • Manages RNG state preservation
  • Offloads only the first tensor (index 0) to CPU while keeping others on original device
  • Maintains proper tensor metadata (device, requires_grad)

Consider adding a configuration parameter to control which tensor indices to offload, as mentioned in the comment on line 122:

# upstream could accept a list of arg indices to offload
+ def __init__(self, offload_indices=None):
+     self.offload_indices = offload_indices or [0]

This would make the implementation more flexible for different use cases.


123-130: Consider documenting the tensor selection logic.

The current implementation only offloads the first tensor (index 0) with a comment about the second tensor being "usually a shared 4D attention mask which is huge [seq,seq]". This assumption might not hold for all use cases.

Add more detailed documentation about the tensor selection criteria:

                # cpu-offload
-               # we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
-               # upstream could accept a list of arg indices to offload
+               # Only offload the first tensor (typically hidden states) to avoid
+               # offloading large shared tensors like attention masks at index 1
+               # TODO: Make offload indices configurable for different model architectures
                if i == 0:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f5772dd and 82e5aed.

📒 Files selected for processing (4)
  • docs/multi-gpu.qmd (1 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (2 hunks)
  • src/axolotl/utils/trainer.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • docs/multi-gpu.qmd
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/loaders/patch_manager.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/trainer.py (1)

549-558: Ensure NCCL backend compatibility across non-CUDA scenarios

While the single-GPU distributed initialization in src/axolotl/utils/trainer.py (lines 549–558) correctly forces WORLD_SIZE=1 and calls dist.init_distributed(dist_backend="nccl"), please verify that:

  • NCCL is available and appropriate when running on MPS or XPU devices
  • CPU-only runs either skip this block or fall back to a compatible backend (e.g. gloo)
  • Any existing tests or scripts (e.g., in tests/e2e) don’t assume NCCL on non-CUDA hardware

Consider adding a conditional backend choice or a clear error message/fallback for environments where NCCL isn’t supported.

src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (1)

143-228: Backward method with robust state restoration and PyTorch compatibility

The backward implementation not only restores CPU-offloaded tensors and RNG state correctly, but also uses the has_device_type guard to support both pre-May 2025 and post-May 2025 PyTorch releases (the device_type parameter in set_device_states was introduced around May 29, 2025).
Key strengths:

  • Proper device restoration for offloaded tensors
  • Correct RNG state management via torch.random.fork_rng
  • Compatibility with PyTorch versions before and after set_device_states(device_type=…) was added
  • Proper autocast context management for both device and CPU
  • Correct gradient recomputation and return signature

No further changes needed.

@winglian winglian merged commit 76aeb16 into main Jul 9, 2025
11 of 14 checks passed
@winglian winglian deleted the tiled-mlp-single-gpu branch July 9, 2025 16:48
winglian added a commit that referenced this pull request Jul 9, 2025
* tiled_mlp supports single gpu

* use checkpoint offloading for arctic training

* patch torch checkpoint too

* support for single gpu zero3

* add linkback to where it was copied from
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant