Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/axolotl/monkeypatch/accelerate/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,33 @@ def patch_parallelism_config():

ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)


def patch_prepare_cp():
import functools

import torch
from accelerate import Accelerator

def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args

from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method

cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)

self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)

for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)

return args

Accelerator._prepare_cp = patched_prepare_cp
3 changes: 3 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,9 @@ def setup_parallelism_envs(cfg):
set_accelerate_parallelism_config = True
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp

patch_prepare_cp()
if set_accelerate_parallelism_config:
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"

Expand Down