From eecd258cf3db4734c19579bd9564335ed7bd16fe Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 27 Jun 2025 21:09:32 +0000 Subject: [PATCH 1/3] move patches; make patch stronger --- src/axolotl/loaders/patch_manager.py | 12 ++++++++ src/axolotl/monkeypatch/ring_attn/patch.py | 29 ++++++++++++++----- .../utils/ctx_managers/sequence_parallel.py | 8 ----- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 3f8116b21b..3a73f2ec88 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -65,6 +65,7 @@ def apply_pre_model_load_patches(self): self._apply_mistral_cross_entropy_patch() self._apply_self_attention_lora_patch() self._apply_gemma3_conditional_generation_forward_patch() + self._apply_sequence_parallel_patches() def apply_post_model_load_patches(self, model: PreTrainedModel): """Apply patches that require the model instance.""" @@ -231,6 +232,17 @@ def _apply_gemma3_conditional_generation_forward_patch(self): patch_gemma3_conditional_generation_forward() + def _apply_sequence_parallel_patches(self): + """Apply sequence parallelism patches.""" + if self.cfg.sequence_parallel_degree > 1: + from axolotl.monkeypatch.ring_attn.patch import ( + patch_prepare_data_loader, + patch_prepare_device_mesh, + ) + + patch_prepare_data_loader() + patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp) + def _patch_attention(self): """Apply attention-specific patches based on model type.""" if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index d83476e5a2..017b420d28 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None): def patch_prepare_data_loader(): """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. - Raies: + Raises: RuntimeError: If source code to patch does not exist. """ original_fn = accelerate.data_loader.prepare_data_loader @@ -168,23 +168,34 @@ def patch_prepare_data_loader(): ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE ) + items_to_import = [] + for item in dir(accelerate.data_loader): + if item in patched_source: + items_to_import.append(item) + # Create a new function from the patched source namespace = {} exec( # pylint: disable=exec-used # nosec B102 - patched_source, accelerate.data_loader.__dict__, namespace + f"from accelerate.data_loader import ({', '.join(items_to_import)})", + globals(), + ) + exec( # pylint: disable=exec-used # nosec B102 + patched_source, globals(), namespace ) + patched_function = namespace["prepare_data_loader"] + original_fn.__code__ = patched_function.__code__ - accelerate.data_loader.prepare_data_loader = patched_function LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") -def patch_prepare_device_mesh(sequence_parallel_degree: int): +def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False): """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh that includes sequence parallelism with the specified degree. Args: - sequence_parallel_degree (int): The degree of sequence parallelism to use. + sequence_parallel_degree: The degree of sequence parallelism to use. + fsdp: Whether to use FSDP. """ def _prepare_device_mesh(self): @@ -207,12 +218,14 @@ def _prepare_device_mesh(self): ) device_ids = list(range(world_size)) - # Note that we use "cp" instead of "sp" to match the PyTorch native "context - # parallelism" implementation naming + # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context + # parallelism" implementation naming. + # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we + # only use "fsdp" and "cp" for the device mesh. return dist.DeviceMesh( "cuda", torch.tensor(device_ids).reshape(mesh_shape), - mesh_dim_names=("dp", "cp"), + mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"), ) # Replace the original method with our new method diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index f429cd2ae5..1ac805a73c 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -12,8 +12,6 @@ from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, - patch_prepare_data_loader, - patch_prepare_device_mesh, register_ring_attn, update_ring_attn_params, ) @@ -238,12 +236,6 @@ def _register_ring_attn(self): ring_attn_func=self.ring_attn_func, ) - # Patches for accelerate functionality - patch_prepare_data_loader() - patch_prepare_device_mesh( - sequence_parallel_degree=self.sequence_parallel_degree - ) - def _register_model_hooks(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): From 4f1f8d19eb37e975451de748e04fcd9e1dbc7949 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 28 Jun 2025 10:00:18 -0400 Subject: [PATCH 2/3] fix broken tests --- .../lora_kernels/test_lora_kernel_patching.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 56ce5a8b9a..b4dc5de542 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -396,7 +396,7 @@ def test_model_architecture(model_config): # pylint: disable=duplicate-code -def test_kernel_training_integration(): +def test_kernel_training_integration(temp_dir): """Test model loading with kernel patches enabled.""" from axolotl.cli.utils import load_model_and_tokenizer @@ -426,6 +426,14 @@ def test_kernel_training_integration(): } ) + # Write cfg to yaml file + path = Path(temp_dir) / "config.yaml" + with open(path, "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + # Load config + cfg = load_cfg(str(path)) + # Load model model, _, _ = load_model_and_tokenizer(cfg=cfg) @@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir): assert found_patched_attn -def test_kernel_training_integration_dropout_non_zero(): +def test_kernel_training_integration_dropout_non_zero(temp_dir): """Test model loading with dropout non-zero should not patch.""" from axolotl.cli.utils import load_model_and_tokenizer @@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero(): } ) + # Write cfg to yaml file + path = Path(temp_dir) / "config.yaml" + with open(path, "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + # Load config + cfg = load_cfg(str(path)) + # Get original attention class attention_cls = get_attention_cls_from_config(cfg) From f71a1c7e56202fa3cb28a3e5c5db3a22435779fe Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 28 Jun 2025 14:06:04 -0400 Subject: [PATCH 3/3] guard sequence_parallel_degree comparison against none --- src/axolotl/loaders/patch_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 3a73f2ec88..610e87c7b7 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -234,7 +234,7 @@ def _apply_gemma3_conditional_generation_forward_patch(self): def _apply_sequence_parallel_patches(self): """Apply sequence parallelism patches.""" - if self.cfg.sequence_parallel_degree > 1: + if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: from axolotl.monkeypatch.ring_attn.patch import ( patch_prepare_data_loader, patch_prepare_device_mesh,