Add support for Accelerate CP, ND examples, and fix for parallel config w fsdp #3019
📝 Walkthrough
This update introduces new distributed training configuration files and a README, adds support for configurable "thinking" fields in chat templates and datasets, refactors distributed parallelism and device mesh setup by removing
Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
📖 Documentation Preview: https://68951fafe37cc23d5d02d203--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit b8cd55c
Actionable comments posted: 1
🔭 Outside diff range comments (1)
examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)
46-47: Incomplete `special_tokens` configuration.
The `special_tokens` section is defined but empty. Consider either removing it if not needed or adding the appropriate token configurations for Qwen3.
🧹 Nitpick comments (2)
src/axolotl/utils/trainer.py (1)
600-617: LGTM! Clean parallelism environment setup.
The new `setup_parallelism_envs` function properly configures environment variables for distributed parallel training. The logic correctly sets variables only when parallelism sizes are greater than 1 and enables the parallelism configuration system appropriately.
Consider adding type hints and a docstring for better maintainability:
-def setup_parallelism_envs(cfg):
+def setup_parallelism_envs(cfg) -> None:
+    """
+    Configure environment variables for parallelism settings.
+
+    Sets environment variables for tensor parallel, data parallel shard/replicate,
+    and context parallel sizes when they are greater than 1.
+
+    Args:
+        cfg: Configuration object containing parallelism settings.
+    """
examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (1)
30-47: Training configuration looks good.
The hyperparameters are reasonable for distributed training. Note that the effective batch size per GPU is 1 (micro_batch_size × gradient_accumulation_steps), which is fine for testing but you may want to increase it for production training.
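The behaviour the review attributes to `setup_parallelism_envs` (export a variable only when a parallelism size is greater than 1, then switch the parallelism config on) can be sketched roughly as follows. This is an illustration, not the PR's implementation: the function name `setup_parallelism_envs_sketch`, the `environ` parameter, and every `EXAMPLE_*` variable name are placeholders invented here, and the names actually read by Accelerate or Axolotl may differ.

```python
import os
from types import SimpleNamespace

# Placeholder variable names chosen for this sketch; the names actually
# read by Accelerate or Axolotl may differ.
ENV_NAME_BY_FIELD = {
    "tensor_parallel_size": "EXAMPLE_TP_SIZE",
    "dp_shard_size": "EXAMPLE_DP_SHARD_SIZE",
    "dp_replicate_size": "EXAMPLE_DP_REPLICATE_SIZE",
    "context_parallel_size": "EXAMPLE_CP_SIZE",
}


def setup_parallelism_envs_sketch(cfg, environ=None) -> dict:
    """Export one variable per parallelism size that is greater than 1."""
    target = os.environ if environ is None else environ
    exported = {}
    for field, env_name in ENV_NAME_BY_FIELD.items():
        size = getattr(cfg, field, None)
        if size is not None and size > 1:
            exported[env_name] = str(size)
    if exported:
        # Hypothetical switch that tells the launcher to honour the sizes above.
        exported["EXAMPLE_USE_PARALLELISM_CONFIG"] = "true"
    target.update(exported)
    return exported


# A plain dict stands in for os.environ so the example has no side effects.
example_cfg = SimpleNamespace(
    tensor_parallel_size=2,
    dp_shard_size=2,
    dp_replicate_size=1,
    context_parallel_size=None,
)
example_env = {}
setup_parallelism_envs_sketch(example_cfg, environ=example_env)
```

Sizes of 1 or `None` leave the environment untouched, so a single-GPU config would export nothing.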
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/distirbuted-parallel/README.md (1 hunks)
- examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
- examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
- requirements.txt (2 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/core/trainers/base.py (3 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/trainer.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (3)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
_save(13-20)
src/axolotl/models/mamba/modeling_mamba.py (1)
save_pretrained(111-119)
src/axolotl/utils/distributed.py (1)
is_main_process(79-94)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: preview
🔇 Additional comments (15)
requirements.txt (2)
18-18: Accelerate commit hash verified
The commit `6891c57072295eec14d9866ca34472d59708db9c` exists in the huggingface/accelerate repo, is dated 2025-08-05, and its message (“Feat: context parallel v2.0 (#3700)…”) confirms the expected parallelism updates. No issues detected.
69-69: Confirmed axolotl-contribs-mit==0.0.5 on PyPI
- Version 0.0.5 is published and is the latest release.
- No upload timestamp is provided in the metadata.
- No security advisories are referenced in the PyPI project metadata.
No further action required.
src/axolotl/train.py (1)
277-277: Approve FSDP save-condition update
Verified that the expanded check (`trainer.is_fsdp_enabled or cfg.fsdp_config`) matches existing FSDP-config patterns elsewhere in the codebase; no inconsistencies detected. No changes required.
Key FSDP-config usages for reference:
- src/axolotl/train.py: final-state-dict saving
- src/axolotl/core/builders/base.py: passing `fsdp_config` into `training_args_kwargs`
- src/axolotl/utils/trainer.py: setting FSDP environment variables
- src/axolotl/loaders/model.py: `is_fsdp_enabled` property based on `cfg.fsdp_config`
- src/axolotl/loaders/patch_manager.py: applying FSDP patches when `cfg.fsdp_config` is present
src/axolotl/utils/trainer.py (1)
637-637: Good integration with existing environment setup.
The placement of `setup_parallelism_envs(cfg)` at the end of `prepare_optim_env` ensures parallelism variables are set after FSDP and DeepSpeed configurations, which is the correct order.
src/axolotl/core/builders/base.py (1)
447-452: LGTM: Simplified accelerator configuration
- Confirmed that `src/axolotl/core/builders/base.py` no longer contains any `PartialState` inspection or `use_configured_state` logic.
- Noticed that the trainer still reads `use_configured_state = accelerator_config.get("use_configured_state", False)` in `src/axolotl/core/trainers/base.py`. Ensure your environment-based parallelism setup drives this flag as intended.
- Recommend a quick smoke test of your training loop both with and without setting `accelerator_config.use_configured_state` in your config to verify existing workflows remain unchanged.
examples/distirbuted-parallel/llama-3_1-8b-hdsp-tp.yaml (4)
1-4: LGTM!
The base model and CutCrossEntropyPlugin configuration are appropriate for memory-efficient distributed training.
6-9: LGTM!
The parallelism configuration correctly sets up hybrid parallelism with 8 GPUs total (2 × 2 × 2).
16-22: LGTM!
FSDP v2 configuration is correctly set up with appropriate wrapping for LlamaDecoderLayer and memory-efficient resharding.
11-12: LGTM!
Dataset and output directory configuration are standard and appropriate.
Also applies to: 24-28
examples/distirbuted-parallel/qwen3-8b-fsdp-tp-cp.yaml (3)
6-9: LGTM!
Parallelism configuration correctly sets up context parallelism alongside tensor and data parallelism for 8 GPUs total.
13-19: LGTM!
FSDP configuration correctly specifies Qwen3DecoderLayer for the Qwen3 model architecture.
27-32: LGTM!
Correctly configured with longer sequence length (8192) appropriate for Qwen3, and micro_batch_size=1 as required for context parallelism.
src/axolotl/core/trainers/base.py (3)
13-29: LGTM!
New imports are appropriate for the enhanced model saving functionality with safetensors and PEFT support.
522-543: LGTM!
The accelerator state reset logic correctly ensures that Accelerate configures itself from environment variables for distributed parallel training, while preserving the `parallelism_config` for downstream use.
610-669: Duplicate saving of training arguments.
The training arguments are saved twice when the model is a supported class:
- Inside the else block (lines 649-668)
- Again at line 669, which is outside the else block
Line 669 should be removed as it's redundant.
Apply this fix:
  self.data_collator.tokenizer.save_pretrained(output_dir)
- # Good practice: save your training arguments together with the trained model
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+ # Good practice: save your training arguments together with the trained model
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
Likely an incorrect or invalid review comment.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)
46-47: Dangling `special_tokens` key produces null value
`special_tokens:` with no following mapping or list parses as YAML `null`.
This overrides any default special-token handling in Axolotl and can trigger key-lookup errors during tokenizer setup.
Either remove the key or supply an explicit mapping, e.g.:
-special_tokens:
+special_tokens: {}
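A bare `special_tokens:` key reaches Python as `None` after YAML loading, so a consumer can also guard against it defensively. The sketch below assumes the config has already been parsed into a plain dict; the helper name `special_tokens_or_empty` and the `pad_token` value are illustrative only, and whether Axolotl itself needs such a guard is not established by this thread.

```python
def special_tokens_or_empty(cfg: dict) -> dict:
    """Return the special_tokens mapping, treating a missing or null key as empty."""
    tokens = cfg.get("special_tokens")
    return dict(tokens) if tokens else {}


# What a YAML loader produces for a bare `special_tokens:` key.
parsed_bare_key = {"special_tokens": None}
# What it produces for an explicit mapping (token value is illustrative).
parsed_explicit = {"special_tokens": {"pad_token": "<pad>"}}
```

With this guard, both the bare key and a missing key behave like `special_tokens: {}`.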
🧹 Nitpick comments (1)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1)
31-33: Micro-batch size of 1 severely limits throughput
Context parallelism forces `micro_batch_size: 1`, but with sequence length 8192 and world size 8 you may hit under-utilisation.
Consider enabling gradient accumulation (`gradient_accumulation_steps > 1`) or increasing `dp_shard_size` if memory allows, to improve device utilisation and optimiser step quality.
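To make the throughput trade-off concrete, the arithmetic can be sketched as below. It rests on an assumption worth stating: only the data-parallel dimensions (`dp_shard_size`, `dp_replicate_size`) multiply the number of distinct sequences per optimizer step, while tensor-parallel and context-parallel ranks cooperate on the same sequences. The function names are hypothetical.

```python
def sequences_per_optimizer_step(
    micro_batch_size: int,
    gradient_accumulation_steps: int,
    dp_shard_size: int = 1,
    dp_replicate_size: int = 1,
) -> int:
    """Distinct sequences consumed per optimizer step (TP/CP ranks assumed to share data)."""
    data_parallel_ranks = dp_shard_size * dp_replicate_size
    return micro_batch_size * gradient_accumulation_steps * data_parallel_ranks


def tokens_per_optimizer_step(sequence_len: int, **batch_kwargs) -> int:
    """Upper bound on tokens per optimizer step, assuming full-length sequences."""
    return sequence_len * sequences_per_optimizer_step(**batch_kwargs)
```

Under these assumptions, the Qwen3 example (micro-batch 1, no accumulation, `dp_shard_size: 2`) processes 2 sequences per step; raising `gradient_accumulation_steps` to 4 would give 8.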
📒 Files selected for processing (9)
- examples/distributed-parallel/README.md (1 hunks)
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
- examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
- requirements.txt (2 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/core/trainers/base.py (3 hunks)
- src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/trainer.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
- examples/distributed-parallel/README.md
- requirements.txt
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
🚧 Files skipped from review as they are similar to previous changes (5)
- src/axolotl/core/trainers/mixins/distributed_parallel.py
- src/axolotl/train.py
- src/axolotl/core/builders/base.py
- src/axolotl/utils/trainer.py
- src/axolotl/core/trainers/base.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#2707
File: src/axolotl/utils/data/sft.py:247-254
Timestamp: 2025-05-29T22:23:39.312Z
Learning: In distributed training scenarios with batch dispatching, placeholder datasets for non-zero ranks may intentionally use temporary files that persist during training. These files are typically very small and don't require explicit cleanup due to their minimal resource impact and specific training requirements.
🔇 Additional comments (2)
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (2)
6-10: Confirm that total world-size (8) matches cluster allocation
With `dp_shard_size: 2`, `context_parallel_size: 2`, and `tensor_parallel_size: 2`, the implied world size is 2×2×2 = 8.
If the launch script allocates fewer or more processes/GPUs, the run will hang or crash during topology discovery.
Please verify that:
- Exactly 8 ranks will be launched, and
- `setup_parallelism_envs(cfg)` is invoked early enough so that the environment variables it sets (`AXOLOTL_TP_SIZE`, etc.) are visible to all ranks before model init.
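The world-size check described above amounts to multiplying the configured parallelism sizes and comparing the product with the number of launched ranks. A minimal sketch, assuming unset sizes count as 1; both function names are hypothetical and are not part of Axolotl or Accelerate:

```python
from math import prod


def implied_world_size(
    dp_shard_size=None,
    dp_replicate_size=None,
    tensor_parallel_size=None,
    context_parallel_size=None,
) -> int:
    """Product of all parallelism sizes, treating unset sizes as 1."""
    sizes = (dp_shard_size, dp_replicate_size, tensor_parallel_size, context_parallel_size)
    return prod(size or 1 for size in sizes)


def check_world_size(world_size: int, **sizes) -> None:
    """Raise early if the launched rank count does not match the config."""
    expected = implied_world_size(**sizes)
    if expected != world_size:
        raise ValueError(
            f"Config implies {expected} ranks but {world_size} were launched: {sizes}"
        )
```

Running such a check before model initialisation turns a potential hang into an immediate, readable error.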
13-20: Validate `transformer_layer_cls_to_wrap` path
`transformer_layer_cls_to_wrap: Qwen3DecoderLayer` is accepted only if `TRANSFORMER_BASED_WRAP` resolves the symbol via import-time introspection.
Accelerate/FSDP usually expects the fully-qualified path (`transformers.models.qwen3.modeling_qwen3.Qwen3DecoderLayer`) or the actual class object supplied programmatically.
Double-check that plain class name resolution works for Qwen3; otherwise, use the FQN to avoid a silent fallback to blanket wrapping (which defeats memory savings).
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
623-648: Potential issue with model type checking and unwrapping logic.
The nested if-else logic for handling different model types could be simplified and made more robust. The current approach checks if the model is a supported class, then if not, checks if the unwrapped model is supported.
Consider refactoring for better readability and maintainability:
- if not isinstance(self.model, supported_classes):
-     if state_dict is None:
-         state_dict = self.model.state_dict()
-     if isinstance(
-         self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
-         supported_classes,
-     ):
-         self.accelerator.unwrap_model(
-             self.model, keep_torch_compile=False
-         ).save_pretrained(
-             output_dir,
-             state_dict=state_dict,
-             safe_serialization=self.args.save_safetensors,
-         )
-     else:
-         LOG.info(
-             "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
-         )
-         if self.args.save_safetensors:
-             safetensors.torch.save_file(
-                 state_dict,
-                 os.path.join(output_dir, SAFE_WEIGHTS_NAME),
-                 metadata={"format": "pt"},
-             )
-         else:
-             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ # Try to get a model that supports save_pretrained
+ model_to_save = None
+ if isinstance(self.model, supported_classes):
+     model_to_save = self.model
+ else:
+     unwrapped_model = self.accelerator.unwrap_model(self.model, keep_torch_compile=False)
+     if isinstance(unwrapped_model, supported_classes):
+         model_to_save = unwrapped_model
+
+ if model_to_save is not None:
+     model_to_save.save_pretrained(
+         output_dir,
+         state_dict=state_dict,
+         safe_serialization=self.args.save_safetensors,
+         is_main_process=self.accelerator.is_main_process,
+     )
  else:
-     self.model.save_pretrained(
-         output_dir,
-         state_dict=state_dict,
-         safe_serialization=self.args.save_safetensors,
-         is_main_process=self.accelerator.is_main_process,
-     )
+     LOG.info(
+         "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
+     )
+     if state_dict is None:
+         state_dict = self.model.state_dict()
+     if self.args.save_safetensors:
+         safetensors.torch.save_file(
+             state_dict,
+             os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+             metadata={"format": "pt"},
+         )
+     else:
+         torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
📒 Files selected for processing (9)
- examples/distributed-parallel/README.md (1 hunks)
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
- examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
- requirements.txt (2 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/core/trainers/base.py (3 hunks)
- src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/trainer.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (8)
- requirements.txt
- src/axolotl/core/trainers/mixins/distributed_parallel.py
- examples/distributed-parallel/README.md
- src/axolotl/train.py
- src/axolotl/core/builders/base.py
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
- examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
- src/axolotl/utils/trainer.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (3)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
_save(15-23)
src/axolotl/models/mamba/modeling_mamba.py (1)
save_pretrained(111-119)
src/axolotl/utils/distributed.py (1)
is_main_process(79-94)
🔇 Additional comments (3)
src/axolotl/core/trainers/base.py (3)
13-13: New imports look appropriate for the functionality.
The added imports support the new accelerator state management and model saving functionality:
- `safetensors` for safe model serialization
- `AcceleratorState`, `PartialState` for accelerator state management
- `PeftModel` for PEFT model support
- Various utilities from transformers for model saving
Also applies to: 15-15, 17-17, 25-28
657-669: Good implementation of processor/tokenizer saving logic.
The logic correctly handles different scenarios:
- Uses `processing_class` if available
- Falls back to `data_collator.tokenizer` with appropriate logging
- Always saves training arguments
The implementation follows best practices by checking for attribute existence before accessing them.
521-544: Verify AcceleratorState Reset and PartialState Side Effects
The call to `AcceleratorState._reset_state(reset_partial_state=True)` followed by restoring only `parallelism_config` can impact every place we read from `PartialState`. Before merging, confirm that resetting, and then only reassigning `parallelism_config`, still leaves downstream code in a valid state:
• src/axolotl/core/builders/base.py
– Uses `PartialState().device_mesh` for optimizer setup
• src/axolotl/loaders/model.py
– Writes to `_shared_state["parallelism_config"]` and expects `device_mesh`
• src/axolotl/utils/distributed.py
– Initializes `PartialState(timeout=…)`, reads `is_main_process` and `get_world_size()`
• src/axolotl/utils/ctx_managers/sequence_parallel.py
– Reads `partial_state.device_mesh` for ring attention registration
• src/axolotl/cli/merge_sharded_fsdp_weights.py
– Uses `PartialState().is_main_process` to gate checkpoint merging
• src/axolotl/monkeypatch/accelerate/fsdp2.py
– References `accelerator.state.parallelism_config.fsdp_dim_names`
• src/axolotl/core/trainers/mixins/distributed_parallel.py
– Checks `self.accelerator.parallelism_config.dp_shard_enabled` during save
Action items:
- Walk through each module and verify that after the reset, with only `parallelism_config` restored, the properties you rely on (e.g. `device_mesh`, `timeout`, `is_main_process`, world size) are still correctly initialized.
- If any fields are cleared by the reset, either restore them explicitly or consider using a public Accelerate API (instead of the protected `_reset_state`) to avoid breaking with future versions.
Actionable comments posted: 2
📒 Files selected for processing (1)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
src/axolotl/utils/distributed.py (1)
is_main_process(79-94)
🔇 Additional comments (3)
src/axolotl/monkeypatch/accelerate/fsdp2.py (3)
10-10: LGTM!
The import of `torch.distributed` is necessary for the new manual broadcast operations added to handle non-sharded parameters.
46-61: LGTM! Efficient handling of sharded parameter distribution.
The code correctly implements memory-efficient distribution where only the main process holds the full tensor while other ranks allocate empty tensors before distribution. This prevents memory spikes on non-main ranks.
63-73: No change needed: `is_main_process` always maps to rank 0
The helper `is_main_process()` returns `True` exactly when `dist.get_rank() == 0` (or when `LOCAL_RANK` is "0" before initialization), so broadcasting with `src=0` is safe and correct.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
79-82: Critical: Fix potential NameError for `full_tensor` deletion.
This appears to be the same issue flagged in previous review comments. The variable `full_tensor` may not be defined on non-main processes for non-sharded parameters, which would cause a NameError when trying to delete it.
The deletion of `full_tensor` and setting `full_sd[param_name] = None` should only occur when `full_tensor` is guaranteed to exist (i.e., on the main process).
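One common way to rule out this kind of NameError is to bind the name on every rank before any conditional use, so later cleanup is always valid. The sketch below shows the pattern in isolation with plain Python lists standing in for tensors; `consume_param` is a hypothetical helper, not the code in `fsdp2.py`.

```python
def consume_param(full_sd: dict, param_name: str, is_main: bool) -> int:
    """Read a parameter on the main rank only, then release it safely on every rank."""
    # Bind the name unconditionally so `del` below can never raise NameError.
    full_tensor = full_sd[param_name] if is_main else None

    # Stand-in for the real work (sharding or broadcasting the tensor).
    num_elements = len(full_tensor) if full_tensor is not None else 0

    if is_main:
        # Only the main rank owns an entry worth clearing.
        full_sd[param_name] = None
    del full_tensor
    return num_elements
```

Non-main ranks pass through the same cleanup path without ever touching `full_sd`.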
🧹 Nitpick comments (1)
src/axolotl/utils/distributed.py (1)
316-367: Validate the parallelism dimension calculations and improve code style.
The function logic for calculating parallelism dimensions appears correct, but there are several areas for improvement:
- Lines 335-336: The nested if statements can be combined as suggested by static analysis.
- Lines 356-357: Another nested if that can be simplified.
- Potential logic issue: The function modifies `remaining_world_size` but uses `get_world_size()` in the error message (line 363), which could be confusing.
Apply these improvements:
- if dp_shard_size is None and dp_replicate_size in (None, 1):
-     if remaining_world_size > 1:
+ if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
      pc_kwargs["dp_shard_size"] = remaining_world_size
      remaining_world_size = 1

- if remaining_world_size > 1:
-     if "dp_shard_size" not in pc_kwargs and is_fsdp:
+ if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
      pc_kwargs["dp_shard_size"] = remaining_world_size
      remaining_world_size = 1

  if remaining_world_size > 1:
      raise ValueError(
-         f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
+         f"The configured parallelisms are incompatible with the current world size ({world_size})!\n"
          f"{pc_kwargs}"
      )
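The inference logic in the suggested diff can be exercised in isolation with a simplified stand-in. The sketch below is not `_get_parallel_config_kwargs`: the function name, its signature, the `tp_size`/`cp_size` keys, and the way `remaining_world_size` is derived (dividing out each explicitly configured size) are assumptions made for illustration; only the three trailing conditions follow the diff.

```python
def infer_parallelism_kwargs(
    world_size,
    tp_size=None,
    cp_size=None,
    dp_shard_size=None,
    dp_replicate_size=None,
    is_fsdp=False,
):
    """Fill in dp_shard_size from whatever world size the explicit settings leave over."""
    pc_kwargs = {}
    remaining_world_size = world_size

    # Assumed derivation: divide out every explicitly configured size.
    configured = {
        "tp_size": tp_size,
        "cp_size": cp_size,
        "dp_replicate_size": dp_replicate_size,
        "dp_shard_size": dp_shard_size,
    }
    for name, size in configured.items():
        if size is not None and size > 1:
            if remaining_world_size % size != 0:
                raise ValueError(f"{name}={size} does not divide {remaining_world_size}")
            pc_kwargs[name] = size
            remaining_world_size //= size

    # The combined conditions from the suggested diff.
    if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
        pc_kwargs["dp_shard_size"] = remaining_world_size
        remaining_world_size = 1

    if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
        pc_kwargs["dp_shard_size"] = remaining_world_size
        remaining_world_size = 1

    if remaining_world_size > 1:
        raise ValueError(
            f"The configured parallelisms are incompatible with the current world size ({world_size})!\n"
            f"{pc_kwargs}"
        )
    return pc_kwargs
```

Under these assumptions, a world size of 8 with `tp_size=2` and `cp_size=2` assigns the leftover factor of 2 to `dp_shard_size`, while an explicit `dp_shard_size=2` with `tp_size=2` leaves ranks unaccounted for and raises.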
📒 Files selected for processing (16)
- examples/distributed-parallel/README.md (1 hunks)
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml (1 hunks)
- examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (1 hunks)
- examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml (2 hunks)
- requirements.txt (2 hunks)
- src/axolotl/core/builders/base.py (3 hunks)
- src/axolotl/core/trainers/base.py (3 hunks)
- src/axolotl/core/trainers/mixins/distributed_parallel.py (2 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
- src/axolotl/prompt_strategies/chat_template.py (5 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/distributed.py (2 hunks)
- src/axolotl/utils/schemas/datasets.py (1 hunks)
- src/axolotl/utils/trainer.py (2 hunks)
- tests/test_loaders.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
- examples/distributed-parallel/README.md
- tests/test_loaders.py
- requirements.txt
🚧 Files skipped from review as they are similar to previous changes (11)
- src/axolotl/core/trainers/mixins/distributed_parallel.py
- src/axolotl/utils/trainer.py
- src/axolotl/core/builders/base.py
- examples/distributed-parallel/llama-3_1-8b-hdsp-tp.yaml
- src/axolotl/loaders/model.py
- src/axolotl/core/trainers/base.py
- src/axolotl/train.py
- src/axolotl/prompt_strategies/chat_template.py
- examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
- examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
- src/axolotl/utils/schemas/datasets.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
src/axolotl/utils/distributed.py (1)
is_main_process(80-95)
🪛 Ruff (0.12.2)
src/axolotl/utils/distributed.py
335-336: Use a single if statement instead of nested if statements
(SIM102)
356-357: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
🔇 Additional comments (5)
src/axolotl/utils/distributed.py (2)
296-313: LGTM on the main function structure.
The `build_parallelism_config` function correctly delegates to the helper function and handles the creation of ParallelismConfig and device mesh appropriately.
344-355: Add missing unit test for FSDP validation in `_get_parallel_config_kwargs`
It looks like there are no existing tests covering the branch where `remaining_world_size > 1`, `dp_shard_size > 1`, and `is_fsdp = False`.
Please add a unit test, e.g. under `tests/`, that invokes `_get_parallel_config_kwargs` with those parameters and asserts that the correct `ValueError` is raised.
src/axolotl/monkeypatch/accelerate/fsdp2.py (3)
10-10: LGTM on the import addition.
Adding `torch.distributed as dist` is necessary for the manual broadcast operation introduced later in the function.
40-45: LGTM on the iteration change.
Changing the iteration from `full_sd.items()` to `meta_sharded_sd.items()` is logical since we need to process each parameter in the sharded model. The conditional initialization of `full_tensor` only on the main process is correct for distributed scenarios.
63-73: Verify the manual broadcast implementation.
The manual broadcast logic for non-sharded parameters looks correct, but ensure that:
- The tensor shapes match between ranks before broadcasting
- The broadcast operation is synchronous across all ranks
The implementation correctly handles the case where non-main processes need to receive the parameter data through explicit broadcasting, which is necessary when `full_sd` is only available on the main process.
Actionable comments posted: 2
📒 Files selected for processing (5)
- examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml (1 hunks)
- requirements.txt (2 hunks)
- src/axolotl/core/builders/causal.py (3 hunks)
- src/axolotl/monkeypatch/multipack.py (1 hunks)
- src/axolotl/utils/collators/batching.py (2 hunks)
✅ Files skipped from review due to trivial changes (2)
- src/axolotl/monkeypatch/multipack.py
- examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- requirements.txt
🔇 Additional comments (3)
src/axolotl/core/builders/causal.py (2)
6-6: LGTM! Import addition is appropriate.
The `partial` import from `functools` is correctly added to support the GPT-OSS specific collator configuration.
367-367: Good refactor: More explicit conditional logic.
Changing from `else` to `elif self.cfg.pad_to_sequence_len is None` makes the padding logic more explicit and readable. This clarifies that the default padding behavior only applies when `pad_to_sequence_len` is explicitly `None`.
src/axolotl/utils/collators/batching.py (1)
164-164: LGTM! New attribute properly defined.
The `squash_position_ids` boolean attribute with default `False` is correctly defined as a dataclass field, maintaining backward compatibility.
# TODO(wing): figure out why gpt-oss performs better without true sample position_ids
if self.cfg.model_config_type in ["gpt_oss"]:
    collator = partial(  # type: ignore
        V2BatchSamplerDataCollatorForSeq2Seq, squash_position_ids=True
    )
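The snippet above relies on `functools.partial` to pre-bind one keyword argument, so the rest of the builder can keep constructing the collator exactly as before. A self-contained sketch of that pattern follows; `ExampleCollator` and `pick_collator_cls` are stand-ins invented here, not Axolotl classes or functions.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class ExampleCollator:
    """Stand-in for a collator dataclass with an opt-in flag."""

    tokenizer: str
    squash_position_ids: bool = False


def pick_collator_cls(model_config_type: str):
    """Return a collator factory, pre-binding the flag for selected model types."""
    if model_config_type in ["gpt_oss"]:
        # Callers still pass the remaining arguments as usual.
        return partial(ExampleCollator, squash_position_ids=True)
    return ExampleCollator


gpt_oss_collator = pick_collator_cls("gpt_oss")(tokenizer="example-tokenizer")
default_collator = pick_collator_cls("llama")(tokenizer="example-tokenizer")
```

Because the flag defaults to `False`, every other model type keeps its existing behaviour.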
💡 Verification agent
🧩 Analysis chain
Verify GPT-OSS position ID squashing behavior.
The implementation uses partial to configure V2BatchSamplerDataCollatorForSeq2Seq with squash_position_ids=True specifically for GPT-OSS models. The TODO comment suggests this was added based on empirical performance observations.
Please verify that this change doesn't break existing functionality by running the following script:
🏁 Script executed:
#!/bin/bash
# Description: Verify GPT-OSS model configurations and ensure no breaking changes
# Expected: Find GPT-OSS configurations and verify the integration
# Search for GPT-OSS model configurations
echo "=== Searching for GPT-OSS model configurations ==="
rg -A 5 -B 2 "gpt_oss|gpt-oss" --type py
# Verify the squash_position_ids parameter usage
echo "=== Verifying squash_position_ids parameter usage ==="
rg -A 10 "squash_position_ids" --type py
# Check for any tests related to GPT-OSS models
echo "=== Checking for GPT-OSS related tests ==="
fd -e py -x rg -l "gpt_oss|gpt-oss" {} \; | head -10
Length of output: 4951
Add tests for GPT-OSS squash_position_ids override
We’ve confirmed that:
- In src/axolotl/core/builders/causal.py, the gpt_oss branch correctly applies partial(V2BatchSamplerDataCollatorForSeq2Seq, squash_position_ids=True).
- In src/axolotl/utils/collators/batching.py, the squash_position_ids flag concatenates the per-sample position_ids and replaces them with a single np.arange over their total length.
However, there are currently no automated tests exercising this code path.
Please add unit or integration tests to cover:
- The GPT-OSS collator branch in src/axolotl/core/builders/causal.py
- The squashing behavior in src/axolotl/utils/collators/batching.py, verifying that for GPT-OSS inputs the output position_ids are a contiguous range.
🤖 Prompt for AI Agents
In src/axolotl/core/builders/causal.py around lines 480 to 484, add unit or
integration tests to cover the GPT-OSS collator branch where squash_position_ids
is set to True using partial(V2BatchSamplerDataCollatorForSeq2Seq). Also, in
src/axolotl/utils/collators/batching.py, add tests verifying that when
squash_position_ids is True, the position_ids output is a contiguous range
created by np.arange, specifically for GPT-OSS inputs. These tests should
confirm the correct application and effect of the squash_position_ids flag.
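As a rough illustration of what such a test would exercise, the sketch below shows how functools.partial presets a dataclass field while leaving the default untouched for other model types. ToyCollator and pick_collator are hypothetical stand-ins, not the real V2BatchSamplerDataCollatorForSeq2Seq or the builder code.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class ToyCollator:
    """Hypothetical stand-in for V2BatchSamplerDataCollatorForSeq2Seq."""

    tokenizer: object
    squash_position_ids: bool = False  # default keeps existing behavior


def pick_collator(model_config_type: str):
    """Return a collator factory, presetting the flag only for gpt_oss."""
    if model_config_type in ["gpt_oss"]:
        # partial binds the keyword now; remaining args are supplied later
        return partial(ToyCollator, squash_position_ids=True)
    return ToyCollator


gpt_oss_collator = pick_collator("gpt_oss")(tokenizer=None)
llama_collator = pick_collator("llama")(tokenizer=None)
```

A test along these lines only needs to assert on the flag of the constructed collator, so it can run without loading a model.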
| elif feature == "position_ids" and self.squash_position_ids: | ||
| arrays = [ | ||
| np.array(item[feature]) for item in features_ if feature in item | ||
| ] | ||
| # concatenate, get total length and create arange of new total position ids | ||
| position_ids = np.concatenate(arrays) | ||
| total_length = position_ids.shape[0] | ||
| position_ids = np.arange(total_length) | ||
| out_features[i][feature] = position_ids |
💡 Verification agent
🧩 Analysis chain
Verify position ID squashing logic correctness.
The position ID squashing logic concatenates position IDs from multiple sequences and replaces them with a continuous range starting from 0. This implementation looks correct for creating continuous position IDs across concatenated sequences.
The logic appears sound, but let's verify there are no edge cases by running this script:
🏁 Script executed:
#!/bin/bash
# Description: Look for any existing tests for V2BatchSamplerDataCollatorForSeq2Seq to understand expected behavior
# Expected: Find test files that validate this collator's behavior
# Search for tests related to the V2BatchSamplerDataCollatorForSeq2Seq
echo "=== Searching for V2BatchSamplerDataCollatorForSeq2Seq tests ==="
rg -A 15 -B 5 "V2BatchSamplerDataCollatorForSeq2Seq" --type py
# Look for any position_ids related test cases
echo "=== Searching for position_ids test cases ==="
rg -A 10 -B 5 "position_ids.*test|test.*position_ids" --type py
# Search for multipack related tests that might cover this collator
echo "=== Searching for multipack collator tests ==="
rg -A 10 -B 5 "multipack.*collator|collator.*multipack" --type py
Length of output: 14136
Add explicit tests for squash_position_ids logic
While the concatenation-and-arange implementation in V2BatchSamplerDataCollatorForSeq2Seq (src/axolotl/utils/collators/batching.py lines 181–189) looks correct, there are currently no tests that validate its behavior. Please add unit tests to cover:
- Multi-sequence inputs when squash_position_ids=True, asserting that out_features[i]["position_ids"] equals a continuous range from 0 to total_length − 1.
- Edge cases, e.g. sequences of varying lengths or items missing position_ids (they should simply be skipped).
You can extend tests/test_packed_batch_sampler.py or create a new test file (e.g. tests/test_batching_collator.py) with assertions like:
# pseudocode
collator = V2BatchSamplerDataCollatorForSeq2Seq(tokenizer, squash_position_ids=True)
batch = [
[{"input_ids": [1,2], "position_ids": [0,1]}],
[{"input_ids": [3,4,5], "position_ids": [0,1,2]}],
]
out = collator(batch)
assert out["position_ids"].tolist() == list(range(5))
🤖 Prompt for AI Agents
In src/axolotl/utils/collators/batching.py around lines 181 to 189, the
squash_position_ids logic is implemented but lacks explicit unit tests. Add
tests in tests/test_packed_batch_sampler.py or create a new test file like
tests/test_batching_collator.py that cover multi-sequence inputs with
squash_position_ids=True, verifying that out_features[i]["position_ids"] is a
continuous range from 0 to total_length - 1. Also include edge cases such as
sequences of varying lengths and items missing position_ids, ensuring those
items are skipped without errors. Use assertions to confirm the output matches
expected continuous position id ranges.
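The expected behavior can be sketched without the collator itself. The function below is a pure-Python approximation of the squashing step in the snippet above (the real code uses np.concatenate and np.arange); squash_position_ids as a free function and the sample pack are assumptions made for illustration.

```python
def squash_position_ids(packed_items):
    """Replace per-sample position ids with one contiguous range.

    `packed_items` is a list of dicts for one pack. Items without
    "position_ids" are skipped, mirroring the `if feature in item` filter.
    """
    total_length = sum(
        len(item["position_ids"]) for item in packed_items if "position_ids" in item
    )
    # Equivalent to np.arange(total_length) on the concatenated ids
    return list(range(total_length))


pack = [
    {"input_ids": [1, 2], "position_ids": [0, 1]},
    {"input_ids": [3, 4, 5], "position_ids": [0, 1, 2]},
    {"input_ids": [6]},  # no position_ids: contributes nothing
]
squashed = squash_position_ids(pack)
```

The per-sample ids restart at 0 for every sequence, while the squashed result counts straight through the pack, which is the property the requested tests should assert.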
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (2)
55-56: Ensure tokenizer is re-trained / loaded with <|return|>
Adding a new special token after fine-tuning begins will shift vocab indices.
Verify the tokenizer JSON used at run-time already contains <|return|>; otherwise pass --additional_special_tokens in the training script to avoid OOV issues.
65-65: Consider enabling cpu_ram_efficient_loading for 20 B model
The 20 B GPT-OSS weights are ~40 GB fp16; on multi-GPU nodes with limited host RAM, enabling this FSDP flag avoids a full duplicate in CPU memory during load.
If hardware allows, uncommenting may reclaim ~30-40 GB RAM at startup.
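For reference, the ~40 GB figure follows from parameter count times bytes per parameter. The sketch below works that out; the 8-rank node is a hypothetical example, and the assumption that each local rank would otherwise materialise its own host copy is an illustration rather than a measured result.

```python
GB = 10**9


def weight_footprint_gb(num_params: int, bytes_per_param: int) -> float:
    """Size of the raw weights in decimal gigabytes."""
    return num_params * bytes_per_param / GB


# 20 B parameters at 2 bytes each (fp16 / bf16)
fp16_gb = weight_footprint_gb(20 * 10**9, 2)

# Hypothetical node with 8 local ranks, each holding a full host copy
host_gb_all_ranks = fp16_gb * 8
```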
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml (3 hunks)
- examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (4 hunks)
- examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml (3 hunks)
- requirements.txt (2 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
- requirements.txt
- examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
- src/axolotl/monkeypatch/accelerate/fsdp2.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#2707
File: src/axolotl/utils/data/sft.py:247-254
Timestamp: 2025-05-29T22:23:39.312Z
Learning: In distributed training scenarios with batch dispatching, placeholder datasets for non-zero ranks may intentionally use temporary files that persist during training. These files are typically very small and don't require explicit cleanup due to their minimal resource impact and specific training requirements.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: pre-commit
🔇 Additional comments (2)
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (2)
51-51: Warm-up ratio cut from 10 % → 3 %
A shorter warm-up can destabilise large-model training, especially with FSDP-2 and mixed precision.
Confirm that LR spikes are not triggering gradient overflow or NaNs in early steps.
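To make the difference concrete, here is a minimal sketch of a linear warm-up under both ratios. The 1000-step run and the normalised peak learning rate of 1.0 are assumptions chosen for the example, and the helper names are hypothetical rather than taken from the trainer.

```python
import math


def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    """Number of warm-up steps implied by a ratio of the total run."""
    return math.ceil(total_steps * warmup_ratio)


def lr_at(step: int, total_steps: int, warmup_ratio: float, peak_lr: float) -> float:
    """Linear ramp from 0 to peak_lr over the warm-up, then flat at peak."""
    steps = warmup_steps(total_steps, warmup_ratio)
    if steps == 0:
        return peak_lr
    return peak_lr * min(1.0, step / steps)


TOTAL = 1000  # hypothetical run length
old_warmup = warmup_steps(TOTAL, 0.10)
new_warmup = warmup_steps(TOTAL, 0.03)

# Fraction of peak LR reached at step 15 under each setting
old_lr_step15 = lr_at(15, TOTAL, 0.10, 1.0)
new_lr_step15 = lr_at(15, TOTAL, 0.03, 1.0)
```

Under these assumptions the shorter warm-up reaches half of the peak rate by step 15, where the longer one is still at 15 %, which is why the early steps deserve a look.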
13-16: Manual Verification Needed: Confirm thinking Column Exists
The sandbox environment couldn’t import datasets, so we couldn’t verify column names automatically. Please manually confirm that the HuggingFaceH4/Multilingual-Thinking dataset exposes a thinking column (case-sensitive) before merging.
You can run a quick check locally, for example:
python - <<'PY'
from datasets import load_dataset
ds = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:1]")
print(ds.column_names)
PY
• If the column is missing or named differently, update field_thinking and template_thinking_key accordingly.
• If it exists as expected, no further changes are needed here.
Actionable comments posted: 2
🧹 Nitpick comments (2)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)
31-31: Remove unnecessary pylint disable comment
The comparison self.cp_size > 1 is not a chained comparison, so the pylint disable comment is unnecessary.
 if (
-    self.cp_size > 1  # pylint: disable=chained-comparison
+    self.cp_size > 1
     and self.dp_shard_size <= 1
src/axolotl/utils/schemas/validation.py (1)
1150-1162: Simplify nested if statements for better readability
The nested if statements can be combined into a single condition.
 @model_validator(mode="before")
 @classmethod
 def check_gpt_oss_fsdp_loading(cls, data):
-    if data.get("model_quantization_config", "") == "Mxfp4Config":
-        if (
-            data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
-            is True
-        ):
-            raise ValueError(
-                "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
-            )
+    if (
+        data.get("model_quantization_config", "") == "Mxfp4Config"
+        and data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) is True
+    ):
+        raise ValueError(
+            "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
+        )
     return data
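The combined condition can be exercised in isolation. The sketch below restates the check as a plain function without the pydantic decorators so it runs standalone; the `or {}` guard for a fsdp_config that is present but None is an added assumption, not something the original validator does.

```python
def check_gpt_oss_fsdp_loading(data: dict) -> dict:
    """Reject cpu_ram_efficient_loading when the model uses Mxfp4Config."""
    # Assumption: tolerate an explicit `fsdp_config: null` in the config
    fsdp_config = data.get("fsdp_config") or {}
    if (
        data.get("model_quantization_config", "") == "Mxfp4Config"
        and fsdp_config.get("cpu_ram_efficient_loading", False) is True
    ):
        raise ValueError(
            "FSDP cpu_ram_efficient_loading is not supported for "
            "Mxfp4Config model quantization."
        )
    return data


ok_config = {"model_quantization_config": "Mxfp4Config", "fsdp_config": None}
bad_config = {
    "model_quantization_config": "Mxfp4Config",
    "fsdp_config": {"cpu_ram_efficient_loading": True},
}

validated = check_gpt_oss_fsdp_loading(ok_config)
try:
    check_gpt_oss_fsdp_loading(bad_config)
    rejected = False
except ValueError:
    rejected = True
```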
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/monkeypatch/accelerate/parallelism_config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/loaders/patch_manager.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)
src/axolotl/utils/distributed.py (1)
is_main_process (80-95)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/parallelism_config.py
53-53: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
src/axolotl/utils/schemas/validation.py
1153-1157: Use a single if statement instead of nested if statements
(SIM102)
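For the B028 finding, a short sketch of what stacklevel=2 changes: the warning is attributed to the caller of the helper rather than to the helper's own warnings.warn line. The function names here are hypothetical and unrelated to the patched module.

```python
import warnings


def deprecated_helper() -> None:
    # stacklevel=2 points the warning at whoever called this function
    warnings.warn(
        "deprecated_helper is deprecated",
        DeprecationWarning,
        stacklevel=2,
    )


def caller() -> None:
    deprecated_helper()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not filtered out
    caller()

first = caught[0]
```

With the default stacklevel=1 the recorded location would be inside deprecated_helper, which tells the user nothing about where the deprecated call originates.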
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/monkeypatch/accelerate/parallelism_config.py (1)
59-63: LGTM!
The patching implementation is correct and follows the standard monkeypatch pattern.
Actionable comments posted: 1
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/core/trainers/mixins/distributed_parallel.py(2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
src/axolotl/core/trainers/base.py (1)
create_accelerator_and_postprocess (522-543)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/core/trainers/mixins/distributed_parallel.py (1)
5-5: LGTM: Import addition supports the new functionality.
The import of PartialState is necessary for the new method implementation and aligns with the accelerator state management pattern used in the codebase.
Description
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Documentation