TiledMLP support for FSDP2#2950
Conversation
|
📖 Documentation Preview: https://688302cab89a2df9e6e11911--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit ed04f49 |
Codecov ReportAttention: Patch coverage is 📢 Thoughts on this report? Let us know! |
📝 WalkthroughWalkthroughThis change adds a new public method in Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Note ⚡️ Unit Test Generation is now available in beta!Learn more here, or try it out under "Finishing Touches" below. 📜 Recent review detailsConfiguration used: .coderabbit.yaml 📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
Documentation and Community
|
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/axolotl/utils/callbacks/__init__.py (1)
874-876: Consider the necessity of garbage collection at step 0.The addition of
state.global_step == 0condition ensures garbage collection occurs at step 0, but this might be redundant with the newon_train_beginmethod that also calls_gc(). While the overhead is minimal and the safety benefit might be worthwhile for distributed training, consider if both are necessary.src/axolotl/monkeypatch/tiled_mlp/fsdp.py (2)
38-40: Add return type annotation for consistency.The backward method is missing its return type annotation.
@staticmethod -def backward(ctx, *grads) -> torch.Tensor: +def backward(ctx, *grads) -> tuple:
31-36: Consider adding input validation for robustness.The forward method assumes the input can be evenly chunked. Consider adding validation or handling for cases where the tensor dimensions might not divide evenly by the shard count.
x_shards = list(torch.chunk(x, chunks=shards, dim=1)) +# Validate that chunking was successful +if len(x_shards) != shards: + raise ValueError(f"Expected {shards} shards but got {len(x_shards)}. Input dim 1 size: {x.shape[1]}") with torch.no_grad(): output_shards = [fn(self, x_shard) for x_shard in x_shards]
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/__init__.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/fsdp.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/patch.py(3 hunks)src/axolotl/utils/callbacks/__init__.py(1 hunks)src/axolotl/utils/schemas/validation.py(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/tiled_mlp/patch.py (2)
src/axolotl/utils/callbacks/models.py (1)
get_causal_lm_model_cls_prefix(8-23)src/axolotl/monkeypatch/tiled_mlp/fsdp.py (1)
TiledMLPFSDP(11-84)
🔇 Additional comments (6)
src/axolotl/utils/callbacks/__init__.py (1)
866-869: Good addition for memory management at training start.The new
on_train_beginmethod properly triggers garbage collection when training begins, which is especially beneficial for distributed training scenarios like FSDP that require careful memory management.src/axolotl/monkeypatch/tiled_mlp/__init__.py (1)
1-14: LGTM!The module initialization is well-structured with clear exports and documentation.
src/axolotl/utils/schemas/validation.py (1)
491-496: Validation logic correctly updated for FSDP support.The condition now properly checks for either DeepSpeed or FSDP configuration when using tiled MLP with multiple GPUs, which aligns with the PR's objective of adding FSDP support.
src/axolotl/monkeypatch/tiled_mlp/patch.py (1)
1-80: Excellent refactoring to support multiple tiled MLP implementations.The generalization of the patching function with
tiled_mlp_clsparameter and the use offunctools.partialto create specialized versions is a clean design that follows the DRY principle while maintaining flexibility.src/axolotl/monkeypatch/tiled_mlp/fsdp.py (2)
87-98: Consider if thread synchronization is necessary.The
GradientAccumulatoruses a threading lock, but PyTorch's autograd engine typically serializes backward passes. Unless you're explicitly using multi-threaded backward passes, this synchronization might be unnecessary overhead.Can you confirm if the backward passes are indeed executed in parallel threads? If not, the threading lock could be removed for better performance.
119-122: Ensure gradient hook return value compatibility with FSDPThe hook in src/axolotl/monkeypatch/tiled_mlp/fsdp.py (lines 119–122) returns the accumulated gradient on the last shard, but FSDP’s autograd integration typically expects hooks to return None to avoid unintended interference. Please verify that returning the tensor here won’t break FSDP’s gradient sharding and reduction logic, or consider refactoring to always return None and apply self.accumulated_grads[param] directly in a post‐hook step.
• File: src/axolotl/monkeypatch/tiled_mlp/fsdp.py
• Lines: 119–122if is_last_shard: param.grad = self.accumulated_grads[param] return self.accumulated_grads[param] return None
0062dc7 to
b366257
Compare
| x_shards = list(torch.chunk(x, chunks=shards, dim=1)) | ||
|
|
||
| # Create a gradient accumulator for parameters | ||
| grad_accumulator = GradientAccumulator(compute_params, shards) |
There was a problem hiding this comment.
Is there a way to detect this is running under FSDP?
You can detect it's running under deepspeed by checking hasattr(param, "ds_param") as deepspeed installs custom attributes into params.
I am thinking these two can then be merged into one class.
Also this would work for DDP as well, right?
There was a problem hiding this comment.
I couldn't consistently find the ds_param attribute, but param_idx_in_group seems to work across the various ZeRO-* implementations so it's still a single unified patch.
There was a problem hiding this comment.
For z1 and z2: hasattr(param, "param_idx_in_group")
For z3: hasattr(param, "ds_id").
It might be better to extend the z3 API to all deepspeed scenarios as the unified API. Please advise if that would be helpful.
There was a problem hiding this comment.
oh, so for z3 param_idx_in_group won't be there?
There was a problem hiding this comment.
IMHO, having some .ds_blah attribute that would indicate the model is wrapped in ZeRO would be most desirable for consistency.
like .ds_zero flag perhaps?
There was a problem hiding this comment.
@sfc-gh-sbekman, I agreed with adding .ds_zero attribute.
@winglian, will this work?
There was a problem hiding this comment.
which would be tricky as it'd require a new deepspeed release first, no?
There was a problem hiding this comment.
Yes, new deepspeed release would be required. Until then, @winglian can continue with the separate detection logic for z1+z2 vs z3. That should work.
There was a problem hiding this comment.
thanks @sfc-gh-truwase ! We'll use param_idx_in_group and ds_id checks for now
24c3cb0 to
e762f1f
Compare
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (1)
154-154: Missing line annotation marker.Line 154 is missing the
~annotation marker that indicates it's a new/changed line.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/__init__.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/base.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/patch.py(3 hunks)src/axolotl/utils/callbacks/__init__.py(1 hunks)src/axolotl/utils/schemas/validation.py(0 hunks)
💤 Files with no reviewable changes (1)
- src/axolotl/utils/schemas/validation.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/axolotl/monkeypatch/tiled_mlp/init.py
- src/axolotl/utils/callbacks/init.py
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/tiled_mlp/patch.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/monkeypatch/tiled_mlp/base.py (3)
1-9: LGTM!The module documentation and imports are appropriate for implementing TiledMLP with distributed training support.
11-86: Well-implemented custom autograd function!The TiledMLP implementation correctly handles:
- Input sharding and parallel computation in forward pass
- Manual gradient accumulation across shards in backward pass
- Proper context management and cleanup
- Correct return signature for autograd functions
88-153: Excellent gradient accumulation implementation!The GradientAccumulator class properly handles:
- Thread-safe gradient accumulation with locking
- Configurable precision for mixed-precision training
- Correct gradient scaling (1/shards) for averaging
- Proper hook lifecycle management
- Memory cleanup to prevent leaks
e762f1f to
b9905f3
Compare
There was a problem hiding this comment.
@stas00 I'm sure this could be better, so happy to accept any changes you have on this README.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
examples/alst/README.md (2)
3-3: Fix the “Arctic” typo.
Artic→Arcticto align with the title and the acronym.-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization +Arctic Long Sequence Training (ALST) is a technique for training long-context models using various optimization
5-7: Align bullet-point style and tighten wording.Minor consistency / style tweaks: keep each bullet in third-person singular, add periods, and avoid sentence-case shifts.
- - TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage - - Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage - - Activation Offloading: Offload activations to CPU RAM to reduce memory usage + - TiledMLP: Leverages tiling over the sequence dimension in MLP layers to reduce memory usage. + - Tiled Loss: Uses optimized loss functions such as Liger-Kernel or Cut Cross Entropy to reduce memory usage. + - Activation Offloading: Offloads activations to CPU RAM to reduce memory usage.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
README.md(1 hunks)examples/alst/README.md(1 hunks)examples/alst/llama3-8b-deepspeed-alst.yaml(1 hunks)examples/alst/llama3-8b-fsdp2-alst.yaml(1 hunks)
🪛 LanguageTool
examples/alst/README.md
[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...
(A_VARIETY_OF)
✅ Files skipped from review due to trivial changes (3)
- README.md
- examples/alst/llama3-8b-fsdp2-alst.yaml
- examples/alst/llama3-8b-deepspeed-alst.yaml
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md
[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...
(A_VARIETY_OF)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: pre-commit
| sample_packing: true | ||
|
|
||
| tiled_mlp: true | ||
| sequence_parallel_degree: 8 |
There was a problem hiding this comment.
note to self that we'll need to change this to context_parallel_size when the other PR lands.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
examples/alst/README.md (2)
3-4: Fix typo and tighten wordingTypo: “Artic” → “Arctic”.
The sentence is also a bit wordy (“technique … using a variety of optimization techniques”). Streamline and avoid repeating technique / techniques.-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization -techniques. It is a combination of: +Arctic Long Sequence Training (ALST) is a method for training long-context models that combines several optimization approaches:
5-8: Align bullet-point style for consistencyThe list mixes sentence fragments and full sentences, and capitalization after the dash is inconsistent. Pick one style (e.g., capitalize each item and phrase them all as noun clauses).
Example:
-- Ulysses SP: all-to-all-based sequence parallelism method -- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage -- Fused Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy that don't manifest full logits to reduce memory usage -- Activation Offloading: Offload activations to CPU RAM to reduce memory usage +- Ulysses SP – All-to-all-based sequence parallelism +- Tiled MLP – Tiling over the sequence dimension in MLP layers to reduce memory footprint +- Fused Tiled Loss – Optimized loss functions (e.g., Liger-Kernel, Cut Cross Entropy) that avoid materializing full logits +- Activation Offloading – Moving activations to CPU RAM to lower GPU memory usagePurely editorial, but it makes the README look more polished.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
README.md(1 hunks)examples/alst/README.md(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- README.md
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md
[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...
(A_VARIETY_OF)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
| @@ -0,0 +1,10 @@ | |||
| # Arctic Long Sequence Training (ALST) | |||
|
|
|||
| Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization | |||
There was a problem hiding this comment.
| Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization | |
| Arctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization |
90ad5e9 to
6a1c20a
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
examples/alst/README.md (1)
3-3: Fix typo in "Arctic"There's a typo in line 3: "Artic" should be "Arctic".
-Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization +Arctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (1)
110-118: Potential memory inefficiency in gradient initializationThe gradient initialization creates a copy when
param.gradexists, which could be memory-intensive for large models. Consider checking if the existing gradient is already in the correct dtype.for param in self.params: if param.grad is not None: - self.accumulated_grads[param] = param.grad.to( - self.grad_accumulation_dtype - ) + if param.grad.dtype == self.grad_accumulation_dtype: + self.accumulated_grads[param] = param.grad.clone() + else: + self.accumulated_grads[param] = param.grad.to( + self.grad_accumulation_dtype + ) param.grad = None
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
README.md(1 hunks)examples/alst/README.md(1 hunks)examples/alst/llama3-8b-deepspeed-alst.yaml(1 hunks)examples/alst/llama3-8b-fsdp2-alst.yaml(1 hunks)src/axolotl/integrations/liger/args.py(1 hunks)src/axolotl/loaders/model.py(1 hunks)src/axolotl/loaders/patch_manager.py(2 hunks)src/axolotl/monkeypatch/tiled_mlp/__init__.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/base.py(1 hunks)src/axolotl/monkeypatch/tiled_mlp/patch.py(4 hunks)src/axolotl/utils/callbacks/__init__.py(1 hunks)src/axolotl/utils/schemas/config.py(1 hunks)src/axolotl/utils/schemas/validation.py(0 hunks)
💤 Files with no reviewable changes (1)
- src/axolotl/utils/schemas/validation.py
🚧 Files skipped from review as they are similar to previous changes (10)
- README.md
- src/axolotl/loaders/model.py
- src/axolotl/monkeypatch/tiled_mlp/init.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/callbacks/init.py
- src/axolotl/integrations/liger/args.py
- src/axolotl/monkeypatch/tiled_mlp/patch.py
- src/axolotl/loaders/patch_manager.py
- examples/alst/llama3-8b-deepspeed-alst.yaml
- examples/alst/llama3-8b-fsdp2-alst.yaml
🧰 Additional context used
🪛 LanguageTool
examples/alst/README.md
[style] ~3-~3: Consider using a more concise synonym.
Context: ... for training long context models using a variety of optimization techniques. It is a combin...
(A_VARIETY_OF)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: preview
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
🔇 Additional comments (5)
src/axolotl/monkeypatch/tiled_mlp/base.py (5)
16-36: LGTM! Well-structured forward pass implementationThe forward pass correctly:
- Saves necessary context for backward pass
- Chunks input along feature dimension
- Disables gradients for computation to avoid double tracking
- Concatenates outputs properly
The context saving is efficient and the no_grad context prevents unnecessary gradient computation during forward pass.
124-138: Excellent gradient hook implementation with proper synchronizationThe hook implementation correctly:
- Uses thread-safe accumulation with locks
- Scales gradients by 1/shards to maintain proper averaging
- Only assigns final gradient on the last shard to avoid race conditions
- Handles type conversions properly
This is a well-thought-out approach for distributed gradient accumulation.
148-153: Good cleanup implementationThe cleanup method properly removes all hooks and clears references to prevent memory leaks. The explicit deletion of
accumulated_gradsis good practice.
74-80: ✅ Autograd chain maintained in custom backward passThe pattern
torch.autograd.backward(output, incoming_grad_shard)intiled_mlp/base.pyis consistent with other monkey-patched modules, so the autograd graph will be correctly connected. No further changes needed.Similar usages found:
- src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py (lines 506, 508, 510)
- src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py (line 65)
98-98: No action needed: union type syntax is supported under Python ≥3.10The project’s
pyproject.tomlspecifiesrequires-python = ">=3.10", so thetorch.dtype | Noneannotation is fully compatible and does not need to be replaced withUnionorOptional.Likely an incorrect or invalid review comment.
| shard_step = x_shards[0].numel() | ||
| for i, x_shard in enumerate(x_shards): | ||
| x_shard.requires_grad_(x_requires_grad) | ||
|
|
||
| shard_offset = i * shard_step | ||
| x_shard.grad = ( | ||
| x_grad.view(-1) | ||
| .narrow(0, shard_offset, x_shard.numel()) | ||
| .view_as(x_shard) | ||
| ) | ||
| incoming_grad_shard = ( | ||
| incoming_grad.view(-1) | ||
| .narrow(0, shard_offset, x_shard.numel()) | ||
| .view_as(x_shard) | ||
| ) |
There was a problem hiding this comment.
Potential issue with gradient view calculation
The shard offset calculation assumes all shards have the same numel(), but this may not be true if the tensor doesn't divide evenly by the number of shards. This could lead to incorrect gradient assignments.
Consider using the actual cumulative sizes instead:
- shard_step = x_shards[0].numel()
+ shard_offsets = []
+ cumulative_size = 0
+ for shard in x_shards:
+ shard_offsets.append(cumulative_size)
+ cumulative_size += shard.numel()
+
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
- shard_offset = i * shard_step
+ shard_offset = shard_offsets[i]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| shard_step = x_shards[0].numel() | |
| for i, x_shard in enumerate(x_shards): | |
| x_shard.requires_grad_(x_requires_grad) | |
| shard_offset = i * shard_step | |
| x_shard.grad = ( | |
| x_grad.view(-1) | |
| .narrow(0, shard_offset, x_shard.numel()) | |
| .view_as(x_shard) | |
| ) | |
| incoming_grad_shard = ( | |
| incoming_grad.view(-1) | |
| .narrow(0, shard_offset, x_shard.numel()) | |
| .view_as(x_shard) | |
| ) | |
| shard_offsets = [] | |
| cumulative_size = 0 | |
| for shard in x_shards: | |
| shard_offsets.append(cumulative_size) | |
| cumulative_size += shard.numel() | |
| for i, x_shard in enumerate(x_shards): | |
| x_shard.requires_grad_(x_requires_grad) | |
| shard_offset = shard_offsets[i] | |
| x_shard.grad = ( | |
| x_grad.view(-1) | |
| .narrow(0, shard_offset, x_shard.numel()) | |
| .view_as(x_shard) | |
| ) | |
| incoming_grad_shard = ( | |
| incoming_grad.view(-1) | |
| .narrow(0, shard_offset, x_shard.numel()) | |
| .view_as(x_shard) | |
| ) |
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp/base.py around lines 57 to 71, the current
shard_offset calculation assumes all shards have the same number of elements,
which can cause incorrect gradient assignments if shards differ in size. To fix
this, compute shard_offset as the cumulative sum of the sizes of all previous
shards instead of using a fixed step. This ensures the gradient views align
correctly with each shard's actual size.
Description
Use manual gradient accumulation with parameter hooks to work with FSDP2. I didn't test with FSDP1, but I expect it should work with that also.
@sfc-gh-sbekman
WandB: https://wandb.ai/axolotl-ai/tiledmlp-fsdp2?nw=nwuserwingaxolotl
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Improvements
Removals