From d483aae6f72dc707b35c5856760d0edb94491bdf Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 8 Dec 2025 21:41:29 +0000 Subject: [PATCH] WIP patch accelerate CP/SP --- .../accelerate/parallelism_config.py | 30 +++++++++++++++++++ src/axolotl/utils/trainer.py | 3 ++ 2 files changed, 33 insertions(+) diff --git a/src/axolotl/monkeypatch/accelerate/parallelism_config.py b/src/axolotl/monkeypatch/accelerate/parallelism_config.py index b2157fb6b6..9b71e914ac 100644 --- a/src/axolotl/monkeypatch/accelerate/parallelism_config.py +++ b/src/axolotl/monkeypatch/accelerate/parallelism_config.py @@ -75,3 +75,33 @@ def patch_parallelism_config(): ParallelismConfig._validate_accelerator = _validate_accelerator AcceleratorState.is_fsdp2 = property(patched_is_fsdp2) + + +def patch_prepare_cp(): + import functools + + import torch + from accelerate import Accelerator + + def patched_prepare_cp(self, *args): + if self.parallelism_config.cp_backend == "deepspeed": + return args + + from accelerate.big_modeling import _attach_context_parallel_hooks + from torch.distributed.tensor.experimental import context_parallel + from torch.distributed.tensor.experimental._attention import set_rotate_method + + cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy + set_rotate_method(cp_comm_strategy) + + self._cp_context = functools.partial( + context_parallel, mesh=self.torch_device_mesh["cp"] + ) + + for arg in args: + if isinstance(arg, torch.nn.Module): + _attach_context_parallel_hooks(arg) + + return args + + Accelerator._prepare_cp = patched_prepare_cp diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d97577d863..0a92b7d454 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -634,6 +634,9 @@ def setup_parallelism_envs(cfg): set_accelerate_parallelism_config = True os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size) os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true" + from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp + + patch_prepare_cp() if set_accelerate_parallelism_config: os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"