Skip to content

TiledMLP support for FSDP2#2950

Merged
winglian merged 8 commits into
mainfrom
tiled-mlp-no-ds
Jul 25, 2025
Merged

TiledMLP support for FSDP2#2950
winglian merged 8 commits into
mainfrom
tiled-mlp-no-ds

Conversation

@winglian
Copy link
Copy Markdown
Collaborator

@winglian winglian commented Jul 19, 2025

Description

Use manual gradient accumulation with parameter hooks to work with FSDP2. I didn't test with FSDP1, but I expect it should work with that also.
@sfc-gh-sbekman

WandB: https://wandb.ai/axolotl-ai/tiledmlp-fsdp2?nw=nwuserwingaxolotl

Screenshot 2025-07-19 at 1 30 11 PM

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Introduced a new tiled multi-layer perceptron (TiledMLP) supporting distributed and single GPU training, with advanced gradient accumulation for mixed precision scenarios.
    • Added a new patch application step between plugin initialization and model loading to enhance model setup.
    • Added example configurations and documentation for Arctic Long Sequence Training (ALST) with Llama 3 8B models using DeepSpeed and FSDP.
  • Improvements

    • Enhanced garbage collection during training to trigger at the very start and at step zero.
    • Updated tiled MLP patch to default to using the original MLP implementation and dynamically select DeepSpeed support when applicable.
    • Refined conflict validation to allow tiled MLP usage when explicitly configured to use the original MLP.
    • Changed default configuration to enable original MLP usage for tiled MLP by default.
  • Removals

    • Removed the validation check that enforced DeepSpeed usage with tiled MLP in multi-GPU setups.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jul 19, 2025

📖 Documentation Preview: https://688302cab89a2df9e6e11911--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit ed04f49

@codecov
Copy link
Copy Markdown

codecov Bot commented Jul 19, 2025

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jul 19, 2025

📝 Walkthrough

Walkthrough

This change adds a new public method in PatchManager to apply tiled MLP patches after plugin initialization and before model loading. It introduces a base tiled MLP autograd implementation with distributed and mixed precision support, updates the tiled MLP patch to select between DeepSpeed and local implementations, and modifies related schema defaults and validation. Garbage collection is triggered at training start. Additionally, new example configurations and documentation for Arctic Long Sequence Training (ALST) are added.

Changes

Files/Paths Change Summary
src/axolotl/loaders/patch_manager.py Added public method apply_post_plugin_pre_model_load_patches calling _apply_tiled_mlp; minor import formatting update.
src/axolotl/loaders/model.py Inserted call to apply_post_plugin_pre_model_load_patches between plugin pre-load and model build steps.
src/axolotl/monkeypatch/tiled_mlp/init.py New module exporting patch_tiled_mlp.
src/axolotl/monkeypatch/tiled_mlp/patch.py Changed default use_original_mlp to True; added logic to select DeepSpeed or local tiled MLP implementation dynamically.
src/axolotl/monkeypatch/tiled_mlp/base.py New module implementing TiledMLP autograd function and GradientAccumulator for tiled distributed mixed precision MLP.
src/axolotl/utils/callbacks/init.py Added on_train_begin method in GCCallback for immediate garbage collection; updated on_step_begin condition.
src/axolotl/utils/schemas/validation.py Removed check_tiled_mlp_deepspeed validator that enforced DeepSpeed requirement for multi-GPU tiled MLP.
src/axolotl/utils/schemas/config.py Changed default of tiled_mlp_use_original_mlp from None to True in AxolotlInputConfig.
src/axolotl/integrations/liger/args.py Refined check_tiled_mlp_conflict validation to allow conflict if tiled_mlp_use_original_mlp is true; updated error msg.
README.md Added July 2025 update announcing TiledMLP support for ALST with multi-GPU training and example usage link.
examples/alst/README.md Added ALST example README describing tiled MLP, tiled loss, and activation offloading techniques with ALST paper link.
examples/alst/llama3-8b-deepspeed-alst.yaml Added new DeepSpeed ALST configuration file for Llama 3 8B with tiled MLP and long sequence training settings.
examples/alst/llama3-8b-fsdp2-alst.yaml Added new FSDP2 ALST configuration file for Llama 3 8B with tiled MLP, activation offloading, and sequence parallelism.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~45 minutes

Possibly related PRs

  • TiledMLP support #2865: Extends the PatchManager with a new method to apply tiled MLP patches post-plugin initialization, building on the existing private _apply_tiled_mlp method.
  • models.py -> loaders/ module refactor #2680: Refactors model loading and patch management, related to the integration of patch application in the model loading sequence.

Suggested labels

scheduled_release

Suggested reviewers

  • djsaunde

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.


📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6a1c20a and ed04f49.

📒 Files selected for processing (2)
  • examples/alst/llama3-8b-fsdp2-alst.yaml (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py
  • examples/alst/llama3-8b-fsdp2-alst.yaml
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch tiled-mlp-no-ds

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (3)
src/axolotl/utils/callbacks/__init__.py (1)

874-876: Consider the necessity of garbage collection at step 0.

The addition of state.global_step == 0 condition ensures garbage collection occurs at step 0, but this might be redundant with the new on_train_begin method that also calls _gc(). While the overhead is minimal and the safety benefit might be worthwhile for distributed training, consider if both are necessary.

src/axolotl/monkeypatch/tiled_mlp/fsdp.py (2)

38-40: Add return type annotation for consistency.

The backward method is missing its return type annotation.

 @staticmethod
-def backward(ctx, *grads) -> torch.Tensor:
+def backward(ctx, *grads) -> tuple:

31-36: Consider adding input validation for robustness.

The forward method assumes the input can be evenly chunked. Consider adding validation or handling for cases where the tensor dimensions might not divide evenly by the shard count.

 x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+# Validate that chunking was successful
+if len(x_shards) != shards:
+    raise ValueError(f"Expected {shards} shards but got {len(x_shards)}. Input dim 1 size: {x.shape[1]}")
 with torch.no_grad():
     output_shards = [fn(self, x_shard) for x_shard in x_shards]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 170322a and 0062dc7.

📒 Files selected for processing (6)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/__init__.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/fsdp.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py (3 hunks)
  • src/axolotl/utils/callbacks/__init__.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/tiled_mlp/patch.py (2)
src/axolotl/utils/callbacks/models.py (1)
  • get_causal_lm_model_cls_prefix (8-23)
src/axolotl/monkeypatch/tiled_mlp/fsdp.py (1)
  • TiledMLPFSDP (11-84)
🔇 Additional comments (6)
src/axolotl/utils/callbacks/__init__.py (1)

866-869: Good addition for memory management at training start.

The new on_train_begin method properly triggers garbage collection when training begins, which is especially beneficial for distributed training scenarios like FSDP that require careful memory management.

src/axolotl/monkeypatch/tiled_mlp/__init__.py (1)

1-14: LGTM!

The module initialization is well-structured with clear exports and documentation.

src/axolotl/utils/schemas/validation.py (1)

491-496: Validation logic correctly updated for FSDP support.

The condition now properly checks for either DeepSpeed or FSDP configuration when using tiled MLP with multiple GPUs, which aligns with the PR's objective of adding FSDP support.

src/axolotl/monkeypatch/tiled_mlp/patch.py (1)

1-80: Excellent refactoring to support multiple tiled MLP implementations.

The generalization of the patching function with tiled_mlp_cls parameter and the use of functools.partial to create specialized versions is a clean design that follows the DRY principle while maintaining flexibility.

src/axolotl/monkeypatch/tiled_mlp/fsdp.py (2)

87-98: Consider if thread synchronization is necessary.

The GradientAccumulator uses a threading lock, but PyTorch's autograd engine typically serializes backward passes. Unless you're explicitly using multi-threaded backward passes, this synchronization might be unnecessary overhead.

Can you confirm if the backward passes are indeed executed in parallel threads? If not, the threading lock could be removed for better performance.


119-122: Ensure gradient hook return value compatibility with FSDP

The hook in src/axolotl/monkeypatch/tiled_mlp/fsdp.py (lines 119–122) returns the accumulated gradient on the last shard, but FSDP’s autograd integration typically expects hooks to return None to avoid unintended interference. Please verify that returning the tensor here won’t break FSDP’s gradient sharding and reduction logic, or consider refactoring to always return None and apply self.accumulated_grads[param] directly in a post‐hook step.

• File: src/axolotl/monkeypatch/tiled_mlp/fsdp.py
• Lines: 119–122

if is_last_shard:
    param.grad = self.accumulated_grads[param]
    return self.accumulated_grads[param]
return None

Comment thread src/axolotl/loaders/patch_manager.py
@winglian winglian changed the title TiledMLP support for FSDP TiledMLP support for FSDP2 Jul 19, 2025
x_shards = list(torch.chunk(x, chunks=shards, dim=1))

# Create a gradient accumulator for parameters
grad_accumulator = GradientAccumulator(compute_params, shards)
Copy link
Copy Markdown

@stas00 stas00 Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to detect this is running under FSDP?

You can detect it's running under deepspeed by checking hasattr(param, "ds_param") as deepspeed installs custom attributes into params.

I am thinking these two can then be merged into one class.

Also this would work for DDP as well, right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't consistently find the ds_param attribute, but param_idx_in_group seems to work across the various ZeRO-* implementations so it's still a single unified patch.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@winglian

For z1 and z2: hasattr(param, "param_idx_in_group")

For z3: hasattr(param, "ds_id").

It might be better to extend the z3 API to all deepspeed scenarios as the unified API. Please advise if that would be helpful.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, so for z3 param_idx_in_group won't be there?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, having some .ds_blah attribute that would indicate the model is wrapped in ZeRO would be most desirable for consistency.

like .ds_zero flag perhaps?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfc-gh-sbekman, I agreed with adding .ds_zero attribute.

@winglian, will this work?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which would be tricky as it'd require a new deepspeed release first, no?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, new deepspeed release would be required. Until then, @winglian can continue with the separate detection logic for z1+z2 vs z3. That should work.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @sfc-gh-truwase ! We'll use param_idx_in_group and ds_id checks for now

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (1)

154-154: Missing line annotation marker.

Line 154 is missing the ~ annotation marker that indicates it's a new/changed line.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 24c3cb0 and e762f1f.

📒 Files selected for processing (6)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/__init__.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/base.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py (3 hunks)
  • src/axolotl/utils/callbacks/__init__.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (0 hunks)
💤 Files with no reviewable changes (1)
  • src/axolotl/utils/schemas/validation.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • src/axolotl/monkeypatch/tiled_mlp/init.py
  • src/axolotl/utils/callbacks/init.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/tiled_mlp/patch.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/monkeypatch/tiled_mlp/base.py (3)

1-9: LGTM!

The module documentation and imports are appropriate for implementing TiledMLP with distributed training support.


11-86: Well-implemented custom autograd function!

The TiledMLP implementation correctly handles:

  • Input sharding and parallel computation in forward pass
  • Manual gradient accumulation across shards in backward pass
  • Proper context management and cleanup
  • Correct return signature for autograd functions

88-153: Excellent gradient accumulation implementation!

The GradientAccumulator class properly handles:

  • Thread-safe gradient accumulation with locking
  • Configurable precision for mixed-precision training
  • Correct gradient scaling (1/shards) for averaging
  • Proper hook lifecycle management
  • Memory cleanup to prevent leaks

Comment thread examples/alst/README.md
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stas00 I'm sure this could be better, so happy to accept any changes you have on this README.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
examples/alst/README.md (2)

3-3: Fix the “Arctic” typo.

ArticArctic to align with the title and the acronym.

-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
+Arctic Long Sequence Training (ALST) is a technique for training long-context models using various optimization

5-7: Align bullet-point style and tighten wording.

Minor consistency / style tweaks: keep each bullet in third-person singular, add periods, and avoid sentence-case shifts.

- - TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
- - Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
- - Activation Offloading: Offload activations to CPU RAM to reduce memory usage
+ - TiledMLP: Leverages tiling over the sequence dimension in MLP layers to reduce memory usage.
+ - Tiled Loss: Uses optimized loss functions such as Liger-Kernel or Cut Cross Entropy to reduce memory usage.
+ - Activation Offloading: Offloads activations to CPU RAM to reduce memory usage.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9905f3 and db6fd18.

📒 Files selected for processing (4)
  • README.md (1 hunks)
  • examples/alst/README.md (1 hunks)
  • examples/alst/llama3-8b-deepspeed-alst.yaml (1 hunks)
  • examples/alst/llama3-8b-fsdp2-alst.yaml (1 hunks)
🪛 LanguageTool
examples/alst/README.md

[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...

(A_VARIETY_OF)

✅ Files skipped from review due to trivial changes (3)
  • README.md
  • examples/alst/llama3-8b-fsdp2-alst.yaml
  • examples/alst/llama3-8b-deepspeed-alst.yaml
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md

[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...

(A_VARIETY_OF)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
  • GitHub Check: pre-commit

Comment thread README.md
Comment thread examples/alst/llama3-8b-fsdp2-alst.yaml Outdated
sample_packing: true

tiled_mlp: true
sequence_parallel_degree: 8
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self that we'll need to change this to context_parallel_size when the other PR lands.

Copy link
Copy Markdown

@sfc-gh-sbekman sfc-gh-sbekman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a few small doc suggestions

Comment thread examples/alst/README.md
Comment thread examples/alst/README.md
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
examples/alst/README.md (2)

3-4: Fix typo and tighten wording

Typo: “Artic” → “Arctic”.
The sentence is also a bit wordy (“technique … using a variety of optimization techniques”). Streamline and avoid repeating technique / techniques.

-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
-techniques. It is a combination of:
+Arctic Long Sequence Training (ALST) is a method for training long-context models that combines several optimization approaches:

5-8: Align bullet-point style for consistency

The list mixes sentence fragments and full sentences, and capitalization after the dash is inconsistent. Pick one style (e.g., capitalize each item and phrase them all as noun clauses).

Example:

-- Ulysses SP: all-to-all-based sequence parallelism method
-- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
-- Fused Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy that don't manifest full logits to reduce memory usage
-- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
+- Ulysses SP – All-to-all-based sequence parallelism  
+- Tiled MLP – Tiling over the sequence dimension in MLP layers to reduce memory footprint  
+- Fused Tiled Loss – Optimized loss functions (e.g., Liger-Kernel, Cut Cross Entropy) that avoid materializing full logits  
+- Activation Offloading – Moving activations to CPU RAM to lower GPU memory usage

Purely editorial, but it makes the README look more polished.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db6fd18 and 90ad5e9.

📒 Files selected for processing (2)
  • README.md (1 hunks)
  • examples/alst/README.md (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • README.md
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md

[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...

(A_VARIETY_OF)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview

Comment thread examples/alst/README.md
@@ -0,0 +1,10 @@
# Arctic Long Sequence Training (ALST)

Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
Copy link
Copy Markdown

@sfc-gh-sbekman sfc-gh-sbekman Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
Arctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
examples/alst/README.md (1)

3-3: Fix typo in "Arctic"

There's a typo in line 3: "Artic" should be "Arctic".

-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
+Arctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (1)

110-118: Potential memory inefficiency in gradient initialization

The gradient initialization creates a copy when param.grad exists, which could be memory-intensive for large models. Consider checking if the existing gradient is already in the correct dtype.

         for param in self.params:
             if param.grad is not None:
-                self.accumulated_grads[param] = param.grad.to(
-                    self.grad_accumulation_dtype
-                )
+                if param.grad.dtype == self.grad_accumulation_dtype:
+                    self.accumulated_grads[param] = param.grad.clone()
+                else:
+                    self.accumulated_grads[param] = param.grad.to(
+                        self.grad_accumulation_dtype
+                    )
                 param.grad = None
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 90ad5e9 and 6a1c20a.

📒 Files selected for processing (13)
  • README.md (1 hunks)
  • examples/alst/README.md (1 hunks)
  • examples/alst/llama3-8b-deepspeed-alst.yaml (1 hunks)
  • examples/alst/llama3-8b-fsdp2-alst.yaml (1 hunks)
  • src/axolotl/integrations/liger/args.py (1 hunks)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/__init__.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/base.py (1 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py (4 hunks)
  • src/axolotl/utils/callbacks/__init__.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (0 hunks)
💤 Files with no reviewable changes (1)
  • src/axolotl/utils/schemas/validation.py
🚧 Files skipped from review as they are similar to previous changes (10)
  • README.md
  • src/axolotl/loaders/model.py
  • src/axolotl/monkeypatch/tiled_mlp/init.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/callbacks/init.py
  • src/axolotl/integrations/liger/args.py
  • src/axolotl/monkeypatch/tiled_mlp/patch.py
  • src/axolotl/loaders/patch_manager.py
  • examples/alst/llama3-8b-deepspeed-alst.yaml
  • examples/alst/llama3-8b-fsdp2-alst.yaml
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md

[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...

(A_VARIETY_OF)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (5)
src/axolotl/monkeypatch/tiled_mlp/base.py (5)

16-36: LGTM! Well-structured forward pass implementation

The forward pass correctly:

  • Saves necessary context for backward pass
  • Chunks input along feature dimension
  • Disables gradients for computation to avoid double tracking
  • Concatenates outputs properly

The context saving is efficient and the no_grad context prevents unnecessary gradient computation during forward pass.


124-138: Excellent gradient hook implementation with proper synchronization

The hook implementation correctly:

  • Uses thread-safe accumulation with locks
  • Scales gradients by 1/shards to maintain proper averaging
  • Only assigns final gradient on the last shard to avoid race conditions
  • Handles type conversions properly

This is a well-thought-out approach for distributed gradient accumulation.


148-153: Good cleanup implementation

The cleanup method properly removes all hooks and clears references to prevent memory leaks. The explicit deletion of accumulated_grads is good practice.


74-80: ✅ Autograd chain maintained in custom backward pass

The pattern torch.autograd.backward(output, incoming_grad_shard) in tiled_mlp/base.py is consistent with other monkey-patched modules, so the autograd graph will be correctly connected. No further changes needed.

Similar usages found:

  • src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py (lines 506, 508, 510)
  • src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (line 65)

98-98: No action needed: union type syntax is supported under Python ≥3.10

The project’s pyproject.toml specifies requires-python = ">=3.10", so the torch.dtype | None annotation is fully compatible and does not need to be replaced with Union or Optional.

Likely an incorrect or invalid review comment.

Comment on lines +57 to +71
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)

shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential issue with gradient view calculation

The shard offset calculation assumes all shards have the same numel(), but this may not be true if the tensor doesn't divide evenly by the number of shards. This could lead to incorrect gradient assignments.

Consider using the actual cumulative sizes instead:

-        shard_step = x_shards[0].numel()
+        shard_offsets = []
+        cumulative_size = 0
+        for shard in x_shards:
+            shard_offsets.append(cumulative_size)
+            cumulative_size += shard.numel()
+            
         for i, x_shard in enumerate(x_shards):
             x_shard.requires_grad_(x_requires_grad)

-            shard_offset = i * shard_step
+            shard_offset = shard_offsets[i]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
shard_offsets = []
cumulative_size = 0
for shard in x_shards:
shard_offsets.append(cumulative_size)
cumulative_size += shard.numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_offset = shard_offsets[i]
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp/base.py around lines 57 to 71, the current
shard_offset calculation assumes all shards have the same number of elements,
which can cause incorrect gradient assignments if shards differ in size. To fix
this, compute shard_offset as the cumulative sum of the sizes of all previous
shards instead of using a fixed step. This ensures the gradient views align
correctly with each shard's actual size.

@winglian winglian merged commit f7ea140 into main Jul 25, 2025
23 of 27 checks passed
@winglian winglian deleted the tiled-mlp-no-ds branch July 25, 2025 11:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants