Walkthrough
A new "tiled MLP" patching feature was introduced, enabling memory-efficient long-context support for transformer models. This includes configuration options for enabling tiled MLP and setting shard counts, validation that requires DeepSpeed ZeRO when the feature is enabled, and patching logic that modifies the MLP forward pass to shard input sequences for tiled execution.
Sequence Diagram(s)
sequenceDiagram
participant User
participant Config
participant Validator
participant PatchManager
participant TiledMLPPatch
User->>Config: Set tiled_mlp=True, deepspeed=True/False
Config->>Validator: Validate config
Validator->>Config: Raise error if tiled_mlp=True and deepspeed!=True
User->>PatchManager: Load model
PatchManager->>TiledMLPPatch: Apply tiled MLP patch if tiled_mlp=True
TiledMLPPatch->>Model: Patch MLP forward method
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/tiled_mlp.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/monkeypatch/tiled_mlp.py (1)
src/axolotl/monkeypatch/mixtral/__init__.py (1)
mlp_forward(11-16)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/tiled_mlp.py (1)
patch_tiled_mlp(7-52)
🔇 Additional comments (4)
src/axolotl/utils/schemas/config.py (1)
551-556: LGTM!
The tiled_mlp field is properly defined with a clear description and appropriate default value.
src/axolotl/loaders/patch_manager.py (2)
235-240: LGTM!
The _apply_tiled_mlp method follows the established pattern for conditional patching.
68-68: model_config_type initialization is guaranteed
The cfg.model_config_type property is set unconditionally in src/axolotl/utils/config/__init__.py:
- In the config registration logic:
  cfg.model_config_type = model_config.model_type
This assignment ensures model_config_type always exists on cfg before any patches are applied. No further validation is required here.
src/axolotl/utils/schemas/validation.py (1)
479-485: LGTM!
The validation correctly enforces the DeepSpeed requirement for tiled MLP functionality.
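As a rough illustration of the rule this comment approves (tiled MLP requires DeepSpeed), here is a minimal standalone sketch. The function name check_tiled_mlp_requires_deepspeed and the plain-dict config are hypothetical stand-ins for illustration only; the actual validator in src/axolotl/utils/schemas/validation.py is not shown on this page and may differ.

```python
def check_tiled_mlp_requires_deepspeed(cfg: dict) -> dict:
    """Hypothetical stand-in for the schema validator; not axolotl's real code."""
    # Assumption: `deepspeed` holds a truthy value (for example a path to a
    # ZeRO config file) when DeepSpeed is enabled, and is missing or None
    # otherwise.
    if cfg.get("tiled_mlp") and not cfg.get("deepspeed"):
        raise ValueError("tiled_mlp requires DeepSpeed ZeRO to be enabled")
    return cfg
```

Under these assumptions, a config with tiled_mlp set and no deepspeed entry is rejected at validation time, before any model is loaded or patched.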
def patch_tiled_mlp(model_type, use_original_mlp=False):
    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
Add error handling for DeepSpeed import.
The DeepSpeed import should be wrapped in a try-except block to provide a clearer error message if DeepSpeed is not installed.
Apply this diff to handle the import error gracefully:
 def patch_tiled_mlp(model_type, use_original_mlp=False):
-    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+    try:
+        from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+    except ImportError as e:
+        raise ImportError(
+            "DeepSpeed is required for tiled_mlp but not installed. "
+            "Please install it with 'pip install deepspeed'."
+        ) from e
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py around lines 7 to 8, the import of
TiledMLP from DeepSpeed is done without error handling. Wrap the import
statement in a try-except block to catch ImportError and raise a clear,
informative error message indicating that DeepSpeed is required but not
installed. This will help users understand the missing dependency more clearly.
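The same guarded-import idea can be written as a small reusable helper. This is only a sketch: require_module is a hypothetical name that does not appear in the PR, and the install hint assumes the pip package name matches the module name, which is not true for every library.

```python
import importlib
from types import ModuleType


def require_module(name: str, feature: str) -> ModuleType:
    """Import `name`, or raise an ImportError naming the feature that needs it."""
    try:
        return importlib.import_module(name)
    except ImportError as err:
        # Chain the original error so the underlying cause stays visible.
        raise ImportError(
            f"{name} is required for {feature} but could not be imported. "
            f"Try installing it with 'pip install {name}'."
        ) from err
```

A caller would use it as require_module("deepspeed", "tiled_mlp") and then import the specific class it needs. Chaining with `from err` matters because an ImportError can also be raised by a broken installation, not only a missing one.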
def tiled_mlp_forward(self, x):
    input_shape = x.shape
    seqlen = input_shape[-2]
    hidden = input_shape[-1]
    num_shards = math.ceil(seqlen / hidden)
    num_shards_tensor = torch.tensor(num_shards, device=x.device)
    dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
    num_shards = num_shards_tensor.item()

    compute_params = [self.down_proj.weight, self.gate_proj.weight, self.up_proj.weight]

    down_res = TiledMLP.apply(
        mlp_forward,
        self,
        x,
        num_shards,
        compute_params,
    )
    return down_res
Add distributed initialization check and validate weight attributes.
The implementation has two issues:
- It uses distributed communication without checking if distributed is initialized
- It assumes specific weight attributes exist
Apply this diff to handle both issues:
 def tiled_mlp_forward(self, x):
     input_shape = x.shape
     seqlen = input_shape[-2]
     hidden = input_shape[-1]
     num_shards = math.ceil(seqlen / hidden)
-    num_shards_tensor = torch.tensor(num_shards, device=x.device)
-    dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
-    num_shards = num_shards_tensor.item()
+
+    # Synchronize num_shards across devices if distributed
+    if dist.is_initialized():
+        num_shards_tensor = torch.tensor(num_shards, device=x.device)
+        dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+        num_shards = num_shards_tensor.item()
-    compute_params = [self.down_proj.weight, self.gate_proj.weight, self.up_proj.weight]
+    # Validate and collect weight parameters
+    try:
+        compute_params = [self.down_proj.weight, self.gate_proj.weight, self.up_proj.weight]
+    except AttributeError as e:
+        raise AttributeError(
+            f"MLP class for {model_type} doesn't have expected weight attributes. "
+            f"Error: {str(e)}"
+        ) from e
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py around lines 27 to 45, add a check to
ensure distributed communication is only used if
torch.distributed.is_initialized() returns true to avoid errors when distributed
is not set up. Also, validate that the weight attributes (down_proj.weight,
gate_proj.weight, up_proj.weight) exist on self before using them, adding
appropriate error handling or fallback logic if they are missing.
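To make the shard arithmetic in the reviewed code concrete, here is a pure-Python sketch of the two steps: the local ceil(seqlen / hidden) computation and the MAX reduction across ranks. The helper names are hypothetical, and the reduction is emulated with max() over a list rather than a real torch.distributed call. The reason given for the reduction (keeping every rank on the same tile count) is our reading of the code, not something the PR states.

```python
import math
from typing import Sequence


def local_num_shards(seqlen: int, hidden: int) -> int:
    """Shard count a single rank would pick: ceil(seqlen / hidden), as in the PR."""
    return math.ceil(seqlen / hidden)


def synced_num_shards(per_rank_seqlens: Sequence[int], hidden: int) -> int:
    """Emulate all_reduce with ReduceOp.MAX: every rank adopts the largest count."""
    return max(local_num_shards(seqlen, hidden) for seqlen in per_rank_seqlens)
```

For example, a 32768-token sequence with a hidden size of 4096 gives 8 shards locally, while a 1000-token sequence gives 1. If one rank holds 32768 tokens and another holds 20000, the reduction leaves both with 8. In a single-process run there is nothing to reduce, which is why the suggested dist.is_initialized() guard simply keeps the local value.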
def generic_mlp_forward(self_, hs):
    return self_.down_proj(self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs))

mlp_forward = torch.compile(generic_mlp_forward)
💡 Verification agent
🧩 Analysis chain
The generic MLP forward makes assumptions about attribute names.
The generic forward function assumes the MLP has specific attributes (down_proj, gate_proj, up_proj, act_fn) which might not be universal across all transformer model types.
Consider adding error handling or validation:
 else:
     def generic_mlp_forward(self_, hs):
-        return self_.down_proj(self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs))
+        try:
+            return self_.down_proj(self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs))
+        except AttributeError as e:
+            raise AttributeError(
+                f"MLP class for {model_type} doesn't have expected attributes. "
+                f"Consider using use_original_mlp=True. Error: {str(e)}"
+            ) from e
Alternatively, verify which model types are supported:
🏁 Script executed:
#!/bin/bash
# Description: Check which transformer models have the expected MLP structure
# Search for MLP classes with the expected attributes
ast-grep --pattern 'class $_MLP($_):
$$$
def __init__($$$):
$$$
self.down_proj = $$$
$$$
self.gate_proj = $$$
$$$
self.up_proj = $$$
$$$'
# Also check for models that might have different MLP structures
rg -A 10 "class \w+MLP" --type py | grep -E "(self\.(w1|w2|w3|fc1|fc2|wi|wo))"
Length of output: 252
To confirm which MLP variants exist (and whether they all use down_proj/gate_proj/up_proj), let’s search for any classes and attribute assignments in the codebase:
🏁 Script executed:
#!/bin/bash
# 1. Find where these attributes are defined
echo "=== self.gate_proj occurrences ==="
rg -n "self\.gate_proj" --type py || true
echo
echo "=== self.down_proj occurrences ==="
rg -n "self\.down_proj" --type py || true
echo
echo "=== self.up_proj occurrences ==="
rg -n "self\.up_proj" --type py || true
echo
# 2. List all classes whose names end with MLP
echo "=== Classes matching *MLP ==="
rg -n "class .*MLP" --type py || true
Length of output: 2577
Validate MLP attributes before applying the generic forward
The generic_mlp_forward in tiled_mlp.py unconditionally calls
self_.down_proj, self_.gate_proj, self_.up_proj, and self_.act_fn.
However, other MLP variants (e.g. llama4’s down_projs/gate_projs lists or xformers’ FusedMLP) don’t use these exact names, which will lead to AttributeError at runtime.
Suggested fix in src/axolotl/monkeypatch/tiled_mlp.py:
 else:
     def generic_mlp_forward(self_, hs):
-        return self_.down_proj(self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs))
+        # Verify required attributes exist
+        required = ("down_proj", "gate_proj", "up_proj", "act_fn")
+        missing = [attr for attr in required if not hasattr(self_, attr)]
+        if missing:
+            raise AttributeError(
+                f"{self_.__class__.__name__} missing MLP attributes {missing}. "
+                "Consider using `use_original_mlp=True` or adapting your model."
+            )
+        return self_.down_proj(
+            self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
+        )
     mlp_forward = torch.compile(generic_mlp_forward)
Files to update:
- src/axolotl/monkeypatch/tiled_mlp.py (the generic_mlp_forward block)
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py around lines 22 to 26, the
generic_mlp_forward function assumes the presence of attributes down_proj,
gate_proj, up_proj, and act_fn on self_, which causes AttributeError for MLP
variants with different attribute names. Modify generic_mlp_forward to first
check if these attributes exist on self_ before accessing them, and handle cases
where they are missing or named differently to avoid runtime errors.
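One possible variation on this suggestion is to run the attribute check once on an MLP instance instead of inside the forward function that is handed to torch.compile. The sketch below uses toy classes and hypothetical names (REQUIRED_MLP_ATTRS, missing_mlp_attrs, GatedMLP, ExpertListMLP) to show only the check itself; it is not the PR's code. Because these attributes are normally set in __init__, the check needs an instance rather than the class.

```python
REQUIRED_MLP_ATTRS = ("down_proj", "gate_proj", "up_proj", "act_fn")


def missing_mlp_attrs(module: object) -> list:
    """Return the gated-MLP attribute names that `module` does not define."""
    return [name for name in REQUIRED_MLP_ATTRS if not hasattr(module, name)]


class GatedMLP:
    """Toy stand-in for an MLP with the expected attribute names."""

    def __init__(self) -> None:
        identity = lambda value: value
        self.down_proj = self.gate_proj = self.up_proj = self.act_fn = identity


class ExpertListMLP:
    """Toy stand-in for a variant that stores its projections under other names."""

    def __init__(self) -> None:
        self.down_projs = []
        self.gate_projs = []
        self.act_fn = lambda value: value
```

With these toy classes, missing_mlp_attrs(GatedMLP()) is empty, while missing_mlp_attrs(ExpertListMLP()) reports down_proj, gate_proj, and up_proj. A caller could raise the AttributeError from the suggestion above whenever the returned list is non-empty.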
Actionable comments posted: 0
♻️ Duplicate comments (4)
src/axolotl/monkeypatch/tiled_mlp.py (4)
10-10: Add error handling for DeepSpeed import.
The DeepSpeed import should be wrapped in a try-except block to provide a clearer error message if DeepSpeed is not installed.
Apply this diff to handle the import error gracefully:
 def patch_tiled_mlp(model_type, use_original_mlp=False):
-    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+    try:
+        from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+    except ImportError as e:
+        raise ImportError(
+            "DeepSpeed is required for tiled_mlp but not installed. "
+            "Please install it with 'pip install deepspeed'."
+        ) from e
25-28: Validate MLP attributes before applying the generic forward.
The generic_mlp_forward function assumes the presence of attributes down_proj, gate_proj, up_proj, and act_fn on self_, which will cause AttributeError for MLP variants with different attribute names.
Apply this diff to add validation:
 def generic_mlp_forward(self_, hs):
-    return self_.down_proj(
-        self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
-    )
+    # Verify required attributes exist
+    required = ("down_proj", "gate_proj", "up_proj", "act_fn")
+    missing = [attr for attr in required if not hasattr(self_, attr)]
+    if missing:
+        raise AttributeError(
+            f"{self_.__class__.__name__} missing MLP attributes {missing}. "
+            "Consider using `use_original_mlp=True` or adapting your model."
+        )
+    return self_.down_proj(
+        self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
+    )
37-39: Add distributed initialization check.
The implementation uses distributed communication without checking if distributed is initialized, which will cause errors when distributed is not set up.
Apply this diff to add the distributed check:
     num_shards = math.ceil(seqlen / hidden)
-    num_shards_tensor = torch.tensor(num_shards, device=x.device)
-    dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
-    num_shards = num_shards_tensor.item()
+
+    # Synchronize num_shards across devices if distributed
+    if dist.is_initialized():
+        num_shards_tensor = torch.tensor(num_shards, device=x.device)
+        dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+        num_shards = num_shards_tensor.item()
41-45: Validate weight attributes before accessing them.
The implementation assumes specific weight attributes exist without validation, which will cause AttributeError for MLP variants with different attribute names.
Apply this diff to add validation:
-    compute_params = [
-        self.down_proj.weight,
-        self.gate_proj.weight,
-        self.up_proj.weight,
-    ]
+    # Validate and collect weight parameters
+    try:
+        compute_params = [
+            self.down_proj.weight,
+            self.gate_proj.weight,
+            self.up_proj.weight,
+        ]
+    except AttributeError as e:
+        raise AttributeError(
+            f"MLP class for {model_type} doesn't have expected weight attributes. "
+            f"Error: {str(e)}"
+        ) from e
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/tiled_mlp.py (1)
9-61: Consider adding type hints and improved documentation.
The function would benefit from type hints and more detailed documentation explaining the tiled MLP concept and its parameters.
Consider this enhancement:
-def patch_tiled_mlp(model_type, use_original_mlp=False):
+def patch_tiled_mlp(model_type: str, use_original_mlp: bool = False) -> None:
+    """
+    Patch the MLP class of a transformer model to use tiled MLP for memory efficiency.
+
+    Args:
+        model_type: The transformer model type (e.g., 'llama', 'mistral')
+        use_original_mlp: Whether to use the original MLP forward method instead of generic
+
+    Raises:
+        ImportError: If DeepSpeed is not installed
+        RuntimeError: If the MLP class cannot be imported or patched
+        AttributeError: If required MLP attributes are missing
+    """
🚧 Files skipped from review as they are similar to previous changes (3)
- src/axolotl/loaders/patch_manager.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/validation.py
Description
Arctic Long Sequence Training (see https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/) introduced TiledMLP into DeepSpeed to reduce the activation memory footprint of long sequences in the MLP modules. This PR adds support for it via the tiled_mlp: true parameter in our YAML. It currently works only with DeepSpeed ZeRO stages 1 through 3; single GPU, DDP, and FSDP are not yet supported.
When using bf16, there appear to be some numerical differences in train and eval loss as well as grad norm, which @stas00 is helping to pinpoint. In the meantime, it's worth adding, as it significantly reduces VRAM requirements at the cost of some reduced accuracy.
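The idea behind tiling can be sketched without any deep learning library. Because an MLP acts on each sequence position independently, splitting the sequence into shards, applying the MLP to one shard at a time, and concatenating the results gives the same output as a single pass, while only one shard's intermediate activations need to exist at once. The names tiled_apply and toy_mlp below are hypothetical, and the sketch covers only this forward-pass equivalence; it does not reproduce DeepSpeed's TiledMLP, which also has to handle the backward pass.

```python
import math
from typing import Callable, List, Sequence

Row = List[float]


def toy_mlp(rows: Sequence[Row]) -> List[Row]:
    """Toy position-wise function: ReLU then doubling, applied row by row."""
    return [[max(value, 0.0) * 2.0 for value in row] for row in rows]


def tiled_apply(
    fn: Callable[[Sequence[Row]], List[Row]],
    rows: Sequence[Row],
    num_shards: int,
) -> List[Row]:
    """Apply `fn` to consecutive sequence shards and concatenate the outputs."""
    if num_shards < 1:
        raise ValueError("num_shards must be at least 1")
    # Guard against a zero step when `rows` is empty.
    shard_len = max(1, math.ceil(len(rows) / num_shards))
    output: List[Row] = []
    for start in range(0, len(rows), shard_len):
        output.extend(fn(rows[start:start + shard_len]))
    return output
```

For any position-wise fn, tiled_apply(fn, rows, n) equals fn(rows) for every n >= 1, so in exact arithmetic the shard count affects peak memory but not the result.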

Summary by CodeRabbit