Add support for Accelerate CP, ND examples, and fix for parallel config w fsdp #3019

Merged
winglian merged 25 commits into main from accelerate-cp on Aug 8, 2025

Conversation

@winglian
Collaborator

@winglian winglian commented Aug 6, 2025

Summary by CodeRabbit

  • New Features

    • Added support for configurable "reasoning/thinking" fields in chat template prompts and dataset schemas.
    • Introduced new utility functions for advanced distributed and parallel training configurations.
    • Added new distributed training example configs for Llama 3.1 8B and Qwen3-8B models.
    • Added a README for distributed parallel training examples.
    • Enabled support for the "gpt_oss" model type in multipack batching.
  • Improvements

    • Enhanced model saving to support PEFT and safetensors.
    • Improved device mesh and parallelism handling for distributed training.
    • Updated dependency versions and added new dependencies for improved compatibility.
    • Expanded FSDP saving logic and device mapping for efficient distributed training.
    • Made padding and position ID handling more flexible and explicit.
    • Added monkeypatch to support parallelism config for standalone context parallelism.
    • Refined parameter distribution and broadcasting for FSDP state dict loading.
    • Added environment variable setup for parallelism configurations.
    • Updated training execution to include device mesh information.
    • Improved internal handling of accelerator and parallelism state.
    • Added support for resetting accelerator state and improved accelerator creation logic.
    • Enhanced parallelism config validation and patching for context parallelism.
    • Improved internal device mesh usage in sequence parallel context management.
    • Extended FSDP patching to handle parallelism config and device mesh correctly.
  • Bug Fixes

    • Corrected architecture mapping for "gpt_oss" models.
    • Added validation to prevent unsupported FSDP CPU RAM efficient loading with certain quantizations.
  • Documentation

    • Added detailed documentation and configuration examples for distributed parallel training.

@coderabbitai
Contributor

coderabbitai Bot commented Aug 6, 2025

Walkthrough

This update introduces new distributed training configuration files and a README, adds support for configurable "thinking" fields in chat templates and datasets, refactors distributed parallelism and device mesh setup by removing PartialState usage, improves model saving and FSDP checkpoint handling, updates dependency versions, and enhances batching and collator logic for position IDs.
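
As a rough illustration of the configurable "thinking" field mentioned above, the sketch below renames a reasoning column in a dataset message to the key a chat template reads. The function name, the default values, and the exact renaming behavior are assumptions made for illustration; only the parameter names field_thinking and template_thinking_key come from this PR.

```python
from typing import Any


def map_thinking_field(
    message: dict[str, Any],
    field_thinking: str = "reasoning_content",  # assumed default, not confirmed here
    template_thinking_key: str = "reasoning_content",  # assumed default, not confirmed here
) -> dict[str, Any]:
    """Return a copy of `message` with the dataset's reasoning column renamed
    to the key the chat template is expected to read."""
    mapped = {k: v for k, v in message.items() if k != field_thinking}
    if field_thinking in message:
        mapped[template_thinking_key] = message[field_thinking]
    return mapped


# A dataset row that stores the reasoning trace under "thinking".
row = {"role": "assistant", "content": "4", "thinking": "2 + 2 = 4"}
mapped = map_thinking_field(
    row, field_thinking="thinking", template_thinking_key="reasoning_content"
)
```

Keeping the two names separate lets one dataset schema feed templates that expect different key names.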

Changes

| Cohort | File(s) | Change Summary |
|---|---|---|
| Distributed Parallel Examples | examples/distributed-parallel/README.md, examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml, examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml | Added distributed parallel training example configs and documentation for Llama 3.1 8B and Qwen3-8B, including FSDP, tensor, and context parallelism settings. |
| GPT-OSS Example Configs | examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml, examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml, examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml | Updated dataset to "HuggingFaceH4/Multilingual-Thinking", added field_thinking and template_thinking_key, tweaked warmup ratios, added "< |
| Dependency and Requirements Update | requirements.txt | Updated versions: accelerate (to 1.10.0), trl (to 0.21.0), gradio (to 5.41.1), axolotl-contribs-mit (to 0.0.5); added trackio. |
| Distributed Parallelism Refactor | src/axolotl/loaders/model.py, src/axolotl/core/builders/base.py, src/axolotl/core/trainers/base.py, src/axolotl/utils/distributed.py, src/axolotl/utils/trainer.py, src/axolotl/train.py | Removed PartialState usage, centralized parallelism config building, improved device mesh handling, added environment setup for parallelism, and refined FSDP logic. |
| Chat Template and Dataset Configurability | src/axolotl/prompt_strategies/chat_template.py, src/axolotl/utils/schemas/datasets.py | Made "thinking" field configurable in chat templates and datasets by adding field_thinking and template_thinking_key parameters and schema fields. |
| FSDP2 State Dict Loading | src/axolotl/monkeypatch/accelerate/fsdp2.py | Enhanced sharded parameter distribution and broadcasting logic for FSDP2 full state dict loading. |
| Batching and Collator Improvements | src/axolotl/utils/collators/batching.py | Added squash_position_ids to collator for handling position ID squashing across packed sequences. |
| Multipack Model Support | src/axolotl/monkeypatch/multipack.py | Added "gpt_oss" to supported multipack model types. |
| Architectures Mapping Update | src/axolotl/common/architectures.py | Changed MOE architecture mapping for "gpt_oss" to use "GptOssDecoderLayer". |
| Causal Builder Padding Logic | src/axolotl/core/builders/causal.py | Made padding logic more explicit by checking for None in pad_to_sequence_len. |
| Test Updates | tests/test_loaders.py | Updated test to reference new location of _get_parallel_config_kwargs after refactor. |
| Accelerate Parallelism Config Monkeypatch | src/axolotl/monkeypatch/accelerate/parallelism_config.py | Added monkeypatch to support pure context parallelism (CP) in ParallelismConfig validation, including environment variable override and compatibility checks. |
| Model Compatibility Validation | src/axolotl/utils/schemas/validation.py | Added validator to disallow FSDP CPU RAM efficient loading when using Mxfp4 quantization config. |
| Patch Manager Update | src/axolotl/loaders/patch_manager.py | Added call to new patch_parallelism_config monkeypatch when FSDP v2 and fsdp_config are used. |
| Sequence Parallel Context Manager Update | src/axolotl/utils/ctx_managers/sequence_parallel.py | Removed PartialState usage and added device_mesh parameter to context manager for sequence parallelism. |
| DistributedParallelMixin Update | src/axolotl/core/trainers/mixins/distributed_parallel.py | Added method to override accelerator distributed type from "FSDP" to "MULTI_GPU" when no FSDP plugin is present, handling context parallelism without FSDP. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

  • Complexity: Significant refactor of distributed parallelism logic, new configuration parameters, dependency updates, and multiple new example configs. Some changes are additive, but core logic for device mesh and state dict handling is non-trivial.

Suggested labels

scheduled_release, ready to merge

Suggested reviewers

  • SalmanMohammadi
  • NanoCode012

📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ed5deba and b8cd55c.

📒 Files selected for processing (1)
  • src/axolotl/core/trainers/base.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/core/trainers/base.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: preview
  • GitHub Check: pre-commit

@github-actions
Contributor

github-actions Bot commented Aug 6, 2025

📖 Documentation Preview: https://68951fafe37cc23d5d02d203--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit b8cd55c

@codecov

codecov Bot commented Aug 6, 2025

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🔭 Outside diff range comments (1)
examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)

46-47: Incomplete special_tokens configuration.

The special_tokens section is defined but empty. Consider either removing it if not needed or adding the appropriate token configurations for Qwen3.

🧹 Nitpick comments (2)
src/axolotl/utils/trainer.py (1)

600-617: LGTM! Clean parallelism environment setup.

The new setup_parallelism_envs function properly configures environment variables for distributed parallel training. The logic correctly sets variables only when parallelism sizes are greater than 1 and enables the parallelism configuration system appropriately.

Consider adding type hints and docstring for better maintainability:

-def setup_parallelism_envs(cfg):
+def setup_parallelism_envs(cfg) -> None:
+    """
+    Configure environment variables for parallelism settings.
+    
+    Sets environment variables for tensor parallel, data parallel shard/replicate,
+    and context parallel sizes when they are greater than 1.
+    
+    Args:
+        cfg: Configuration object containing parallelism settings.
+    """
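
For readers without the diff at hand, the behavior described above (export a size only when it is greater than 1, then switch the parallelism config on) might look roughly like the sketch below. The environment variable names are placeholders invented for this example, and the dp_replicate_size key is assumed; the names the PR actually uses are not shown in this thread.

```python
import os
from typing import Mapping, MutableMapping, Optional

# Placeholder names for illustration only; not the variables the PR sets.
_ENV_NAMES = {
    "tensor_parallel_size": "EXAMPLE_TP_SIZE",
    "dp_shard_size": "EXAMPLE_DP_SHARD_SIZE",
    "dp_replicate_size": "EXAMPLE_DP_REPLICATE_SIZE",  # assumed config key
    "context_parallel_size": "EXAMPLE_CP_SIZE",
}
_ENABLE_FLAG = "EXAMPLE_USE_PARALLELISM_CONFIG"


def setup_parallelism_envs_sketch(
    cfg: Mapping[str, Optional[int]],
    env: Optional[MutableMapping[str, str]] = None,
) -> None:
    """Export each parallelism size that is greater than 1 and set an enable flag."""
    target = os.environ if env is None else env
    enabled = False
    for key, name in _ENV_NAMES.items():
        size = cfg.get(key)
        if size is not None and size > 1:
            target[name] = str(size)
            enabled = True
    if enabled:
        target[_ENABLE_FLAG] = "true"


# Write into a plain dict instead of os.environ so the example has no side effects.
env: dict[str, str] = {}
setup_parallelism_envs_sketch(
    {"tensor_parallel_size": 2, "dp_shard_size": 1, "context_parallel_size": None},
    env,
)
```

Passing the target mapping explicitly keeps the function easy to test without touching the real process environment.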
examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (1)

30-47: Training configuration looks good.

The hyperparameters are reasonable for distributed training. Note that the effective batch size per GPU is 1 (micro_batch_size × gradient_accumulation_steps), which is fine for testing but you may want to increase it for production training.
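
The batch-size arithmetic in this note can be worked through as below. This is a sketch under stated assumptions: micro_batch_size and gradient_accumulation_steps are both taken to be 1 (the only positive integers whose product is 1), and the data-parallel rank count of 2 × 2 is a guess based on the config's file name, not something confirmed in this thread.

```python
def per_device_batch(micro_batch_size: int, gradient_accumulation_steps: int) -> int:
    """Samples one GPU contributes to a single optimizer step."""
    return micro_batch_size * gradient_accumulation_steps


def global_batch(
    micro_batch_size: int,
    gradient_accumulation_steps: int,
    data_parallel_ranks: int,
) -> int:
    """Samples per optimizer step across all data-parallel ranks.

    Tensor-parallel ranks work on the same samples, so they do not multiply
    the batch; only data-parallel ranks do.
    """
    return per_device_batch(micro_batch_size, gradient_accumulation_steps) * data_parallel_ranks


per_gpu = per_device_batch(1, 1)
# Assumed layout: 2 shard ranks x 2 replicate ranks see distinct data.
total = global_batch(1, 1, data_parallel_ranks=2 * 2)
```

Raising gradient_accumulation_steps to 8 in the same assumed layout would give `global_batch(1, 8, 4)`, that is 32 samples per step.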

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 42f5e6f and 571a89a.

📒 Files selected for processing (8)
  • examples/distirbuted-parallel/README.md (1 hunks)
  • examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
  • examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/core/trainers/base.py (3 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/trainer.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (3)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
  • _save (13-20)
src/axolotl/models/mamba/modeling_mamba.py (1)
  • save_pretrained (111-119)
src/axolotl/utils/distributed.py (1)
  • is_main_process (79-94)
🔇 Additional comments (15)
requirements.txt (2)

18-18: Accelerate commit hash verified

The commit 6891c57072295eec14d9866ca34472d59708db9c exists in the huggingface/accelerate repo, is dated 2025-08-05, and its message (“Feat: context parallel v2.0 (#3700)…”) confirms the expected parallelism updates. No issues detected.


69-69: Confirmed axolotl-contribs-mit==0.0.5 on PyPI

  • Version 0.0.5 is published and is the latest release.
  • No upload timestamp is provided in the metadata.
  • No security advisories are referenced in the PyPI project metadata.

No further action required.

src/axolotl/train.py (1)

277-277: Approve FSDP save-condition update

Verified that the expanded check (trainer.is_fsdp_enabled or cfg.fsdp_config) matches existing FSDP-config patterns elsewhere in the codebase—no inconsistencies detected. No changes required.

Key FSDP-config usages for reference:

  • src/axolotl/train.py: final-state‐dict saving
  • src/axolotl/core/builders/base.py: passing fsdp_config into training_args_kwargs
  • src/axolotl/utils/trainer.py: setting FSDP environment variables
  • src/axolotl/loaders/model.py: is_fsdp_enabled property based on cfg.fsdp_config
  • src/axolotl/loaders/patch_manager.py: applying FSDP patches when cfg.fsdp_config is present
src/axolotl/utils/trainer.py (1)

637-637: Good integration with existing environment setup.

The placement of setup_parallelism_envs(cfg) at the end of prepare_optim_env ensures parallelism variables are set after FSDP and DeepSpeed configurations, which is the correct order.

src/axolotl/core/builders/base.py (1)

447-452: LGTM: Simplified accelerator configuration

  • Confirmed that src/axolotl/core/builders/base.py no longer contains any PartialState inspection or use_configured_state logic.
  • Noticed that the trainer still reads
    use_configured_state = accelerator_config.get("use_configured_state", False)
    in src/axolotl/core/trainers/base.py. Ensure your environment‐based parallelism setup drives this flag as intended.
  • Recommend a quick smoke test of your training loop both with and without setting accelerator_config.use_configured_state in your config to verify existing workflows remain unchanged.
examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (4)

1-4: LGTM!

The base model and CutCrossEntropyPlugin configuration are appropriate for memory-efficient distributed training.


6-9: LGTM!

The parallelism configuration correctly sets up hybrid parallelism with 8 GPUs total (2 × 2 × 2).


16-22: LGTM!

FSDP v2 configuration is correctly set up with appropriate wrapping for LlamaDecoderLayer and memory-efficient resharding.


11-12: LGTM!

Dataset and output directory configuration are standard and appropriate.

Also applies to: 24-28

examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (3)

6-9: LGTM!

Parallelism configuration correctly sets up context parallelism alongside tensor and data parallelism for 8 GPUs total.


13-19: LGTM!

FSDP configuration correctly specifies Qwen3DecoderLayer for the Qwen3 model architecture.


27-32: LGTM!

Correctly configured with longer sequence length (8192) appropriate for Qwen3, and micro_batch_size=1 as required for context parallelism.

src/axolotl/core/trainers/base.py (3)

13-29: LGTM!

New imports are appropriate for the enhanced model saving functionality with safetensors and PEFT support.


522-543: LGTM!

The accelerator state reset logic correctly ensures that Accelerate configures itself from environment variables for distributed parallel training, while preserving the parallelism_config for downstream use.


610-669: Duplicate saving of training arguments.

The training arguments are saved twice when the model is a supported class:

  1. Inside the else block (lines 649-668)
  2. Again at line 669 which is outside the else block

Line 669 should be removed as it's redundant.

Apply this fix:

             self.data_collator.tokenizer.save_pretrained(output_dir)
-            # Good practice: save your training arguments together with the trained model
-            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

Likely an incorrect or invalid review comment.

Comment thread examples/distirbuted-parallel/README.md
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 0

🔭 Outside diff range comments (1)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)

46-47: Dangling special_tokens key produces null value

special_tokens: with no following mapping or list parses as YAML null.
This overrides any default special-token handling in Axolotl and can trigger key-look-up errors during tokenizer setup.

Either remove the key or supply an explicit mapping, e.g.:

-special_tokens:
+special_tokens: {}
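
The difference between a missing key and a key that parsed as null can be seen in plain Python. The sketch below starts from an already-parsed config dict, which is an assumption about how the YAML ends up in memory; whether Axolotl's tokenizer setup actually takes this path is not shown in this thread.

```python
from typing import Any, Mapping

# What a bare `special_tokens:` line is assumed to look like after parsing.
parsed_cfg: dict[str, Any] = {"special_tokens": None}

# The default is ignored because the key exists, so this is still None.
tokens_default = parsed_cfg.get("special_tokens", {})

# Coalescing None to an empty mapping makes later iteration safe.
tokens_safe = parsed_cfg.get("special_tokens") or {}


def special_token_items(cfg: Mapping[str, Any]) -> list[tuple[str, str]]:
    """Return (name, token) pairs, treating a null or missing section as empty."""
    return list((cfg.get("special_tokens") or {}).items())


pairs = special_token_items(parsed_cfg)
```

Either writing `special_tokens: {}` in the config or coalescing on read avoids calling mapping methods on None.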
🧹 Nitpick comments (1)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)

31-33: Micro-batch size of 1 severely limits throughput

Context parallelism forces micro_batch_size: 1, but with sequence length 8192 and world-size 8 you may hit under-utilisation.
Consider enabling gradient accumulation (gradient_accumulation_steps > 1) or increasing dp_shard_size if memory allows, to improve device utilisation and optimiser step quality.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cd4339a and 8b52919.

📒 Files selected for processing (9)
  • examples/distributed-parallel/README.md (1 hunks)
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
  • examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/core/trainers/base.py (3 hunks)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/trainer.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
  • examples/distributed-parallel/README.md
  • requirements.txt
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
🚧 Files skipped from review as they are similar to previous changes (5)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py
  • src/axolotl/train.py
  • src/axolotl/core/builders/base.py
  • src/axolotl/utils/trainer.py
  • src/axolotl/core/trainers/base.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#2707
File: src/axolotl/utils/data/sft.py:247-254
Timestamp: 2025-05-29T22:23:39.312Z
Learning: In distributed training scenarios with batch dispatching, placeholder datasets for non-zero ranks may intentionally use temporary files that persist during training. These files are typically very small and don't require explicit cleanup due to their minimal resource impact and specific training requirements.
🔇 Additional comments (2)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (2)

6-10: Confirm that total world-size (8) matches cluster allocation

With dp_shard_size: 2, context_parallel_size: 2, and tensor_parallel_size: 2, the implied world-size is 2×2×2 = 8.
If the launch script allocates fewer or more processes/GPUs, the run will hang or crash during topology discovery.

Please verify that:

  1. Exactly 8 ranks will be launched, and
  2. setup_parallelism_envs(cfg) is invoked early enough so that the environment variables it sets (AXOLOTL_TP_SIZE, etc.) are visible to all ranks before model init.
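
The first check can be automated before launch. The following is a minimal sketch, assuming the config is available as a plain mapping and that unset sizes count as 1; the helper names are hypothetical and dp_replicate_size is an assumed key.

```python
import math
from typing import Mapping, Optional

PARALLEL_KEYS = (
    "dp_shard_size",
    "dp_replicate_size",  # assumed key, absent from this config
    "context_parallel_size",
    "tensor_parallel_size",
)


def expected_world_size(cfg: Mapping[str, Optional[int]]) -> int:
    """Product of all parallelism sizes, treating missing or null sizes as 1."""
    return math.prod(int(cfg.get(key) or 1) for key in PARALLEL_KEYS)


def check_world_size(cfg: Mapping[str, Optional[int]], world_size: int) -> None:
    """Raise early instead of hanging later if the launch size does not match."""
    expected = expected_world_size(cfg)
    if expected != world_size:
        raise ValueError(
            f"parallelism sizes imply {expected} ranks, but {world_size} were launched"
        )


qwen_cfg = {"dp_shard_size": 2, "context_parallel_size": 2, "tensor_parallel_size": 2}
check_world_size(qwen_cfg, 8)  # passes: 2 * 2 * 2 == 8
```

Failing fast with a clear message is usually easier to debug than a hang during process-group setup.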

13-20: Validate transformer_layer_cls_to_wrap path

transformer_layer_cls_to_wrap: Qwen3DecoderLayer is accepted only if
TRANSFORMER_BASED_WRAP resolves the symbol via import-time introspection.
Accelerate/FSDP usually expects the fully-qualified path (transformers.models.qwen3.modeling_qwen3.Qwen3DecoderLayer) or the actual class object supplied programmatically.

Double-check that plain class name resolution works for Qwen3; otherwise, use the FQN to avoid a silent fallback to blanket wrapping (which defeats memory savings).
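
One way a bare class name can be resolved is by walking the instantiated model and comparing class names. The sketch below illustrates that idea with a tiny stand-in module tree; the Module and Model classes are hypothetical, this Qwen3DecoderLayer is a placeholder rather than the transformers class, and this is not a claim about how Accelerate performs the lookup.

```python
from typing import Iterator, Optional


class Module:
    """Tiny stand-in for a module tree; this is not torch.nn.Module."""

    def __init__(self, *children: "Module") -> None:
        self.children = list(children)

    def modules(self) -> Iterator["Module"]:
        """Yield this module and all descendants, depth first."""
        yield self
        for child in self.children:
            yield from child.modules()


class Qwen3DecoderLayer(Module):
    """Placeholder with the same name as the layer class in the config."""


class Model(Module):
    """Placeholder top-level model."""


def find_class_by_name(root: Module, name: str) -> Optional[type]:
    """Return the first class in the tree whose __name__ equals `name`."""
    for module in root.modules():
        if type(module).__name__ == name:
            return type(module)
    return None


model = Model(Qwen3DecoderLayer(), Qwen3DecoderLayer())
layer_cls = find_class_by_name(model, "Qwen3DecoderLayer")
```

A lookup like this returns None for a misspelled name, so checking the result explicitly guards against a silent fallback.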

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)

623-648: Potential issue with model type checking and unwrapping logic.

The nested if-else logic for handling different model types could be simplified and made more robust. The current approach checks if the model is a supported class, then if not, checks if the unwrapped model is supported.

Consider refactoring for better readability and maintainability:

-        if not isinstance(self.model, supported_classes):
-            if state_dict is None:
-                state_dict = self.model.state_dict()
-            if isinstance(
-                self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
-                supported_classes,
-            ):
-                self.accelerator.unwrap_model(
-                    self.model, keep_torch_compile=False
-                ).save_pretrained(
-                    output_dir,
-                    state_dict=state_dict,
-                    safe_serialization=self.args.save_safetensors,
-                )
-            else:
-                LOG.info(
-                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
-                )
-                if self.args.save_safetensors:
-                    safetensors.torch.save_file(
-                        state_dict,
-                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
-                        metadata={"format": "pt"},
-                    )
-                else:
-                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        # Try to get a model that supports save_pretrained
+        model_to_save = None
+        if isinstance(self.model, supported_classes):
+            model_to_save = self.model
+        else:
+            unwrapped_model = self.accelerator.unwrap_model(self.model, keep_torch_compile=False)
+            if isinstance(unwrapped_model, supported_classes):
+                model_to_save = unwrapped_model
+        
+        if model_to_save is not None:
+            model_to_save.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                is_main_process=self.accelerator.is_main_process,
+            )
         else:
-            self.model.save_pretrained(
-                output_dir,
-                state_dict=state_dict,
-                safe_serialization=self.args.save_safetensors,
-                is_main_process=self.accelerator.is_main_process,
-            )
+            LOG.info(
+                "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
+            )
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+            if self.args.save_safetensors:
+                safetensors.torch.save_file(
+                    state_dict,
+                    os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                    metadata={"format": "pt"},
+                )
+            else:
+                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8b52919 and 9275ca3.

📒 Files selected for processing (9)
  • examples/distributed-parallel/README.md (1 hunks)
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
  • examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/core/trainers/base.py (3 hunks)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/trainer.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (8)
  • requirements.txt
  • src/axolotl/core/trainers/mixins/distributed_parallel.py
  • examples/distributed-parallel/README.md
  • src/axolotl/train.py
  • src/axolotl/core/builders/base.py
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
  • examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
  • src/axolotl/utils/trainer.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (3)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
  • _save (15-23)
src/axolotl/models/mamba/modeling_mamba.py (1)
  • save_pretrained (111-119)
src/axolotl/utils/distributed.py (1)
  • is_main_process (79-94)
🔇 Additional comments (3)
src/axolotl/core/trainers/base.py (3)

13-13: New imports look appropriate for the functionality.

The added imports support the new accelerator state management and model saving functionality:

  • safetensors for safe model serialization
  • AcceleratorState, PartialState for accelerator state management
  • PeftModel for PEFT model support
  • Various utilities from transformers for model saving

Also applies to: 15-15, 17-17, 25-28


657-669: Good implementation of processor/tokenizer saving logic.

The logic correctly handles different scenarios:

  1. Uses processing_class if available
  2. Falls back to data_collator.tokenizer with appropriate logging
  3. Always saves training arguments

The implementation follows best practices by checking for attribute existence before accessing them.


521-544: Verify AcceleratorState Reset and PartialState Side Effects
The call to AcceleratorState._reset_state(reset_partial_state=True) followed by restoring only parallelism_config can impact every place we read from PartialState. Before merging, confirm that resetting—and then only reassigning parallelism_config—still leaves downstream code in a valid state:

src/axolotl/core/builders/base.py
– Uses PartialState().device_mesh for optimizer setup
src/axolotl/loaders/model.py
– Writes to _shared_state["parallelism_config"] and expects device_mesh
src/axolotl/utils/distributed.py
– Initializes PartialState(timeout=…), reads is_main_process and get_world_size()
src/axolotl/utils/ctx_managers/sequence_parallel.py
– Reads partial_state.device_mesh for ring attention registration
src/axolotl/cli/merge_sharded_fsdp_weights.py
– Uses PartialState().is_main_process to gate checkpoint merging
src/axolotl/monkeypatch/accelerate/fsdp2.py
– References accelerator.state.parallelism_config.fsdp_dim_names
src/axolotl/core/trainers/mixins/distributed_parallel.py
– Checks self.accelerator.parallelism_config.dp_shard_enabled during save

Action items:

  1. Walk through each module and verify that after the reset+restore only parallelism_config, the properties you rely on (e.g. device_mesh, timeout, is_main_process, world size) are still correctly initialized.
  2. If any fields are cleared by the reset, either restore them explicitly or consider using a public Accelerate API (instead of the protected _reset_state) to avoid breaking with future versions.
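
The concern behind these action items can be illustrated with a stand-in for a shared-state singleton. The SharedState class below is hypothetical and only mimics the general pattern of instances sharing one dict; it does not reproduce Accelerate's internals.

```python
from typing import Any, Iterable


class SharedState:
    """Stand-in singleton: every instance shares one class-level dict."""

    _shared_state: dict[str, Any] = {}

    def __init__(self, **fields: Any) -> None:
        self.__dict__ = self._shared_state
        self.__dict__.update(fields)

    @classmethod
    def reset(cls) -> None:
        """Drop every field held in the shared dict."""
        cls._shared_state.clear()


def reset_preserving(keys: Iterable[str]) -> None:
    """Reset the shared state, then restore only the named fields."""
    saved = {k: SharedState._shared_state[k] for k in keys if k in SharedState._shared_state}
    SharedState.reset()
    SharedState(**saved)


SharedState(parallelism_config={"tp_size": 2}, device_mesh="mesh", timeout=1800)
reset_preserving(["parallelism_config"])
state = SharedState()
# parallelism_config survives; device_mesh and timeout are gone until re-initialized.
```

Anything not listed in the restore step has to be rebuilt by later initialization, which is what the walkthrough of each module is meant to confirm.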

Comment thread src/axolotl/core/trainers/base.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f20fbb4 and cdfe57b.

📒 Files selected for processing (1)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#2707
File: src/axolotl/utils/data/sft.py:247-254
Timestamp: 2025-05-29T22:23:39.312Z
Learning: In distributed training scenarios with batch dispatching, placeholder datasets for non-zero ranks may intentionally use temporary files that persist during training. These files are typically very small and don't require explicit cleanup due to their minimal resource impact and specific training requirements.
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
src/axolotl/utils/distributed.py (1)
  • is_main_process (79-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (3)
src/axolotl/monkeypatch/accelerate/fsdp2.py (3)

10-10: LGTM!

The import of torch.distributed is necessary for the new manual broadcast operations added to handle non-sharded parameters.


46-61: LGTM! Efficient handling of sharded parameter distribution.

The code correctly implements memory-efficient distribution where only the main process holds the full tensor while other ranks allocate empty tensors before distribution. This prevents memory spikes on non-main ranks.


63-73: No change needed: is_main_process always maps to rank 0

The helper is_main_process() returns True exactly when dist.get_rank() == 0 (or when LOCAL_RANK is "0" before initialization), so broadcasting with src=0 is safe and correct.

Comment thread src/axolotl/monkeypatch/accelerate/fsdp2.py
Comment thread src/axolotl/monkeypatch/accelerate/fsdp2.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)

79-82: Critical: Fix potential NameError for full_tensor deletion.

This appears to be the same issue flagged in previous review comments. The variable full_tensor may not be defined on non-main processes for non-sharded parameters, which would cause a NameError when trying to delete it.

The deletion of full_tensor and setting full_sd[param_name] = None should only occur when full_tensor is guaranteed to exist (i.e., on the main process).
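
The failure mode can be reproduced without any distributed setup. What follows is a minimal sketch in plain Python, not the code from fsdp2.py: the function names and the dict standing in for full_sd are hypothetical. It shows why deleting a name that is bound only on the main process fails on other ranks, and one possible guard.

```python
def release_unguarded(is_main: bool, full_sd: dict, param_name: str) -> None:
    """Mimics the flagged pattern: full_tensor is bound only on the main process."""
    if is_main:
        full_tensor = full_sd[param_name]
    # On non-main ranks the name was never bound, so this raises
    # UnboundLocalError, which is a subclass of NameError.
    del full_tensor
    full_sd[param_name] = None


def release_guarded(is_main: bool, full_sd: dict, param_name: str) -> None:
    """Binds the name on every rank and frees it only where it holds data."""
    full_tensor = full_sd.get(param_name) if is_main else None
    if full_tensor is not None:
        del full_tensor
        full_sd[param_name] = None
```

Binding the name to None on every rank keeps the cleanup branch uniform; gating the cleanup on the main-process check would work equally well.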

🧹 Nitpick comments (1)
src/axolotl/utils/distributed.py (1)

316-367: Validate the parallelism dimension calculations and improve code style.

The function logic for calculating parallelism dimensions appears correct, but there are several areas for improvement:

  1. Line 335-336: The nested if statements can be combined as suggested by static analysis.
  2. Line 356-357: Another nested if that can be simplified.
  3. Potential logic issue: The function modifies remaining_world_size but uses get_world_size() in the error message (line 363), which could be confusing.

Apply these improvements:

-    if dp_shard_size is None and dp_replicate_size in (None, 1):
-        if remaining_world_size > 1:
-            pc_kwargs["dp_shard_size"] = remaining_world_size
-            remaining_world_size = 1
+    if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
+        pc_kwargs["dp_shard_size"] = remaining_world_size
+        remaining_world_size = 1

-    if remaining_world_size > 1:
-        if "dp_shard_size" not in pc_kwargs and is_fsdp:
-            pc_kwargs["dp_shard_size"] = remaining_world_size
-            remaining_world_size = 1
+    if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
+        pc_kwargs["dp_shard_size"] = remaining_world_size
+        remaining_world_size = 1

     if remaining_world_size > 1:
         raise ValueError(
-            f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
+            f"The configured parallelisms are incompatible with the current world size ({world_size})!\n"
             f"{pc_kwargs}"
         )
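
For context, the bookkeeping this function performs can be sketched in isolation. The helper below is a simplified, hypothetical illustration and not the PR's _get_parallel_config_kwargs: the name resolve_parallel_sizes, its signature, and the divisibility check are assumptions, and the is_fsdp branch is omitted. It divides the world size by each explicitly requested dimension and assigns whatever remains to dp_shard_size using the flattened condition.

```python
def resolve_parallel_sizes(world_size, tp_size=1, cp_size=1,
                           dp_replicate_size=None, dp_shard_size=None):
    """Hypothetical helper: split world_size across parallelism dimensions."""
    requested = {
        "tp_size": tp_size,
        "cp_size": cp_size,
        "dp_replicate_size": dp_replicate_size,
        "dp_shard_size": dp_shard_size,
    }
    pc_kwargs = {}
    remaining = world_size
    for name, size in requested.items():
        if size is None or size <= 1:
            continue  # dimension not in use
        if remaining % size != 0:
            raise ValueError(
                f"{name}={size} does not divide the remaining world size {remaining}"
            )
        pc_kwargs[name] = size
        remaining //= size

    # One combined condition instead of two nested ifs.
    if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining > 1:
        pc_kwargs["dp_shard_size"] = remaining
        remaining = 1

    if remaining > 1:
        raise ValueError(
            f"The configured parallelisms are incompatible with the world size "
            f"({world_size}): {pc_kwargs}"
        )
    return pc_kwargs
```

With 8 ranks and tp_size=2, the remaining factor of 4 goes to dp_shard_size; with tp_size=2 and an explicit dp_shard_size=2, two ranks are left unassigned and the sketch raises.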
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ba46c75 and 1e1db74.

📒 Files selected for processing (16)
  • examples/distributed-parallel/README.md (1 hunks)
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
  • examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
  • examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml (2 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/core/builders/base.py (3 hunks)
  • src/axolotl/core/trainers/base.py (3 hunks)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
  • src/axolotl/loaders/model.py (7 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
  • src/axolotl/prompt_strategies/chat_template.py (5 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/distributed.py (2 hunks)
  • src/axolotl/utils/schemas/datasets.py (1 hunks)
  • src/axolotl/utils/trainer.py (2 hunks)
  • tests/test_loaders.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
  • examples/distributed-parallel/README.md
  • tests/test_loaders.py
  • requirements.txt
🚧 Files skipped from review as they are similar to previous changes (11)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py
  • src/axolotl/utils/trainer.py
  • src/axolotl/core/builders/base.py
  • examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
  • src/axolotl/loaders/model.py
  • src/axolotl/core/trainers/base.py
  • src/axolotl/train.py
  • src/axolotl/prompt_strategies/chat_template.py
  • examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
  • examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
  • src/axolotl/utils/schemas/datasets.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
src/axolotl/utils/distributed.py (1)
  • is_main_process (80-95)
🪛 Ruff (0.12.2)
src/axolotl/utils/distributed.py

335-336: Use a single if statement instead of nested if statements

(SIM102)


356-357: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (5)
src/axolotl/utils/distributed.py (2)

296-313: LGTM on the main function structure.

The build_parallelism_config function correctly delegates to the helper function and handles the creation of ParallelismConfig and device mesh appropriately.


344-355: Add missing unit test for FSDP validation in _get_parallel_config_kwargs

It looks like there are no existing tests covering the branch where
remaining_world_size > 1, dp_shard_size > 1, and is_fsdp = False.
Please add a unit test—e.g. under tests/—that invokes _get_parallel_config_kwargs with those parameters and asserts that the correct ValueError is thrown.

src/axolotl/monkeypatch/accelerate/fsdp2.py (3)

10-10: LGTM on the import addition.

Adding torch.distributed as dist is necessary for the manual broadcast operation introduced later in the function.


40-45: LGTM on the iteration change.

Changing the iteration from full_sd.items() to meta_sharded_sd.items() is logical since we need to process each parameter in the sharded model. The conditional initialization of full_tensor only on the main process is correct for distributed scenarios.


63-73: Verify the manual broadcast implementation.

The manual broadcast logic for non-sharded parameters looks correct, but ensure that:

  1. The tensor shapes match between ranks before broadcasting
  2. The broadcast operation is synchronous across all ranks

The implementation correctly handles the case where non-main processes need to receive the parameter data through explicit broadcasting, which is necessary when full_sd is only available on the main process.

Comment thread src/axolotl/prompt_strategies/chat_template.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1e1db74 and 5f5eb8e.

📒 Files selected for processing (5)
  • examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml (1 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/core/builders/causal.py (3 hunks)
  • src/axolotl/monkeypatch/multipack.py (1 hunks)
  • src/axolotl/utils/collators/batching.py (2 hunks)
✅ Files skipped from review due to trivial changes (2)
  • src/axolotl/monkeypatch/multipack.py
  • examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • requirements.txt
🧰 Additional context used
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/core/builders/causal.py (2)

6-6: LGTM! Import addition is appropriate.

The partial import from functools is correctly added to support the GPT-OSS specific collator configuration.


367-367: Good refactor: More explicit conditional logic.

Changing from else to elif self.cfg.pad_to_sequence_len is None makes the padding logic more explicit and readable. This clarifies that the default padding behavior only applies when pad_to_sequence_len is explicitly None.

src/axolotl/utils/collators/batching.py (1)

164-164: LGTM! New attribute properly defined.

The squash_position_ids boolean attribute with default False is correctly defined as a dataclass field, maintaining backward compatibility.
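
The two changes combine in a simple way: a dataclass field with a default, and functools.partial to pre-bind it for one model type. The sketch below uses a hypothetical ToyCollator in place of V2BatchSamplerDataCollatorForSeq2Seq, so it only illustrates the pattern and makes no claim about the real class's other fields.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class ToyCollator:
    """Hypothetical stand-in for a dataclass-based collator."""

    pad_token_id: int
    squash_position_ids: bool = False  # default leaves existing callers unchanged


# Pre-bind the flag; the result is still called like the class itself.
SquashingCollator = partial(ToyCollator, squash_position_ids=True)

default_collator = ToyCollator(pad_token_id=0)
squashing_collator = SquashingCollator(pad_token_id=0)
```

Because partial returns a callable rather than a subclass, code that later instantiates the collator with extra keyword arguments keeps working, while isinstance checks still see the original class.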

Comment thread src/axolotl/core/builders/causal.py Outdated
Comment on lines +480 to +484
# TODO(wing): figure out why gpt-oss performs better without true sample position_ids
if self.cfg.model_config_type in ["gpt_oss"]:
    collator = partial(  # type: ignore
        V2BatchSamplerDataCollatorForSeq2Seq, squash_position_ids=True
    )

💡 Verification agent

🧩 Analysis chain

Verify GPT-OSS position ID squashing behavior.

The implementation uses partial to configure V2BatchSamplerDataCollatorForSeq2Seq with squash_position_ids=True specifically for GPT-OSS models. The TODO comment suggests this was added based on empirical performance observations.

Please verify that this change doesn't break existing functionality by running the following script:


🏁 Script executed:

#!/bin/bash
# Description: Verify GPT-OSS model configurations and ensure no breaking changes
# Expected: Find GPT-OSS configurations and verify the integration

# Search for GPT-OSS model configurations
echo "=== Searching for GPT-OSS model configurations ==="
rg -A 5 -B 2 "gpt_oss|gpt-oss" --type py

# Verify the squash_position_ids parameter usage
echo "=== Verifying squash_position_ids parameter usage ==="
rg -A 10 "squash_position_ids" --type py

# Check for any tests related to GPT-OSS models
echo "=== Checking for GPT-OSS related tests ==="
fd -e py -x rg -l "gpt_oss|gpt-oss" {} \; | head -10

Length of output: 4951


Add tests for GPT-OSS squash_position_ids override

We’ve confirmed that:

  • In src/axolotl/core/builders/causal.py, the gpt_oss branch correctly applies
    partial(V2BatchSamplerDataCollatorForSeq2Seq, squash_position_ids=True).
  • In src/axolotl/utils/collators/batching.py, the squash_position_ids flag produces a concatenated position_ids array via np.arange.

However, there are currently no automated tests exercising this code path.

Please add unit or integration tests to cover:

  • The GPT-OSS collator branch in src/axolotl/core/builders/causal.py
  • The squashing behavior in src/axolotl/utils/collators/batching.py, verifying that for GPT-OSS inputs the output position_ids are a contiguous range.
🤖 Prompt for AI Agents
In src/axolotl/core/builders/causal.py around lines 480 to 484, add unit or
integration tests to cover the GPT-OSS collator branch where squash_position_ids
is set to True using partial(V2BatchSamplerDataCollatorForSeq2Seq). Also, in
src/axolotl/utils/collators/batching.py, add tests verifying that when
squash_position_ids is True, the position_ids output is a contiguous range
created by np.arange, specifically for GPT-OSS inputs. These tests should
confirm the correct application and effect of the squash_position_ids flag.

Comment on lines +181 to +189
elif feature == "position_ids" and self.squash_position_ids:
    arrays = [
        np.array(item[feature]) for item in features_ if feature in item
    ]
    # concatenate, get total length and create arange of new total position ids
    position_ids = np.concatenate(arrays)
    total_length = position_ids.shape[0]
    position_ids = np.arange(total_length)
    out_features[i][feature] = position_ids

💡 Verification agent

🧩 Analysis chain

Verify position ID squashing logic correctness.

The position ID squashing logic concatenates position IDs from multiple sequences and replaces them with a continuous range starting from 0. This implementation looks correct for creating continuous position IDs across concatenated sequences.
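
To make the behavior concrete, here is a standalone NumPy sketch of the same concatenate-then-arange step for a single pack. The helper name squash_pack and the toy input are hypothetical; only np.array, np.concatenate, and np.arange come from the snippet under review.

```python
import numpy as np


def squash_pack(pack, feature="position_ids"):
    """Replace per-sequence position ids in one pack with a single 0..N-1 range."""
    # Items that lack the feature are skipped, as in the reviewed comprehension.
    arrays = [np.array(item[feature]) for item in pack if feature in item]
    total_length = np.concatenate(arrays).shape[0]
    return np.arange(total_length)


pack = [
    {"position_ids": [0, 1]},
    {"position_ids": [0, 1, 2]},
    {"input_ids": [7]},  # no position_ids, so it contributes nothing
]
squashed = squash_pack(pack)
```

Note that the original per-sequence values are discarded entirely; only their combined length matters.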

The logic appears sound, but let's verify there are no edge cases by running this script:


🏁 Script executed:

#!/bin/bash
# Description: Look for any existing tests for V2BatchSamplerDataCollatorForSeq2Seq to understand expected behavior
# Expected: Find test files that validate this collator's behavior

# Search for tests related to the V2BatchSamplerDataCollatorForSeq2Seq
echo "=== Searching for V2BatchSamplerDataCollatorForSeq2Seq tests ==="
rg -A 15 -B 5 "V2BatchSamplerDataCollatorForSeq2Seq" --type py

# Look for any position_ids related test cases
echo "=== Searching for position_ids test cases ==="
rg -A 10 -B 5 "position_ids.*test|test.*position_ids" --type py

# Search for multipack related tests that might cover this collator
echo "=== Searching for multipack collator tests ==="
rg -A 10 -B 5 "multipack.*collator|collator.*multipack" --type py

Length of output: 14136


Add explicit tests for squash_position_ids logic

While the concatenation-and-arange implementation in V2BatchSamplerDataCollatorForSeq2Seq (src/axolotl/utils/collators/batching.py lines 181–189) looks correct, there are currently no existing tests that validate its behavior. Please add unit tests to cover:

  • Multi-sequence inputs when squash_position_ids=True, asserting that
    out_features[i]["position_ids"] equals a continuous range from 0 to total_length − 1.
  • Edge cases, e.g. sequences of varying lengths or items missing position_ids (they should simply be skipped).

You can extend tests/test_packed_batch_sampler.py or create a new test file (e.g. tests/test_batching_collator.py) with assertions like:

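# pseudocode: the outer list is the batch, each inner list is one pack of sequences
collator = V2BatchSamplerDataCollatorForSeq2Seq(tokenizer, squash_position_ids=True)
batch = [
    [
        {"input_ids": [1, 2], "position_ids": [0, 1]},
        {"input_ids": [3, 4, 5], "position_ids": [0, 1, 2]},
    ],
]
out = collator(batch)
# squashing happens per pack, so the single packed row should be one contiguous range
assert out["position_ids"][0].tolist() == list(range(5))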
🤖 Prompt for AI Agents
In src/axolotl/utils/collators/batching.py around lines 181 to 189, the
squash_position_ids logic is implemented but lacks explicit unit tests. Add
tests in tests/test_packed_batch_sampler.py or create a new test file like
tests/test_batching_collator.py that cover multi-sequence inputs with
squash_position_ids=True, verifying that out_features[i]["position_ids"] is a
continuous range from 0 to total_length - 1. Also include edge cases such as
sequences of varying lengths and items missing position_ids, ensuring those
items are skipped without errors. Use assertions to confirm the output matches
expected continuous position id ranges.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (2)

55-56: Ensure tokenizer is re-trained / loaded with <|return|>

Adding a new special token after fine-tuning begins will shift vocab indices.
Verify the tokenizer JSON used at run-time already contains <|return|>; otherwise pass --additional_special_tokens in the training script to avoid OOV issues.


65-65: Consider enabling cpu_ram_efficient_loading for 20 B model

The 20 B GPT-OSS weights are ~40 GB fp16; on multi-GPU nodes with limited host RAM, enabling this FSDP flag avoids a full duplicate in CPU memory during load.
If hardware allows, uncommenting may reclaim ~30-40 GB RAM at startup.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 43f263d and ab450f7.

📒 Files selected for processing (5)
  • examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml (3 hunks)
  • examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (4 hunks)
  • examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml (3 hunks)
  • requirements.txt (2 hunks)
  • src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
  • requirements.txt
  • examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
  • src/axolotl/monkeypatch/accelerate/fsdp2.py
🧰 Additional context used
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: pre-commit
🔇 Additional comments (2)
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (2)

51-51: Warm-up ratio cut from 10 % → 3 %

A shorter warm-up can destabilise large-model training, especially with FSDP-2 and mixed precision.
Confirm that LR spikes are not triggering gradient overflow or NaNs in early steps.


13-16: Manual Verification Needed: Confirm thinking Column Exists

The sandbox environment couldn’t import datasets, so we couldn’t verify column names automatically. Please manually confirm that the HuggingFaceH4/Multilingual-Thinking dataset exposes a thinking column (case-sensitive) before merging.

You can run a quick check locally, for example:

python - <<'PY'
from datasets import load_dataset
ds = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:1]")
print(ds.column_names)
PY

• If the column is missing or named differently, update field_thinking and template_thinking_key accordingly.
• If it exists as expected, no further changes are needed here.

Comment thread examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)

31-31: Remove unnecessary pylint disable comment

The comparison self.cp_size > 1 is not a chained comparison, so the pylint disable comment is unnecessary.

    if (
-        self.cp_size > 1  # pylint: disable=chained-comparison
+        self.cp_size > 1
        and self.dp_shard_size <= 1
src/axolotl/utils/schemas/validation.py (1)

1150-1162: Simplify nested if statements for better readability

The nested if statements can be combined into a single condition.

    @model_validator(mode="before")
    @classmethod
    def check_gpt_oss_fsdp_loading(cls, data):
-        if data.get("model_quantization_config", "") == "Mxfp4Config":
-            if (
-                data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
-                is True
-            ):
-                raise ValueError(
-                    "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
-                )
+        if (
+            data.get("model_quantization_config", "") == "Mxfp4Config"
+            and data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) is True
+        ):
+            raise ValueError(
+                "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
+            )
        return data
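
For a quick local check of the flattened condition, the same logic can be exercised as a plain function outside Pydantic. This is a sketch under assumptions: the function name is hypothetical, and the `or {}` fallback for an explicit null fsdp_config is an addition that the original validator does not have.

```python
def check_mxfp4_fsdp_loading(data: dict) -> dict:
    """Hypothetical plain-function version of the validator's check."""
    # `or {}` also tolerates `fsdp_config: null`; whether that input can reach
    # the real validator is an assumption.
    fsdp_config = data.get("fsdp_config") or {}
    if (
        data.get("model_quantization_config") == "Mxfp4Config"
        and fsdp_config.get("cpu_ram_efficient_loading") is True
    ):
        raise ValueError(
            "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
        )
    return data
```

Configs without Mxfp4Config, or without cpu_ram_efficient_loading enabled, pass through unchanged.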
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ab450f7 and a13470e.

📒 Files selected for processing (3)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/accelerate/parallelism_config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/loaders/patch_manager.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)
src/axolotl/utils/distributed.py (1)
  • is_main_process (80-95)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/parallelism_config.py

53-53: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)
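
For reference, a minimal sketch of what the B028 fix looks like. The function name, the condition, and the message text here are hypothetical placeholders and not the contents of parallelism_config.py; only warnings.warn and its stacklevel parameter are standard library behavior.

```python
import warnings


def warn_if_cp_without_sharding(cp_size: int, dp_shard_size: int) -> None:
    """Hypothetical check that warns about a questionable configuration."""
    if cp_size > 1 and dp_shard_size <= 1:
        warnings.warn(
            "cp_size > 1 is configured without dp_shard_size > 1",
            UserWarning,
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_cp_without_sharding(cp_size=2, dp_shard_size=1)  # warns
    warn_if_cp_without_sharding(cp_size=2, dp_shard_size=2)  # silent
```

With stacklevel=2 the reported file and line point at the code that called the helper, which is usually where the misconfiguration originates.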

src/axolotl/utils/schemas/validation.py

1153-1157: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)

59-63: LGTM!

The patching implementation is correct and follows the standard monkeypatch pattern.

Comment thread src/axolotl/monkeypatch/accelerate/parallelism_config.py
Comment thread src/axolotl/monkeypatch/accelerate/parallelism_config.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4dc2200 and ed5deba.

📒 Files selected for processing (1)
  • src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
src/axolotl/core/trainers/base.py (1)
  • create_accelerator_and_postprocess (522-543)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)

5-5: LGTM: Import addition supports the new functionality.

The import of PartialState is necessary for the new method implementation and aligns with the accelerator state management pattern used in the codebase.

Comment thread src/axolotl/core/trainers/mixins/distributed_parallel.py
@winglian winglian merged commit 9d5c95d into main Aug 8, 2025
19 of 20 checks passed
@winglian winglian deleted the accelerate-cp branch August 8, 2025 01:22
@coderabbitai coderabbitai Bot mentioned this pull request Feb 15, 2026
@coderabbitai coderabbitai Bot mentioned this pull request Apr 30, 2026
4 tasks