Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def apply_pre_model_load_patches(self):
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
self._apply_tiled_mlp(self.cfg.model_config_type)

def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
Expand Down Expand Up @@ -243,6 +244,12 @@ def _apply_sequence_parallel_patches(self):
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)

def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)

def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
Expand Down
64 changes: 64 additions & 0 deletions src/axolotl/monkeypatch/tiled_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Monkeypatch for Tiled MLP implementation"""

import math

import torch
import torch.distributed as dist


def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
    """Replace the forward of a transformers MLP class with a DeepSpeed TiledMLP forward.

    The MLP class is looked up as ``<Prefix>MLP`` in
    ``transformers.models.<model_type>.modeling_<model_type>``, where
    ``<Prefix>`` is ``model_type`` with each ``_``-separated part capitalized
    (e.g. ``"llama"`` -> ``LlamaMLP``). The class is patched in place, so every
    instance of it is affected.

    Args:
        model_type: transformers model type used to locate the modeling module
            and MLP class.
        use_original_mlp: When true, the class's existing ``forward`` is the
            function run per shard. Otherwise a ``torch.compile``-d gated MLP
            (``down_proj(act_fn(gate_proj(x)) * up_proj(x))``) is used.
        cfg_num_shards: Fixed number of shards. When ``None``, the shard count
            is ``ceil(seqlen / hidden)`` per forward call, taken as the maximum
            across ranks when a process group is initialized.

    Raises:
        ValueError: If ``cfg_num_shards`` is given but is less than 1.
        RuntimeError: If the modeling module or MLP class cannot be resolved.
    """
    # Fail at patch time rather than inside the first forward pass.
    if cfg_num_shards is not None and cfg_num_shards < 1:
        raise ValueError(
            f"cfg_num_shards must be a positive integer, got {cfg_num_shards}"
        )

    import importlib

    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP

    # Only the class lookup is guarded, so unrelated errors raised later
    # (e.g. by torch.compile) are not reported as an MLP import failure.
    try:
        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
        model_cls_prefix = "".join(
            part.capitalize() for part in model_type.split("_")
        )
        module = importlib.import_module(module_path)
        mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
    except (ImportError, AttributeError) as e:
        raise RuntimeError(
            f"Could not import MLP class for model_type: {model_type}. "
            f"Error: {str(e)}"
        ) from e

    if use_original_mlp:
        mlp_forward = mlp_cls.forward
    else:

        def generic_mlp_forward(self_, hs):
            return self_.down_proj(
                self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
            )

        mlp_forward = torch.compile(generic_mlp_forward)

    def tiled_mlp_forward(self, x):
        if cfg_num_shards is None:
            seqlen, hidden = x.shape[-2], x.shape[-1]
            # At least one shard, even for an empty sequence.
            num_shards = max(1, math.ceil(seqlen / hidden))
            # Agree on a single shard count across ranks; skipped when no
            # process group exists, where all_reduce would raise.
            if dist.is_available() and dist.is_initialized():
                num_shards_tensor = torch.tensor(num_shards, device=x.device)
                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
                num_shards = num_shards_tensor.item()
        else:
            num_shards = cfg_num_shards

        # Parameters handed to TiledMLP as its compute params. Biases are
        # included when a projection has one, so every projection parameter
        # used by the forward is listed.
        compute_params = []
        for proj in (self.down_proj, self.gate_proj, self.up_proj):
            compute_params.append(proj.weight)
            bias = getattr(proj, "bias", None)
            if bias is not None:
                compute_params.append(bias)

        return TiledMLP.apply(
            mlp_forward,
            self,
            x,
            num_shards,
            compute_params,
        )

    mlp_cls.forward = tiled_mlp_forward
14 changes: 14 additions & 0 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,20 @@ class AxolotlInputConfig(
},
)

# Opt-in flag for the tiled MLP monkeypatch; unset (None) is treated the same
# as False by the truthiness checks that read it.
tiled_mlp: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use ALST tiled mlp for memory efficient long context"
},
)

# Fixed shard count for the tiled MLP. Constrained to >= 1 because a zero or
# negative shard count cannot be used to split the sequence; None leaves the
# count to be derived at runtime.
tiled_mlp_num_shards: int | None = Field(
    default=None,
    ge=1,
    json_schema_extra={
        "description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size"
    },
)

llama4_linearized_experts: bool | None = None

deepspeed: str | dict[str, Any] | None = Field(
Expand Down
7 changes: 7 additions & 0 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ def pretrain_with_tps(cls, data):

return data

@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
    """Reject configs that enable `tiled_mlp` without a truthy `deepspeed` entry.

    Returns:
        The input `data`, unchanged, when the combination is acceptable.

    Raises:
        ValueError: If `tiled_mlp` is truthy and `deepspeed` is missing or falsy.
    """
    # Nothing to check unless the feature is switched on.
    if not data.get("tiled_mlp", False):
        return data
    if data.get("deepspeed"):
        return data
    raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")


class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
Expand Down
Loading