Skip to content

Distributed/ND-Parallel #2977

Merged
winglian merged 65 commits into
mainfrom
ndp
Jul 31, 2025
Merged

Distributed/ND-Parallel #2977
winglian merged 65 commits into
mainfrom
ndp

Conversation

@salmanmohammadi
Copy link
Copy Markdown
Contributor

@salmanmohammadi salmanmohammadi commented Jul 24, 2025

Description

Supersedes #2947

Needs upstream:

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added support for context parallelism, replacing sequence parallelism throughout the configuration and user interface.
    • Introduced new configuration options for data parallel sharding and replication.
    • Enhanced distributed training support with improved model saving for parallel and sharded scenarios.
    • Improved GPU memory usage reporting and logging, now including peak and maximum statistics.
    • Added compatibility checks for LIGER and tensor parallelism settings.
    • Introduced new end-to-end tests for tensor parallelism.
  • Bug Fixes

    • Fixed compatibility and validation issues with parallelism and optimizer configurations.
    • Improved error handling and logging for optimizer and distributed training scenarios.
  • Documentation

    • Updated documentation and configuration examples to use the new context parallelism terminology and parameters.
  • Tests

    • Updated and added tests to reflect new parallelism configuration.
    • Added Hopper GPU-specific test decorators and new distributed parallelism tests.
  • Chores

    • Updated dependency versions for transformers and accelerate.
    • Improved environment variable handling in test and CI scripts.
    • Enhanced logging and performance instrumentation for batch sampling.
  • Refactor

    • Unified and streamlined parallelism configuration and validation logic.
    • Replaced manual distributed process group management with device mesh abstractions for attention mechanisms.
    • Deprecated sequence parallelism configuration in favor of context parallelism.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jul 24, 2025

📝 Walkthrough

Walkthrough

This change refactors the codebase to replace the sequence parallelism configuration parameter (sequence_parallel_degree) with context_parallel_size across all modules, documentation, and tests. It also introduces device mesh-based registration for ring attention, updates distributed parallelism handling, adds new validators, and improves GPU memory usage logging and dataset saving logic.

Changes

Cohort / File(s) Summary
Sequence Parallelism → Context Parallelism Refactor
src/axolotl/core/trainers/grpo/args.py, src/axolotl/core/trainers/grpo/sampler.py, src/axolotl/core/trainers/grpo/trainer.py, src/axolotl/core/trainers/grpo/__init__.py, src/axolotl/core/trainers/rl.py, src/axolotl/core/trainers/dpo/trainer.py, src/axolotl/core/trainers/base.py, src/axolotl/core/trainers/trl.py, src/axolotl/core/trainers/mixins/__init__.py, src/axolotl/core/trainers/mixins/distributed_parallel.py, src/axolotl/core/builders/base.py, src/axolotl/cli/merge_lora.py, src/axolotl/train.py, src/axolotl/utils/ctx_managers/sequence_parallel.py, src/axolotl/utils/schemas/config.py, src/axolotl/utils/schemas/validation.py, src/axolotl/utils/trainer.py, src/axolotl/loaders/model.py, src/axolotl/monkeypatch/ring_attn/patch.py, src/axolotl/monkeypatch/ring_attn/__init__.py, src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py, src/axolotl/loaders/patch_manager.py, docs/sequence_parallelism.qmd, examples/alst/llama3-8b-deepspeed-alst.yaml, tests/core/test_builders.py, tests/e2e/multigpu/patched/test_sp.py, tests/e2e/multigpu/solo/test_grpo.py, tests/e2e/test_load_model.py
Replaces sequence_parallel_degree with context_parallel_size everywhere, updates related logic, documentation, and tests, and adapts distributed and context parallelism handling to use device mesh abstractions and new registration methods.
Distributed Parallelism and Trainer Mixins
src/axolotl/core/trainers/mixins/distributed_parallel.py, src/axolotl/core/trainers/mixins/__init__.py, src/axolotl/core/trainers/base.py, src/axolotl/core/trainers/dpo/trainer.py, src/axolotl/core/trainers/trl.py
Introduces and applies DistributedParallelMixin to all relevant trainer classes to ensure correct model state saving in distributed/FSDP scenarios.
Monkeypatch and Ring Attention Refactor
src/axolotl/monkeypatch/ring_attn/patch.py, src/axolotl/monkeypatch/ring_attn/__init__.py, src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py, src/axolotl/loaders/patch_manager.py
Removes manual patching of accelerate internals, replaces ring attention registration with device mesh-based approach, and introduces a patch for transformers.modeling_flash_attention_utils._prepare_from_posids.
Parallelism Config and Validation Enhancements
src/axolotl/loaders/model.py, src/axolotl/utils/schemas/validation.py, src/axolotl/utils/schemas/config.py, src/axolotl/integrations/liger/args.py, src/axolotl/utils/environment.py, tests/test_loaders.py
Adds device mesh and parallelism config support, new distributed validation mixin, additional config fields, and new validators for Liger and parallel optimizer compatibility.
GPU Memory Usage Improvements
src/axolotl/utils/bench.py, src/axolotl/utils/callbacks/__init__.py
Refactors GPU memory usage functions to report peak and max values, adds a unified retrieval function, and improves logging and WandB integration.
Dataset Saving and Data Utilities
src/axolotl/utils/data/shared.py
Changes dataset save process to use a minimum rows per process of 256 instead of a hardcoded divisor of 8.
Test Suite and E2E Enhancements
tests/e2e/multigpu/test_tp.py, tests/e2e/multigpu/test_fp8_fsdp2.py, tests/e2e/utils.py, tests/e2e/test_load_model.py, tests/e2e/multigpu/solo/test_grpo.py, tests/core/test_builders.py, tests/e2e/multigpu/patched/test_sp.py, tests/test_loaders.py, tests/utils/schemas/validation/test_fsdp.py, tests/e2e/patched/test_sp.py
Adds a tensor parallelism E2E test, restricts FP8 test to Hopper GPUs, adds utility decorator, updates tests for new config, and removes legacy sequence parallelism tests and FSDP sharded state dict test.
Requirements and Setup Adjustments
requirements.txt, setup.py
Updates transformers and accelerate versions, adjusts vllm dependency handling for PyTorch versions.
Miscellaneous
src/axolotl/core/trainers/mamba.py, src/axolotl/integrations/kd/trainer.py, src/axolotl/core/trainers/mixins/checkpoints.py, src/axolotl/monkeypatch/accelerate/fsdp2.py, src/axolotl/utils/samplers/multipack.py, cicd/single_gpu.py, cicd/multigpu.sh
Minor changes: disables pylint warnings, expands exception handling, adds timing logs, sets env vars for tests, and tweaks test runner options.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • #2699: Refactors sequence parallelism context manager, consolidates patching, and improves ring attention registration—related as both PRs modify sequence parallelism handling, but at different abstraction levels.
  • #2948: Limits number of processes for dataset saving, directly related to the change in this PR adjusting the minimum rows per process for dataset saving.
  • #2842: Adds sequence parallel patching logic, which is later removed or replaced in this PR, making the changes directly related but opposite in effect.

Suggested labels

enhancement, review requested

Suggested reviewers

  • SalmanMohammadi
  • djsaunde
  • NanoCode012

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch ndp

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@winglian winglian marked this pull request as ready for review July 24, 2025 20:08
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jul 24, 2025

📖 Documentation Preview: https://688baeb302c8405999dacbe0--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit f8df5bf

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
src/axolotl/utils/ctx_managers/sequence_parallel.py (1)

158-166: Consider removing or implementing the commented code.

This commented-out code appears to be an alternative implementation for computing the global token count using all-reduce. Either implement it properly or remove it to avoid confusion.

src/axolotl/utils/schemas/validation.py (1)

1253-1259: Flash attention compatibility patch needs documentation.

The aliasing of _flash_supports_window_size to _flash_supports_window appears to be a compatibility fix. Consider adding a comment explaining why this patching is necessary.

 try:
     import transformers.modeling_flash_attention_utils
 
+    # Compatibility patch: ring_flash_attn expects _flash_supports_window
+    # but newer transformers versions may have _flash_supports_window_size
     # pylint: disable=protected-access
     transformers.modeling_flash_attention_utils._flash_supports_window_size = (
         transformers.modeling_flash_attention_utils._flash_supports_window
     )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1407aac and 0ea778d.

📒 Files selected for processing (25)
  • docs/sequence_parallelism.qmd (3 hunks)
  • requirements.txt (1 hunks)
  • src/axolotl/cli/merge_lora.py (1 hunks)
  • src/axolotl/core/builders/base.py (3 hunks)
  • src/axolotl/core/builders/rl.py (1 hunks)
  • src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
  • src/axolotl/core/trainers/grpo/args.py (1 hunks)
  • src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
  • src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/loaders/model.py (5 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (5 hunks)
  • src/axolotl/utils/trainer.py (3 hunks)
  • tests/core/test_builders.py (1 hunks)
  • tests/e2e/multigpu/patched/test_sp.py (1 hunks)
  • tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
  • tests/e2e/patched/test_sp.py (6 hunks)
🧰 Additional context used
🧠 Learnings (4)
tests/e2e/patched/test_sp.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

src/axolotl/core/trainers/grpo/__init__.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

src/axolotl/core/builders/rl.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

src/axolotl/utils/schemas/validation.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

🧬 Code Graph Analysis (4)
src/axolotl/monkeypatch/ring_attn/__init__.py (1)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
  • register_ring_attn_from_device_mesh (248-325)
src/axolotl/core/trainers/mixins/checkpoints.py (1)
src/axolotl/utils/logging.py (1)
  • warning_once (31-39)
src/axolotl/utils/ctx_managers/sequence_parallel.py (2)
src/axolotl/monkeypatch/ring_attn/patch.py (2)
  • get_ring_attn_group (62-66)
  • register_ring_attn_from_device_mesh (248-325)
tests/e2e/patched/test_sp.py (1)
  • partial_state (25-28)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
  • load (151-179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
🔇 Additional comments (44)
src/axolotl/core/trainers/mixins/checkpoints.py (1)

13-23: Excellent improvements for distributed training robustness.

The changes enhance error handling and logging for distributed training scenarios:

  1. Exception handling expansion: Adding KeyError alongside NotImplementedError is appropriate, as optimizer state dictionaries may have missing keys in complex parallelism setups like FSDP2.

  2. Improved logging: Using LOG.warning_once with main_process_only=True follows best practices for distributed training by preventing duplicate log messages across processes and avoiding log spam.

  3. Documentation: The TODO comment provides valuable context about the known FSDP2 limitation.

These changes align well with the broader parallelism improvements in this PR and enhance the checkpoint saving reliability.

tests/e2e/multigpu/patched/test_sp.py (1)

70-70: LGTM: Parameter rename aligns with codebase refactoring.

The change from sequence_parallel_degree to context_parallel_size is consistent with the broader renaming effort across the codebase. The test configuration maintains the same functionality while using the updated parameter name.

src/axolotl/monkeypatch/accelerate/fsdp2.py (1)

257-259: Ensure device_mesh and parallelism_config Are Always Initialized

Before using

accelerator.state.device_mesh[
    accelerator.state.parallelism_config.model_shard_dim_names
]

make sure that:

  • accelerator.state.parallelism_config has been assigned (e.g., via parallelism_config.build_device_mesh(...)) in every model‐loading and accelerator setup path.
  • accelerator.state.device_mesh is non-None and contains all keys named in model_shard_dim_names.

Points to review in your codebase:

  • src/axolotl/loaders/model.py around line 410, where device_mesh = parallelism_config.build_device_mesh("cuda") is created.
  • Any other initialization branch (e.g., CPU or multi-GPU flows) to confirm they invoke build_device_mesh and assign to accelerator.state.device_mesh before FSDP2 is instantiated.
src/axolotl/cli/merge_lora.py (1)

73-73: LGTM: Consistent parameter renaming in CLI interface.

The change from sequence_parallel_degree to context_parallel_size maintains consistency with the broader refactoring effort while preserving the same functionality and default value.

src/axolotl/core/trainers/grpo/args.py (1)

16-16: LGTM: Dataclass field rename maintains consistency.

The parameter rename from sequence_parallel_degree to context_parallel_size in the AxolotlGRPOConfig dataclass aligns with the codebase-wide refactoring while preserving the same type annotation and default value.

tests/core/test_builders.py (1)

67-67: LGTM: Test configuration updated for parameter consistency.

The change from sequence_parallel_degree to context_parallel_size in the base test configuration fixture ensures all test cases use the updated parameter name while maintaining the same test behavior.

src/axolotl/core/builders/rl.py (1)

56-56: LGTM! Clean parameter rename.

The parameter rename from sequence_parallel_degree to context_parallel_size is consistent with the broader refactoring effort across the codebase. The logic remains identical and correct.

src/axolotl/monkeypatch/ring_attn/__init__.py (1)

11-11: LGTM! Proper API exposure for DeviceMesh-based ring attention.

The addition of register_ring_attn_from_device_mesh to both the import and __all__ list correctly exposes the new DeviceMesh-based ring attention registration function as part of the public API. This aligns with the modernization effort to use accelerate's device mesh abstractions.

Also applies to: 21-21

tests/e2e/multigpu/solo/test_grpo.py (1)

301-301: LGTM! Test configuration updated for parameter rename.

The test configuration correctly uses the new context_parallel_size parameter instead of the deprecated sequence_parallel_degree. The test semantics remain unchanged while staying consistent with the broader parameter renaming effort.

src/axolotl/train.py (1)

205-205: LGTM! Consistent parameter rename throughout training logic.

Both the conditional check and the parameter passed to SequenceParallelContextManager have been correctly updated to use context_parallel_size instead of sequence_parallel_degree. The logic remains identical while maintaining consistency with the broader refactoring effort.

Also applies to: 213-213

docs/sequence_parallelism.qmd (1)

25-25: LGTM! Documentation updated to reflect parameter rename.

All documentation references have been consistently updated from sequence_parallel_degree to context_parallel_size. The examples, explanations, and usage descriptions maintain the same semantics while staying aligned with the codebase changes. This ensures users have accurate documentation for the new parameter name.

Also applies to: 33-33, 69-69, 92-92, 94-94, 99-99

tests/e2e/patched/test_sp.py (2)

114-114: LGTM: Consistent parameter rename in test.

The update from sequence_parallel_degree to context_parallel_size aligns with the codebase-wide refactoring for clearer terminology.


159-285: LGTM: Comprehensive test coverage maintained with updated parameter name.

All test configurations have been systematically updated to use context_parallel_size instead of sequence_parallel_degree. The test logic, assertions, and expected behavior remain unchanged, ensuring continued validation of the sequence parallelism functionality with the new parameter name.

src/axolotl/core/trainers/grpo/__init__.py (1)

85-86: LGTM: Consistent parameter rename in GRPO configuration.

The update from cfg.sequence_parallel_degree to cfg.context_parallel_size maintains the same conditional logic while aligning with the codebase-wide terminology standardization.

requirements.txt (1)

16-18: Verify Git dependency usage and pin commit hashes

Using branch names for transformers and a custom accelerate fork introduces non-determinism and potential instability. Please confirm the following:

• Intent: Are these Git dependencies temporary until upstream releases include the needed changes, or intended as permanent sources?
• Pinning: For reproducible builds, replace branch references with the exact commit hashes you’ve verified (shown below).
• Timeline: When do you expect the upstream changes to land in an official release?

Proposed snippet for requirements.txt (Lines 16–18):

-transformers @ git+https://github.com/huggingface/transformers.git@main
-tokenizers>=0.21.1
-accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
+transformers @ git+https://github.com/huggingface/transformers.git@5a81d7e0b388fb2b86fc1279cdc07d9dc7e84b4c
+tokenizers>=0.21.1
+accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@168b520279a21a4d3fb89413aacb86c70a2f0a99
src/axolotl/utils/trainer.py (3)

445-445: LGTM: Consistent parameter rename in training step calculations.

The update from cfg.sequence_parallel_degree to cfg.context_parallel_size maintains the same multiplication logic for total step calculations while aligning with the codebase terminology standardization.


487-487: LGTM: Consistent parameter rename maintained.

Same parameter rename applied consistently in the data loader length calculation path.


514-514: LGTM: Consistent parameter rename in fallback calculation.

The parameter rename is consistently applied in the fallback total steps calculation, maintaining mathematical correctness.

src/axolotl/loaders/patch_manager.py (1)

264-269: LGTM: Consistent parameter rename in patch application.

The parameter rename from sequence_parallel_degree to context_parallel_size is applied consistently within the sequence parallel patches method, maintaining the same conditional logic.

src/axolotl/utils/schemas/config.py (1)

647-664: LGTM! Well-structured parameter refactoring with proper backwards compatibility.

The changes correctly:

  • Add the new dp_shard_size field with clear documentation
  • Maintain backwards compatibility by keeping sequence_parallel_degree with deprecation notice
  • Introduce context_parallel_size with comprehensive documentation explaining its purpose
  • Follow consistent field definition patterns
src/axolotl/core/trainers/grpo/sampler.py (4)

23-23: LGTM! Documentation updated consistently.

The comment correctly reflects the new parameter name context_parallel_size.


48-48: LGTM! Parameter documentation updated.

The docstring parameter description correctly uses the new context_parallel_size naming.


62-62: LGTM! Constructor parameter renamed consistently.

The parameter name change from sequence_parallel_degree to context_parallel_size is correct and maintains the same functionality.


80-82: LGTM! Internal attribute usage updated consistently.

All internal references to the parameter have been correctly updated:

  • self.context_parallel_size
  • self.num_sp_groups = world_size // context_parallel_size
  • self.sp_group_id = rank // context_parallel_size

The logic remains unchanged, only the parameter name is updated.

src/axolotl/core/trainers/grpo/trainer.py (4)

103-103: LGTM! Calculation updated with new parameter name.

The calculation num_sp_groups = num_processes // self.args.context_parallel_size correctly uses the renamed parameter while maintaining the same logic.


133-138: LGTM! Error message updated consistently.

The error message correctly references the new parameter name context_parallel_size and maintains clear, informative messaging about the validation requirements.


170-172: LGTM! Sampler initialization updated correctly.

Both the batch size calculation and the sampler parameter are correctly updated to use context_parallel_size.


238-238: LGTM! All conditional checks and calculations updated consistently.

All references to the parameter in:

  • Conditional checks (if self.args.context_parallel_size > 1:)
  • Variable assignments (context_parallel_size = self.args.context_parallel_size)
  • Calculations (group_leader_rank = sp_group_id * context_parallel_size)

are correctly updated while maintaining the same logic flow.

Also applies to: 311-311, 314-315, 322-322, 338-338, 355-355, 362-362, 586-586

src/axolotl/core/builders/base.py (3)

30-30: LGTM! Correct import added for AcceleratorConfig.

The import is properly added to support the new accelerator configuration handling.


438-449: LGTM! Improved accelerator configuration handling.

The method now properly instantiates AcceleratorConfig objects instead of passing raw dictionaries, which provides better type safety and validation. The logic correctly:

  • Preserves the use_configured_state parameter handling
  • Handles both cases when config exists and when it doesn't
  • Maintains backwards compatibility

514-514: LGTM! Consistent training argument configuration.

Setting "average_tokens_across_devices": False aligns with the device mesh parallelism changes mentioned in the PR and ensures consistent token handling across devices.

src/axolotl/integrations/liger/args.py (2)

31-37: LGTM! Type annotations modernized.

The change from Optional[bool] to bool | None follows modern Python type annotation practices and improves readability.


64-72: LGTM! Proper validation for incompatible configurations.

The new validator correctly prevents the incompatible combination of liger_rms_norm with tensor parallelism. The implementation:

  • Checks the logical conditions properly (liger_rms_norm enabled AND tensor_parallel_size > 1)
  • Provides a clear error message explaining the incompatibility
  • Includes a helpful reference URL for more context
  • Follows the established pattern of other validators in the class
src/axolotl/utils/ctx_managers/sequence_parallel.py (3)

8-8: LGTM!

The import of PartialState from accelerate is appropriate for accessing the device mesh used in the new registration method.


180-181: Parameter renaming looks good.

The renaming from sequence_parallel_degree to context_parallel_size is consistently applied in the docstring and throughout the class initialization.

Also applies to: 192-192, 199-199


243-249: Dimension name “cp” is correct and needs no change

The mesh_dim_names in src/axolotl/monkeypatch/ring_attn/patch.py is always set to either

  • ("dp", "cp") (data + column parallel)
  • ("fsdp", "cp") (FSDP + column parallel)

This guarantees that "cp" is a valid dimension in the device mesh. Passing sequence_parallel_dim=("cp",) therefore matches the patched mesh configuration and requires no updates.

src/axolotl/loaders/model.py (2)

16-17: Import changes look good.

The addition of PartialState and ParallelismConfig imports and consolidation of distributed utilities are appropriate for the new parallelism configuration functionality.

Also applies to: 52-52


646-652: Tensor parallel configuration looks good.

The integration with the device mesh from PartialState and the removal of incompatible device_map are handled correctly.

src/axolotl/monkeypatch/ring_attn/patch.py (3)

166-204: Parameter renaming is consistent.

The renaming from sequence_parallel_degree to context_parallel_size is properly applied throughout the function, including in assertions, calculations, and logging messages.


248-326: New device mesh-based registration function is well-implemented.

The function properly extracts the sequence parallel submesh with appropriate error handling and logging. The integration with different ring attention implementations is handled correctly.


383-427: Device mesh patching updates are correct.

The parameter renaming and dimension naming changes are consistent. Good documentation explaining the use of "cp" to match PyTorch native context parallelism naming convention.

src/axolotl/utils/schemas/validation.py (3)

689-689: Parameter renaming in GRPO validation is correct.

The check for GRPO + sequence parallel + Liger loss properly uses the new context_parallel_size parameter.


1237-1276: Sequence parallel validation updates are consistent.

All references to sequence_parallel_degree have been properly updated to context_parallel_size in method names, error messages, and logging statements.

Also applies to: 1280-1280


916-940: DeepSpeed configuration update logic improvement.

Good change to make the tensor parallel configuration conditional on having a DeepSpeed config file, rather than raising an error when missing. This provides better flexibility for users.

Comment thread src/axolotl/loaders/model.py Outdated
Comment thread src/axolotl/loaders/patch_manager.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/axolotl/loaders/patch_manager.py (1)

69-69: Re-enable or confirm removal of sequence parallel patches

This issue has already been raised in a previous review. The call to _apply_sequence_parallel_patches() is commented out, which disables sequence-parallelism patches.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b8bd173 and f7af069.

📒 Files selected for processing (3)
  • src/axolotl/loaders/patch_manager.py (3 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
  • src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/monkeypatch/ring_attn/patch.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
  • patch_prepare_from_posids (77-87)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
  • patch_prepare_device_mesh (389-433)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py

83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (3)

52-52: LGTM! Correct placement of transformers patches.

The _apply_transformers_patches() call is appropriately placed at the beginning of the pre-model load patches sequence.


72-78: LGTM! Clean implementation of transformers patches.

The method correctly imports and applies the Flash Attention utility patch following the established pattern.


272-277: LGTM! Correct parameter renaming.

The method correctly updates all references from sequence_parallel_degree to context_parallel_size, maintaining the same logic while aligning with the PR's refactoring objectives.

src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)

12-74: LGTM! Well-implemented Flash Attention utility function.

The function correctly prepares tensors for Flash Attention with proper handling of position IDs and sequence lengths. The documentation and TorchDynamo compatibility note are particularly helpful.

Comment thread src/axolotl/monkeypatch/accelerate/fsdp2.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (3)
train.yaml (3)

70-70: Add trailing newline to satisfy YAML lint
The linter error (no new line character at the end of file) can break some tooling that strictly enforces POSIX-style text files.

 # save_first_step: true  # uncomment this to validate checkpoint saving works with your config
+

30-33: Eliminate null-valued WandB keys to avoid silent mis-configuration
wandb_watch: and wandb_log_model: are currently set to explicit null (empty value). The Axolotl config loader passes these directly to the Weights-and-Biases SDK; null values can override sensible defaults and disable expected logging.

- wandb_watch:
- wandb_log_model:
+ # wandb_watch: gradients  # valid options: false | gradients | all
+ # wandb_log_model: checkpoint  # valid options: false | checkpoint | artifact

Confirm the intended behaviour or comment out the keys entirely.


25-27: Large sequence_len + sample packing may exhaust memory
A sequence_len of 4096 combined with sample_packing: true, flash-attention and gradient-checkpointing can easily exceed the memory footprint of 8-bit AdamW on 8 × 24 GB GPUs.

Consider:

  • Reducing sequence_len during warm-up or
  • Enabling q_lora / LoRA adapters
  • Increasing gradient_accumulation_steps to trade compute for memory.
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c894610 and b93c173.

📒 Files selected for processing (2)
  • src/axolotl/loaders/model.py (5 hunks)
  • train.yaml (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/loaders/model.py
🧰 Additional context used
🪛 YAMLlint (1.37.1)
train.yaml

[error] 70-70: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (2)
train.yaml (2)

44-46: Verify context_parallel_size against available GPUs and upstream patches
Setting context_parallel_size: 8 requires exactly eight devices per node (or an 8-way device mesh) and the new accelerate & transformers PRs referenced in the main description. Mismatching hardware or an un-patched dependency will raise runtime errors during mesh creation.

Double-check that:

  1. huggingface/accelerate PR 3682 is installed.
  2. You actually have 8 GPUs per process group, or override with CUDA_VISIBLE_DEVICES.
  3. dp_shard_size: 1 is intentional; otherwise the total world size becomes context_parallel_size * dp_shard_size.

58-60: Ensure placeholder special tokens exist in the tokenizer
<|finetune_right_pad_id|> and <|eot_id|> must already be present in the Qwen tokenizer vocab. If they are not, loading will fail or silently mis-align IDs.

If the base tokenizer lacks these tokens, add them via --add-special-tokens preprocessing or switch to the built-in <|endoftext|> / <|pad|> tokens.

Comment on lines 653 to +669
sequence_parallel_degree: int | None = Field(
default=None,
json_schema_extra={
"description": "Deprecated: use `context_parallel_size` instead"
},
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I missed this. Should we add this as an alias to context_parallel_size or make sure to add migration from the old config to new one in validator?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 need to raise deprecation warning and migrate configs

@NanoCode012
Copy link
Copy Markdown
Collaborator

Could we also have an doc that explains how to compose TP/Hybrid Sharding/CP?

Like, when to use TP, how to shard intra-node & replicate across nodes? I'll link it from README.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
src/axolotl/monkeypatch/ring_attn/patch.py (1)

254-332: Extract duplicated ring attention function setup logic.

The ring_attn_func handling logic (lines 306-331) is duplicated from the register_ring_attn function (lines 220-252). This violates the DRY principle and makes maintenance harder.

Extract the common logic into a separate function:

+def _setup_ring_attn_func(ring_attn_func: RingAttnFunc | None, heads_k_stride: int | None):
+    """Setup ring attention function implementation."""
+    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
+        # fmt: off
+        # pylint: disable=protected-access
+        import transformers.modeling_flash_attention_utils
+        transformers.modeling_flash_attention_utils._flash_supports_window_size = (
+            transformers.modeling_flash_attention_utils._flash_supports_window
+        )
+
+        import ring_flash_attn.adapters.hf_adapter
+
+        from ring_flash_attn.adapters.hf_adapter import (  # isort: skip  # pylint: disable=unused-import
+            create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
+        )
+
+        create_ring_flash_attention_forward_orig = (  # noqa: F811,F841
+            create_ring_flash_attention_forward
+        )
+        ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
+        # fmt: on
+
+        ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func is RingAttnFunc.BATCH_RING:
+        from axolotl.monkeypatch.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )

Then use it in both functions:

-    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
-        # ... (lines 220-242)
-    elif ring_attn_func is RingAttnFunc.BATCH_RING:
-        # ... (lines 243-251)
+    _setup_ring_attn_func(ring_attn_func, heads_k_stride)
♻️ Duplicate comments (1)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)

77-87: Remove redundant patching with setattr.

The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.

Apply this fix to remove the redundancy:

 def patch_prepare_from_posids():
     import transformers.modeling_flash_attention_utils
 
     transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
         _prepare_from_posids
     )
-    setattr(
-        sys.modules["transformers.modeling_flash_attention_utils"],
-        "_prepare_from_posids",
-        _prepare_from_posids,
-    )
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)

1221-1227: Consider moving the transformers patching to a more appropriate location.

The transformers attribute aliasing (lines 1221-1226) is performed during validation, which seems like an unexpected place for runtime patching. This patching appears to be a prerequisite for importing ring_flash_attn.

Consider moving this patching logic to the patch manager or a dedicated initialization module where other patches are applied, rather than in the validation logic.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d1b64c3 and 451c085.

📒 Files selected for processing (28)
  • docs/sequence_parallelism.qmd (3 hunks)
  • requirements.txt (1 hunks)
  • src/axolotl/cli/merge_lora.py (1 hunks)
  • src/axolotl/core/builders/base.py (3 hunks)
  • src/axolotl/core/builders/rl.py (1 hunks)
  • src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
  • src/axolotl/core/trainers/grpo/args.py (1 hunks)
  • src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
  • src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/loaders/model.py (5 hunks)
  • src/axolotl/loaders/patch_manager.py (3 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
  • src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (5 hunks)
  • src/axolotl/utils/trainer.py (3 hunks)
  • tests/core/test_builders.py (1 hunks)
  • tests/e2e/multigpu/patched/test_sp.py (1 hunks)
  • tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
  • tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
  • tests/e2e/patched/test_sp.py (6 hunks)
  • tests/e2e/utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (7)
  • tests/e2e/multigpu/patched/test_sp.py
  • src/axolotl/core/trainers/grpo/args.py
  • src/axolotl/monkeypatch/ring_attn/init.py
  • docs/sequence_parallelism.qmd
  • src/axolotl/utils/trainer.py
  • tests/e2e/multigpu/solo/test_grpo.py
  • src/axolotl/core/builders/rl.py
🚧 Files skipped from review as they are similar to previous changes (18)
  • tests/core/test_builders.py
  • requirements.txt
  • src/axolotl/train.py
  • src/axolotl/cli/merge_lora.py
  • tests/e2e/utils.py
  • src/axolotl/core/trainers/grpo/init.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/loaders/patch_manager.py
  • tests/e2e/patched/test_sp.py
  • tests/e2e/multigpu/test_fp8_fsdp2.py
  • src/axolotl/core/trainers/mixins/checkpoints.py
  • src/axolotl/utils/ctx_managers/sequence_parallel.py
  • src/axolotl/monkeypatch/accelerate/fsdp2.py
  • src/axolotl/core/trainers/grpo/trainer.py
  • src/axolotl/loaders/model.py
  • src/axolotl/core/builders/base.py
  • src/axolotl/integrations/liger/args.py
  • src/axolotl/core/trainers/grpo/sampler.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
  • load (151-180)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py

83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: pre-commit
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
🔇 Additional comments (4)
src/axolotl/monkeypatch/ring_attn/patch.py (2)

165-252: LGTM! Consistent renaming and enhanced ring attention support.

The renaming from sequence_parallel_degree to context_parallel_size is consistently applied throughout the function. The addition of support for multiple ring attention implementations (VARLEN_LLAMA3 and BATCH_RING) enhances the flexibility of the system.


389-433: LGTM! Consistent renaming and alignment with PyTorch conventions.

The renaming from sequence_parallel_degree to context_parallel_size is properly implemented. Good decision to use "cp" instead of "sp" for the device mesh dimension name to align with PyTorch's native context parallelism implementation.

src/axolotl/utils/schemas/validation.py (2)

900-927: Verify the removal of DeepSpeed config requirement is intentional.

The validation no longer raises an error when tensor_parallel_size > 1 but DeepSpeed config is missing. This is a significant behavior change that makes the validation more permissive.

Please confirm this change is intentional and that tensor parallelism can now work without DeepSpeed configuration.


1205-1244: LGTM! Consistent renaming and improved error messages.

The renaming from check_sequence_parallel_degree to check_context_parallel_size is properly implemented with all references, error messages, and warnings updated to use the new terminology.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 451c085 and 55a68cb.

📒 Files selected for processing (3)
  • src/axolotl/loaders/model.py (5 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py
  • src/axolotl/loaders/model.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/fsdp2.py

260-260: Found useless expression. Either assign it to a variable or remove it.

(B018)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)

257-257: LGTM! Device mesh integration enhances parallelism support.

The addition of the mesh parameter properly integrates the device mesh from accelerator state into FSDP2 model preparation, aligning with the PR's parallelism configuration improvements.

Comment thread src/axolotl/monkeypatch/accelerate/fsdp2.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)

112-128: Enhanced logging function with improved memory reporting.

The updated log_gpu_memory_usage function leverages the new centralized memory retrieval and provides clearer logging:

  • Uses the new get_gpu_memory_usage helper
  • Updates variable names to match the new memory metrics (active, allocated, reserved)
  • Maintains backward compatibility with the existing logging interface

However, there's a minor issue with the conditional string formatting that can be simplified.

Apply this fix to simplify the conditional string formatting:

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7706fb2 and 3abe558.

📒 Files selected for processing (6)
  • src/axolotl/loaders/model.py (5 hunks)
  • src/axolotl/utils/bench.py (2 hunks)
  • src/axolotl/utils/callbacks/__init__.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (5 hunks)
  • train.yaml (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • train.yaml
🚧 Files skipped from review as they are similar to previous changes (3)
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/loaders/model.py
  • src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
src/axolotl/utils/callbacks/__init__.py (1)
src/axolotl/utils/bench.py (2)
  • get_gpu_memory_usage (96-109)
  • log_gpu_memory_usage (112-128)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py

124-124: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (5)
src/axolotl/utils/callbacks/__init__.py (2)

38-38: LGTM: Import statement correctly updated.

The addition of get_gpu_memory_usage import is properly aligned with its usage in the updated callback implementation.


111-123: Improved GPU memory monitoring with proper step handling.

The changes enhance GPU memory monitoring by:

  • Removing the one-time logging guard (logged attribute) for continuous monitoring
  • Adding detailed WandB metrics (active, allocated, reserved memory)
  • Maintaining existing debug logging functionality

The implementation correctly uses state.global_step > 0 to avoid logging on the initial step and properly gates WandB logging with both configuration and process checks.

src/axolotl/utils/bench.py (3)

60-64: Improved memory metrics with peak usage tracking.

The refactoring correctly updates the memory metrics to track:

  • Peak active memory from CUDA memory stats
  • Maximum allocated memory (peak allocation)
  • Maximum reserved memory (peak reservation)

This provides more meaningful monitoring data compared to current memory usage, especially for understanding memory pressure during training.


96-109: Well-designed centralized memory retrieval function.

The new get_gpu_memory_usage function effectively centralizes device-specific memory retrieval logic with proper error handling for unsupported devices. The function correctly:

  • Uses get_device_type() to determine the current device
  • Handles MPS, NPU, and CUDA devices appropriately
  • Raises a clear error for unsupported device types

102-102: Fix: Corrected device type check.

Good catch! The device type check was properly corrected from "gpu" to "cuda" to match the actual CUDA device type string.

Comment thread src/axolotl/utils/bench.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/utils/bench.py (1)

119-122: Critical issue: Incorrect variable usage persists.

The same bug identified in previous reviews still exists - both conditions incorrectly use reserved instead of the appropriate variables.

Apply this fix to correct the variable usage:

-    if allocated > 0:
-        extras.append(f"+{reserved:.03f}GB allocated")
-    if reserved > 0:
-        extras.append(f"+{reserved:.03f}GB reserved")
+    if allocated > 0:
+        extras.append(f"+{allocated:.03f}GB allocated")
+    if reserved > 0:
+        extras.append(f"+{reserved:.03f}GB reserved")
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)

123-123: Apply static analysis suggestion for cleaner conditional.

The static analysis tool correctly identifies a more readable conditional expression pattern.

Apply this improvement:

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3abe558 and fded751.

📒 Files selected for processing (1)
  • src/axolotl/utils/bench.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py

123-123: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/utils/bench.py (3)

60-63: LGTM: Improved memory reporting with peak metrics.

The refactoring from current memory usage to peak/max metrics provides more meaningful insights for performance analysis. The implementation correctly uses:

  • memory_stats().get("active_bytes.all.peak", 0) for peak active memory
  • max_memory_allocated() for maximum allocated memory
  • max_memory_reserved() for maximum reserved memory

95-108: Well-designed centralized memory retrieval function.

The new get_gpu_memory_usage function effectively centralizes device-specific memory logic with proper error handling for unsupported devices. The device type detection using get_device_type() is robust and the fallback error message is informative.


111-127: Good refactoring with enhanced logging capabilities.

The function refactoring successfully:

  • Uses the new centralized get_gpu_memory_usage function
  • Updates terminology to more accurate labels ("active", "allocated", "reserved")
  • Maintains backward compatibility with existing function signature
  • Improves logging with debug level and proper stack level

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)

123-123: Apply static analysis suggestion for improved readability.

The conditional expression can be simplified for better readability.

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfa16fa and ab642d5.

📒 Files selected for processing (1)
  • src/axolotl/utils/bench.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py

123-123: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/bench.py (2)

60-63: LGTM! Enhanced memory reporting with peak/max statistics.

The updated implementation provides more valuable memory insights by reporting peak active memory and maximum allocated/reserved memory instead of current usage. This is particularly useful for training monitoring and debugging memory issues.


111-127: LGTM! Function refactored with proper variable usage.

The refactoring successfully:

  • Uses the centralized get_gpu_memory_usage function
  • Adds proper type hints
  • Uses more descriptive variable names (active/allocated/reserved)
  • Correctly fixes the previous issue where both conditions used reserved instead of the appropriate variables

The debug-level logging is appropriate for detailed memory monitoring.

Comment thread src/axolotl/utils/bench.py
Comment thread src/axolotl/utils/bench.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/utils/bench.py (1)

95-109: Address previous review feedback on device detection.

The current implementation still has the device detection issues flagged in previous reviews:

  1. Fragile string comparisons: Lines 99 and 101 use string containment checks which are unreliable
  2. Inconsistent MPS detection: Line 97 checks MPS availability independently of device type
  3. Unnecessary ValueError: Based on winglian's feedback, raising errors may not be appropriate for this non-critical utility

Consider this approach based on previous feedback:

 def get_gpu_memory_usage(device: int | torch.device = 0):
-    cur_device_type = str(get_device_type())
-    if torch.backends.mps.is_available():
+    cur_device = get_device_type()
+    if cur_device.type == "mps":
         usage, cache, misc = mps_memory_usage_all()
-    elif "npu" in cur_device_type and is_torch_npu_available():
+    elif cur_device.type == "npu":
         usage, cache, misc = npu_memory_usage_all(device)
-    elif "cuda" in cur_device_type and torch.cuda.is_available():
+    elif cur_device.type == "cuda":
         usage, cache, misc = gpu_memory_usage_all(device)
     else:
-        raise ValueError(
-            f"Unable to determine memory statistics for current device {device}"
-        )
+        # Return default values for unsupported devices (CPU, etc.)
+        return 0.0, 0.0, 0.0

     return usage, cache, misc
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)

127-127: Apply style improvement suggested by static analysis.

The ternary operator can be simplified for better readability.

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ab642d5 and 21127c6.

📒 Files selected for processing (2)
  • src/axolotl/utils/bench.py (2 hunks)
  • src/axolotl/utils/callbacks/__init__.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/callbacks/init.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py

127-127: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (4)
src/axolotl/utils/bench.py (4)

60-63: LGTM: Improved memory metrics for benchmarking.

The change from current memory usage to peak/max memory metrics is a good improvement for benchmarking purposes. Using memory_stats().get("active_bytes.all.peak"), max_memory_allocated(), and max_memory_reserved() provides more valuable insights into actual memory requirements during training.


117-120: Good exception handling for unsupported devices.

The try-catch block properly handles cases where get_gpu_memory_usage() raises a ValueError (e.g., for CPU-only environments), gracefully returning without logging.


123-126: Variable usage bug has been correctly fixed.

The previous issue with using reserved for both allocated and reserved memory has been properly addressed. The code now correctly uses allocated for the allocated memory string and reserved for the reserved memory string.


111-131: Well-structured refactoring with improved separation of concerns.

The refactoring successfully centralizes memory usage logic and improves the logging function with:

  • Proper type hints for better code quality
  • Clear separation between memory retrieval and logging
  • Appropriate exception handling for unsupported devices
  • Updated variable names reflecting the new peak/max memory metrics

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
cicd/single_gpu.py (1)

68-72: LGTM! Consider making the process count configurable.

The explicit environment configuration for dataset processing is good for CI/CD consistency. The hardcoded value of "8" processes aligns with the broader parallelism improvements in this PR.

Consider making this configurable via an environment variable:

 sp_env = os.environ.copy()
-sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
+sp_env["AXOLOTL_DATASET_PROCESSES"] = os.environ.get("AXOLOTL_DATASET_PROCESSES", "8")
src/axolotl/loaders/model.py (1)

400-452: Excellent implementation of parallelism configuration calculation.

The static method correctly computes parallelism dimensions by systematically dividing the world size. The validation logic ensures compatibility and provides clear error messages for misconfigurations.

Consider simplifying the nested if statements as suggested by static analysis:

-        if dp_shard_size is None and dp_replicate_size in (None, 1):
-            if remaining_world_size > 1:
-                pc_kwargs["dp_shard_size"] = remaining_world_size
-                remaining_world_size = 1
+        if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
+            pc_kwargs["dp_shard_size"] = remaining_world_size
+            remaining_world_size = 1
-        if remaining_world_size > 1:
-            if "dp_shard_size" not in pc_kwargs and is_fsdp:
-                pc_kwargs["dp_shard_size"] = remaining_world_size
-                remaining_world_size = 1
+        if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
+            pc_kwargs["dp_shard_size"] = remaining_world_size
+            remaining_world_size = 1
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b2dee8e and 5b74f91.

📒 Files selected for processing (4)
  • cicd/single_gpu.py (1 hunks)
  • src/axolotl/loaders/model.py (5 hunks)
  • src/axolotl/utils/data/shared.py (1 hunks)
  • tests/test_loaders.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
cicd/single_gpu.py (1)

Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.

🧬 Code Graph Analysis (1)
tests/test_loaders.py (1)
src/axolotl/loaders/model.py (1)
  • _get_parallel_config_kwargs (401-452)
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py

420-421: Use a single if statement instead of nested if statements

(SIM102)


441-442: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
🔇 Additional comments (6)
src/axolotl/utils/data/shared.py (1)

433-437: Excellent refactoring! Improves code readability.

Replacing the magic number "8" with the descriptive min_rows_per_proc = 256 makes the intent much clearer. This change aligns well with the parallelism configuration improvements throughout the PR.

tests/test_loaders.py (1)

175-214: Excellent test coverage for parallelism configuration logic.

The parameterized test comprehensively covers various combinations of parallelism settings. The test cases effectively validate the _get_parallel_config_kwargs method's behavior for different world sizes and parallelism configurations.

src/axolotl/loaders/model.py (4)

16-17: Good addition of accelerate imports for parallelism support.

The imports for PartialState and ParallelismConfig are essential for the new parallelism configuration functionality.


184-192: Well-structured conditional logic for parallel configuration.

The logic correctly determines when to enable parallel configuration based on FSDP, tensor parallelism, and context parallelism settings, with proper handling of FSDP version constraints.


454-472: Solid integration with accelerate's parallelism infrastructure.

The method correctly creates a ParallelismConfig and device mesh, then properly configures the PartialState. This establishes the parallelism configuration early in the model loading process.


707-712: Proper tensor parallelism configuration in model building.

The code correctly sets up tensor parallelism parameters and removes the incompatible device_map when tensor parallelism is enabled, which prevents conflicts between device mapping and tensor parallel plans.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
src/axolotl/loaders/model.py (1)

408-460: Add validation for world size divisibility.

The method should validate that the world size is evenly divisible by each parallelism factor before performing division to prevent runtime errors.

 @staticmethod
 def _get_parallel_config_kwargs(
     world_size: int,
     tensor_parallel_size: int = 1,
     context_parallel_size: int = 1,
     dp_shard_size: int | None = None,
     dp_replicate_size: int | None = None,
     is_fsdp: bool = False,
 ):
     pc_kwargs = {}
     remaining_world_size = world_size

     if tensor_parallel_size and tensor_parallel_size > 1:
+        if remaining_world_size % tensor_parallel_size != 0:
+            raise ValueError(
+                f"World size ({world_size}) must be divisible by tensor_parallel_size ({tensor_parallel_size})"
+            )
         pc_kwargs["tp_size"] = tensor_parallel_size
         remaining_world_size = remaining_world_size // tensor_parallel_size

     if context_parallel_size and context_parallel_size > 1:
+        if remaining_world_size % context_parallel_size != 0:
+            raise ValueError(
+                f"Remaining world size ({remaining_world_size}) must be divisible by context_parallel_size ({context_parallel_size})"
+            )
         pc_kwargs["cp_size"] = context_parallel_size
         remaining_world_size = remaining_world_size // context_parallel_size

Also consider simplifying the nested if statements at lines 428-429 and 449-450:

-    if dp_shard_size is None and dp_replicate_size in (None, 1):
-        if remaining_world_size > 1:
+    if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
             pc_kwargs["dp_shard_size"] = remaining_world_size
             remaining_world_size = 1
🧹 Nitpick comments (8)
src/axolotl/utils/environment.py (1)

34-37: Consider adding error handling for missing packages.

The function may raise PackageNotFoundError if the package isn't installed. Consider documenting this behavior or adding error handling depending on the expected usage pattern.

 def get_package_version(package: str) -> Version:
+    """Get the installed version of a package.
+    
+    Args:
+        package: The name of the package to check.
+        
+    Returns:
+        The parsed version of the package.
+        
+    Raises:
+        PackageNotFoundError: If the package is not installed.
+    """
     version_str = version(package)
     return parse(version_str)
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)

8-18: Document the temporary nature of this monkeypatch.

This patch sets protected attributes on the model, which suggests it's working around a limitation in the transformers library. Consider adding a comment explaining why this is needed and referencing the upstream PR (transformers #39622) that may eliminate the need for this patch.

 def distribute_model(model, distributed_config, device_mesh, tp_size):
+    """
+    Wrapper for transformers' distribute_model that adds tensor parallel metadata.
+    
+    This is a temporary workaround until transformers properly exposes TP information.
+    See: https://github.com/huggingface/transformers/pull/39622
+    """
     res = transformers.integrations.tensor_parallel.distribute_model(
src/axolotl/core/trainers/mixins/dist_parallel.py (1)

25-27: Document the upstream dependency for this workaround.

Consider adding a reference to the accelerate PR that will eliminate the need for this workaround.

                 # check for device mesh as we don't worry about this for DDP and it wouldn't be set
                 # and is only specific to older accelerate atm
+                # This workaround can be removed once accelerate PR #3682 is merged and released
                 if "device_mesh" in PartialState()._shared_state:
src/axolotl/integrations/liger/args.py (1)

68-83: Validators correctly enforce tensor parallelism constraints.

Both validators properly check for incompatibilities between Liger features and tensor parallelism. The reference to the GitHub issue in the first validator provides helpful context.

The TODO comment indicates this might need a more comprehensive fix. Would you like me to help investigate a more permanent solution or open an issue to track this technical debt?

src/axolotl/monkeypatch/accelerate/distributed.py (3)

5-5: Consider reducing pylint disable scope.

The broad pylint disable includes important checks like protected-access and inconsistent-return-statements. Consider:

  1. Addressing the underlying issues instead of disabling the checks
  2. Using more targeted inline disables where absolutely necessary
  3. Documenting why each disable is required

165-180: Add strict=True to zip for safer tuple unpacking.

The static analysis correctly identifies that zip() should use the strict parameter to ensure both sequences have the same length.

-            return tuple(zip(*sorted_items))
+            return tuple(zip(*sorted_items, strict=True))

214-219: Use more Pythonic dict membership check.

-            assert (
-                parallelism in self._sizes.keys()
-            ), f"Parallelism must be one of {self._sizes.keys()}"
+            assert (
+                parallelism in self._sizes
+            ), f"Parallelism must be one of {list(self._sizes.keys())}"
src/axolotl/loaders/model.py (1)

842-845: Simplify nested condition and document protected member access.

-        if self.cfg.tensor_parallel_size > 1:
-            if self.model._tp_size != self.cfg.tensor_parallel_size:
+        if self.cfg.tensor_parallel_size > 1 and self.model._tp_size != self.cfg.tensor_parallel_size:
                 self.model._tp_size = self.cfg.tensor_parallel_size

Consider adding a comment explaining why accessing the protected _tp_size attribute is necessary for tensor parallelism functionality.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5b74f91 and 8433009.

📒 Files selected for processing (15)
  • requirements.txt (1 hunks)
  • setup.py (2 hunks)
  • src/axolotl/core/builders/base.py (3 hunks)
  • src/axolotl/core/trainers/base.py (2 hunks)
  • src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
  • src/axolotl/core/trainers/mixins/dist_parallel.py (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/loaders/patch_manager.py (3 hunks)
  • src/axolotl/monkeypatch/accelerate/distributed.py (1 hunks)
  • src/axolotl/monkeypatch/transformers/tensor_parallel.py (1 hunks)
  • src/axolotl/utils/environment.py (2 hunks)
  • src/axolotl/utils/schemas/validation.py (5 hunks)
  • tests/e2e/multigpu/patched/test_sp.py (2 hunks)
  • tests/e2e/multigpu/test_tp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • src/axolotl/core/trainers/mixins/init.py
  • src/axolotl/core/trainers/base.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tests/e2e/multigpu/patched/test_sp.py
  • requirements.txt
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/core/builders/base.py
  • src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/core/trainers/mixins/dist_parallel.py

20-22: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/monkeypatch/transformers/tensor_parallel.py

22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

src/axolotl/monkeypatch/accelerate/distributed.py

180-180: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


216-216: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

src/axolotl/loaders/model.py

428-429: Use a single if statement instead of nested if statements

(SIM102)


449-450: Use a single if statement instead of nested if statements

(SIM102)


842-843: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: pre-commit
🔇 Additional comments (7)
setup.py (1)

30-30: Version compatibility handling looks correct.

The conditional vllm version management based on PyTorch version is appropriate, especially given the incompatibility between vllm 0.9.x and the updated transformers version.

Also applies to: 72-76

tests/e2e/multigpu/test_tp.py (1)

13-64: Well-structured tensor parallelism test.

The test appropriately validates basic tensor parallelism functionality with a minimal configuration. The use of a small model and few training steps makes it suitable for CI/CD pipelines.

src/axolotl/integrations/liger/args.py (1)

31-37: LGTM! Modern type annotation syntax.

The update to use bool | None instead of Optional[bool] follows modern Python 3.10+ conventions and improves code readability.

src/axolotl/monkeypatch/accelerate/distributed.py (1)

182-212: Robust validation logic for parallelism configurations.

The validation comprehensively checks size constraints and enforces sensible limitations on parallelism combinations. The error messages provide clear guidance for users.

src/axolotl/loaders/model.py (3)

88-89: LGTM! Well-structured parallelism configuration attributes.

The class attributes are properly typed and initialized with sensible defaults.


187-200: Well-structured parallelism setup logic.

The method correctly determines when to enable parallelism configuration based on FSDP and parallel size settings, with appropriate version compatibility checks.


722-728: Correct tensor parallelism configuration for model initialization.

The setup properly configures tensor parallelism arguments and correctly removes incompatible device_map when using tp_plan.

Comment on lines +20 to +41
if not is_package_version_ge("accelerate", "1.10.0"):
# pylint: disable=protected-access
if int(os.environ.get("WORLD_SIZE", "1")) > 1:
from accelerate.state import PartialState

# check for device mesh as we don't worry about this for DDP and it wouldn't be set
# and is only specific to older accelerate atm
if "device_mesh" in PartialState()._shared_state:
device_mesh: DeviceMesh = PartialState()._shared_state[
"device_mesh"
]
mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
self.accelerator.state.distributed_type = "TP"
PartialState().distributed_type = "TP"
tp_plugin = TorchTensorParallelPlugin(
tp_size=device_mesh["tp"].size(),
torch_device_mesh=device_mesh,
)
self.accelerator.state.torch_tp_plugin = tp_plugin

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider combining nested if statements and addressing Python compatibility.

The implementation has a few areas for improvement:

  1. The nested if statements can be combined as suggested by static analysis
  2. The type hint tuple[str, ...] | None requires Python 3.10+
-        if not is_package_version_ge("accelerate", "1.10.0"):
-            # pylint: disable=protected-access
-            if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+        if (
+            not is_package_version_ge("accelerate", "1.10.0")
+            and int(os.environ.get("WORLD_SIZE", "1")) > 1
+        ):
+            # pylint: disable=protected-access
                 from accelerate.state import PartialState

                 # check for device mesh as we don't worry about this for DDP and it wouldn't be set
                 # and is only specific to older accelerate atm
                 if "device_mesh" in PartialState()._shared_state:
                     device_mesh: DeviceMesh = PartialState()._shared_state[
                         "device_mesh"
                     ]
-                    mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
+                    mesh_dim_names = device_mesh.mesh_dim_names
                     if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if not is_package_version_ge("accelerate", "1.10.0"):
# pylint: disable=protected-access
if int(os.environ.get("WORLD_SIZE", "1")) > 1:
from accelerate.state import PartialState
# check for device mesh as we don't worry about this for DDP and it wouldn't be set
# and is only specific to older accelerate atm
if "device_mesh" in PartialState()._shared_state:
device_mesh: DeviceMesh = PartialState()._shared_state[
"device_mesh"
]
mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
self.accelerator.state.distributed_type = "TP"
PartialState().distributed_type = "TP"
tp_plugin = TorchTensorParallelPlugin(
tp_size=device_mesh["tp"].size(),
torch_device_mesh=device_mesh,
)
self.accelerator.state.torch_tp_plugin = tp_plugin
if (
not is_package_version_ge("accelerate", "1.10.0")
and int(os.environ.get("WORLD_SIZE", "1")) > 1
):
# pylint: disable=protected-access
from accelerate.state import PartialState
# check for device mesh as we don't worry about this for DDP and it wouldn't be set
# and is only specific to older accelerate atm
if "device_mesh" in PartialState()._shared_state:
device_mesh: DeviceMesh = PartialState()._shared_state[
"device_mesh"
]
- mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
+ mesh_dim_names = device_mesh.mesh_dim_names
if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
self.accelerator.state.distributed_type = "TP"
PartialState().distributed_type = "TP"
tp_plugin = TorchTensorParallelPlugin(
tp_size=device_mesh["tp"].size(),
torch_device_mesh=device_mesh,
)
self.accelerator.state.torch_tp_plugin = tp_plugin
🧰 Tools
🪛 Ruff (0.12.2)

20-22: Use a single if statement instead of nested if statements

(SIM102)

🤖 Prompt for AI Agents
In src/axolotl/core/trainers/mixins/dist_parallel.py around lines 20 to 40,
combine the nested if statements checking the package version and WORLD_SIZE
into a single condition to simplify the logic. Also, replace the type hint
`tuple[str, ...] | None` with a syntax compatible with earlier Python versions,
such as using `Optional[Tuple[str, ...]]` and importing the necessary types from
the typing module to maintain compatibility below Python 3.10.

Comment thread src/axolotl/loaders/model.py
Comment on lines +20 to +26
def patch_tp_fix():
transformers.integrations.tensor_parallel.distribute_model = distribute_model
setattr(
sys.modules["transformers.integrations.tensor_parallel"],
"distribute_model",
distribute_model,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove redundant setattr call.

The static analysis correctly identifies that the setattr call is redundant. The direct assignment on line 21 is sufficient to patch the module.

 def patch_tp_fix():
     transformers.integrations.tensor_parallel.distribute_model = distribute_model
-    setattr(
-        sys.modules["transformers.integrations.tensor_parallel"],
-        "distribute_model",
-        distribute_model,
-    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def patch_tp_fix():
transformers.integrations.tensor_parallel.distribute_model = distribute_model
setattr(
sys.modules["transformers.integrations.tensor_parallel"],
"distribute_model",
distribute_model,
)
def patch_tp_fix():
transformers.integrations.tensor_parallel.distribute_model = distribute_model
🧰 Tools
🪛 Ruff (0.12.2)

22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/transformers/tensor_parallel.py between lines 20 and
26, remove the redundant setattr call that sets "distribute_model" on the module
since the direct assignment on line 21 already patches the module correctly.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (1)
src/axolotl/utils/schemas/validation.py (1)

1280-1292: Include DistributedValidationMixin in ValidationMixin inheritance

The DistributedValidationMixin (defined at lines 1265–1273) provides the check_tensor_parallel_optimizer validator for distributed training, but it isn’t currently inherited by ValidationMixin (lines 1280–1292). Without it, distributed-specific validations won’t run.

File: src/axolotl/utils/schemas/validation.py
Location: class ValidationMixin definition

Suggested diff:

 class ValidationMixin(
     DatasetValidationMixin,
     AttentionValidationMixin,
     TrainingValidationMixin,
     LoRAValidationMixin,
     RLValidationMixin,
     OptimizationValidationMixin,
     SystemValidationMixin,
     ChatTemplateValidationMixin,
     PretrainingValidationMixin,
     ModelCompatibilityValidationMixin,
+    DistributedValidationMixin,
     ComplexValidationMixin,
 ):
♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)

408-460: Add validation for world size divisibility.

The method correctly calculates the data parallel replicate size, but should validate that the world size is evenly divisible by all parallelism factors to prevent runtime errors.


462-488: Replace print with logging and document protected member access.

Two issues to address:

  1. Line 474 uses print() instead of the logger
  2. Lines 481-487 access protected members of PartialState
🧹 Nitpick comments (3)
cicd/multigpu.sh (1)

24-25: Consider surfacing upload failures instead of silencing them

|| true neutralises any non-zero exit from codecov, which means genuine issues (e.g., malformed XML, network outages) won’t be visible in CI. If the upload is meant to be non-blocking, emit a warning so it’s still detectable in logs:

-  codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
+  if ! codecov upload-process -t "${CODECOV_TOKEN}" \
+       -f multigpu-coverage.xml \
+       -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}; then
+    echo "[WARN] Codecov upload failed but CI will proceed."
+  fi
src/axolotl/loaders/model.py (1)

841-847: Consider simplifying nested if statements.

The workaround logic is necessary and well-documented. Consider combining the nested if statements for cleaner code:

-        if self.cfg.tensor_parallel_size > 1:
-            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-            # TODO(wing): remove once 4.54.1 is released
-            if self.model._tp_size != self.cfg.tensor_parallel_size:
+        if (self.cfg.tensor_parallel_size > 1 
+            and self.model._tp_size != self.cfg.tensor_parallel_size):
+            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+            # TODO(wing): remove once 4.54.1 is released
src/axolotl/utils/schemas/validation.py (1)

1265-1277: LGTM! Important validation for tensor parallelism compatibility.

The new validation correctly prevents incompatible optimizer configurations with tensor parallelism. Consider simplifying the nested if statements:

-        if self.tensor_parallel_size > 1:
-            if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
+        if (self.tensor_parallel_size > 1 
+            and self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8433009 and 8e55df3.

📒 Files selected for processing (5)
  • cicd/multigpu.sh (1 hunks)
  • setup.py (2 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/utils/samplers/multipack.py (2 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/utils/samplers/multipack.py
🧰 Additional context used
🧠 Learnings (1)
src/axolotl/utils/schemas/validation.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py

428-429: Use a single if statement instead of nested if statements

(SIM102)


449-450: Use a single if statement instead of nested if statements

(SIM102)


842-845: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/utils/schemas/validation.py

1270-1271: Use a single if statement instead of nested if statements

(SIM102)

🔇 Additional comments (10)
cicd/multigpu.sh (1)

22-23: Guard clause for missing token is a solid reliability win

Wrapping the upload in an if [ -n "$CODECOV_TOKEN" ] block prevents noisy failures on forks / local runs where the secret isn’t present.

setup.py (1)

30-30: LGTM! Consistent pattern for version extraction.

The vllm version extraction follows the same pattern as xformers and autoawq versions.

src/axolotl/loaders/model.py (4)

16-16: LGTM! Import changes support new parallelism functionality.

The added imports for PartialState, ParallelismConfig, and get_world_size are necessary for the new parallelism configuration features.

Also applies to: 49-49, 52-52


88-89: LGTM! Clean attribute declarations for parallelism state.

The class attributes are properly typed and initialized for managing parallelism configuration.


187-200: LGTM! Sound logic for parallelism configuration activation.

The conditions for enabling parallelism configuration are well-defined and the FSDP version 2 requirement is appropriately handled.


722-727: LGTM! Proper tensor parallelism setup with conflict resolution.

The tensor parallelism configuration correctly sets the required parameters and appropriately removes conflicting device_map settings.

src/axolotl/utils/schemas/validation.py (4)

676-676: LGTM! Parameter rename aligns with codebase-wide changes.

The change from sequence_parallel_degree to context_parallel_size is consistent with the broader parameter renaming effort across the codebase.


903-926: LGTM! More flexible DeepSpeed configuration handling.

The conditional logic appropriately handles cases where no DeepSpeed configuration file is present, making the validation more robust.


1205-1244: LGTM! Comprehensive parameter rename with consistent updates.

The method rename and all parameter references have been consistently updated from sequence_parallel_degree to context_parallel_size. The transformers patching logic for flash attention compatibility is appropriately maintained.


1248-1248: LGTM! Consistent parameter reference update.

The parameter reference update maintains consistency with the broader renaming effort.

Comment thread setup.py
Comment on lines +72 to +75
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop(_install_requires.index(vllm_version))
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
extras_require_map["vllm"] = ["vllm>=0.10.0"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix syntax error in vllm version handling.

Line 73 has a critical syntax error - extras_require_map.pop() expects a dictionary key, but _install_requires.index(vllm_version) returns an integer index.

-            # vllm 0.9.x is incompatible with latest transformers
-            extras_require_map.pop(_install_requires.index(vllm_version))
+            # vllm 0.9.x is incompatible with latest transformers
+            _install_requires.pop(_install_requires.index(vllm_version))

This should remove vllm from the install requirements (similar to xformers on line 69) rather than from the extras_require_map.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop(_install_requires.index(vllm_version))
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
extras_require_map["vllm"] = ["vllm>=0.10.0"]
# vllm 0.9.x is incompatible with latest transformers
_install_requires.pop(_install_requires.index(vllm_version))
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
🤖 Prompt for AI Agents
In setup.py around lines 72 to 76, the code incorrectly uses
extras_require_map.pop() with an integer index instead of a dictionary key,
causing a syntax error. To fix this, remove the vllm version from the
_install_requires list using _install_requires.remove(vllm_version) instead of
popping from extras_require_map. This aligns with the intended behavior of
removing vllm from install requirements rather than from extras_require_map.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)

408-460: Method implementation looks correct, but validation issue remains unaddressed.

The parallelism configuration calculation logic is well-structured and handles various scenarios appropriately. However, the past review comment about adding validation for world size divisibility by each parallelism factor has not been addressed.


462-488: Print statement and protected member access issues remain unaddressed.

The parallelism configuration setup logic is correct, but the previously identified issues with using print() instead of logging and accessing protected members of PartialState have not been resolved.

🧹 Nitpick comments (3)
.github/workflows/multi-gpu-e2e.yml (1)

46-48: Minor: quote scalar to stay consistent with other string values

All other scalar strings in the matrix ("3.11", "true") are quoted. Quoting vllm keeps style consistent and guards against accidental boolean parsing.

-            axolotl_extras: vllm
+            axolotl_extras: "vllm"
src/axolotl/loaders/model.py (1)

841-847: LGTM! Necessary workaround with minor optimization opportunity.

The workaround for the upstream transformers 4.54.0 issue is appropriate and well-documented. The TODO comment ensures it will be removed when no longer needed.

Consider simplifying the nested if statement as suggested by static analysis:

-        if self.cfg.tensor_parallel_size > 1:
-            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-            # TODO(wing): remove once 4.54.1 is released
-            if self.model._tp_size != self.cfg.tensor_parallel_size:
-                self.model._tp_size = self.cfg.tensor_parallel_size
-                self.model._device_mesh = self.model_kwargs["device_mesh"]
+        if (
+            self.cfg.tensor_parallel_size > 1
+            and self.model._tp_size != self.cfg.tensor_parallel_size
+        ):
+            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+            # TODO(wing): remove once 4.54.1 is released
+            self.model._tp_size = self.cfg.tensor_parallel_size
+            self.model._device_mesh = self.model_kwargs["device_mesh"]
src/axolotl/utils/schemas/validation.py (1)

1265-1276: LGTM! New tensor parallelism optimizer validation with minor optimization opportunity.

The new validation mixin and check_tensor_parallel_optimizer method appropriately prevent incompatible optimizer configurations with tensor parallelism. The logic correctly identifies optimizers that don't work with tensor parallelism.

Consider simplifying the nested if statement as suggested by static analysis:

-    def check_tensor_parallel_optimizer(self):
-        if self.tensor_parallel_size > 1:
-            if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-                raise ValueError(
-                    "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-                )
+    def check_tensor_parallel_optimizer(self):
+        if (
+            self.tensor_parallel_size > 1
+            and self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]
+        ):
+            raise ValueError(
+                "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+            )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6852319 and 5351801.

📒 Files selected for processing (6)
  • .github/workflows/multi-gpu-e2e.yml (1 hunks)
  • cicd/multigpu.sh (1 hunks)
  • setup.py (1 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/utils/samplers/multipack.py (2 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/utils/samplers/multipack.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • cicd/multigpu.sh
  • setup.py
🧰 Additional context used
🧠 Learnings (2)
.github/workflows/multi-gpu-e2e.yml (1)

Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.

src/axolotl/utils/schemas/validation.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
  • load (154-183)
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py

428-429: Use a single if statement instead of nested if statements

(SIM102)


449-450: Use a single if statement instead of nested if statements

(SIM102)


842-845: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/utils/schemas/validation.py

1270-1271: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (8)
src/axolotl/loaders/model.py (4)

16-16: LGTM! Import additions support new parallelism features.

The new imports for PartialState, ParallelismConfig, and consolidated get_world_size are appropriate for the parallelism configuration functionality being added.

Also applies to: 49-49, 52-52


88-89: LGTM! New class attributes are properly typed.

The new parallelism configuration attributes are correctly typed and initialized with appropriate default values to manage the parallelism state.


187-200: LGTM! Parallelism configuration logic is well-structured.

The conditional logic properly determines when parallelism configuration should be enabled based on FSDP, tensor parallel, and context parallel settings. The check to disable parallelism for non-FSDP v2 configurations is appropriate.


722-728: LGTM! Tensor parallelism configuration is properly implemented.

The logic correctly configures tensor parallelism by setting the appropriate model kwargs and removing the incompatible device_map. Using tp_plan: "auto" is a good approach for automatic tensor parallel planning.

src/axolotl/utils/schemas/validation.py (4)

676-676: LGTM! Parameter rename is consistent with the refactoring.

The change from sequence_parallel_degree to context_parallel_size maintains the same validation logic while aligning with the systematic parameter renaming throughout the codebase.


903-926: LGTM! Improved robustness by making DeepSpeed config updates conditional.

The addition of the conditional check ensures the DeepSpeed configuration is only updated when a DeepSpeed config is actually present, preventing errors when tensor parallelism is used without DeepSpeed. This is a sensible improvement to the validation logic.


1205-1244: LGTM! Method renamed consistently with preserved validation logic.

The systematic renaming from check_sequence_parallel_degree to check_context_parallel_size maintains all the existing validation logic while using the new parameter naming convention. The validation requirements for flash attention, micro batch size constraints, and ring_flash_attn imports are properly preserved.


1248-1248: LGTM! Consistent parameter rename in ring attention validation.

The change maintains the same validation logic while using the updated parameter name, consistent with the systematic refactoring.

Comment thread .github/workflows/multi-gpu-e2e.yml
@winglian winglian force-pushed the ndp branch 3 times, most recently from 5351801 to 94a6a84 Compare July 27, 2025 21:35
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (7)
src/axolotl/loaders/patch_manager.py (1)

69-69: Re-enable or confirm removal of sequence parallel patches.

The call to _apply_sequence_parallel_patches() remains commented out, which disables sequence-parallelism patches and makes context_parallel_size ineffective. This was previously flagged in past reviews.

Please confirm:

  1. Is disabling sequence-parallelism intentional and permanent?
  2. If not, where are these patches now applied?
  3. If you still need context_parallel_size functionality, please re-enable:
-        # self._apply_sequence_parallel_patches()
+        self._apply_sequence_parallel_patches()
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)

20-26: Remove redundant setattr call.

The static analysis correctly identifies that the setattr call is redundant. The direct assignment on line 21 is sufficient to patch the module.

 def patch_tp_fix():
     transformers.integrations.tensor_parallel.distribute_model = distribute_model
-    setattr(
-        sys.modules["transformers.integrations.tensor_parallel"],
-        "distribute_model",
-        distribute_model,
-    )
src/axolotl/core/trainers/mixins/dist_parallel.py (1)

20-22: Combine nested if statements and address Python compatibility.

The implementation has a few areas for improvement:

  1. The nested if statements can be combined as suggested by static analysis
  2. The type hint tuple[str, ...] | None requires Python 3.10+
-        if not is_package_version_ge("accelerate", "1.10.0"):
-            # pylint: disable=protected-access
-            if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+        if (
+            not is_package_version_ge("accelerate", "1.10.0")
+            and int(os.environ.get("WORLD_SIZE", "1")) > 1
+        ):
+            # pylint: disable=protected-access
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)

77-87: Remove redundant patching with setattr.

The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.

 def patch_prepare_from_posids():
     import transformers.modeling_flash_attention_utils

     transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
         _prepare_from_posids
     )
-    setattr(
-        sys.modules["transformers.modeling_flash_attention_utils"],
-        "_prepare_from_posids",
-        _prepare_from_posids,
-    )
src/axolotl/loaders/model.py (3)

408-465: Static method for computing parallelism config kwargs is well-structured.

The _get_parallel_config_kwargs method properly computes and validates parallelism parameters by sequentially dividing the world size. The validation ensures the combined parallelism matches the world size exactly.


478-478: Replace print with logging.


485-490: Document protected member access.

🧹 Nitpick comments (6)
src/axolotl/core/builders/base.py (1)

439-462: Consider simplifying the parallelism configuration logic.

The current implementation has complex nested checks for parallelism_config presence. While functional, it could be more readable and maintainable.

Consider refactoring:

def _configure_accelerator_config(self, training_args_kwargs: dict):
    partial_state = PartialState()
-    has_pc_attr = (
-        hasattr(partial_state, "parallelism_config")
-        and partial_state.parallelism_config
-    )
-    has_pc_key = (
-        "parallelism_config"
-        in partial_state._shared_state  # pylint: disable=protected-access
-        and partial_state._shared_state[  # pylint: disable=protected-access
-            "parallelism_config"
-        ]
-    )
-    use_configured_state = has_pc_attr or has_pc_key
+    
+    def _has_parallelism_config():
+        return (
+            (hasattr(partial_state, "parallelism_config") and partial_state.parallelism_config) or
+            (partial_state._shared_state.get("parallelism_config"))  # pylint: disable=protected-access
+        )
+    
+    use_configured_state = _has_parallelism_config()
    
    if self.cfg.accelerator_config:
        use_configured_state = self.cfg.accelerator_config.pop(
            "use_configured_state", use_configured_state
        )
        training_args_kwargs["accelerator_config"] = AcceleratorConfig(
            use_configured_state=use_configured_state, **self.cfg.accelerator_config
        )
    else:
        training_args_kwargs["accelerator_config"] = AcceleratorConfig(
            use_configured_state=use_configured_state,
        )
src/axolotl/utils/bench.py (2)

104-106: Consider removing ValueError for non-critical utility.

Based on past feedback, raising ValueErrors may not be necessary for this non-critical utility. Consider returning default values or silently handling unsupported devices instead.

-        raise ValueError(
-            f"Unable to determine memory statistics for current device {device}"
-        )
+        # Return default values for unsupported devices
+        return 0.0, 0.0, 0.0

111-131: LGTM - Good refactoring with minor style improvement opportunity.

The refactored function correctly uses the new helper, handles errors gracefully, and provides clearer logging format.

Consider the static analysis suggestion for a simpler ternary expression:

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
src/axolotl/integrations/liger/args.py (1)

68-83: Validators correctly enforce tensor parallelism constraints.

The new validators properly prevent incompatible configurations between LIGER components and tensor parallelism. The TODO comment on line 80 indicates this might be part of a larger compatibility issue that needs investigation.

Would you like me to help investigate the broader compatibility issues between tensor parallelism and liger losses mentioned in the TODO?

src/axolotl/monkeypatch/accelerate/distributed.py (1)

1-220: Comprehensive fallback implementation for ParallelismConfig.

The fallback implementation provides a complete interface for parallelism configuration with proper validation logic. The class structure and methods are well-designed to support various parallelism strategies.

Consider addressing the static analysis hints:

-            return tuple(zip(*sorted_items))
+            return tuple(zip(*sorted_items, strict=False))
-            assert (
-                parallelism in self._sizes.keys()
-            ), f"Parallelism must be one of {self._sizes.keys()}"
+            assert (
+                parallelism in self._sizes
+            ), f"Parallelism must be one of {list(self._sizes.keys())}"
src/axolotl/utils/schemas/validation.py (1)

1265-1278: New distributed validation mixin correctly enforces optimizer constraints.

The DistributedValidationMixin properly validates that tensor parallelism is incompatible with certain 8-bit optimizers (paged_adamw_8bit, adamw_8bit, adamw_bnb_8bit).

Consider simplifying the nested if statement as suggested by static analysis:

-    def check_tensor_parallel_optimizer(self):
-        if self.tensor_parallel_size > 1:
-            if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-                raise ValueError(
-                    "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-                )
+    def check_tensor_parallel_optimizer(self):
+        if (self.tensor_parallel_size > 1 and 
+            self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):
+            raise ValueError(
+                "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+            )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 87b7217 and 94a6a84.

📒 Files selected for processing (43)
  • cicd/single_gpu.py (1 hunks)
  • docs/sequence_parallelism.qmd (3 hunks)
  • setup.py (1 hunks)
  • src/axolotl/cli/merge_lora.py (1 hunks)
  • src/axolotl/core/builders/base.py (2 hunks)
  • src/axolotl/core/builders/rl.py (1 hunks)
  • src/axolotl/core/trainers/base.py (2 hunks)
  • src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
  • src/axolotl/core/trainers/grpo/args.py (1 hunks)
  • src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
  • src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
  • src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
  • src/axolotl/core/trainers/mixins/dist_parallel.py (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/loaders/patch_manager.py (3 hunks)
  • src/axolotl/monkeypatch/accelerate/distributed.py (1 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
  • src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
  • src/axolotl/monkeypatch/transformers/tensor_parallel.py (1 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/bench.py (2 hunks)
  • src/axolotl/utils/callbacks/__init__.py (2 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
  • src/axolotl/utils/data/shared.py (1 hunks)
  • src/axolotl/utils/environment.py (2 hunks)
  • src/axolotl/utils/samplers/multipack.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • src/axolotl/utils/trainer.py (3 hunks)
  • tests/core/test_builders.py (1 hunks)
  • tests/e2e/multigpu/patched/test_sp.py (2 hunks)
  • tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
  • tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
  • tests/e2e/multigpu/test_tp.py (1 hunks)
  • tests/e2e/patched/test_sp.py (6 hunks)
  • tests/e2e/test_load_model.py (1 hunks)
  • tests/e2e/utils.py (1 hunks)
  • tests/test_loaders.py (1 hunks)
  • train.yaml (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • src/axolotl/core/trainers/mixins/init.py
  • src/axolotl/utils/samplers/multipack.py
🚧 Files skipped from review as they are similar to previous changes (26)
  • src/axolotl/cli/merge_lora.py
  • src/axolotl/monkeypatch/accelerate/fsdp2.py
  • tests/e2e/test_load_model.py
  • docs/sequence_parallelism.qmd
  • src/axolotl/utils/data/shared.py
  • src/axolotl/core/trainers/grpo/args.py
  • tests/core/test_builders.py
  • tests/e2e/utils.py
  • src/axolotl/core/trainers/mixins/checkpoints.py
  • tests/e2e/patched/test_sp.py
  • src/axolotl/train.py
  • cicd/single_gpu.py
  • tests/e2e/multigpu/solo/test_grpo.py
  • src/axolotl/monkeypatch/ring_attn/init.py
  • src/axolotl/core/trainers/grpo/sampler.py
  • src/axolotl/core/trainers/grpo/trainer.py
  • src/axolotl/utils/trainer.py
  • tests/e2e/multigpu/test_fp8_fsdp2.py
  • tests/test_loaders.py
  • src/axolotl/core/trainers/grpo/init.py
  • tests/e2e/multigpu/patched/test_sp.py
  • src/axolotl/core/builders/rl.py
  • setup.py
  • train.yaml
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/ctx_managers/sequence_parallel.py
🧰 Additional context used
🧠 Learnings (1)
src/axolotl/utils/schemas/validation.py (1)

Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

🧬 Code Graph Analysis (8)
src/axolotl/core/trainers/base.py (1)
src/axolotl/core/trainers/mixins/dist_parallel.py (1)
  • DistParallelMixin (12-41)
src/axolotl/loaders/patch_manager.py (4)
src/axolotl/integrations/base.py (2)
  • cfg (350-351)
  • cfg (354-355)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
  • patch_prepare_from_posids (77-87)
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)
  • patch_tp_fix (20-26)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
  • patch_prepare_device_mesh (389-433)
tests/e2e/multigpu/test_tp.py (2)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/utils.py (2)
  • check_tensorboard (149-163)
  • require_torch_2_7_0 (80-89)
src/axolotl/utils/callbacks/__init__.py (1)
src/axolotl/utils/bench.py (2)
  • get_gpu_memory_usage (95-108)
  • log_gpu_memory_usage (111-131)
src/axolotl/core/builders/base.py (1)
tests/e2e/patched/test_sp.py (1)
  • partial_state (25-28)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
src/axolotl/core/trainers/mixins/dist_parallel.py (2)
src/axolotl/utils/environment.py (1)
  • is_package_version_ge (39-41)
src/axolotl/core/trainers/base.py (1)
  • create_accelerator_and_postprocess (515-525)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
  • load (154-183)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/tensor_parallel.py

22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

src/axolotl/utils/bench.py

127-127: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

src/axolotl/core/trainers/mixins/dist_parallel.py

20-22: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/monkeypatch/accelerate/distributed.py

180-180: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


216-216: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

src/axolotl/loaders/model.py

432-433: Use a single if statement instead of nested if statements

(SIM102)


453-454: Use a single if statement instead of nested if statements

(SIM102)


846-849: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py

83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

src/axolotl/utils/schemas/validation.py

1270-1271: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (24)
src/axolotl/core/trainers/base.py (2)

30-30: LGTM: Clean integration of distributed parallelism mixin.

The import and integration of DistParallelMixin is correctly implemented.


54-54: LGTM: Proper mixin inheritance order.

The DistParallelMixin is correctly positioned in the inheritance chain, allowing it to extend the trainer's accelerator creation capabilities for distributed training.

tests/e2e/multigpu/test_tp.py (1)

16-64: LGTM: Well-structured tensor parallelism integration test.

This test effectively validates tensor parallelism functionality with:

  • Proper version requirements (@require_torch_2_7_0)
  • Appropriate test configuration (small model, limited steps for quick execution)
  • Multi-process execution setup matching the tensor_parallel_size
  • Validation through TensorBoard logs

The test configuration and execution approach follow established patterns from the existing test suite.

src/axolotl/loaders/patch_manager.py (3)

52-52: LGTM: Appropriate integration of transformers patches.

The new _apply_transformers_patches() method properly applies Flash Attention and tensor parallelism fixes from the transformers monkeypatch modules.


75-84: LGTM: Clean implementation of transformers patching.

The new method properly imports and applies the required transformers patches for Flash Attention and tensor parallelism support.


279-284: LGTM: Consistent parameter renaming.

The renaming from sequence_parallel_degree to context_parallel_size is consistent with the broader refactoring across the codebase and aligns with PyTorch's native context parallelism terminology.

src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)

8-17: LGTM - Clean monkey patch implementation.

The wrapped distribute_model function correctly calls the original implementation and adds the necessary tensor parallel metadata (_tp_size and _device_mesh) to the model instance. The pylint disable comments are appropriate for this monkey patch context.

src/axolotl/core/trainers/mixins/dist_parallel.py (1)

23-40: LGTM - Solid tensor parallelism setup logic.

The implementation correctly handles backward compatibility with older accelerate versions by manually configuring tensor parallelism. The device mesh inspection and plugin setup follow the expected patterns for accelerate integration.

src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)

12-74: LGTM - Comprehensive flash attention utility implementation.

The _prepare_from_posids function is well-implemented with:

  • Comprehensive documentation explaining all parameters and return values
  • Proper tensor handling with contiguous views and device management
  • Good awareness of TorchDynamo compatibility requirements
  • Sound logic for extracting cumulative sequence lengths and position handling
src/axolotl/utils/callbacks/__init__.py (2)

38-38: LGTM - Correct import addition.

The import of get_gpu_memory_usage is needed for the enhanced GPU memory logging functionality implemented in the callback.


111-126: LGTM - Enhanced GPU memory monitoring.

The updated callback implementation provides better monitoring capabilities:

  • Continuous logging after the first step gives better visibility into memory patterns
  • WandB integration with structured metrics (active, allocated, reserved) provides detailed tracking
  • Proper error handling for unsupported devices
  • Correct process-zero checking for WandB logging
src/axolotl/utils/bench.py (1)

60-63: LGTM - Improved memory metrics for monitoring.

The change to peak/max memory metrics (active bytes peak, max allocated, max reserved) provides better insights into memory usage patterns during training compared to current usage snapshots.

src/axolotl/integrations/liger/args.py (1)

31-37: Type annotation updates look good.

The migration from Optional[bool] to bool | None syntax is consistent with modern Python type hints.

src/axolotl/monkeypatch/ring_attn/patch.py (4)

166-196: Parameter renaming is consistent and well-documented.

The renaming from sequence_parallel_degree to context_parallel_size is properly implemented with updated docstrings and logging messages.


220-252: Ring attention function selection logic is properly implemented.

The conditional logic for selecting between VARLEN_LLAMA3 and BATCH_RING implementations based on the ring_attn_func parameter is well-structured. The monkeypatch approach for overriding the flash attention forward function is appropriate for this use case.


254-332: New DeviceMesh-based registration function is well-designed.

The register_ring_attn_from_device_mesh function provides a clean interface for creating ring attention groups from existing device meshes. Error handling for missing dimensions is appropriate, and the logging provides good visibility into the configuration.


389-433: Device mesh patching correctly uses "cp" dimension name.

The update to use "cp" (context parallelism) instead of "sp" aligns with PyTorch's native naming conventions as noted in the comments. The mesh shape construction properly uses context_parallel_size.

src/axolotl/loaders/model.py (3)

88-90: Parallelism configuration attributes are properly initialized.

The new attributes use_parallel_config and parallelism_config are correctly typed and initialized.


726-732: Tensor parallelism setup in model kwargs is correct.

The code properly configures tensor parallelism parameters and removes incompatible device_map when tensor parallelism is enabled.


846-852: Workaround for upstream transformers issue is documented.

The workaround for transformers 4.54.0 not setting _tp_size or _device_mesh is properly documented with a TODO comment for future removal.

src/axolotl/utils/schemas/validation.py (4)

676-679: GRPO validation correctly uses context_parallel_size.

The validation properly checks for incompatibility between GRPO, Liger loss, and context parallelism using the renamed parameter.


900-927: DeepSpeed tensor parallelism config update is now conditional.

The logic now only proceeds if a DeepSpeed config path is present, which is more robust than the previous implementation that would error if missing.


1205-1244: Context parallel size validation is comprehensive.

The validation properly:

  1. Sets default value to 1 if not specified
  2. Requires flash attention when context_parallel_size > 1
  3. Enforces micro_batch_size=1 with sample packing due to ring-flash-attn requirements
  4. Patches transformers before importing ring_flash_attn
  5. Provides helpful error messages and warnings

1247-1263: Ring attention function validation properly uses context_parallel_size.

The validation correctly checks context_parallel_size and sets appropriate defaults based on sample packing configuration.

Comment on lines +34 to +41
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)


def is_package_version_ge(package: str, version_: str) -> bool:
package_version = get_package_version(package)
return package_version >= parse(version_)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for missing packages.

The functions don't handle the case where a package is not installed, which will cause importlib.metadata.version() to raise a PackageNotFoundError. This could lead to unexpected crashes in calling code.

Consider adding error handling:

def get_package_version(package: str) -> Version:
-    version_str = version(package)
-    return parse(version_str)
+    try:
+        version_str = version(package)
+        return parse(version_str)
+    except PackageNotFoundError:
+        raise ValueError(f"Package '{package}' is not installed")

def is_package_version_ge(package: str, version_: str) -> bool:
-    package_version = get_package_version(package)
-    return package_version >= parse(version_)
+    try:
+        package_version = get_package_version(package)
+        return package_version >= parse(version_)
+    except ValueError:
+        return False  # Or handle appropriately based on use case

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/utils/environment.py around lines 34 to 41, the functions
get_package_version and is_package_version_ge do not handle the case where the
specified package is not installed, causing importlib.metadata.version() to
raise a PackageNotFoundError. To fix this, add try-except blocks around the
calls to version() to catch PackageNotFoundError and handle it gracefully, such
as returning None or raising a custom error, and update is_package_version_ge to
handle this case accordingly to avoid unexpected crashes.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)

462-488: Replace print with logging and document protected member access.

Two issues to address:

  1. Line 474 uses print() instead of the logger
  2. Lines 481-487 access protected members of PartialState
         if pc_kwargs:
-            print(pc_kwargs)
+            LOG.info("Parallelism configuration: %s", pc_kwargs)
             self.parallelism_config = ParallelismConfig(
                 **pc_kwargs,
             )

Consider adding a comment explaining why protected member access is necessary:

# We need to access PartialState's protected members to properly configure
# the parallelism state before model initialization. This is required for
# tensor parallelism to work correctly with accelerate.

408-460: Missing divisibility validation and code structure improvements needed.

Two issues to address:

  1. Missing divisibility validation: The past review comment about adding validation for world size divisibility by parallelism factors hasn't been implemented. This could lead to runtime errors.

  2. Simplify nested if statements: Static analysis suggests combining nested if statements for better readability.

Apply these improvements:

 @staticmethod
 def _get_parallel_config_kwargs(
     world_size: int,
     tensor_parallel_size: int = 1,
     context_parallel_size: int = 1,
     dp_shard_size: int | None = None,
     dp_replicate_size: int | None = None,
     is_fsdp: bool = False,
 ):
     pc_kwargs = {}
     remaining_world_size = world_size

     if tensor_parallel_size and tensor_parallel_size > 1:
+        if remaining_world_size % tensor_parallel_size != 0:
+            raise ValueError(
+                f"World size ({remaining_world_size}) must be divisible by tensor_parallel_size ({tensor_parallel_size})"
+            )
         pc_kwargs["tp_size"] = tensor_parallel_size
         remaining_world_size = remaining_world_size // tensor_parallel_size

     if context_parallel_size and context_parallel_size > 1:
+        if remaining_world_size % context_parallel_size != 0:
+            raise ValueError(
+                f"Remaining world size ({remaining_world_size}) must be divisible by context_parallel_size ({context_parallel_size})"
+            )
         pc_kwargs["cp_size"] = context_parallel_size
         remaining_world_size = remaining_world_size // context_parallel_size

-    if dp_shard_size is None and dp_replicate_size in (None, 1):
-        if remaining_world_size > 1:
+    if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
             pc_kwargs["dp_shard_size"] = remaining_world_size
             remaining_world_size = 1

     if dp_replicate_size and dp_replicate_size > 1:
+        if remaining_world_size % dp_replicate_size != 0:
+            raise ValueError(
+                f"Remaining world size ({remaining_world_size}) must be divisible by dp_replicate_size ({dp_replicate_size})"
+            )
         pc_kwargs["dp_replicate_size"] = dp_replicate_size
         remaining_world_size = remaining_world_size // dp_replicate_size

-    if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
+    if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
+        if remaining_world_size % dp_shard_size != 0:
+            raise ValueError(
+                f"Remaining world size ({remaining_world_size}) must be divisible by dp_shard_size ({dp_shard_size})"
+            )
         if not is_fsdp:
             raise ValueError(
                 "dp_shard_size was configured without a corresponding fsdp_config! "
                 "Please ensure you have configured FSDP using fsdp_config."
             )
         pc_kwargs["dp_shard_size"] = dp_shard_size
         remaining_world_size = remaining_world_size // dp_shard_size
         if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
             pc_kwargs["dp_replicate_size"] = remaining_world_size
             remaining_world_size = 1

-    if remaining_world_size > 1:
-        if "dp_shard_size" not in pc_kwargs and is_fsdp:
+    if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
             pc_kwargs["dp_shard_size"] = remaining_world_size
             remaining_world_size = 1
🧹 Nitpick comments (1)
src/axolotl/loaders/model.py (1)

841-848: Simplify nested if statement.

The workaround logic is correct, but the nested if statements can be combined for better readability as suggested by static analysis.

-        if self.cfg.tensor_parallel_size > 1:
-            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-            # TODO(wing): remove once 4.54.1 is released
-            if self.model._tp_size != self.cfg.tensor_parallel_size:
+        # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+        # TODO(wing): remove once 4.54.1 is released
+        if self.cfg.tensor_parallel_size > 1 and self.model._tp_size != self.cfg.tensor_parallel_size:
                 self.model._tp_size = self.cfg.tensor_parallel_size
                 self.model._device_mesh = self.model_kwargs["device_mesh"]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 94a6a84 and c44e565.

📒 Files selected for processing (1)
  • src/axolotl/loaders/model.py (7 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py

428-429: Use a single if statement instead of nested if statements

(SIM102)


449-450: Use a single if statement instead of nested if statements

(SIM102)


842-845: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/model.py (4)

16-16: Import changes look good.

The new imports for PartialState, ParallelismConfig, and get_world_size are appropriate for the parallelism functionality being added.

Also applies to: 49-49, 52-52


88-89: Class attributes properly typed.

The new parallelism-related class attributes are well-defined with appropriate type annotations and default values.


187-200: Parallelism configuration logic is well-structured.

The conditional logic correctly determines when to enable parallelism configuration based on FSDP, tensor parallelism, and context parallelism settings. The FSDP version check is a good safeguard.


722-728: Tensor parallelism setup is correct.

The logic properly configures tensor parallelism by setting the required model kwargs and removing incompatible device_map. The conditional check ensures it only applies when tensor parallelism is enabled.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9dad77 and f46cf48.

📒 Files selected for processing (2)
  • cicd/multigpu.sh (1 hunks)
  • requirements.txt (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
🔇 Additional comments (1)
cicd/multigpu.sh (1)

5-5: Good call adding --durations=10 to surface slow tests
This will automatically highlight the ten slowest tests in CI logs, making it easier to spot performance regressions without extra tooling.

Comment thread requirements.txt Outdated
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Pin transformers to an immutable commit or tag to guarantee reproducible installs

Pointing to winglian/transformers.git@v4.54.1-maybe relies on a mutable branch/tag.
If that ref is force-pushed or deleted, every fresh environment will break or pick up different code, making binary wheels & training runs non-reproducible. Pin to a commit SHA (or an immutable annotated tag) and document why this fork is required.

Example fix:

-transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
+# TODO: replace <commit-sha> with the exact commit you depend on
+transformers @ git+https://github.com/winglian/transformers.git@<commit-sha>

Also verify the fork remains compatible with accelerate==1.9.0; the upstream PR 3682 suggests you may need a custom accelerate build as well.


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "🔍 Verifying ‘v4.54.1-maybe’ ref in winglian/transformers:"
git ls-remote --refs https://github.com/winglian/transformers.git \
  refs/tags/v4.54.1-maybe refs/heads/v4.54.1-maybe || true

echo
echo "🔍 Checking for any ‘accelerate’ pins in requirements.txt:"
grep -n "accelerate" requirements.txt || true

Length of output: 523


Pin transformers to an exact commit SHA

You’re currently tracking the mutable branch v4.54.1-maybe, which can be force-pushed or deleted at any time—breaking reproducibility. Replace it with the precise commit hash we just looked up:

-requirements.txt:16
-transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
+requirements.txt:16
+transformers @ git+https://github.com/winglian/transformers.git@d274901a61d5ff2977072ada1738bd38ff14be70

• Add a comment (or README entry) explaining why you’re using this fork.
• Verify that this commit remains compatible with accelerate==1.9.0 (pinned on line 18).

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In requirements.txt at line 16, replace the current transformers dependency that
tracks the mutable branch v4.54.1-maybe with the exact commit SHA to ensure
reproducibility. Add a comment explaining why this specific fork is used. Also,
verify that this commit is compatible with accelerate==1.9.0 pinned on line 18.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
baseten/run.sh (1)

2-2: Add pipefail for safer error-handling in pipelines

set -eux does not cover failures that occur inside a pipeline (cmd1 | cmd2).
Using set -euo pipefail (or explicitly set -o pipefail) prevents silent data loss by exiting on the first failing command in any pipeline.

-set -eux
+set -euo pipefail
+set -x        # keep the `-x` (debug) behaviour
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f46cf48 and 1337db6.

📒 Files selected for processing (1)
  • baseten/run.sh (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
baseten/run.sh (1)

Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit

Comment thread baseten/run.sh Outdated
Comment on lines +8 to +10
# if node rank 0
axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

axolotl preprocess will be executed on every node – guard with node-rank check

The comment says “if node rank 0” but the command is unconditional.
When launched via torchrun, every node/container will attempt to write to the same $BT_CHECKPOINT_DIR, leading to corruption or a noisy/slow preprocess phase.

-# if node rank 0
-axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared
+# Only the leader should run preprocessing
+if [[ "${BT_NODE_RANK:-0}" == "0" ]]; then
+  axolotl preprocess train.yaml \
+    --output-dir="$BT_CHECKPOINT_DIR" \
+    --dataset-prepared-path="${BT_CHECKPOINT_DIR}/last_run_prepared"
+fi

(Uses parameter expansion so the script still works when BT_NODE_RANK is undefined in single-node setups.)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# if node rank 0
axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared
# Only the leader should run preprocessing
if [[ "${BT_NODE_RANK:-0}" == "0" ]]; then
axolotl preprocess train.yaml \
--output-dir="$BT_CHECKPOINT_DIR" \
--dataset-prepared-path="${BT_CHECKPOINT_DIR}/last_run_prepared"
fi
🤖 Prompt for AI Agents
In baseten/run.sh around lines 8 to 10, the axolotl preprocess command runs
unconditionally on all nodes, risking concurrent writes to $BT_CHECKPOINT_DIR.
Wrap the preprocess command in a conditional that checks if BT_NODE_RANK is
unset or equals 0, using parameter expansion to handle single-node setups
safely. This ensures only the node with rank 0 executes the preprocess step,
preventing directory corruption and redundant processing.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (5)
src/axolotl/utils/bench.py (1)

95-107: Improve device detection consistency.

The device type detection logic has inconsistencies that should be addressed:

  1. Line 97 checks MPS availability independently of the device type from line 96
  2. String containment checks on lines 99 and 101 are fragile
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)

77-87: Remove redundant patching with setattr.

The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.

src/axolotl/loaders/model.py (1)

408-460: Add validation for world size divisibility.

The method should validate that world_size is evenly divisible by each parallelism factor before performing division to prevent runtime errors.

src/axolotl/utils/schemas/validation.py (2)

1213-1218: Monkey-patching transformers internals needs documentation and version guards.

This unconditional monkey patch of transformers.modeling_flash_attention_utils._flash_supports_window_size was previously flagged but remains unaddressed. The patch runs on every transformers version and may break if internals change.

Please add:

  • Version check or attribute existence guard before applying the patch
  • Documentation explaining why this patch is needed and which transformers versions require it
  • Clear specification of supported transformers versions in setup.py/pyproject.toml

1257-1268: Critical: New validation mixin not included in main ValidationMixin.

The DistributedValidationMixin contains important tensor parallel + 8-bit optimizer compatibility validation, but it's not included in the main ValidationMixin class (lines 1272-1284), so it won't be executed.

Add DistributedValidationMixin to the main ValidationMixin class:

 class ValidationMixin(
     DatasetValidationMixin,
     AttentionValidationMixin,
     TrainingValidationMixin,
     LoRAValidationMixin,
     RLValidationMixin,
     OptimizationValidationMixin,
     SystemValidationMixin,
     ChatTemplateValidationMixin,
     PretrainingValidationMixin,
     ModelCompatibilityValidationMixin,
     ComplexValidationMixin,
+    DistributedValidationMixin,
 ):

Also, simplify the nested if statements:

 @model_validator(mode="after")
 def check_tensor_parallel_optimizer(self):
-    if self.tensor_parallel_size > 1:
-        if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-            raise ValueError(
-                "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-            )
+    if (self.tensor_parallel_size > 1 and 
+        self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):
+        raise ValueError(
+            "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+        )
🧹 Nitpick comments (2)
src/axolotl/utils/bench.py (1)

125-125: Simplify ternary expression.

-    msg = f"{cur_device_type} memory active:" if not msg else msg
+    msg = msg if msg else f"{cur_device_type} memory active:"
src/axolotl/loaders/model.py (1)

841-847: Simplify nested conditional and document the workaround.

-        if self.cfg.tensor_parallel_size > 1:
-            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-            # TODO(wing): remove once 4.54.1 is released
-            if self.model._tp_size != self.cfg.tensor_parallel_size:
+        # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+        # TODO(wing): remove once 4.54.1 is released
+        if (self.cfg.tensor_parallel_size > 1 and 
+            self.model._tp_size != self.cfg.tensor_parallel_size):
                self.model._tp_size = self.cfg.tensor_parallel_size
                self.model._device_mesh = self.model_kwargs["device_mesh"]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9b739ee and f8df5bf.

📒 Files selected for processing (48)
  • cicd/multigpu.sh (1 hunks)
  • cicd/single_gpu.py (1 hunks)
  • docs/sequence_parallelism.qmd (3 hunks)
  • examples/alst/llama3-8b-deepspeed-alst.yaml (1 hunks)
  • requirements.txt (1 hunks)
  • setup.py (1 hunks)
  • src/axolotl/cli/merge_lora.py (1 hunks)
  • src/axolotl/core/builders/base.py (2 hunks)
  • src/axolotl/core/builders/rl.py (1 hunks)
  • src/axolotl/core/trainers/base.py (2 hunks)
  • src/axolotl/core/trainers/dpo/trainer.py (2 hunks)
  • src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
  • src/axolotl/core/trainers/grpo/args.py (1 hunks)
  • src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (10 hunks)
  • src/axolotl/core/trainers/mamba.py (1 hunks)
  • src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
  • src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py (1 hunks)
  • src/axolotl/core/trainers/trl.py (5 hunks)
  • src/axolotl/integrations/kd/trainer.py (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/__init__.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (3 hunks)
  • src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/bench.py (2 hunks)
  • src/axolotl/utils/callbacks/__init__.py (2 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
  • src/axolotl/utils/data/shared.py (1 hunks)
  • src/axolotl/utils/environment.py (2 hunks)
  • src/axolotl/utils/samplers/multipack.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • src/axolotl/utils/trainer.py (3 hunks)
  • tests/core/test_builders.py (1 hunks)
  • tests/e2e/multigpu/patched/test_sp.py (2 hunks)
  • tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
  • tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
  • tests/e2e/multigpu/test_tp.py (1 hunks)
  • tests/e2e/patched/test_sp.py (0 hunks)
  • tests/e2e/test_load_model.py (1 hunks)
  • tests/e2e/utils.py (1 hunks)
  • tests/test_loaders.py (1 hunks)
  • tests/utils/schemas/validation/test_fsdp.py (0 hunks)
💤 Files with no reviewable changes (2)
  • tests/utils/schemas/validation/test_fsdp.py
  • tests/e2e/patched/test_sp.py
✅ Files skipped from review due to trivial changes (3)
  • src/axolotl/integrations/kd/trainer.py
  • tests/e2e/multigpu/patched/test_sp.py
  • src/axolotl/core/trainers/dpo/trainer.py
🚧 Files skipped from review as they are similar to previous changes (38)
  • examples/alst/llama3-8b-deepspeed-alst.yaml
  • src/axolotl/core/trainers/mamba.py
  • cicd/multigpu.sh
  • src/axolotl/core/trainers/mixins/init.py
  • src/axolotl/core/trainers/grpo/args.py
  • docs/sequence_parallelism.qmd
  • src/axolotl/utils/samplers/multipack.py
  • src/axolotl/core/trainers/mixins/checkpoints.py
  • src/axolotl/cli/merge_lora.py
  • tests/core/test_builders.py
  • src/axolotl/train.py
  • src/axolotl/core/trainers/base.py
  • cicd/single_gpu.py
  • tests/e2e/test_load_model.py
  • src/axolotl/core/builders/rl.py
  • requirements.txt
  • src/axolotl/monkeypatch/accelerate/fsdp2.py
  • tests/e2e/multigpu/solo/test_grpo.py
  • tests/e2e/utils.py
  • src/axolotl/core/trainers/grpo/init.py
  • tests/e2e/multigpu/test_tp.py
  • src/axolotl/utils/environment.py
  • src/axolotl/loaders/patch_manager.py
  • setup.py
  • src/axolotl/utils/trainer.py
  • tests/e2e/multigpu/test_fp8_fsdp2.py
  • src/axolotl/core/trainers/mixins/distributed_parallel.py
  • src/axolotl/core/builders/base.py
  • src/axolotl/core/trainers/grpo/trainer.py
  • src/axolotl/core/trainers/grpo/sampler.py
  • src/axolotl/monkeypatch/ring_attn/init.py
  • tests/test_loaders.py
  • src/axolotl/core/trainers/trl.py
  • src/axolotl/utils/callbacks/init.py
  • src/axolotl/utils/ctx_managers/sequence_parallel.py
  • src/axolotl/utils/data/shared.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/integrations/liger/args.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: grpo (cfg.rl == "grpo" or cfg.rl is rltype.grpo) should be excluded from dataset label checking duri...
Learnt from: NanoCode012
PR: axolotl-ai-cloud/axolotl#2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.

Applied to files:

  • src/axolotl/utils/schemas/validation.py
🧬 Code Graph Analysis (4)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
  • get_device_type (20-28)
src/axolotl/loaders/model.py (1)
src/axolotl/utils/distributed.py (3)
  • get_device_count (31-37)
  • get_device_type (20-28)
  • get_world_size (103-104)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
src/axolotl/utils/schemas/enums.py (1)
  • RingAttnFunc (84-92)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
  • load (154-183)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py

125-125: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg

Replace with msg if msg else f"{cur_device_type} memory active:"

(SIM212)

src/axolotl/loaders/model.py

428-429: Use a single if statement instead of nested if statements

(SIM102)


449-450: Use a single if statement instead of nested if statements

(SIM102)


841-844: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py

83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

src/axolotl/utils/schemas/validation.py

1262-1263: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: preview
🔇 Additional comments (8)
src/axolotl/utils/bench.py (1)

60-63: LGTM! More comprehensive memory metrics.

The updated return values provide better visibility into GPU memory usage with distinct metrics for active, allocated, and reserved memory.

src/axolotl/monkeypatch/ring_attn/patch.py (1)

137-188: Excellent refactoring to use DeviceMesh abstraction.

The migration from manual process group creation to using PyTorch's DeviceMesh is a significant improvement. The error handling and logging are well-implemented.

src/axolotl/loaders/model.py (2)

187-200: LGTM! Clear parallelism configuration setup.

The conditional logic properly determines when to use parallel configuration based on FSDP, tensor parallelism, and context parallelism settings.


462-487: Well-structured parallelism configuration.

The method properly sets up the parallelism configuration and device mesh in PartialState. The protected member access is necessary for proper integration with accelerate.

src/axolotl/utils/schemas/validation.py (4)

676-676: LGTM! Parameter renamed consistently with refactoring effort.

The update from sequence_parallel_degree to context_parallel_size aligns with the broader codebase refactoring mentioned in the AI summary.


888-912: LGTM! Improved conditional handling of DeepSpeed configuration.

The updated logic now gracefully handles cases where no DeepSpeed configuration is present, making tensor parallelism setup more flexible. The temporary config file generation and settings updates are correctly implemented.


1192-1236: LGTM! Well-handled parameter renaming with backward compatibility.

The method rename from check_sequence_parallel_degree to check_context_parallel_size is implemented correctly with:

  • Proper backward compatibility for the deprecated parameter
  • Clear deprecation warning to guide users
  • Consistent validation logic maintained
  • Updated error messages using the new parameter name

This follows best practices for API migration.


1240-1240: LGTM! Consistent parameter name update.

The update to use context_parallel_size instead of sequence_parallel_degree is consistent with the broader refactoring effort.

@winglian winglian merged commit 294c7fe into main Jul 31, 2025
17 checks passed
@winglian winglian deleted the ndp branch July 31, 2025 19:25
@coderabbitai coderabbitai Bot mentioned this pull request Mar 11, 2026
@coderabbitai coderabbitai Bot mentioned this pull request Apr 30, 2026
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants