fix: make prepare_context_parallel_inputs no-op #3520
Walkthrough
This PR refactors context-parallel input patching by removing the old dynamic FlashAttention guard-relaxation approach and replacing it with a simpler no-op patch that prevents upstream CP partitioning.
Changes: Context-Parallel Patching Refactor
Yes, this makes sense to me
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/monkeypatch/accelerate/parallelism_config.py`:
- Around line 99-106: The no-op monkeypatch
_noop_prepare_context_parallel_inputs is skipping trainer-side CP input
preparation for all backends; modify it to detect when cp_backend == "deepspeed"
and in that case call the original Trainer._prepare_context_parallel_inputs
implementation (save the original before monkeypatching, e.g.
_orig_prepare_context_parallel_inputs) so DeepSpeed follows the same guarded
path as patched_prepare_cp; otherwise keep returning contextlib.nullcontext,
inputs.
📒 Files selected for processing (4)
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/accelerate/parallelism_config.py
- src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
- tests/monkeypatch/test_trainer_context_parallel_patch.py
💤 Files with no reviewable changes (3)
- src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
- tests/monkeypatch/test_trainer_context_parallel_patch.py
- src/axolotl/loaders/patch_manager.py
def _noop_prepare_context_parallel_inputs(self, model, inputs):
    return contextlib.nullcontext, inputs

# prevent double CP partition
Accelerator._prepare_cp = patched_prepare_cp

# remove unneeded calculation upstream
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
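To illustrate why the patch returns the `contextlib.nullcontext` class itself rather than an instance, here is a minimal sketch. It assumes the caller invokes the returned object before entering it (`with cp_context():`). `FakeTrainer` and its `training_step` are hypothetical stand-ins written for this example, not the transformers implementation.

```python
import contextlib


class FakeTrainer:
    """Hypothetical stand-in for a trainer; not the transformers class."""

    def _prepare_context_parallel_inputs(self, model, inputs):
        # Pretend upstream behaviour: mark the inputs as partitioned.
        return contextlib.nullcontext, {**inputs, "partitioned": True}

    def training_step(self, model, inputs):
        cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)
        # Assumed calling convention: the returned object is called to
        # produce the actual context manager.
        with cp_context():
            return inputs


def _noop_prepare_context_parallel_inputs(self, model, inputs):
    # Hand back a context-manager factory and the inputs untouched.
    return contextlib.nullcontext, inputs


# Patch at class level, mirroring what the diff does for Trainer.
FakeTrainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs

batch = {"input_ids": [1, 2, 3, 4]}
result = FakeTrainer().training_step(model=None, inputs=batch)
```

After the patch, `result` is the same dict object that was passed in, so no partitioning happened on the trainer side.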
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether DeepSpeed can reach this patch and whether there is regression coverage.
rg -nC4 '\bpatch_prepare_cp\s*\(' src tests
rg -nC4 'cp_backend.*deepspeed|deepspeed.*cp_backend' src tests
rg -nC4 '_prepare_context_parallel_inputs' src tests
Repository: axolotl-ai-cloud/axolotl
Length of output: 1979
🏁 Script executed:
# Verify the review's claims about the code logic and patterns
# First, check if we can establish basic context about the file structure
cd /tmp && echo "Testing basic access patterns for Python code analysis"
# Create a test to verify the logical consistency of the review's argument
cat > verify_logic.py << 'EOF'
# Logical verification of the review comment's argument:
#
# Current state (per review):
# 1. Line 87: if self.parallelism_config.cp_backend == "deepspeed": return args
# 2. Line 106: Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
#
# The review's concern: Line 106 applies globally, negating the DeepSpeed guard at Line 87
#
# Question: Are these two guards independent paths or do they interact?
# - patched_prepare_cp is a method on Accelerator
# - _prepare_context_parallel_inputs is a method on Trainer
#
# Both are being monkeypatched but on different classes
# The issue is whether both need to be DeepSpeed-aware for consistency
print("Logical Structure Analysis:")
print("- patched_prepare_cp guards DeepSpeed -> returns args unchanged")
print("- _noop_prepare_context_parallel_inputs applies globally -> no-op for all backends")
print("")
print("Potential Impact:")
print("- If both are called in a DeepSpeed CP scenario:")
print(" - patched_prepare_cp will pass args through (correct)")
print(" - _prepare_context_parallel_inputs will return nullcontext (may be wrong)")
print("")
print("The review's fix adds a backend check to _prepare_context_parallel_inputs")
print("This ensures symmetrical behavior across both monkeypatches")
EOF
python verify_logic.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 633
🌐 Web query:
transformers.Trainer._prepare_context_parallel_inputs context parallel inputs
💡 Result:
transformers.Trainer._prepare_context_parallel_inputs is a private method in the Hugging Face Transformers library's Trainer class, used to prepare model inputs for context parallelism (CP), a training technique that splits sequences across GPUs using Ring Attention (typically with FSDP2) to handle long contexts efficiently [1][2][3]. It is called during the training_step, e.g., cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs) [4][5]. The method expects inputs as dict[str, torch.Tensor], sharding relevant tensors (like input_ids, attention_mask) along the sequence dimension for CP groups, and returns a context manager (cp_context) for entering the parallelism scope and modified inputs [5][6]. It is part of support for ParallelismConfig, where CP size is set via accelerator_config (e.g., cp_size >1), alongside TP/SP/DP [1][2][3]. Errors occur if inputs are lists instead of tensors (e.g., in custom trainers like GRPOTrainer) [5]. Official docs confirm Trainer methods like get_cp_size retrieve the context parallel size [1][2]. Source in src/transformers/trainer.py (main branch) [7][8].
Citations:
- 1: https://huggingface.co/docs/transformers/en/main_classes/trainer
- 2: https://huggingface.co/docs/transformers/main_classes/trainer
- 3: https://hf.co/docs/transformers/v5.5.0/en/main_classes/trainer
- 4: AttributeError when using context_parallel_size #3204
- 5: ParallelismConfig not applied in GRPOTrainer: Trainer._prepare_context_parallel_inputs expects dict[torch.Tensor] but receives list[dict[list]] huggingface/trl#4016
- 6: https://huggingface.co/docs/accelerate/concept_guides/context_parallelism
- 7: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
- 8: https://raw.githubusercontent.com/huggingface/transformers/main/src/transformers/trainer.py
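As a rough illustration of "sharding along the sequence dimension", the sketch below splits a token list into contiguous, equal chunks per CP rank. The helper `shard_sequence` and its parameters are hypothetical, and contiguous even splitting is a simplifying assumption; real implementations operate on tensors and may use a different layout.

```python
def shard_sequence(tokens, cp_size, cp_rank):
    """Return the contiguous slice of `tokens` owned by `cp_rank`."""
    if len(tokens) % cp_size != 0:
        raise ValueError("sequence length must be divisible by cp_size")
    chunk = len(tokens) // cp_size
    return tokens[cp_rank * chunk:(cp_rank + 1) * chunk]


input_ids = list(range(8))
# One shard per rank in a hypothetical CP group of size 4.
shards = [shard_sequence(input_ids, cp_size=4, cp_rank=r) for r in range(4)]
```

Applying such a split twice (once in the trainer, once elsewhere) would leave each rank with a fraction of its intended shard, which is the "double CP partition" the patch comment refers to.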
🌐 Web query:
accelerate Accelerator _prepare_cp context parallel backend
💡 Result:
In the Hugging Face Accelerate library, _prepare_cp is an internal method of the Accelerator class used to prepare models, optimizers, dataloaders, and other objects for context parallelism (CP) training when enabled via ParallelismConfig(cp_enabled=True) or CLI flags like --parallelism-config-cp-size [1][2][3]. It is called within the prepare method after tensor parallelism preparation (if enabled) and before FP8 or other backends, specifically when self.parallelism_config.cp_enabled is True: args = self._prepare_cp(*args) [1]. This method initializes context parallelism support, including creating the _cp_context object required for the torch.distributed.tensor.experimental.context_parallel wrapper, which is used as a context manager in training loops (noop if CP not enabled) [4][5][2]. Context parallelism shards QKV matrices across the sequence dimension using primitives like allgather/alltoall for rotation, primarily supported with FSDP2 in a joint mesh of cp_size x dp_shard_size for memory efficiency [4][5][3]. Configuration uses ParallelismConfig(cp_size=N, cp_handler=TorchContextParallelConfig(cp_comm_strategy="allgather")) or accelerate launch --parallelism-config-cp-size N [3]. A known issue (#3856) involved _prepare_cp skipping initialization due to a sequence parallelism check mixup, preventing _cp_context creation [2].
Citations:
- 1: https://github.com/huggingface/accelerate/blob/main/src/accelerate/accelerator.py
- 2: context parallel `_cp_context` not created huggingface/accelerate#3856
- 3: https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md
- 4: https://hf.co/docs/accelerate/concept_guides/context_parallelism
- 5: https://huggingface.co/docs/accelerate/concept_guides/context_parallelism
Add backend check to prevent skipping trainer-side context parallel preparation in DeepSpeed mode.
Line 106 applies the no-op globally to all backends, but Line 87 explicitly guards DeepSpeed in patched_prepare_cp to preserve its path. This creates an asymmetry: while patched_prepare_cp returns early for DeepSpeed, Trainer._prepare_context_parallel_inputs still skips the upstream preparation for all backends including DeepSpeed. In a DeepSpeed context parallel run, this means the trainer-side CP input preparation is bypassed despite the backend guard.
The suggested fix adds a backend check to the monkeypatched method to dispatch back to the original implementation when cp_backend == "deepspeed", ensuring both guards act consistently.
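The suggested dispatch could look roughly like the following sketch. `FakeTrainer`, the `self.accelerator.parallelism_config.cp_backend` attribute path, and the `"torch"` backend value are assumptions made for illustration; the real Trainer may expose the backend differently.

```python
import contextlib
from types import SimpleNamespace


class FakeTrainer:
    """Hypothetical stand-in; the attribute layout is an assumption."""

    def __init__(self, cp_backend):
        self.accelerator = SimpleNamespace(
            parallelism_config=SimpleNamespace(cp_backend=cp_backend)
        )

    def _prepare_context_parallel_inputs(self, model, inputs):
        # Stand-in for the upstream preparation logic.
        return contextlib.nullcontext, {**inputs, "prepared_upstream": True}


# Save the original before patching so DeepSpeed can still reach it.
_orig_prepare_context_parallel_inputs = FakeTrainer._prepare_context_parallel_inputs


def _backend_aware_prepare_context_parallel_inputs(self, model, inputs):
    config = getattr(self.accelerator, "parallelism_config", None)
    if getattr(config, "cp_backend", None) == "deepspeed":
        # DeepSpeed keeps the original trainer-side preparation.
        return _orig_prepare_context_parallel_inputs(self, model, inputs)
    # Every other backend gets the no-op.
    return contextlib.nullcontext, inputs


FakeTrainer._prepare_context_parallel_inputs = (
    _backend_aware_prepare_context_parallel_inputs
)

_, ds_inputs = FakeTrainer("deepspeed")._prepare_context_parallel_inputs(None, {"x": 1})
_, other_inputs = FakeTrainer("torch")._prepare_context_parallel_inputs(None, {"x": 1})
```

Whether this dispatch is actually needed depends on whether the DeepSpeed path relies on the trainer-side preparation, which the review asks to be verified against current code.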
Description
Followup to #3498
Since we use ring FA and don't depend on transformers to perform the CP partition, we can bypass this call and remove a patch.
@lorenzbaraldi what do you think of this change?