Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def apply_pre_model_load_patches(self):
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()

def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
Expand Down Expand Up @@ -231,6 +232,17 @@ def _apply_gemma3_conditional_generation_forward_patch(self):

patch_gemma3_conditional_generation_forward()

def _apply_sequence_parallel_patches(self):
    """Apply sequence parallelism patches.

    Does nothing unless ``self.cfg.sequence_parallel_degree`` is set and
    greater than 1. Otherwise it calls:

    * ``patch_prepare_data_loader()``, which patches
      ``accelerate.data_loader.prepare_data_loader`` to respect the SP degree.
    * ``patch_prepare_device_mesh(degree, fsdp)``, which patches
      ``Accelerator._prepare_device_mesh`` to build a device mesh that
      includes sequence parallelism of the given degree.
    """
    sp_degree = self.cfg.sequence_parallel_degree

    # Guard clause: an unset (falsy) degree or a degree of 1 means sequence
    # parallelism is disabled, so there is nothing to patch.
    if not sp_degree or sp_degree <= 1:
        return

    # Imported inside the method so the ring-attention patch module is only
    # loaded when sequence parallelism is actually enabled.
    # NOTE(review): presumably this also avoids a hard dependency on the
    # ring-attention backend at module import time -- confirm.
    from axolotl.monkeypatch.ring_attn.patch import (
        patch_prepare_data_loader,
        patch_prepare_device_mesh,
    )

    patch_prepare_data_loader()
    # `self.cfg.fsdp` is forwarded as-is; the patch only checks its truthiness
    # when choosing the device-mesh dimension names.
    patch_prepare_device_mesh(sp_degree, self.cfg.fsdp)

def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
Expand Down
29 changes: 21 additions & 8 deletions src/axolotl/monkeypatch/ring_attn/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.

Raies:
Raises:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
Expand All @@ -168,23 +168,34 @@ def patch_prepare_data_loader():
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)

items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)

# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
patched_source, accelerate.data_loader.__dict__, namespace
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)

patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__

accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")


def patch_prepare_device_mesh(sequence_parallel_degree: int):
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.

Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use.
sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
"""

def _prepare_device_mesh(self):
Expand All @@ -207,12 +218,14 @@ def _prepare_device_mesh(self):
)
device_ids = list(range(world_size))

# Note that we use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
)

# Replace the original method with our new method
Expand Down
8 changes: 0 additions & 8 deletions src/axolotl/utils/ctx_managers/sequence_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
Expand Down Expand Up @@ -238,12 +236,6 @@ def _register_ring_attn(self):
ring_attn_func=self.ring_attn_func,
)

# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
)

def _register_model_hooks(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
Expand Down
20 changes: 18 additions & 2 deletions tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test_model_architecture(model_config):


# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer

Expand Down Expand Up @@ -426,6 +426,14 @@ def test_kernel_training_integration():
}
)

# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

# Load config
cfg = load_cfg(str(path))

# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)

Expand Down Expand Up @@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert found_patched_attn


def test_kernel_training_integration_dropout_non_zero():
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""

from axolotl.cli.utils import load_model_and_tokenizer
Expand Down Expand Up @@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
}
)

# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

# Load config
cfg = load_cfg(str(path))
Comment thread
djsaunde marked this conversation as resolved.

# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)

Expand Down
Loading