fix: make prepare_context_parallel_inputs no-op #3520

Merged
NanoCode012 merged 1 commit into main from fix/cp-waste
May 13, 2026

Conversation

@NanoCode012 NanoCode012 commented Mar 20, 2026

Collaborator

Description

Follow-up to #3498

Since we don't rely on transformers to perform the CP partition and we use ring FA instead, we can bypass this call and remove a patch.

@lorenzbaraldi what do you think of this change?

Summary by CodeRabbit

  • Bug Fixes

    • Improved context parallelism handling to prevent double input partitioning and enhance FlashAttention compatibility.
  • Tests

    • Removed outdated context parallelism patch tests.

coderabbitai Bot commented Mar 20, 2026

Contributor

Walkthrough

This PR refactors context-parallel input patching by removing the old dynamic FlashAttention guard-relaxation approach and replacing it with a simpler no-op patch that prevents upstream CP partitioning.

Changes

Context-Parallel Patching Refactor

| Layer / File(s) | Summary |
| --- | --- |
| Old Trainer Patch Module Removal: src/axolotl/monkeypatch/transformers/trainer_context_parallel.py | Entire module is deleted, including the dynamic guard-relaxation patch function, string constants, and idempotency tracking. |
| Patch Manager Integration: src/axolotl/loaders/patch_manager.py | PatchManager.apply_pre_model_load_patches no longer invokes the deleted patch_prepare_context_parallel_inputs function. |
| New No-Op CP Implementation: src/axolotl/monkeypatch/accelerate/parallelism_config.py | Adds Trainer import, defines _noop_prepare_context_parallel_inputs helper returning nullcontext and unchanged inputs, and patches both Trainer and Accelerator methods to prevent double CP partitioning. |
| Test Module Removal: tests/monkeypatch/test_trainer_context_parallel_patch.py | Entire test module is deleted, removing tests for guard-pattern swapping and patch idempotency. |
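
As an illustration of the "New No-Op CP Implementation" row, here is a minimal sketch of what replacing a class-level method with a no-op does. `StandInTrainer` and `apply_noop_patch` are hypothetical names used only for this example; they stand in for `transformers.Trainer` and the patching code, and the "upstream" sharding shown here is invented to make the effect visible.

```python
import contextlib


class StandInTrainer:
    """Hypothetical stand-in for transformers.Trainer (not the real class)."""

    def _prepare_context_parallel_inputs(self, model, inputs):
        # Invented upstream behaviour: keep only the first half of the sequence,
        # as if this rank owned the first of two CP shards.
        half = len(inputs["input_ids"]) // 2
        return contextlib.nullcontext, {"input_ids": inputs["input_ids"][:half]}


def _noop_prepare_context_parallel_inputs(self, model, inputs):
    # Same contract as described in the walkthrough: nullcontext plus untouched inputs.
    return contextlib.nullcontext, inputs


def apply_noop_patch(trainer_cls):
    """Replace the method on the class so every instance skips CP input preparation."""
    trainer_cls._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs


batch = {"input_ids": [1, 2, 3, 4]}

# Before the patch, the stand-in "upstream" method shards the batch.
before_ctx, before_inputs = StandInTrainer()._prepare_context_parallel_inputs(None, batch)

apply_noop_patch(StandInTrainer)

# After the patch, the very same batch object comes back unchanged.
after_ctx, after_inputs = StandInTrainer()._prepare_context_parallel_inputs(None, batch)
```

Because the attribute is replaced on the class rather than on an instance, trainers created before or after the patch both pick up the no-op.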

Estimated Code Review Effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested Reviewers

  • winglian
  • SalmanMohammadi
🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'fix: make prepare_context_parallel_inputs no-op' accurately describes the main change - removing the prepare_context_parallel_inputs patch application and making the function a no-op. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@lorenzbaraldi

Contributor

Yes, this makes sense to me

@NanoCode012 NanoCode012 marked this pull request as ready for review May 8, 2026 08:34

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/monkeypatch/accelerate/parallelism_config.py`:
- Around line 99-106: The no-op monkeypatch
_noop_prepare_context_parallel_inputs is skipping trainer-side CP input
preparation for all backends; modify it to detect when cp_backend == "deepspeed"
and in that case call the original Trainer._prepare_context_parallel_inputs
implementation (save the original before monkeypatching, e.g.
_orig_prepare_context_parallel_inputs) so DeepSpeed follows the same guarded
path as patched_prepare_cp; otherwise keep returning contextlib.nullcontext,
inputs.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6c29ce12-9025-4dbc-833e-7d40cdc79d78

📥 Commits

Reviewing files that changed from the base of the PR and between 5352d41 and e029d7d.

📒 Files selected for processing (4)
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/accelerate/parallelism_config.py
  • src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
  • tests/monkeypatch/test_trainer_context_parallel_patch.py
💤 Files with no reviewable changes (3)
  • src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
  • tests/monkeypatch/test_trainer_context_parallel_patch.py
  • src/axolotl/loaders/patch_manager.py

Comment on lines +99 to +106
def _noop_prepare_context_parallel_inputs(self, model, inputs):
    return contextlib.nullcontext, inputs

# prevent double CP partition
Accelerator._prepare_cp = patched_prepare_cp

# remove unneeded calculation upstream
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
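
The helper above returns `contextlib.nullcontext` itself rather than an instance. A minimal sketch of why that can work, assuming the caller treats the first return value as a zero-argument factory and enters `cp_context()`: `run_step` and `fake_cp_scope` are hypothetical names for this example, and the `functools.partial` case only imitates how a real CP scope might be handed back.

```python
import contextlib
import functools


@contextlib.contextmanager
def fake_cp_scope(log, buffers):
    # Hypothetical stand-in for a real context-parallel scope; it only records entry and exit.
    log.append(("enter", len(buffers)))
    try:
        yield
    finally:
        log.append(("exit", len(buffers)))


def run_step(cp_context, inputs):
    """Hypothetical, simplified caller: cp_context is called to obtain the context manager."""
    with cp_context():
        return sum(inputs["input_ids"])


log = []
inputs = {"input_ids": [1, 2, 3]}

# No-op path: the nullcontext class is itself a zero-argument factory.
noop_result = run_step(contextlib.nullcontext, inputs)

# CP-like path: a partial with its arguments pre-bound satisfies the same calling convention.
scoped_factory = functools.partial(fake_cp_scope, log, buffers=[inputs["input_ids"]])
scoped_result = run_step(scoped_factory, inputs)
```

Under that assumption, both return styles are interchangeable from the caller's point of view, which is what lets the no-op slot in without touching the training step.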

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether DeepSpeed can reach this patch and whether there is regression coverage.
rg -nC4 '\bpatch_prepare_cp\s*\(' src tests
rg -nC4 'cp_backend.*deepspeed|deepspeed.*cp_backend' src tests
rg -nC4 '_prepare_context_parallel_inputs' src tests

Repository: axolotl-ai-cloud/axolotl

Length of output: 1979


🏁 Script executed:

# Verify the review's claims about the code logic and patterns
# First, check if we can establish basic context about the file structure

cd /tmp && echo "Testing basic access patterns for Python code analysis"

# Create a test to verify the logical consistency of the review's argument
cat > verify_logic.py << 'EOF'
# Logical verification of the review comment's argument:
# 
# Current state (per review):
# 1. Line 87: if self.parallelism_config.cp_backend == "deepspeed": return args
# 2. Line 106: Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
#
# The review's concern: Line 106 applies globally, negating the DeepSpeed guard at Line 87
# 
# Question: Are these two guards independent paths or do they interact?
# - patched_prepare_cp is a method on Accelerator
# - _prepare_context_parallel_inputs is a method on Trainer
# 
# Both are being monkeypatched but on different classes
# The issue is whether both need to be DeepSpeed-aware for consistency

print("Logical Structure Analysis:")
print("- patched_prepare_cp guards DeepSpeed -> returns args unchanged")
print("- _noop_prepare_context_parallel_inputs applies globally -> no-op for all backends")
print("")
print("Potential Impact:")
print("- If both are called in a DeepSpeed CP scenario:")
print("  - patched_prepare_cp will pass args through (correct)")
print("  - _prepare_context_parallel_inputs will return nullcontext (may be wrong)")
print("")
print("The review's fix adds a backend check to _prepare_context_parallel_inputs")
print("This ensures symmetrical behavior across both monkeypatches")
EOF

python verify_logic.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 633


🌐 Web query:

transformers.Trainer._prepare_context_parallel_inputs context parallel inputs

💡 Result:

transformers.Trainer._prepare_context_parallel_inputs is a private method in the Hugging Face Transformers library's Trainer class, used to prepare model inputs for context parallelism (CP), a training technique that splits sequences across GPUs using Ring Attention (typically with FSDP2) to handle long contexts efficiently [1][2][3]. It is called during the training_step, e.g., cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs) [4][5]. The method expects inputs as dict[str, torch.Tensor], sharding relevant tensors (like input_ids, attention_mask) along the sequence dimension for CP groups, and returns a context manager (cp_context) for entering the parallelism scope and modified inputs [5][6]. It is part of support for ParallelismConfig, where CP size is set via accelerator_config (e.g., cp_size >1), alongside TP/SP/DP [1][2][3]. Errors occur if inputs are lists instead of tensors (e.g., in custom trainers like GRPOTrainer) [5]. Official docs confirm Trainer methods like get_cp_size retrieve the context parallel size [1][2]. Source in src/transformers/trainer.py (main branch) [7][8].

Citations:


🌐 Web query:

accelerate Accelerator _prepare_cp context parallel backend

💡 Result:

In the Hugging Face Accelerate library, _prepare_cp is an internal method of the Accelerator class used to prepare models, optimizers, dataloaders, and other objects for context parallelism (CP) training when enabled via ParallelismConfig(cp_enabled=True) or CLI flags like --parallelism-config-cp-size [1][2][3]. It is called within the prepare method after tensor parallelism preparation (if enabled) and before FP8 or other backends, specifically when self.parallelism_config.cp_enabled is True: args = self._prepare_cp(*args) [1]. This method initializes context parallelism support, including creating the _cp_context object required for the torch.distributed.tensor.experimental.context_parallel wrapper, which is used as a context manager in training loops (noop if CP not enabled) [4][5][2]. Context parallelism shards QKV matrices across the sequence dimension using primitives like allgather/alltoall for rotation, primarily supported with FSDP2 in a joint mesh of cp_size x dp_shard_size for memory efficiency [4][5][3]. Configuration uses ParallelismConfig(cp_size=N, cp_handler=TorchContextParallelConfig(cp_comm_strategy="allgather")) or accelerate launch --parallelism-config-cp-size N [3]. A known issue (#3856) involved _prepare_cp skipping initialization due to a sequence parallelism check mixup, preventing _cp_context creation [2].

Citations:


Add backend check to prevent skipping trainer-side context parallel preparation in DeepSpeed mode.

Line 106 applies the no-op globally to all backends, but Line 87 explicitly guards DeepSpeed in patched_prepare_cp to preserve its path. This creates an asymmetry: while patched_prepare_cp returns early for DeepSpeed, Trainer._prepare_context_parallel_inputs still skips the upstream preparation for all backends including DeepSpeed. In a DeepSpeed context parallel run, this means the trainer-side CP input preparation is bypassed despite the backend guard.

The suggested fix adds a backend check to the monkeypatched method to dispatch back to the original implementation when cp_backend == "deepspeed", ensuring both guards act consistently.
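
A minimal sketch of the backend-aware dispatch this comment proposes, not of what was merged (the thread does not say whether the suggestion was adopted). `StandInTrainer`, `_guarded_prepare_context_parallel_inputs`, and the `self.accelerator.parallelism_config.cp_backend` attribute path are assumptions made for illustration; only the `cp_backend == "deepspeed"` check and the save-the-original idea come from the comment.

```python
import contextlib
from types import SimpleNamespace


class StandInTrainer:
    """Hypothetical stand-in for transformers.Trainer (not the real class)."""

    def __init__(self, cp_backend):
        # Assumed attribute layout for reaching the backend name from the trainer.
        self.accelerator = SimpleNamespace(
            parallelism_config=SimpleNamespace(cp_backend=cp_backend)
        )

    def _prepare_context_parallel_inputs(self, model, inputs):
        # Invented upstream behaviour: tag the inputs so the call is observable.
        return contextlib.nullcontext, {**inputs, "prepared_upstream": True}


# Save the original before monkeypatching so the DeepSpeed path can still reach it.
_orig_prepare_context_parallel_inputs = StandInTrainer._prepare_context_parallel_inputs


def _guarded_prepare_context_parallel_inputs(self, model, inputs):
    if self.accelerator.parallelism_config.cp_backend == "deepspeed":
        # DeepSpeed: defer to the original implementation.
        return _orig_prepare_context_parallel_inputs(self, model, inputs)
    # Every other backend: no-op, inputs untouched.
    return contextlib.nullcontext, inputs


StandInTrainer._prepare_context_parallel_inputs = _guarded_prepare_context_parallel_inputs

batch = {"input_ids": [1, 2]}
_, ds_inputs = StandInTrainer("deepspeed")._prepare_context_parallel_inputs(None, batch)
_, torch_inputs = StandInTrainer("torch")._prepare_context_parallel_inputs(None, batch)
```

The original must be captured before the assignment; capturing it afterwards would make the DeepSpeed branch call the guard itself and recurse.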

codecov Bot commented May 8, 2026

Codecov Report

❌ Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...olotl/monkeypatch/accelerate/parallelism_config.py | 0.00% | 4 Missing ⚠️ |

@NanoCode012 NanoCode012 merged commit 4f4d5d8 into main May 13, 2026
13 of 15 checks passed
@NanoCode012 NanoCode012 deleted the fix/cp-waste branch May 13, 2026 05:37