Conversation
📝 Walkthrough
This change refactors the codebase to replace the sequence parallelism configuration parameter (
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
📖 Documentation Preview: https://688baeb302c8405999dacbe0--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit f8df5bf |
Actionable comments posted: 2
🧹 Nitpick comments (2)
src/axolotl/utils/ctx_managers/sequence_parallel.py (1)
158-166: Consider removing or implementing the commented code.
This commented-out code appears to be an alternative implementation for computing the global token count using all-reduce. Either implement it properly or remove it to avoid confusion.
src/axolotl/utils/schemas/validation.py (1)
1253-1259: Flash attention compatibility patch needs documentation.
The aliasing of _flash_supports_window_size to _flash_supports_window appears to be a compatibility fix. Consider adding a comment explaining why this patching is necessary.
     try:
         import transformers.modeling_flash_attention_utils
    +    # Compatibility patch: ring_flash_attn expects _flash_supports_window_size,
    +    # but newer transformers versions may only provide _flash_supports_window
         # pylint: disable=protected-access
         transformers.modeling_flash_attention_utils._flash_supports_window_size = (
             transformers.modeling_flash_attention_utils._flash_supports_window
         )
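A guarded form of this alias can make the intent easier to read. The following is only a sketch under stated assumptions: the helper name alias_flash_window_attr is hypothetical (it is not part of axolotl, transformers, or ring_flash_attn), and the demo runs against a stand-in namespace rather than the real transformers module.

```python
from types import SimpleNamespace


def alias_flash_window_attr(module) -> bool:
    """Expose `_flash_supports_window` under the name `_flash_supports_window_size`.

    Returns True if an alias was created, False if nothing was changed.
    """
    alias_name = "_flash_supports_window_size"  # name the patch creates
    source_name = "_flash_supports_window"      # name the patch copies from
    if hasattr(module, alias_name):
        return False  # the alias target already exists; leave it alone
    if not hasattr(module, source_name):
        return False  # nothing to copy from; leave the module untouched
    setattr(module, alias_name, getattr(module, source_name))
    return True


# Stand-in for transformers.modeling_flash_attention_utils (assumption for the demo).
fake_utils = SimpleNamespace(_flash_supports_window=True)
created = alias_flash_window_attr(fake_utils)
```

Calling the helper a second time is a no-op, which keeps the patch idempotent if it is applied from more than one code path.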
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (25)
- docs/sequence_parallelism.qmd (3 hunks)
- requirements.txt (1 hunks)
- src/axolotl/cli/merge_lora.py (1 hunks)
- src/axolotl/core/builders/base.py (3 hunks)
- src/axolotl/core/builders/rl.py (1 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
- src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
- src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/loaders/model.py (5 hunks)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
- src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (5 hunks)
- src/axolotl/utils/trainer.py (3 hunks)
- tests/core/test_builders.py (1 hunks)
- tests/e2e/multigpu/patched/test_sp.py (1 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
- tests/e2e/patched/test_sp.py (6 hunks)
🧰 Additional context used
🧠 Learnings (4)
tests/e2e/patched/test_sp.py (1)
src/axolotl/core/trainers/grpo/__init__.py (1)
src/axolotl/core/builders/rl.py (1)
src/axolotl/utils/schemas/validation.py (1)
The same learning was applied to each of the four files above:
Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
🧬 Code Graph Analysis (4)
src/axolotl/monkeypatch/ring_attn/__init__.py (1)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
register_ring_attn_from_device_mesh(248-325)
src/axolotl/core/trainers/mixins/checkpoints.py (1)
src/axolotl/utils/logging.py (1)
warning_once(31-39)
src/axolotl/utils/ctx_managers/sequence_parallel.py (2)
src/axolotl/monkeypatch/ring_attn/patch.py (2)
get_ring_attn_group(62-66)
register_ring_attn_from_device_mesh(248-325)
tests/e2e/patched/test_sp.py (1)
partial_state(25-28)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
load(151-179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: preview
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
🔇 Additional comments (44)
src/axolotl/core/trainers/mixins/checkpoints.py (1)
13-23: Excellent improvements for distributed training robustness.
The changes enhance error handling and logging for distributed training scenarios:
- Exception handling expansion: Adding KeyError alongside NotImplementedError is appropriate, as optimizer state dictionaries may have missing keys in complex parallelism setups like FSDP2.
- Improved logging: Using LOG.warning_once with main_process_only=True follows best practices for distributed training by preventing duplicate log messages across processes and avoiding log spam.
- Documentation: The TODO comment provides valuable context about the known FSDP2 limitation.
These changes align well with the broader parallelism improvements in this PR and enhance the checkpoint saving reliability.
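The pattern praised above (catch both exception types, then warn once on the main process only) can be sketched as follows. This is an illustrative stand-in, not the axolotl implementation: the functions warning_once and save_optimizer_state below are hypothetical, and the is_main_process flag is passed in by hand instead of being read from a distributed runtime.

```python
import logging
from functools import lru_cache

LOG = logging.getLogger("checkpoint_sketch")


@lru_cache(maxsize=None)
def _warn_once(message: str) -> None:
    """Emit each distinct message at most once per process."""
    LOG.warning(message)


def warning_once(message: str, is_main_process: bool = True) -> None:
    """Warn once, and only on the main process, to avoid per-rank log spam."""
    if is_main_process:
        _warn_once(message)


def save_optimizer_state(get_state_dict, is_main_process: bool = True):
    """Return the optimizer state, or None if it cannot be gathered."""
    try:
        return get_state_dict()
    except (NotImplementedError, KeyError) as exc:
        # Both failure modes are treated as "skip the optimizer state".
        warning_once(
            f"Skipping optimizer state save: {type(exc).__name__}",
            is_main_process,
        )
        return None
```

Caching on the message text means a repeated failure at every checkpoint step produces one log line rather than one per step.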
tests/e2e/multigpu/patched/test_sp.py (1)
70-70: LGTM: Parameter rename aligns with codebase refactoring.
The change from sequence_parallel_degree to context_parallel_size is consistent with the broader renaming effort across the codebase. The test configuration maintains the same functionality while using the updated parameter name.
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
257-259: Ensure device_mesh and parallelism_config Are Always Initialized
Before using accelerator.state.device_mesh[accelerator.state.parallelism_config.model_shard_dim_names], make sure that:
- accelerator.state.parallelism_config has been assigned (e.g., via parallelism_config.build_device_mesh(...)) in every model-loading and accelerator setup path.
- accelerator.state.device_mesh is non-None and contains all keys named in model_shard_dim_names.
Points to review in your codebase:
- src/axolotl/loaders/model.py around line 410, where device_mesh = parallelism_config.build_device_mesh("cuda") is created.
- Any other initialization branch (e.g., CPU or multi-GPU flows) to confirm they invoke build_device_mesh and assign to accelerator.state.device_mesh before FSDP2 is instantiated.
src/axolotl/cli/merge_lora.py (1)
73-73: LGTM: Consistent parameter renaming in CLI interface.
The change from sequence_parallel_degree to context_parallel_size maintains consistency with the broader refactoring effort while preserving the same functionality and default value.
src/axolotl/core/trainers/grpo/args.py (1)
16-16: LGTM: Dataclass field rename maintains consistency.
The parameter rename from sequence_parallel_degree to context_parallel_size in the AxolotlGRPOConfig dataclass aligns with the codebase-wide refactoring while preserving the same type annotation and default value.
tests/core/test_builders.py (1)
67-67: LGTM: Test configuration updated for parameter consistency.
The change from sequence_parallel_degree to context_parallel_size in the base test configuration fixture ensures all test cases use the updated parameter name while maintaining the same test behavior.
src/axolotl/core/builders/rl.py (1)
56-56: LGTM! Clean parameter rename.
The parameter rename from sequence_parallel_degree to context_parallel_size is consistent with the broader refactoring effort across the codebase. The logic remains identical and correct.
src/axolotl/monkeypatch/ring_attn/__init__.py (1)
11-11: LGTM! Proper API exposure for DeviceMesh-based ring attention.
The addition of register_ring_attn_from_device_mesh to both the import and __all__ list correctly exposes the new DeviceMesh-based ring attention registration function as part of the public API. This aligns with the modernization effort to use accelerate's device mesh abstractions.
Also applies to: 21-21
tests/e2e/multigpu/solo/test_grpo.py (1)
301-301: LGTM! Test configuration updated for parameter rename.
The test configuration correctly uses the new context_parallel_size parameter instead of the deprecated sequence_parallel_degree. The test semantics remain unchanged while staying consistent with the broader parameter renaming effort.
src/axolotl/train.py (1)
205-205: LGTM! Consistent parameter rename throughout training logic.
Both the conditional check and the parameter passed to SequenceParallelContextManager have been correctly updated to use context_parallel_size instead of sequence_parallel_degree. The logic remains identical while maintaining consistency with the broader refactoring effort.
Also applies to: 213-213
docs/sequence_parallelism.qmd (1)
25-25: LGTM! Documentation updated to reflect parameter rename.
All documentation references have been consistently updated from sequence_parallel_degree to context_parallel_size. The examples, explanations, and usage descriptions maintain the same semantics while staying aligned with the codebase changes. This ensures users have accurate documentation for the new parameter name.
Also applies to: 33-33, 69-69, 92-92, 94-94, 99-99
tests/e2e/patched/test_sp.py (2)
114-114: LGTM: Consistent parameter rename in test.
The update from sequence_parallel_degree to context_parallel_size aligns with the codebase-wide refactoring for clearer terminology.
159-285: LGTM: Comprehensive test coverage maintained with updated parameter name.
All test configurations have been systematically updated to use context_parallel_size instead of sequence_parallel_degree. The test logic, assertions, and expected behavior remain unchanged, ensuring continued validation of the sequence parallelism functionality with the new parameter name.
src/axolotl/core/trainers/grpo/__init__.py (1)
85-86: LGTM: Consistent parameter rename in GRPO configuration.
The update from cfg.sequence_parallel_degree to cfg.context_parallel_size maintains the same conditional logic while aligning with the codebase-wide terminology standardization.
requirements.txt (1)
16-18: Verify Git dependency usage and pin commit hashes
Using branch names for transformers and a custom accelerate fork introduces non-determinism and potential instability. Please confirm the following:
• Intent: Are these Git dependencies temporary until upstream releases include the needed changes, or intended as permanent sources?
• Pinning: For reproducible builds, replace branch references with the exact commit hashes you've verified (shown below).
• Timeline: When do you expect the upstream changes to land in an official release?
Proposed snippet for requirements.txt (Lines 16–18):
    -transformers @ git+https://github.com/huggingface/transformers.git@main
    -tokenizers>=0.21.1
    -accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
    +transformers @ git+https://github.com/huggingface/transformers.git@5a81d7e0b388fb2b86fc1279cdc07d9dc7e84b4c
    +tokenizers>=0.21.1
    +accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@168b520279a21a4d3fb89413aacb86c70a2f0a99
src/axolotl/utils/trainer.py (3)
445-445: LGTM: Consistent parameter rename in training step calculations.
The update from cfg.sequence_parallel_degree to cfg.context_parallel_size maintains the same multiplication logic for total step calculations while aligning with the codebase terminology standardization.
487-487: LGTM: Consistent parameter rename maintained.
Same parameter rename applied consistently in the data loader length calculation path.
514-514: LGTM: Consistent parameter rename in fallback calculation.
The parameter rename is consistently applied in the fallback total steps calculation, maintaining mathematical correctness.
src/axolotl/loaders/patch_manager.py (1)
264-269: LGTM: Consistent parameter rename in patch application.
The parameter rename from sequence_parallel_degree to context_parallel_size is applied consistently within the sequence parallel patches method, maintaining the same conditional logic.
src/axolotl/utils/schemas/config.py (1)
647-664: LGTM! Well-structured parameter refactoring with proper backwards compatibility.
The changes correctly:
- Add the new dp_shard_size field with clear documentation
- Maintain backwards compatibility by keeping sequence_parallel_degree with deprecation notice
- Introduce context_parallel_size with comprehensive documentation explaining its purpose
- Follow consistent field definition patterns
src/axolotl/core/trainers/grpo/sampler.py (4)
23-23: LGTM! Documentation updated consistently.
The comment correctly reflects the new parameter name context_parallel_size.
48-48: LGTM! Parameter documentation updated.
The docstring parameter description correctly uses the new context_parallel_size naming.
62-62: LGTM! Constructor parameter renamed consistently.
The parameter name change from sequence_parallel_degree to context_parallel_size is correct and maintains the same functionality.
80-82: LGTM! Internal attribute usage updated consistently.
All internal references to the parameter have been correctly updated:
- self.context_parallel_size
- self.num_sp_groups = world_size // context_parallel_size
- self.sp_group_id = rank // context_parallel_size
The logic remains unchanged; only the parameter name is updated.
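As a worked illustration of the two formulas quoted above, the sketch below maps every rank of an 8-process job to its group when context_parallel_size is 4. The function name sp_group_layout and the divisibility check are assumptions added for the example; they are not taken from the sampler code.

```python
def sp_group_layout(
    world_size: int, rank: int, context_parallel_size: int
) -> tuple[int, int]:
    """Return (num_sp_groups, sp_group_id) for one rank."""
    # Assumption for this sketch: the world size must split evenly into groups.
    if context_parallel_size < 1 or world_size % context_parallel_size != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"context_parallel_size={context_parallel_size}"
        )
    num_sp_groups = world_size // context_parallel_size
    sp_group_id = rank // context_parallel_size
    return num_sp_groups, sp_group_id


# 8 processes, 4 ranks per group: ranks 0-3 form group 0, ranks 4-7 form group 1.
layout = [sp_group_layout(8, rank, 4) for rank in range(8)]
```

Integer division is what keeps consecutive ranks in the same group, so each group of context_parallel_size ranks shares one slice of the batch.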
src/axolotl/core/trainers/grpo/trainer.py (4)
103-103: LGTM! Calculation updated with new parameter name.
The calculation num_sp_groups = num_processes // self.args.context_parallel_size correctly uses the renamed parameter while maintaining the same logic.
133-138: LGTM! Error message updated consistently.
The error message correctly references the new parameter name context_parallel_size and maintains clear, informative messaging about the validation requirements.
170-172: LGTM! Sampler initialization updated correctly.
Both the batch size calculation and the sampler parameter are correctly updated to use context_parallel_size.
238-238: LGTM! All conditional checks and calculations updated consistently.
All references to the parameter in:
- Conditional checks (if self.args.context_parallel_size > 1:)
- Variable assignments (context_parallel_size = self.args.context_parallel_size)
- Calculations (group_leader_rank = sp_group_id * context_parallel_size)
are correctly updated while maintaining the same logic flow.
Also applies to: 311-311, 314-315, 322-322, 338-338, 355-355, 362-362, 586-586
src/axolotl/core/builders/base.py (3)
30-30: LGTM! Correct import added for AcceleratorConfig.
The import is properly added to support the new accelerator configuration handling.
438-449: LGTM! Improved accelerator configuration handling.
The method now properly instantiates AcceleratorConfig objects instead of passing raw dictionaries, which provides better type safety and validation. The logic correctly:
- Preserves the use_configured_state parameter handling
- Handles both cases when config exists and when it doesn't
- Maintains backwards compatibility
514-514: LGTM! Consistent training argument configuration.
Setting "average_tokens_across_devices": False aligns with the device mesh parallelism changes mentioned in the PR and ensures consistent token handling across devices.
src/axolotl/integrations/liger/args.py (2)
31-37: LGTM! Type annotations modernized.
The change from Optional[bool] to bool | None follows modern Python type annotation practices and improves readability.
64-72: LGTM! Proper validation for incompatible configurations.
The new validator correctly prevents the incompatible combination of liger_rms_norm with tensor parallelism. The implementation:
- Checks the logical conditions properly (liger_rms_norm enabled AND tensor_parallel_size > 1)
- Provides a clear error message explaining the incompatibility
- Includes a helpful reference URL for more context
- Follows the established pattern of other validators in the class
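The condition described in the list above can be sketched as a plain function over a config dictionary. This is not the validator from the PR: the function name check_liger_rms_norm_with_tp, the dictionary-based signature, and the error text are all assumptions made for illustration, and the reference URL mentioned in the review is not reproduced here.

```python
def check_liger_rms_norm_with_tp(cfg: dict) -> dict:
    """Reject configs that enable liger_rms_norm together with tensor parallelism."""
    liger_rms_norm = bool(cfg.get("liger_rms_norm"))
    # A missing or null tensor_parallel_size is treated as 1 (no tensor parallelism).
    tensor_parallel_size = cfg.get("tensor_parallel_size") or 1
    if liger_rms_norm and tensor_parallel_size > 1:
        raise ValueError(
            "liger_rms_norm cannot be combined with tensor_parallel_size > 1 "
            f"(got tensor_parallel_size={tensor_parallel_size})"
        )
    return cfg


ok_cfg = check_liger_rms_norm_with_tp(
    {"liger_rms_norm": True, "tensor_parallel_size": 1}
)
```

Returning the config unchanged mirrors the usual shape of a pass-through validator, so the check can be chained with others.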
src/axolotl/utils/ctx_managers/sequence_parallel.py (3)
8-8: LGTM!
The import of PartialState from accelerate is appropriate for accessing the device mesh used in the new registration method.
180-181: Parameter renaming looks good.
The renaming from sequence_parallel_degree to context_parallel_size is consistently applied in the docstring and throughout the class initialization.
Also applies to: 192-192, 199-199
243-249: Dimension name "cp" is correct and needs no change
The mesh_dim_names in src/axolotl/monkeypatch/ring_attn/patch.py is always set to either
- ("dp", "cp") (data + context parallel)
- ("fsdp", "cp") (FSDP + context parallel)
This guarantees that "cp" is a valid dimension in the device mesh. Passing sequence_parallel_dim=("cp",) therefore matches the patched mesh configuration and requires no updates.
src/axolotl/loaders/model.py (2)
16-17: Import changes look good.
The addition of PartialState and ParallelismConfig imports and consolidation of distributed utilities are appropriate for the new parallelism configuration functionality.
Also applies to: 52-52
646-652: Tensor parallel configuration looks good.
The integration with the device mesh from PartialState and the removal of incompatible device_map are handled correctly.
src/axolotl/monkeypatch/ring_attn/patch.py (3)
166-204: Parameter renaming is consistent.
The renaming from sequence_parallel_degree to context_parallel_size is properly applied throughout the function, including in assertions, calculations, and logging messages.
248-326: New device mesh-based registration function is well-implemented.
The function properly extracts the sequence parallel submesh with appropriate error handling and logging. The integration with different ring attention implementations is handled correctly.
383-427: Device mesh patching updates are correct.
The parameter renaming and dimension naming changes are consistent. Good documentation explaining the use of "cp" to match PyTorch native context parallelism naming convention.
src/axolotl/utils/schemas/validation.py (3)
689-689: Parameter renaming in GRPO validation is correct.
The check for GRPO + sequence parallel + Liger loss properly uses the new context_parallel_size parameter.
1237-1276: Sequence parallel validation updates are consistent.
All references to sequence_parallel_degree have been properly updated to context_parallel_size in method names, error messages, and logging statements.
Also applies to: 1280-1280
916-940: DeepSpeed configuration update logic improvement.
Good change to make the tensor parallel configuration conditional on having a DeepSpeed config file, rather than raising an error when missing. This provides better flexibility for users.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/axolotl/loaders/patch_manager.py (1)
69-69: Re-enable or confirm removal of sequence parallel patches
This issue has already been raised in a previous review. The call to _apply_sequence_parallel_patches() is commented out, which disables sequence-parallelism patches.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/loaders/patch_manager.py (3 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
- src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/monkeypatch/ring_attn/patch.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
patch_prepare_from_posids(77-87)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
patch_prepare_device_mesh(389-433)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (3)
52-52: LGTM! Correct placement of transformers patches.
The _apply_transformers_patches() call is appropriately placed at the beginning of the pre-model load patches sequence.
72-78: LGTM! Clean implementation of transformers patches.
The method correctly imports and applies the Flash Attention utility patch following the established pattern.
272-277: LGTM! Correct parameter renaming.
The method correctly updates all references from sequence_parallel_degree to context_parallel_size, maintaining the same logic while aligning with the PR's refactoring objectives.
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
12-74: LGTM! Well-implemented Flash Attention utility function.
The function correctly prepares tensors for Flash Attention with proper handling of position IDs and sequence lengths. The documentation and TorchDynamo compatibility note are particularly helpful.
Actionable comments posted: 0
🧹 Nitpick comments (3)
train.yaml (3)
70-70: Add trailing newline to satisfy YAML lint
The linter error (no new line character at the end of file) can break some tooling that strictly enforces POSIX-style text files.
     # save_first_step: true # uncomment this to validate checkpoint saving works with your config
    +
30-33: Eliminate null-valued WandB keys to avoid silent mis-configuration
wandb_watch: and wandb_log_model: are currently set to explicit null (empty value). The Axolotl config loader passes these directly to the Weights-and-Biases SDK; null values can override sensible defaults and disable expected logging.
    - wandb_watch:
    - wandb_log_model:
    + # wandb_watch: gradients # valid options: false | gradients | all
    + # wandb_log_model: checkpoint # valid options: false | checkpoint | artifact
Confirm the intended behaviour or comment out the keys entirely.
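If the null-valued keys are a concern, one defensive option is to drop them before the config reaches any downstream consumer. The sketch below is an assumption-laden illustration: drop_null_keys is a hypothetical helper, the config is a plain dictionary, and nothing here is claimed about how the Axolotl loader or the Weights-and-Biases SDK actually treats null values.

```python
def drop_null_keys(cfg: dict, keys: tuple[str, ...]) -> dict:
    """Return a copy of cfg without the listed keys when their value is None."""
    return {
        key: value
        for key, value in cfg.items()
        if not (key in keys and value is None)
    }


# Empty YAML values such as `wandb_watch:` load as None in Python.
raw_cfg = {"wandb_project": "demo", "wandb_watch": None, "wandb_log_model": None}
cleaned = drop_null_keys(raw_cfg, ("wandb_watch", "wandb_log_model"))
```

Only the named keys are filtered, so other fields that are intentionally null pass through untouched.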
25-27: Large sequence_len + sample packing may exhaust memory
A sequence_len of 4096 combined with sample_packing: true, flash-attention and gradient-checkpointing can easily exceed the memory footprint of 8-bit AdamW on 8 × 24 GB GPUs.
Consider:
- Reducing sequence_len during warm-up or
- Enabling q_lora / LoRA adapters
- Increasing gradient_accumulation_steps to trade compute for memory.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/loaders/model.py (5 hunks)
- train.yaml (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/loaders/model.py
🧰 Additional context used
🪛 YAMLlint (1.37.1)
train.yaml
[error] 70-70: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (2)
train.yaml (2)
44-46: Verify context_parallel_size against available GPUs and upstream patches
Setting context_parallel_size: 8 requires exactly eight devices per node (or an 8-way device mesh) and the new accelerate & transformers PRs referenced in the main description. Mismatching hardware or an un-patched dependency will raise runtime errors during mesh creation.
Double-check that:
- huggingface/accelerate PR 3682 is installed.
- You actually have 8 GPUs per process group, or override with CUDA_VISIBLE_DEVICES.
- dp_shard_size: 1 is intentional; otherwise the total world size becomes context_parallel_size * dp_shard_size.
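The world-size relationship in the last bullet can be checked before launch with a few lines of arithmetic. This sketch assumes the only parallel dimensions in play are context_parallel_size and dp_shard_size, as in that bullet; any tensor-parallel or replicate dimension would add further factors. The function name check_world_size is hypothetical.

```python
def check_world_size(
    world_size: int, context_parallel_size: int = 1, dp_shard_size: int = 1
) -> int:
    """Return the mesh size, or raise if it does not match the number of processes."""
    expected = context_parallel_size * dp_shard_size
    if world_size != expected:
        raise ValueError(
            f"world_size={world_size} does not match "
            f"context_parallel_size * dp_shard_size = {expected}"
        )
    return expected


# The train.yaml under review: context_parallel_size 8, dp_shard_size 1, 8 GPUs.
mesh_size = check_world_size(8, context_parallel_size=8, dp_shard_size=1)
```

With dp_shard_size raised to 2, the same check would only pass on 16 processes.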
58-60: Ensure placeholder special tokens exist in the tokenizer
<|finetune_right_pad_id|> and <|eot_id|> must already be present in the Qwen tokenizer vocab. If they are not, loading will fail or silently mis-align IDs.
If the base tokenizer lacks these tokens, add them via --add-special-tokens preprocessing or switch to the built-in <|endoftext|> / <|pad|> tokens.
    sequence_parallel_degree: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Deprecated: use `context_parallel_size` instead"
        },
    )
Not sure if I missed this. Should we add this as an alias to context_parallel_size or make sure to add migration from the old config to new one in validator?
+1 need to raise deprecation warning and migrate configs
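One possible shape for the migration requested in the two comments above is sketched here. It is an assumption about how such a step could look, not the validator that was eventually merged: the function name migrate_sequence_parallel_degree, the dictionary-based config, and the conflict rule are all hypothetical.

```python
import warnings


def migrate_sequence_parallel_degree(cfg: dict) -> dict:
    """Copy a deprecated sequence_parallel_degree value into context_parallel_size."""
    cfg = dict(cfg)  # work on a copy so the caller's config is not mutated
    old_value = cfg.pop("sequence_parallel_degree", None)
    if old_value is None:
        return cfg  # nothing to migrate

    warnings.warn(
        "`sequence_parallel_degree` is deprecated; use `context_parallel_size` instead",
        DeprecationWarning,
        stacklevel=2,
    )
    new_value = cfg.get("context_parallel_size")
    if new_value is not None and new_value != old_value:
        # Assumption: conflicting values are an error rather than a silent override.
        raise ValueError(
            "sequence_parallel_degree and context_parallel_size are both set "
            "to different values"
        )
    cfg["context_parallel_size"] = old_value
    return cfg
```

For example, a config containing only sequence_parallel_degree: 4 would come back with context_parallel_size: 4 and a single DeprecationWarning.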
Could we also have a doc that explains how to compose TP/Hybrid Sharding/CP? Like, when to use TP, how to shard intra-node & replicate across nodes? I'll link it from the README.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
254-332: Extract duplicated ring attention function setup logic.
The ring_attn_func handling logic (lines 306-331) is duplicated from the register_ring_attn function (lines 220-252). This violates the DRY principle and makes maintenance harder.
Extract the common logic into a separate function:
    +def _setup_ring_attn_func(ring_attn_func: RingAttnFunc | None, heads_k_stride: int | None):
    +    """Setup ring attention function implementation."""
    +    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
    +        # fmt: off
    +        # pylint: disable=protected-access
    +        import transformers.modeling_flash_attention_utils
    +        transformers.modeling_flash_attention_utils._flash_supports_window_size = (
    +            transformers.modeling_flash_attention_utils._flash_supports_window
    +        )
    +
    +        import ring_flash_attn.adapters.hf_adapter
    +
    +        from ring_flash_attn.adapters.hf_adapter import (  # isort: skip  # pylint: disable=unused-import
    +            create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
    +        )
    +
    +        create_ring_flash_attention_forward_orig = (  # noqa: F811,F841
    +            create_ring_flash_attention_forward
    +        )
    +        ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
    +        # fmt: on
    +
    +        ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
    +            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
    +        )
    +    elif ring_attn_func is RingAttnFunc.BATCH_RING:
    +        from axolotl.monkeypatch.ring_attn.adapters.batch import (
    +            substitute_hf_flash_attn,
    +        )
    +
    +        substitute_hf_flash_attn(
    +            process_group=get_ring_attn_group(),
    +            ring_attn_func=ring_attn_func,
    +        )
Then use it in both functions:
    -    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
    -        # ... (lines 220-242)
    -    elif ring_attn_func is RingAttnFunc.BATCH_RING:
    -        # ... (lines 243-251)
    +    _setup_ring_attn_func(ring_attn_func, heads_k_stride)
♻️ Duplicate comments (1)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
77-87: Remove redundant patching with setattr.
The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.
Apply this fix to remove the redundancy:
     def patch_prepare_from_posids():
         import transformers.modeling_flash_attention_utils

         transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
             _prepare_from_posids
         )
    -    setattr(
    -        sys.modules["transformers.modeling_flash_attention_utils"],
    -        "_prepare_from_posids",
    -        _prepare_from_posids,
    -    )
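The reason the second patch is redundant can be shown with a stand-in module: the object reached through a module reference and the entry in sys.modules are normally the same object, so assigning an attribute on one is visible through the other. The module name demo_flash_utils and the stub function below are hypothetical and exist only for this demonstration.

```python
import sys
import types

# Register a throwaway module so the example does not depend on transformers.
demo = types.ModuleType("demo_flash_utils")
sys.modules["demo_flash_utils"] = demo


def _prepare_from_posids_stub():
    """Placeholder for the replacement function."""
    return "patched"


# Direct attribute assignment, as on lines 80-82 of the reviewed file.
demo._prepare_from_posids = _prepare_from_posids_stub

# The sys.modules entry is the same module object, so it already sees the patch.
same_object = sys.modules["demo_flash_utils"] is demo
already_patched = (
    sys.modules["demo_flash_utils"]._prepare_from_posids is _prepare_from_posids_stub
)
```

Both flags evaluate to True here, which is why a follow-up setattr on the sys.modules entry would change nothing.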
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)
1221-1227: Consider moving the transformers patching to a more appropriate location.
The transformers attribute aliasing (lines 1221-1226) is performed during validation, which seems like an unexpected place for runtime patching. This patching appears to be a prerequisite for importing ring_flash_attn.
Consider moving this patching logic to the patch manager or a dedicated initialization module where other patches are applied, rather than in the validation logic.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (28)
- docs/sequence_parallelism.qmd (3 hunks)
- requirements.txt (1 hunks)
- src/axolotl/cli/merge_lora.py (1 hunks)
- src/axolotl/core/builders/base.py (3 hunks)
- src/axolotl/core/builders/rl.py (1 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
- src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
- src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/loaders/model.py (5 hunks)
- src/axolotl/loaders/patch_manager.py (3 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
- src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
- src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (5 hunks)
- src/axolotl/utils/trainer.py (3 hunks)
- tests/core/test_builders.py (1 hunks)
- tests/e2e/multigpu/patched/test_sp.py (1 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
- tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
- tests/e2e/patched/test_sp.py (6 hunks)
- tests/e2e/utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (7)
- tests/e2e/multigpu/patched/test_sp.py
- src/axolotl/core/trainers/grpo/args.py
- src/axolotl/monkeypatch/ring_attn/__init__.py
- docs/sequence_parallelism.qmd
- src/axolotl/utils/trainer.py
- tests/e2e/multigpu/solo/test_grpo.py
- src/axolotl/core/builders/rl.py
🚧 Files skipped from review as they are similar to previous changes (18)
- tests/core/test_builders.py
- requirements.txt
- src/axolotl/train.py
- src/axolotl/cli/merge_lora.py
- tests/e2e/utils.py
- src/axolotl/core/trainers/grpo/__init__.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/loaders/patch_manager.py
- tests/e2e/patched/test_sp.py
- tests/e2e/multigpu/test_fp8_fsdp2.py
- src/axolotl/core/trainers/mixins/checkpoints.py
- src/axolotl/utils/ctx_managers/sequence_parallel.py
- src/axolotl/monkeypatch/accelerate/fsdp2.py
- src/axolotl/core/trainers/grpo/trainer.py
- src/axolotl/loaders/model.py
- src/axolotl/core/builders/base.py
- src/axolotl/integrations/liger/args.py
- src/axolotl/core/trainers/grpo/sampler.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
load(151-180)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
🔇 Additional comments (4)
src/axolotl/monkeypatch/ring_attn/patch.py (2)
165-252: LGTM! Consistent renaming and enhanced ring attention support.
The renaming from `sequence_parallel_degree` to `context_parallel_size` is consistently applied throughout the function. The addition of support for multiple ring attention implementations (VARLEN_LLAMA3 and BATCH_RING) enhances the flexibility of the system.
389-433: LGTM! Consistent renaming and alignment with PyTorch conventions.
The renaming from `sequence_parallel_degree` to `context_parallel_size` is properly implemented. Good decision to use "cp" instead of "sp" for the device mesh dimension name to align with PyTorch's native context parallelism implementation.
src/axolotl/utils/schemas/validation.py (2)
900-927: Verify the removal of DeepSpeed config requirement is intentional.
The validation no longer raises an error when `tensor_parallel_size > 1` but DeepSpeed config is missing. This is a significant behavior change that makes the validation more permissive.
Please confirm this change is intentional and that tensor parallelism can now work without DeepSpeed configuration.
1205-1244: LGTM! Consistent renaming and improved error messages.
The renaming from `check_sequence_parallel_degree` to `check_context_parallel_size` is properly implemented with all references, error messages, and warnings updated to use the new terminology.
Actionable comments posted: 1
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/loaders/model.py (5 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/utils/ctx_managers/sequence_parallel.py
- src/axolotl/loaders/model.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/accelerate/fsdp2.py
260-260: Found useless expression. Either assign it to a variable or remove it.
(B018)
🔇 Additional comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
257-257: LGTM! Device mesh integration enhances parallelism support.
The addition of the mesh parameter properly integrates the device mesh from accelerator state into FSDP2 model preparation, aligning with the PR's parallelism configuration improvements.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)
112-128: Enhanced logging function with improved memory reporting.
The updated `log_gpu_memory_usage` function leverages the new centralized memory retrieval and provides clearer logging:
- Uses the new `get_gpu_memory_usage` helper
- Updates variable names to match the new memory metrics (active, allocated, reserved)
- Maintains backward compatibility with the existing logging interface
However, there's a minor issue with the conditional string formatting that can be simplified.
Apply this fix to simplify the conditional string formatting:
    - msg = f"{cur_device_type} memory active:" if not msg else msg
    + msg = msg if msg else f"{cur_device_type} memory active:"
📒 Files selected for processing (6)
- src/axolotl/loaders/model.py (5 hunks)
- src/axolotl/utils/bench.py (2 hunks)
- src/axolotl/utils/callbacks/__init__.py (2 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (5 hunks)
- train.yaml (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- train.yaml
🚧 Files skipped from review as they are similar to previous changes (3)
- src/axolotl/utils/schemas/config.py
- src/axolotl/loaders/model.py
- src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
src/axolotl/utils/callbacks/__init__.py (1)
src/axolotl/utils/bench.py (2)
get_gpu_memory_usage(96-109)
log_gpu_memory_usage(112-128)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py
124-124: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
🔇 Additional comments (5)
src/axolotl/utils/callbacks/__init__.py (2)
38-38: LGTM: Import statement correctly updated.
The addition of the `get_gpu_memory_usage` import is properly aligned with its usage in the updated callback implementation.
111-123: Improved GPU memory monitoring with proper step handling.
The changes enhance GPU memory monitoring by:
- Removing the one-time logging guard (`logged` attribute) for continuous monitoring
- Adding detailed WandB metrics (active, allocated, reserved memory)
- Maintaining existing debug logging functionality
The implementation correctly uses `state.global_step > 0` to avoid logging on the initial step and properly gates WandB logging with both configuration and process checks.
src/axolotl/utils/bench.py (3)
60-64: Improved memory metrics with peak usage tracking.
The refactoring correctly updates the memory metrics to track:
- Peak active memory from CUDA memory stats
- Maximum allocated memory (peak allocation)
- Maximum reserved memory (peak reservation)
This provides more meaningful monitoring data compared to current memory usage, especially for understanding memory pressure during training.
96-109: Well-designed centralized memory retrieval function.
The new `get_gpu_memory_usage` function effectively centralizes device-specific memory retrieval logic with proper error handling for unsupported devices. The function correctly:
- Uses `get_device_type()` to determine the current device
- Handles MPS, NPU, and CUDA devices appropriately
- Raises a clear error for unsupported device types
102-102: Fix: Corrected device type check.
Good catch! The device type check was properly corrected from `"gpu"` to `"cuda"` to match the actual CUDA device type string.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/utils/bench.py (1)
119-122: Critical issue: Incorrect variable usage persists.
The same bug identified in previous reviews still exists: both branches format `reserved`, so the "allocated" entry reports the reserved value instead of `allocated`.
Apply this fix to correct the variable usage:
    - if allocated > 0:
    -     extras.append(f"+{reserved:.03f}GB allocated")
    - if reserved > 0:
    -     extras.append(f"+{reserved:.03f}GB reserved")
    + if allocated > 0:
    +     extras.append(f"+{allocated:.03f}GB allocated")
    + if reserved > 0:
    +     extras.append(f"+{reserved:.03f}GB reserved")
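As a minimal sketch of the corrected formatting logic, assuming the values are already in GB: the helper name `format_memory_message` and its signature are hypothetical and are not part of `bench.py`; only the two `extras.append(...)` lines and the `msg` fallback mirror the review's suggestions.

```python
def format_memory_message(
    active: float,
    allocated: float,
    reserved: float,
    device_type: str = "cuda",
    msg: str = "",
) -> str:
    """Build a one-line memory summary (hypothetical helper, values in GB)."""
    extras = []
    if allocated > 0:
        # Report the allocated figure itself, not the reserved one.
        extras.append(f"+{allocated:.03f}GB allocated")
    if reserved > 0:
        extras.append(f"+{reserved:.03f}GB reserved")
    # Fall back to a default prefix only when the caller gave no message.
    prefix = msg if msg else f"{device_type} memory active:"
    summary = f"{prefix} {active:.03f}GB"
    if extras:
        summary += f" ({', '.join(extras)})"
    return summary
```

Keeping the string building in a pure function like this would make the variable mix-up easy to catch with a plain unit test, independent of any GPU.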
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)
123-123: Apply static analysis suggestion for cleaner conditional.
The static analysis tool correctly identifies a more readable conditional expression pattern.
Apply this improvement:
    - msg = f"{cur_device_type} memory active:" if not msg else msg
    + msg = msg if msg else f"{cur_device_type} memory active:"
📒 Files selected for processing (1)
src/axolotl/utils/bench.py(2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py
123-123: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
🔇 Additional comments (3)
src/axolotl/utils/bench.py (3)
60-63: LGTM: Improved memory reporting with peak metrics.
The refactoring from current memory usage to peak/max metrics provides more meaningful insights for performance analysis. The implementation correctly uses:
- `memory_stats().get("active_bytes.all.peak", 0)` for peak active memory
- `max_memory_allocated()` for maximum allocated memory
- `max_memory_reserved()` for maximum reserved memory
95-108: Well-designed centralized memory retrieval function.
The new `get_gpu_memory_usage` function effectively centralizes device-specific memory logic with proper error handling for unsupported devices. The device type detection using `get_device_type()` is robust and the fallback error message is informative.
111-127: Good refactoring with enhanced logging capabilities.
The function refactoring successfully:
- Uses the new centralized `get_gpu_memory_usage` function
- Updates terminology to more accurate labels ("active", "allocated", "reserved")
- Maintains backward compatibility with existing function signature
- Improves logging with debug level and proper stack level
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)
123-123: Apply static analysis suggestion for improved readability.
The conditional expression can be simplified for better readability.
    - msg = f"{cur_device_type} memory active:" if not msg else msg
    + msg = msg if msg else f"{cur_device_type} memory active:"
📒 Files selected for processing (1)
src/axolotl/utils/bench.py(2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py
123-123: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
🔇 Additional comments (2)
src/axolotl/utils/bench.py (2)
60-63: LGTM! Enhanced memory reporting with peak/max statistics.
The updated implementation provides more valuable memory insights by reporting peak active memory and maximum allocated/reserved memory instead of current usage. This is particularly useful for training monitoring and debugging memory issues.
111-127: LGTM! Function refactored with proper variable usage.
The refactoring successfully:
- Uses the centralized `get_gpu_memory_usage` function
- Adds proper type hints
- Uses more descriptive variable names (active/allocated/reserved)
- Correctly fixes the previous issue where both conditions used `reserved` instead of the appropriate variables
The debug-level logging is appropriate for detailed memory monitoring.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/utils/bench.py (1)
95-109: Address previous review feedback on device detection.
The current implementation still has the device detection issues flagged in previous reviews:
- Fragile string comparisons: Lines 99 and 101 use string containment checks which are unreliable
- Inconsistent MPS detection: Line 97 checks MPS availability independently of device type
- Unnecessary ValueError: Based on winglian's feedback, raising errors may not be appropriate for this non-critical utility
Consider this approach based on previous feedback:
     def get_gpu_memory_usage(device: int | torch.device = 0):
    -    cur_device_type = str(get_device_type())
    -    if torch.backends.mps.is_available():
    +    cur_device = get_device_type()
    +    if cur_device.type == "mps":
             usage, cache, misc = mps_memory_usage_all()
    -    elif "npu" in cur_device_type and is_torch_npu_available():
    +    elif cur_device.type == "npu":
             usage, cache, misc = npu_memory_usage_all(device)
    -    elif "cuda" in cur_device_type and torch.cuda.is_available():
    +    elif cur_device.type == "cuda":
             usage, cache, misc = gpu_memory_usage_all(device)
         else:
    -        raise ValueError(
    -            f"Unable to determine memory statistics for current device {device}"
    -        )
    +        # Return default values for unsupported devices (CPU, etc.)
    +        return 0.0, 0.0, 0.0
         return usage, cache, misc
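The dispatch pattern suggested above can be sketched without any GPU dependency. This is only an illustration under stated assumptions: `get_memory_usage`, the `readers` mapping, and the `MemoryStats` alias are hypothetical names, and the real code would call the `*_memory_usage_all` helpers from `bench.py` rather than injected callables.

```python
from typing import Callable

# (active, allocated, reserved) in GB; hypothetical alias for this sketch.
MemoryStats = tuple[float, float, float]


def get_memory_usage(
    device_type: str,
    readers: dict[str, Callable[[], MemoryStats]],
) -> MemoryStats:
    """Look up a memory reader by exact device type, defaulting to zeros."""
    # Exact key lookup avoids fragile substring checks such as `"cuda" in name`.
    reader = readers.get(device_type)
    if reader is None:
        # Unsupported device (for example CPU): report zeros instead of raising.
        return 0.0, 0.0, 0.0
    return reader()
```

A caller might pass something like `{"cuda": read_cuda, "mps": read_mps}` (both hypothetical); returning zeros for unknown devices means a logging utility never needs its own error handling.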
🧹 Nitpick comments (1)
src/axolotl/utils/bench.py (1)
127-127: Apply style improvement suggested by static analysis.
The ternary operator can be simplified for better readability.
    - msg = f"{cur_device_type} memory active:" if not msg else msg
    + msg = msg if msg else f"{cur_device_type} memory active:"
📒 Files selected for processing (2)
- src/axolotl/utils/bench.py (2 hunks)
- src/axolotl/utils/callbacks/__init__.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/utils/callbacks/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py
127-127: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
🔇 Additional comments (4)
src/axolotl/utils/bench.py (4)
60-63: LGTM: Improved memory metrics for benchmarking.
The change from current memory usage to peak/max memory metrics is a good improvement for benchmarking purposes. Using `memory_stats().get("active_bytes.all.peak")`, `max_memory_allocated()`, and `max_memory_reserved()` provides more valuable insights into actual memory requirements during training.
117-120: Good exception handling for unsupported devices.
The try/except block properly handles cases where `get_gpu_memory_usage()` raises a ValueError (e.g., for CPU-only environments), gracefully returning without logging.
123-126: Variable usage bug has been correctly fixed.
The previous issue with using `reserved` for both allocated and reserved memory has been properly addressed. The code now correctly uses `allocated` for the allocated memory string and `reserved` for the reserved memory string.
111-131: Well-structured refactoring with improved separation of concerns.
The refactoring successfully centralizes memory usage logic and improves the logging function with:
- Proper type hints for better code quality
- Clear separation between memory retrieval and logging
- Appropriate exception handling for unsupported devices
- Updated variable names reflecting the new peak/max memory metrics
Actionable comments posted: 0
🧹 Nitpick comments (2)
cicd/single_gpu.py (1)
68-72: LGTM! Consider making the process count configurable.
The explicit environment configuration for dataset processing is good for CI/CD consistency. The hardcoded value of "8" processes aligns with the broader parallelism improvements in this PR.
Consider making this configurable via an environment variable:
     sp_env = os.environ.copy()
    -sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
    +sp_env["AXOLOTL_DATASET_PROCESSES"] = os.environ.get("AXOLOTL_DATASET_PROCESSES", "8")
src/axolotl/loaders/model.py (1)
400-452: Excellent implementation of parallelism configuration calculation.
The static method correctly computes parallelism dimensions by systematically dividing the world size. The validation logic ensures compatibility and provides clear error messages for misconfigurations.
Consider simplifying the nested if statements as suggested by static analysis:
    - if dp_shard_size is None and dp_replicate_size in (None, 1):
    -     if remaining_world_size > 1:
    -         pc_kwargs["dp_shard_size"] = remaining_world_size
    -         remaining_world_size = 1
    + if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
    +     pc_kwargs["dp_shard_size"] = remaining_world_size
    +     remaining_world_size = 1

    - if remaining_world_size > 1:
    -     if "dp_shard_size" not in pc_kwargs and is_fsdp:
    -         pc_kwargs["dp_shard_size"] = remaining_world_size
    -         remaining_world_size = 1
    + if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
    +     pc_kwargs["dp_shard_size"] = remaining_world_size
    +     remaining_world_size = 1
📒 Files selected for processing (4)
- cicd/single_gpu.py (1 hunks)
- src/axolotl/loaders/model.py (5 hunks)
- src/axolotl/utils/data/shared.py (1 hunks)
- tests/test_loaders.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
cicd/single_gpu.py (1)
Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.
🧬 Code Graph Analysis (1)
tests/test_loaders.py (1)
src/axolotl/loaders/model.py (1)
_get_parallel_config_kwargs(401-452)
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py
420-421: Use a single if statement instead of nested if statements
(SIM102)
441-442: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (6)
src/axolotl/utils/data/shared.py (1)
433-437: Excellent refactoring! Improves code readability.
Replacing the magic number "8" with the descriptive `min_rows_per_proc = 256` makes the intent much clearer. This change aligns well with the parallelism configuration improvements throughout the PR.
tests/test_loaders.py (1)
175-214: Excellent test coverage for parallelism configuration logic.
The parameterized test comprehensively covers various combinations of parallelism settings. The test cases effectively validate the `_get_parallel_config_kwargs` method's behavior for different world sizes and parallelism configurations.
src/axolotl/loaders/model.py (4)
16-17: Good addition of accelerate imports for parallelism support.
The imports for `PartialState` and `ParallelismConfig` are essential for the new parallelism configuration functionality.
184-192: Well-structured conditional logic for parallel configuration.
The logic correctly determines when to enable parallel configuration based on FSDP, tensor parallelism, and context parallelism settings, with proper handling of FSDP version constraints.
454-472: Solid integration with accelerate's parallelism infrastructure.
The method correctly creates a `ParallelismConfig` and device mesh, then properly configures the `PartialState`. This establishes the parallelism configuration early in the model loading process.
707-712: Proper tensor parallelism configuration in model building.
The code correctly sets up tensor parallelism parameters and removes the incompatible `device_map` when tensor parallelism is enabled, which prevents conflicts between device mapping and tensor parallel plans.
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/axolotl/loaders/model.py (1)
408-460: Add validation for world size divisibility.
The method should validate that the world size is evenly divisible by each parallelism factor before performing division to prevent runtime errors.
         @staticmethod
         def _get_parallel_config_kwargs(
             world_size: int,
             tensor_parallel_size: int = 1,
             context_parallel_size: int = 1,
             dp_shard_size: int | None = None,
             dp_replicate_size: int | None = None,
             is_fsdp: bool = False,
         ):
             pc_kwargs = {}
             remaining_world_size = world_size

             if tensor_parallel_size and tensor_parallel_size > 1:
    +            if remaining_world_size % tensor_parallel_size != 0:
    +                raise ValueError(
    +                    f"World size ({world_size}) must be divisible by tensor_parallel_size ({tensor_parallel_size})"
    +                )
                 pc_kwargs["tp_size"] = tensor_parallel_size
                 remaining_world_size = remaining_world_size // tensor_parallel_size

             if context_parallel_size and context_parallel_size > 1:
    +            if remaining_world_size % context_parallel_size != 0:
    +                raise ValueError(
    +                    f"Remaining world size ({remaining_world_size}) must be divisible by context_parallel_size ({context_parallel_size})"
    +                )
                 pc_kwargs["cp_size"] = context_parallel_size
                 remaining_world_size = remaining_world_size // context_parallel_size
Also consider simplifying the nested if statements at lines 428-429 and 449-450:
    -        if dp_shard_size is None and dp_replicate_size in (None, 1):
    -            if remaining_world_size > 1:
    +        if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
                 pc_kwargs["dp_shard_size"] = remaining_world_size
                 remaining_world_size = 1
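To make the divisibility requirement concrete, here is a simplified standalone sketch. It is not the PR's `_get_parallel_config_kwargs`: the function name `split_world_size` is hypothetical, the `dp_shard_size`, `dp_replicate_size`, and `is_fsdp` handling is left out, and any leftover ranks are assumed to go to `dp_shard_size`.

```python
def split_world_size(
    world_size: int,
    tensor_parallel_size: int = 1,
    context_parallel_size: int = 1,
) -> dict[str, int]:
    """Divide the world size across TP and CP, validating divisibility first."""
    pc_kwargs: dict[str, int] = {}
    remaining = world_size
    for key, size in (
        ("tp_size", tensor_parallel_size),
        ("cp_size", context_parallel_size),
    ):
        if size and size > 1:
            # Without this check, floor division would silently drop ranks.
            if remaining % size != 0:
                raise ValueError(
                    f"remaining world size ({remaining}) must be divisible by {key} ({size})"
                )
            pc_kwargs[key] = size
            remaining //= size
    # Assumption for this sketch: leftover ranks become the data-parallel shard dimension.
    if remaining > 1:
        pc_kwargs["dp_shard_size"] = remaining
    return pc_kwargs
```

For example, 8 ranks with `tensor_parallel_size=2` and `context_parallel_size=2` leave 2 ranks for sharding, whereas 8 ranks with `tensor_parallel_size=3` raise a `ValueError` instead of silently running on 6.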
🧹 Nitpick comments (8)
src/axolotl/utils/environment.py (1)
34-37: Consider adding error handling for missing packages.
The function may raise `PackageNotFoundError` if the package isn't installed. Consider documenting this behavior or adding error handling depending on the expected usage pattern.
     def get_package_version(package: str) -> Version:
    +    """Get the installed version of a package.
    +
    +    Args:
    +        package: The name of the package to check.
    +
    +    Returns:
    +        The parsed version of the package.
    +
    +    Raises:
    +        PackageNotFoundError: If the package is not installed.
    +    """
         version_str = version(package)
         return parse(version_str)
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)
8-18: Document the temporary nature of this monkeypatch.
This patch sets protected attributes on the model, which suggests it's working around a limitation in the transformers library. Consider adding a comment explaining why this is needed and referencing the upstream PR (transformers #39622) that may eliminate the need for this patch.
     def distribute_model(model, distributed_config, device_mesh, tp_size):
    +    """
    +    Wrapper for transformers' distribute_model that adds tensor parallel metadata.
    +
    +    This is a temporary workaround until transformers properly exposes TP information.
    +    See: https://github.com/huggingface/transformers/pull/39622
    +    """
         res = transformers.integrations.tensor_parallel.distribute_model(
src/axolotl/core/trainers/mixins/dist_parallel.py (1)
25-27: Document the upstream dependency for this workaround.
Consider adding a reference to the accelerate PR that will eliminate the need for this workaround.
     # check for device mesh as we don't worry about this for DDP and it wouldn't be set
     # and is only specific to older accelerate atm
    +# This workaround can be removed once accelerate PR #3682 is merged and released
     if "device_mesh" in PartialState()._shared_state:
src/axolotl/integrations/liger/args.py (1)
68-83: Validators correctly enforce tensor parallelism constraints.
Both validators properly check for incompatibilities between Liger features and tensor parallelism. The reference to the GitHub issue in the first validator provides helpful context.
The TODO comment indicates this might need a more comprehensive fix. Would you like me to help investigate a more permanent solution or open an issue to track this technical debt?
src/axolotl/monkeypatch/accelerate/distributed.py (3)
5-5: Consider reducing pylint disable scope.
The broad pylint disable includes important checks like `protected-access` and `inconsistent-return-statements`. Consider:
- Addressing the underlying issues instead of disabling the checks
- Using more targeted inline disables where absolutely necessary
- Documenting why each disable is required
165-180: Add `strict=True` to zip for safer tuple unpacking.
The static analysis correctly identifies that `zip()` should use the `strict` parameter to ensure both sequences have the same length.
    - return tuple(zip(*sorted_items))
    + return tuple(zip(*sorted_items, strict=True))
214-219: Use more Pythonic dict membership check.
    - assert (
    -     parallelism in self._sizes.keys()
    - ), f"Parallelism must be one of {self._sizes.keys()}"
    + assert (
    +     parallelism in self._sizes
    + ), f"Parallelism must be one of {list(self._sizes.keys())}"
src/axolotl/loaders/model.py (1)
842-845: Simplify nested condition and document protected member access.
    - if self.cfg.tensor_parallel_size > 1:
    -     if self.model._tp_size != self.cfg.tensor_parallel_size:
    + if self.cfg.tensor_parallel_size > 1 and self.model._tp_size != self.cfg.tensor_parallel_size:
          self.model._tp_size = self.cfg.tensor_parallel_size
Consider adding a comment explaining why accessing the protected `_tp_size` attribute is necessary for tensor parallelism functionality.
📒 Files selected for processing (15)
- requirements.txt (1 hunks)
- setup.py (2 hunks)
- src/axolotl/core/builders/base.py (3 hunks)
- src/axolotl/core/trainers/base.py (2 hunks)
- src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
- src/axolotl/core/trainers/mixins/dist_parallel.py (1 hunks)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/loaders/patch_manager.py (3 hunks)
- src/axolotl/monkeypatch/accelerate/distributed.py (1 hunks)
- src/axolotl/monkeypatch/transformers/tensor_parallel.py (1 hunks)
- src/axolotl/utils/environment.py (2 hunks)
- src/axolotl/utils/schemas/validation.py (5 hunks)
- tests/e2e/multigpu/patched/test_sp.py (2 hunks)
- tests/e2e/multigpu/test_tp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- src/axolotl/core/trainers/mixins/__init__.py
- src/axolotl/core/trainers/base.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/e2e/multigpu/patched/test_sp.py
- requirements.txt
- src/axolotl/loaders/patch_manager.py
- src/axolotl/core/builders/base.py
- src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/core/trainers/mixins/dist_parallel.py
20-22: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/monkeypatch/transformers/tensor_parallel.py
22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
src/axolotl/monkeypatch/accelerate/distributed.py
180-180: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
216-216: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
src/axolotl/loaders/model.py
428-429: Use a single if statement instead of nested if statements
(SIM102)
449-450: Use a single if statement instead of nested if statements
(SIM102)
842-843: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (7)
setup.py (1)
30-30: Version compatibility handling looks correct. The conditional vllm version management based on PyTorch version is appropriate, especially given the incompatibility between vllm 0.9.x and the updated transformers version.
Also applies to: 72-76
tests/e2e/multigpu/test_tp.py (1)
13-64: Well-structured tensor parallelism test. The test appropriately validates basic tensor parallelism functionality with a minimal configuration. The use of a small model and few training steps makes it suitable for CI/CD pipelines.
src/axolotl/integrations/liger/args.py (1)
31-37: LGTM! Modern type annotation syntax. The update to use `bool | None` instead of `Optional[bool]` follows modern Python 3.10+ conventions and improves code readability.
src/axolotl/monkeypatch/accelerate/distributed.py (1)
182-212: Robust validation logic for parallelism configurations. The validation comprehensively checks size constraints and enforces sensible limitations on parallelism combinations. The error messages provide clear guidance for users.
src/axolotl/loaders/model.py (3)
88-89: LGTM! Well-structured parallelism configuration attributes. The class attributes are properly typed and initialized with sensible defaults.
187-200: Well-structured parallelism setup logic. The method correctly determines when to enable parallelism configuration based on FSDP and parallel size settings, with appropriate version compatibility checks.
722-728: Correct tensor parallelism configuration for model initialization. The setup properly configures tensor parallelism arguments and correctly removes incompatible device_map when using tp_plan.
    if not is_package_version_ge("accelerate", "1.10.0"):
        # pylint: disable=protected-access
        if int(os.environ.get("WORLD_SIZE", "1")) > 1:
            from accelerate.state import PartialState

            # check for device mesh as we don't worry about this for DDP and it wouldn't be set
            # and is only specific to older accelerate atm
            if "device_mesh" in PartialState()._shared_state:
                device_mesh: DeviceMesh = PartialState()._shared_state[
                    "device_mesh"
                ]
                mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
                if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
                    self.accelerator.state.distributed_type = "TP"
                    PartialState().distributed_type = "TP"
                    tp_plugin = TorchTensorParallelPlugin(
                        tp_size=device_mesh["tp"].size(),
                        torch_device_mesh=device_mesh,
                    )
                    self.accelerator.state.torch_tp_plugin = tp_plugin
🛠️ Refactor suggestion
Consider combining nested if statements and addressing Python compatibility.
The implementation has a few areas for improvement:
- The nested if statements can be combined as suggested by static analysis
- The type hint `tuple[str, ...] | None` requires Python 3.10+

- if not is_package_version_ge("accelerate", "1.10.0"):
-     # pylint: disable=protected-access
-     if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+ if (
+     not is_package_version_ge("accelerate", "1.10.0")
+     and int(os.environ.get("WORLD_SIZE", "1")) > 1
+ ):
+     # pylint: disable=protected-access
      from accelerate.state import PartialState
      # check for device mesh as we don't worry about this for DDP and it wouldn't be set
      # and is only specific to older accelerate atm
      if "device_mesh" in PartialState()._shared_state:
          device_mesh: DeviceMesh = PartialState()._shared_state[
              "device_mesh"
          ]
-         mesh_dim_names: tuple[str, ...] | None = device_mesh.mesh_dim_names
+         mesh_dim_names = device_mesh.mesh_dim_names
          if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if (
        not is_package_version_ge("accelerate", "1.10.0")
        and int(os.environ.get("WORLD_SIZE", "1")) > 1
    ):
        # pylint: disable=protected-access
        from accelerate.state import PartialState
        # check for device mesh as we don't worry about this for DDP and it wouldn't be set
        # and is only specific to older accelerate atm
        if "device_mesh" in PartialState()._shared_state:
            device_mesh: DeviceMesh = PartialState()._shared_state[
                "device_mesh"
            ]
            mesh_dim_names = device_mesh.mesh_dim_names
            if "tp" in mesh_dim_names and device_mesh["tp"].size() > 1:
                self.accelerator.state.distributed_type = "TP"
                PartialState().distributed_type = "TP"
                tp_plugin = TorchTensorParallelPlugin(
                    tp_size=device_mesh["tp"].size(),
                    torch_device_mesh=device_mesh,
                )
                self.accelerator.state.torch_tp_plugin = tp_plugin
🧰 Tools
🪛 Ruff (0.12.2)
20-22: Use a single if statement instead of nested if statements
(SIM102)
🤖 Prompt for AI Agents
In src/axolotl/core/trainers/mixins/dist_parallel.py around lines 20 to 40,
combine the nested if statements checking the package version and WORLD_SIZE
into a single condition to simplify the logic. Also, replace the type hint
`tuple[str, ...] | None` with a syntax compatible with earlier Python versions,
such as using `Optional[Tuple[str, ...]]` and importing the necessary types from
the typing module to maintain compatibility below Python 3.10.
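A minimal sketch of the nested-if flattening (Ruff SIM102) discussed above, outside the axolotl code base. The function names and the `accelerate_is_new` flag are hypothetical stand-ins for the `is_package_version_ge("accelerate", "1.10.0")` check; the point is only that joining the two conditions with `and` keeps the same truth table and the same short-circuit order.

```python
import os


def needs_tp_fallback_nested(accelerate_is_new: bool) -> bool:
    """Original shape: two nested conditions."""
    if not accelerate_is_new:
        if int(os.environ.get("WORLD_SIZE", "1")) > 1:
            return True
    return False


def needs_tp_fallback_flat(accelerate_is_new: bool) -> bool:
    """Flattened shape: one condition joined with `and`.

    WORLD_SIZE is still only read when the version check fails,
    because `and` short-circuits left to right.
    """
    return (
        not accelerate_is_new
        and int(os.environ.get("WORLD_SIZE", "1")) > 1
    )
```

Both helpers should agree for every combination of the flag and `WORLD_SIZE`, which is the property worth checking before accepting a mechanical flattening like this one.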
    def patch_tp_fix():
        transformers.integrations.tensor_parallel.distribute_model = distribute_model
        setattr(
            sys.modules["transformers.integrations.tensor_parallel"],
            "distribute_model",
            distribute_model,
        )
Remove redundant setattr call.
The static analysis correctly identifies that the setattr call is redundant. The direct assignment on line 21 is sufficient to patch the module.
  def patch_tp_fix():
      transformers.integrations.tensor_parallel.distribute_model = distribute_model
-     setattr(
-         sys.modules["transformers.integrations.tensor_parallel"],
-         "distribute_model",
-         distribute_model,
-     )

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    def patch_tp_fix():
        transformers.integrations.tensor_parallel.distribute_model = distribute_model
🧰 Tools
🪛 Ruff (0.12.2)
22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/transformers/tensor_parallel.py between lines 20 and
26, remove the redundant setattr call that sets "distribute_model" on the module
since the direct assignment on line 21 already patches the module correctly.
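To illustrate why the two patching forms are normally equivalent, here is a small sketch using a throwaway module rather than transformers itself. The module name `fake_tp_module` and both functions are hypothetical; the sketch assumes the `sys.modules` entry is the same object that attribute access on the imported name resolves to, which is the usual case.

```python
import sys
import types

# Hypothetical stand-in for transformers.integrations.tensor_parallel.
fake = types.ModuleType("fake_tp_module")
fake.distribute_model = lambda model: "original"
sys.modules["fake_tp_module"] = fake

import fake_tp_module  # resolved from sys.modules, no file needed


def patched_distribute_model(model):
    """Replacement function installed by the patch."""
    return "patched"


# Direct attribute assignment, as on line 21 of the reviewed file.
fake_tp_module.distribute_model = patched_distribute_model

# The sys.modules entry is the very same module object, so the
# patch is already visible through it without a second setattr.
same_object = sys.modules["fake_tp_module"] is fake_tp_module
seen_via_sys_modules = (
    sys.modules["fake_tp_module"].distribute_model is patched_distribute_model
)
```

Neither form affects code that already ran `from module import distribute_model` before the patch, since that binds the old function object under a separate name.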
Actionable comments posted: 1
🔭 Outside diff range comments (1)
src/axolotl/utils/schemas/validation.py (1)
1280-1292: Include DistributedValidationMixin in ValidationMixin inheritance
The `DistributedValidationMixin` (defined at lines 1265–1273) provides the `check_tensor_parallel_optimizer` validator for distributed training, but it isn’t currently inherited by `ValidationMixin` (lines 1280–1292). Without it, distributed-specific validations won’t run.
File: src/axolotl/utils/schemas/validation.py
Location: `class ValidationMixin` definition
Suggested diff:

  class ValidationMixin(
      DatasetValidationMixin,
      AttentionValidationMixin,
      TrainingValidationMixin,
      LoRAValidationMixin,
      RLValidationMixin,
      OptimizationValidationMixin,
      SystemValidationMixin,
      ChatTemplateValidationMixin,
      PretrainingValidationMixin,
      ModelCompatibilityValidationMixin,
+     DistributedValidationMixin,
      ComplexValidationMixin,
  ):

♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)
408-460: Add validation for world size divisibility. The method correctly calculates the data parallel replicate size, but should validate that the world size is evenly divisible by all parallelism factors to prevent runtime errors.
462-488: Replace print with logging and document protected member access. Two issues to address:
- Line 474 uses `print()` instead of the logger
- Lines 481-487 access protected members of `PartialState`
🧹 Nitpick comments (3)
cicd/multigpu.sh (1)
24-25: Consider surfacing upload failures instead of silencing them
`|| true` neutralises any non-zero exit from `codecov`, which means genuine issues (e.g., malformed XML, network outages) won’t be visible in CI. If the upload is meant to be non-blocking, emit a warning so it’s still detectable in logs:

- codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
+ if ! codecov upload-process -t "${CODECOV_TOKEN}" \
+     -f multigpu-coverage.xml \
+     -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}; then
+   echo "[WARN] Codecov upload failed but CI will proceed."
+ fi

src/axolotl/loaders/model.py (1)
841-847: Consider simplifying nested if statements. The workaround logic is necessary and well-documented. Consider combining the nested if statements for cleaner code:

- if self.cfg.tensor_parallel_size > 1:
-     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-     # TODO(wing): remove once 4.54.1 is released
-     if self.model._tp_size != self.cfg.tensor_parallel_size:
+ if (self.cfg.tensor_parallel_size > 1
+         and self.model._tp_size != self.cfg.tensor_parallel_size):
+     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+     # TODO(wing): remove once 4.54.1 is released

src/axolotl/utils/schemas/validation.py (1)
1265-1277: LGTM! Important validation for tensor parallelism compatibility. The new validation correctly prevents incompatible optimizer configurations with tensor parallelism. Consider simplifying the nested if statements:

- if self.tensor_parallel_size > 1:
-     if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
+ if (self.tensor_parallel_size > 1
+         and self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):

📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- cicd/multigpu.sh (1 hunks)
- setup.py (2 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/utils/samplers/multipack.py (2 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/samplers/multipack.py
🧰 Additional context used
🧠 Learnings (1)
src/axolotl/utils/schemas/validation.py (1)
Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py
428-429: Use a single if statement instead of nested if statements
(SIM102)
449-450: Use a single if statement instead of nested if statements
(SIM102)
842-845: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/utils/schemas/validation.py
1270-1271: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (10)
cicd/multigpu.sh (1)
22-23: Guard clause for missing token is a solid reliability win
Wrapping the upload in an `if [ -n "$CODECOV_TOKEN" ]` block prevents noisy failures on forks / local runs where the secret isn’t present.
setup.py (1)
30-30: LGTM! Consistent pattern for version extraction. The vllm version extraction follows the same pattern as xformers and autoawq versions.
src/axolotl/loaders/model.py (4)
16-16: LGTM! Import changes support new parallelism functionality. The added imports for `PartialState`, `ParallelismConfig`, and `get_world_size` are necessary for the new parallelism configuration features.
Also applies to: 49-49, 52-52
88-89: LGTM! Clean attribute declarations for parallelism state. The class attributes are properly typed and initialized for managing parallelism configuration.
187-200: LGTM! Sound logic for parallelism configuration activation. The conditions for enabling parallelism configuration are well-defined and the FSDP version 2 requirement is appropriately handled.
722-727: LGTM! Proper tensor parallelism setup with conflict resolution. The tensor parallelism configuration correctly sets the required parameters and appropriately removes conflicting device_map settings.
src/axolotl/utils/schemas/validation.py (4)
676-676: LGTM! Parameter rename aligns with codebase-wide changes. The change from `sequence_parallel_degree` to `context_parallel_size` is consistent with the broader parameter renaming effort across the codebase.
903-926: LGTM! More flexible DeepSpeed configuration handling. The conditional logic appropriately handles cases where no DeepSpeed configuration file is present, making the validation more robust.
1205-1244: LGTM! Comprehensive parameter rename with consistent updates. The method rename and all parameter references have been consistently updated from `sequence_parallel_degree` to `context_parallel_size`. The transformers patching logic for flash attention compatibility is appropriately maintained.
1248-1248: LGTM! Consistent parameter reference update. The parameter reference update maintains consistency with the broader renaming effort.
      # vllm 0.9.x is incompatible with latest transformers
      extras_require_map.pop(_install_requires.index(vllm_version))
  else:
      _install_requires.append("xformers==0.0.31.post1")
-     extras_require_map["vllm"] = ["vllm>=0.9.0"]
+     extras_require_map["vllm"] = ["vllm>=0.10.0"]
Fix incorrect pop() target in vllm version handling.
Line 73 has a critical bug (a runtime failure, not a syntax error): `extras_require_map.pop()` expects a dictionary key, but `_install_requires.index(vllm_version)` returns an integer index.

-     # vllm 0.9.x is incompatible with latest transformers
-     extras_require_map.pop(_install_requires.index(vllm_version))
+     # vllm 0.9.x is incompatible with latest transformers
+     _install_requires.pop(_install_requires.index(vllm_version))

This should remove vllm from the install requirements (similar to xformers on line 69) rather than from the extras_require_map.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        # vllm 0.9.x is incompatible with latest transformers
        _install_requires.pop(_install_requires.index(vllm_version))
    else:
        _install_requires.append("xformers==0.0.31.post1")
        extras_require_map["vllm"] = ["vllm>=0.10.0"]
🤖 Prompt for AI Agents
In setup.py around lines 72 to 76, the code incorrectly uses
extras_require_map.pop() with an integer index instead of a dictionary key,
causing a runtime error rather than the intended removal. To fix this, remove
the vllm version from the _install_requires list using
_install_requires.remove(vllm_version) instead of popping from
extras_require_map. This aligns with the intended behavior of removing vllm
from install requirements rather than from extras_require_map.
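A small sketch of the failure mode, using made-up requirement lists rather than the real setup.py contents. It assumes `extras_require_map` is keyed by extra name (strings) and that `vllm_version` is an entry of the install list. In that case, `dict.pop` with an integer raises `KeyError` when the line executes, while `list.remove` does what the comment intends.

```python
# Hypothetical data; the real setup.py builds these lists differently.
install_requires = ["torch", "vllm==0.9.1", "xformers"]
extras_require_map = {"vllm": ["vllm>=0.10.0"]}
vllm_version = "vllm==0.9.1"

# list.index returns an integer position, not a dictionary key.
idx = install_requires.index(vllm_version)

try:
    extras_require_map.pop(idx)  # integer is not a key in this dict
    pop_error = None
except KeyError as exc:
    pop_error = exc

# Removing the pinned requirement from the list itself.
install_requires.remove(vllm_version)
```

`list.remove(value)` and `list.pop(list.index(value))` both drop the first matching entry; `remove` just avoids the intermediate index.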
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)
408-460: Method implementation looks correct, but validation issue remains unaddressed. The parallelism configuration calculation logic is well-structured and handles various scenarios appropriately. However, the past review comment about adding validation for world size divisibility by each parallelism factor has not been addressed.
462-488: Print statement and protected member access issues remain unaddressed. The parallelism configuration setup logic is correct, but the previously identified issues with using `print()` instead of logging and accessing protected members of `PartialState` have not been resolved.
🧹 Nitpick comments (3)
.github/workflows/multi-gpu-e2e.yml (1)
46-48: Minor: quote scalar to stay consistent with other string values
All other scalar strings in the matrix (`"3.11"`, `"true"`) are quoted. Quoting `vllm` keeps style consistent and guards against accidental boolean parsing.

- axolotl_extras: vllm
+ axolotl_extras: "vllm"

src/axolotl/loaders/model.py (1)
841-847: LGTM! Necessary workaround with minor optimization opportunity. The workaround for the upstream transformers 4.54.0 issue is appropriate and well-documented. The TODO comment ensures it will be removed when no longer needed.
Consider simplifying the nested if statement as suggested by static analysis:

- if self.cfg.tensor_parallel_size > 1:
-     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-     # TODO(wing): remove once 4.54.1 is released
-     if self.model._tp_size != self.cfg.tensor_parallel_size:
-         self.model._tp_size = self.cfg.tensor_parallel_size
-         self.model._device_mesh = self.model_kwargs["device_mesh"]
+ if (
+     self.cfg.tensor_parallel_size > 1
+     and self.model._tp_size != self.cfg.tensor_parallel_size
+ ):
+     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+     # TODO(wing): remove once 4.54.1 is released
+     self.model._tp_size = self.cfg.tensor_parallel_size
+     self.model._device_mesh = self.model_kwargs["device_mesh"]

src/axolotl/utils/schemas/validation.py (1)
1265-1276: LGTM! New tensor parallelism optimizer validation with minor optimization opportunity. The new validation mixin and `check_tensor_parallel_optimizer` method appropriately prevent incompatible optimizer configurations with tensor parallelism. The logic correctly identifies optimizers that don't work with tensor parallelism.
Consider simplifying the nested if statement as suggested by static analysis:

- def check_tensor_parallel_optimizer(self):
-     if self.tensor_parallel_size > 1:
-         if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-             raise ValueError(
-                 "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-             )
+ def check_tensor_parallel_optimizer(self):
+     if (
+         self.tensor_parallel_size > 1
+         and self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]
+     ):
+         raise ValueError(
+             "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+         )

📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- .github/workflows/multi-gpu-e2e.yml (1 hunks)
- cicd/multigpu.sh (1 hunks)
- setup.py (1 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/utils/samplers/multipack.py (2 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/samplers/multipack.py
🚧 Files skipped from review as they are similar to previous changes (2)
- cicd/multigpu.sh
- setup.py
🧰 Additional context used
🧠 Learnings (2)
.github/workflows/multi-gpu-e2e.yml (1)
Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.
src/axolotl/utils/schemas/validation.py (1)
Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
load(154-183)
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py
428-429: Use a single if statement instead of nested if statements
(SIM102)
449-450: Use a single if statement instead of nested if statements
(SIM102)
842-845: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/utils/schemas/validation.py
1270-1271: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (8)
src/axolotl/loaders/model.py (4)
16-16: LGTM! Import additions support new parallelism features. The new imports for `PartialState`, `ParallelismConfig`, and consolidated `get_world_size` are appropriate for the parallelism configuration functionality being added.
Also applies to: 49-49, 52-52
88-89: LGTM! New class attributes are properly typed. The new parallelism configuration attributes are correctly typed and initialized with appropriate default values to manage the parallelism state.
187-200: LGTM! Parallelism configuration logic is well-structured. The conditional logic properly determines when parallelism configuration should be enabled based on FSDP, tensor parallel, and context parallel settings. The check to disable parallelism for non-FSDP v2 configurations is appropriate.
722-728: LGTM! Tensor parallelism configuration is properly implemented. The logic correctly configures tensor parallelism by setting the appropriate model kwargs and removing the incompatible `device_map`. Using `tp_plan: "auto"` is a good approach for automatic tensor parallel planning.
src/axolotl/utils/schemas/validation.py (4)
676-676: LGTM! Parameter rename is consistent with the refactoring. The change from `sequence_parallel_degree` to `context_parallel_size` maintains the same validation logic while aligning with the systematic parameter renaming throughout the codebase.
903-926: LGTM! Improved robustness by making DeepSpeed config updates conditional. The addition of the conditional check ensures the DeepSpeed configuration is only updated when a DeepSpeed config is actually present, preventing errors when tensor parallelism is used without DeepSpeed. This is a sensible improvement to the validation logic.
1205-1244: LGTM! Method renamed consistently with preserved validation logic. The systematic renaming from `check_sequence_parallel_degree` to `check_context_parallel_size` maintains all the existing validation logic while using the new parameter naming convention. The validation requirements for flash attention, micro batch size constraints, and ring_flash_attn imports are properly preserved.
1248-1248: LGTM! Consistent parameter rename in ring attention validation. The change maintains the same validation logic while using the updated parameter name, consistent with the systematic refactoring.
5351801 to
94a6a84
Compare
Actionable comments posted: 1
♻️ Duplicate comments (7)
src/axolotl/loaders/patch_manager.py (1)
69-69: Re-enable or confirm removal of sequence parallel patches. The call to `_apply_sequence_parallel_patches()` remains commented out, which disables sequence-parallelism patches and makes `context_parallel_size` ineffective. This was previously flagged in past reviews.
Please confirm:
- Is disabling sequence-parallelism intentional and permanent?
- If not, where are these patches now applied?
- If you still need `context_parallel_size` functionality, please re-enable:

- # self._apply_sequence_parallel_patches()
+ self._apply_sequence_parallel_patches()

src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)
20-26: Remove redundant setattr call. The static analysis correctly identifies that the `setattr` call is redundant. The direct assignment on line 21 is sufficient to patch the module.

  def patch_tp_fix():
      transformers.integrations.tensor_parallel.distribute_model = distribute_model
-     setattr(
-         sys.modules["transformers.integrations.tensor_parallel"],
-         "distribute_model",
-         distribute_model,
-     )

src/axolotl/core/trainers/mixins/dist_parallel.py (1)
20-22: Combine nested if statements and address Python compatibility. The implementation has a few areas for improvement:
- The nested if statements can be combined as suggested by static analysis
- The type hint `tuple[str, ...] | None` requires Python 3.10+

- if not is_package_version_ge("accelerate", "1.10.0"):
-     # pylint: disable=protected-access
-     if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+ if (
+     not is_package_version_ge("accelerate", "1.10.0")
+     and int(os.environ.get("WORLD_SIZE", "1")) > 1
+ ):
+     # pylint: disable=protected-access

src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
77-87: Remove redundant patching with `setattr`. The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.

  def patch_prepare_from_posids():
      import transformers.modeling_flash_attention_utils
      transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
          _prepare_from_posids
      )
-     setattr(
-         sys.modules["transformers.modeling_flash_attention_utils"],
-         "_prepare_from_posids",
-         _prepare_from_posids,
-     )

src/axolotl/loaders/model.py (3)
408-465: Static method for computing parallelism config kwargs is well-structured. The `_get_parallel_config_kwargs` method properly computes and validates parallelism parameters by sequentially dividing the world size. The validation ensures the combined parallelism matches the world size exactly.
478-478: Replace print with logging.
485-490: Document protected member access.
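As a rough illustration of the sequential-division check described for lines 408-465, here is a standalone sketch. It is not axolotl's `_get_parallel_config_kwargs`: the function name, parameter names, and the choice to return the leftover factor as the data-parallel replicate size are all assumptions made for the example.

```python
def resolve_dp_replicate_size(
    world_size: int,
    tensor_parallel_size: int = 1,
    context_parallel_size: int = 1,
    dp_shard_size: int = 1,
) -> int:
    """Divide world_size by each parallelism factor in turn.

    Raises ValueError as soon as a factor does not divide what is
    left, and returns the remainder of the division chain, treated
    here as the data-parallel replicate size (hypothetical naming).
    """
    remaining = world_size
    factors = {
        "tensor_parallel_size": tensor_parallel_size,
        "context_parallel_size": context_parallel_size,
        "dp_shard_size": dp_shard_size,
    }
    for name, size in factors.items():
        if size < 1:
            raise ValueError(f"{name} must be >= 1, got {size}")
        if remaining % size != 0:
            raise ValueError(
                f"{name}={size} does not evenly divide the remaining "
                f"world size {remaining}"
            )
        remaining //= size
    return remaining
```

With this shape, the product of all factors and the returned value always equals `world_size`, so a mismatched configuration fails early with a message naming the offending factor.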
🧹 Nitpick comments (6)
src/axolotl/core/builders/base.py (1)
439-462: Consider simplifying the parallelism configuration logic. The current implementation has complex nested checks for `parallelism_config` presence. While functional, it could be more readable and maintainable.
Consider refactoring:

  def _configure_accelerator_config(self, training_args_kwargs: dict):
      partial_state = PartialState()
-     has_pc_attr = (
-         hasattr(partial_state, "parallelism_config")
-         and partial_state.parallelism_config
-     )
-     has_pc_key = (
-         "parallelism_config"
-         in partial_state._shared_state  # pylint: disable=protected-access
-         and partial_state._shared_state[  # pylint: disable=protected-access
-             "parallelism_config"
-         ]
-     )
-     use_configured_state = has_pc_attr or has_pc_key
+
+     def _has_parallelism_config():
+         return (
+             (hasattr(partial_state, "parallelism_config") and partial_state.parallelism_config) or
+             (partial_state._shared_state.get("parallelism_config"))  # pylint: disable=protected-access
+         )
+
+     use_configured_state = _has_parallelism_config()
      if self.cfg.accelerator_config:
          use_configured_state = self.cfg.accelerator_config.pop(
              "use_configured_state", use_configured_state
          )
          training_args_kwargs["accelerator_config"] = AcceleratorConfig(
              use_configured_state=use_configured_state, **self.cfg.accelerator_config
          )
      else:
          training_args_kwargs["accelerator_config"] = AcceleratorConfig(
              use_configured_state=use_configured_state,
          )

src/axolotl/utils/bench.py (2)
104-106: Consider removing ValueError for non-critical utility. Based on past feedback, raising ValueErrors may not be necessary for this non-critical utility. Consider returning default values or silently handling unsupported devices instead.

- raise ValueError(
-     f"Unable to determine memory statistics for current device {device}"
- )
+ # Return default values for unsupported devices
+ return 0.0, 0.0, 0.0

111-131: LGTM - Good refactoring with minor style improvement opportunity. The refactored function correctly uses the new helper, handles errors gracefully, and provides clearer logging format.
Consider the static analysis suggestion for a simpler ternary expression:

- msg = f"{cur_device_type} memory active:" if not msg else msg
+ msg = msg if msg else f"{cur_device_type} memory active:"

src/axolotl/integrations/liger/args.py (1)
68-83: Validators correctly enforce tensor parallelism constraints. The new validators properly prevent incompatible configurations between LIGER components and tensor parallelism. The TODO comment on line 80 indicates this might be part of a larger compatibility issue that needs investigation.
Would you like me to help investigate the broader compatibility issues between tensor parallelism and liger losses mentioned in the TODO?
src/axolotl/monkeypatch/accelerate/distributed.py (1)
1-220: Comprehensive fallback implementation for ParallelismConfig. The fallback implementation provides a complete interface for parallelism configuration with proper validation logic. The class structure and methods are well-designed to support various parallelism strategies.
Consider addressing the static analysis hints:

- return tuple(zip(*sorted_items))
+ return tuple(zip(*sorted_items, strict=False))

- assert (
-     parallelism in self._sizes.keys()
- ), f"Parallelism must be one of {self._sizes.keys()}"
+ assert (
+     parallelism in self._sizes
+ ), f"Parallelism must be one of {list(self._sizes.keys())}"

src/axolotl/utils/schemas/validation.py (1)
1265-1278: New distributed validation mixin correctly enforces optimizer constraints. The `DistributedValidationMixin` properly validates that tensor parallelism is incompatible with certain 8-bit optimizers (paged_adamw_8bit, adamw_8bit, adamw_bnb_8bit).
Consider simplifying the nested if statement as suggested by static analysis:

- def check_tensor_parallel_optimizer(self):
-     if self.tensor_parallel_size > 1:
-         if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-             raise ValueError(
-                 "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-             )
+ def check_tensor_parallel_optimizer(self):
+     if (self.tensor_parallel_size > 1 and
+             self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):
+         raise ValueError(
+             "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+         )

📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (43)
- cicd/single_gpu.py (1 hunks)
- docs/sequence_parallelism.qmd (3 hunks)
- setup.py (1 hunks)
- src/axolotl/cli/merge_lora.py (1 hunks)
- src/axolotl/core/builders/base.py (2 hunks)
- src/axolotl/core/builders/rl.py (1 hunks)
- src/axolotl/core/trainers/base.py (2 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
- src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (8 hunks)
- src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
- src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
- src/axolotl/core/trainers/mixins/dist_parallel.py (1 hunks)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/loaders/patch_manager.py (3 hunks)
- src/axolotl/monkeypatch/accelerate/distributed.py (1 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
- src/axolotl/monkeypatch/ring_attn/__init__.py (2 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (7 hunks)
- src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
- src/axolotl/monkeypatch/transformers/tensor_parallel.py (1 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/bench.py (2 hunks)
- src/axolotl/utils/callbacks/__init__.py (2 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
- src/axolotl/utils/data/shared.py (1 hunks)
- src/axolotl/utils/environment.py (2 hunks)
- src/axolotl/utils/samplers/multipack.py (2 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
- src/axolotl/utils/trainer.py (3 hunks)
- tests/core/test_builders.py (1 hunks)
- tests/e2e/multigpu/patched/test_sp.py (2 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
- tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
- tests/e2e/multigpu/test_tp.py (1 hunks)
- tests/e2e/patched/test_sp.py (6 hunks)
- tests/e2e/test_load_model.py (1 hunks)
- tests/e2e/utils.py (1 hunks)
- tests/test_loaders.py (1 hunks)
- train.yaml (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- src/axolotl/core/trainers/mixins/__init__.py
- src/axolotl/utils/samplers/multipack.py
🚧 Files skipped from review as they are similar to previous changes (26)
- src/axolotl/cli/merge_lora.py
- src/axolotl/monkeypatch/accelerate/fsdp2.py
- tests/e2e/test_load_model.py
- docs/sequence_parallelism.qmd
- src/axolotl/utils/data/shared.py
- src/axolotl/core/trainers/grpo/args.py
- tests/core/test_builders.py
- tests/e2e/utils.py
- src/axolotl/core/trainers/mixins/checkpoints.py
- tests/e2e/patched/test_sp.py
- src/axolotl/train.py
- cicd/single_gpu.py
- tests/e2e/multigpu/solo/test_grpo.py
- src/axolotl/monkeypatch/ring_attn/__init__.py
- src/axolotl/core/trainers/grpo/sampler.py
- src/axolotl/core/trainers/grpo/trainer.py
- src/axolotl/utils/trainer.py
- tests/e2e/multigpu/test_fp8_fsdp2.py
- tests/test_loaders.py
- src/axolotl/core/trainers/grpo/__init__.py
- tests/e2e/multigpu/patched/test_sp.py
- src/axolotl/core/builders/rl.py
- setup.py
- train.yaml
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/ctx_managers/sequence_parallel.py
🧰 Additional context used
🧠 Learnings (1)
src/axolotl/utils/schemas/validation.py (1)
Learnt from: NanoCode012
PR: #2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
🧬 Code Graph Analysis (8)
src/axolotl/core/trainers/base.py (1)
src/axolotl/core/trainers/mixins/dist_parallel.py (1)
DistParallelMixin(12-41)
src/axolotl/loaders/patch_manager.py (4)
src/axolotl/integrations/base.py (2)
cfg(350-351)
cfg(354-355)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
patch_prepare_from_posids(77-87)
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)
patch_tp_fix(20-26)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
patch_prepare_device_mesh(389-433)
tests/e2e/multigpu/test_tp.py (2)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)
tests/e2e/utils.py (2)
check_tensorboard(149-163)
require_torch_2_7_0(80-89)
src/axolotl/utils/callbacks/__init__.py (1)
src/axolotl/utils/bench.py (2)
get_gpu_memory_usage(95-108)
log_gpu_memory_usage(111-131)
src/axolotl/core/builders/base.py (1)
tests/e2e/patched/test_sp.py (1)
partial_state(25-28)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
src/axolotl/core/trainers/mixins/dist_parallel.py (2)
src/axolotl/utils/environment.py (1)
is_package_version_ge(39-41)
src/axolotl/core/trainers/base.py (1)
create_accelerator_and_postprocess(515-525)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
load(154-183)
🪛 Ruff (0.12.2)
src/axolotl/monkeypatch/transformers/tensor_parallel.py
22-26: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
src/axolotl/utils/bench.py
127-127: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
src/axolotl/core/trainers/mixins/dist_parallel.py
20-22: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/monkeypatch/accelerate/distributed.py
180-180: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
216-216: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
src/axolotl/loaders/model.py
432-433: Use a single if statement instead of nested if statements
(SIM102)
453-454: Use a single if statement instead of nested if statements
(SIM102)
846-849: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
src/axolotl/utils/schemas/validation.py
1270-1271: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
🔇 Additional comments (24)
src/axolotl/core/trainers/base.py (2)
30-30: LGTM: Clean integration of distributed parallelism mixin.
The import and integration of DistParallelMixin is correctly implemented.
54-54: LGTM: Proper mixin inheritance order.
The DistParallelMixin is correctly positioned in the inheritance chain, allowing it to extend the trainer's accelerator creation capabilities for distributed training.
tests/e2e/multigpu/test_tp.py (1)
16-64: LGTM: Well-structured tensor parallelism integration test.
This test effectively validates tensor parallelism functionality with:
- Proper version requirements (@require_torch_2_7_0)
- Appropriate test configuration (small model, limited steps for quick execution)
- Multi-process execution setup matching the tensor_parallel_size
- Validation through TensorBoard logs
The test configuration and execution approach follow established patterns from the existing test suite.
src/axolotl/loaders/patch_manager.py (3)
52-52: LGTM: Appropriate integration of transformers patches.
The new _apply_transformers_patches() method properly applies Flash Attention and tensor parallelism fixes from the transformers monkeypatch modules.
75-84: LGTM: Clean implementation of transformers patching.
The new method properly imports and applies the required transformers patches for Flash Attention and tensor parallelism support.
279-284: LGTM: Consistent parameter renaming.
The renaming from sequence_parallel_degree to context_parallel_size is consistent with the broader refactoring across the codebase and aligns with PyTorch's native context parallelism terminology.
src/axolotl/monkeypatch/transformers/tensor_parallel.py (1)
8-17: LGTM - Clean monkey patch implementation.
The wrapped distribute_model function correctly calls the original implementation and adds the necessary tensor parallel metadata (_tp_size and _device_mesh) to the model instance. The pylint disable comments are appropriate for this monkey patch context.
src/axolotl/core/trainers/mixins/dist_parallel.py (1)
23-40: LGTM - Solid tensor parallelism setup logic.
The implementation correctly handles backward compatibility with older accelerate versions by manually configuring tensor parallelism. The device mesh inspection and plugin setup follow the expected patterns for accelerate integration.
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
12-74: LGTM - Comprehensive flash attention utility implementation.
The _prepare_from_posids function is well-implemented with:
- Comprehensive documentation explaining all parameters and return values
- Proper tensor handling with contiguous views and device management
- Good awareness of TorchDynamo compatibility requirements
- Sound logic for extracting cumulative sequence lengths and position handling
src/axolotl/utils/callbacks/__init__.py (2)
38-38: LGTM - Correct import addition.
The import of get_gpu_memory_usage is needed for the enhanced GPU memory logging functionality implemented in the callback.
111-126: LGTM - Enhanced GPU memory monitoring.
The updated callback implementation provides better monitoring capabilities:
- Continuous logging after the first step gives better visibility into memory patterns
- WandB integration with structured metrics (active, allocated, reserved) provides detailed tracking
- Proper error handling for unsupported devices
- Correct process-zero checking for WandB logging
src/axolotl/utils/bench.py (1)
60-63: LGTM - Improved memory metrics for monitoring.
The change to peak/max memory metrics (active bytes peak, max allocated, max reserved) provides better insights into memory usage patterns during training compared to current usage snapshots.
src/axolotl/integrations/liger/args.py (1)
31-37: Type annotation updates look good.
The migration from Optional[bool] to bool | None syntax is consistent with modern Python type hints.
src/axolotl/monkeypatch/ring_attn/patch.py (4)
166-196: Parameter renaming is consistent and well-documented.
The renaming from sequence_parallel_degree to context_parallel_size is properly implemented with updated docstrings and logging messages.
220-252: Ring attention function selection logic is properly implemented.
The conditional logic for selecting between VARLEN_LLAMA3 and BATCH_RING implementations based on the ring_attn_func parameter is well-structured. The monkeypatch approach for overriding the flash attention forward function is appropriate for this use case.
254-332: New DeviceMesh-based registration function is well-designed.
The register_ring_attn_from_device_mesh function provides a clean interface for creating ring attention groups from existing device meshes. Error handling for missing dimensions is appropriate, and the logging provides good visibility into the configuration.
389-433: Device mesh patching correctly uses "cp" dimension name.
The update to use "cp" (context parallelism) instead of "sp" aligns with PyTorch's native naming conventions as noted in the comments. The mesh shape construction properly uses context_parallel_size.
src/axolotl/loaders/model.py (3)
88-90: Parallelism configuration attributes are properly initialized.
The new attributes use_parallel_config and parallelism_config are correctly typed and initialized.
726-732: Tensor parallelism setup in model kwargs is correct.
The code properly configures tensor parallelism parameters and removes incompatible device_map when tensor parallelism is enabled.
846-852: Workaround for upstream transformers issue is documented.
The workaround for transformers 4.54.0 not setting _tp_size or _device_mesh is properly documented with a TODO comment for future removal.
src/axolotl/utils/schemas/validation.py (4)
676-679: GRPO validation correctly uses context_parallel_size.
The validation properly checks for incompatibility between GRPO, Liger loss, and context parallelism using the renamed parameter.
900-927: DeepSpeed tensor parallelism config update is now conditional.
The logic now only proceeds if a DeepSpeed config path is present, which is more robust than the previous implementation that would error if missing.
1205-1244: Context parallel size validation is comprehensive.
The validation properly:
- Sets default value to 1 if not specified
- Requires flash attention when context_parallel_size > 1
- Enforces micro_batch_size=1 with sample packing due to ring-flash-attn requirements
- Patches transformers before importing ring_flash_attn
- Provides helpful error messages and warnings
1247-1263: Ring attention function validation properly uses context_parallel_size.
The validation correctly checks context_parallel_size and sets appropriate defaults based on sample packing configuration.
def get_package_version(package: str) -> Version:
    version_str = version(package)
    return parse(version_str)


def is_package_version_ge(package: str, version_: str) -> bool:
    package_version = get_package_version(package)
    return package_version >= parse(version_)
Add error handling for missing packages.
The functions don't handle the case where a package is not installed, which will cause importlib.metadata.version() to raise a PackageNotFoundError. This could lead to unexpected crashes in calling code.
Consider adding error handling:
def get_package_version(package: str) -> Version:
- version_str = version(package)
- return parse(version_str)
+ try:
+ version_str = version(package)
+ return parse(version_str)
+ except PackageNotFoundError:
+ raise ValueError(f"Package '{package}' is not installed")
def is_package_version_ge(package: str, version_: str) -> bool:
- package_version = get_package_version(package)
- return package_version >= parse(version_)
+ try:
+ package_version = get_package_version(package)
+ return package_version >= parse(version_)
+ except ValueError:
+ return False # Or handle appropriately based on use case
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/utils/environment.py around lines 34 to 41, the functions
get_package_version and is_package_version_ge do not handle the case where the
specified package is not installed, causing importlib.metadata.version() to
raise a PackageNotFoundError. To fix this, add try-except blocks around the
calls to version() to catch PackageNotFoundError and handle it gracefully, such
as returning None or raising a custom error, and update is_package_version_ge to
handle this case accordingly to avoid unexpected crashes.
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/axolotl/loaders/model.py (2)
462-488: Replace print with logging and document protected member access.
Two issues to address:
- Line 474 uses print() instead of the logger
- Lines 481-487 access protected members of PartialState
  if pc_kwargs:
-     print(pc_kwargs)
+     LOG.info("Parallelism configuration: %s", pc_kwargs)
      self.parallelism_config = ParallelismConfig(
          **pc_kwargs,
      )
Consider adding a comment explaining why protected member access is necessary:
# We need to access PartialState's protected members to properly configure
# the parallelism state before model initialization. This is required for
# tensor parallelism to work correctly with accelerate.
408-460: Missing divisibility validation and code structure improvements needed.
Two issues to address:
- Missing divisibility validation: The past review comment about adding validation for world size divisibility by parallelism factors hasn't been implemented. This could lead to runtime errors.
- Simplify nested if statements: Static analysis suggests combining nested if statements for better readability.
Apply these improvements:
  @staticmethod
  def _get_parallel_config_kwargs(
      world_size: int,
      tensor_parallel_size: int = 1,
      context_parallel_size: int = 1,
      dp_shard_size: int | None = None,
      dp_replicate_size: int | None = None,
      is_fsdp: bool = False,
  ):
      pc_kwargs = {}
      remaining_world_size = world_size
      if tensor_parallel_size and tensor_parallel_size > 1:
+         if remaining_world_size % tensor_parallel_size != 0:
+             raise ValueError(
+                 f"World size ({remaining_world_size}) must be divisible by tensor_parallel_size ({tensor_parallel_size})"
+             )
          pc_kwargs["tp_size"] = tensor_parallel_size
          remaining_world_size = remaining_world_size // tensor_parallel_size
      if context_parallel_size and context_parallel_size > 1:
+         if remaining_world_size % context_parallel_size != 0:
+             raise ValueError(
+                 f"Remaining world size ({remaining_world_size}) must be divisible by context_parallel_size ({context_parallel_size})"
+             )
          pc_kwargs["cp_size"] = context_parallel_size
          remaining_world_size = remaining_world_size // context_parallel_size
-     if dp_shard_size is None and dp_replicate_size in (None, 1):
-         if remaining_world_size > 1:
+     if dp_shard_size is None and dp_replicate_size in (None, 1) and remaining_world_size > 1:
              pc_kwargs["dp_shard_size"] = remaining_world_size
              remaining_world_size = 1
      if dp_replicate_size and dp_replicate_size > 1:
+         if remaining_world_size % dp_replicate_size != 0:
+             raise ValueError(
+                 f"Remaining world size ({remaining_world_size}) must be divisible by dp_replicate_size ({dp_replicate_size})"
+             )
          pc_kwargs["dp_replicate_size"] = dp_replicate_size
          remaining_world_size = remaining_world_size // dp_replicate_size
-     if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
+     if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
+         if remaining_world_size % dp_shard_size != 0:
+             raise ValueError(
+                 f"Remaining world size ({remaining_world_size}) must be divisible by dp_shard_size ({dp_shard_size})"
+             )
          if not is_fsdp:
              raise ValueError(
                  "dp_shard_size was configured without a corresponding fsdp_config! "
                  "Please ensure you have configured FSDP using fsdp_config."
              )
          pc_kwargs["dp_shard_size"] = dp_shard_size
          remaining_world_size = remaining_world_size // dp_shard_size
      if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
          pc_kwargs["dp_replicate_size"] = remaining_world_size
          remaining_world_size = 1
-     if remaining_world_size > 1:
-         if "dp_shard_size" not in pc_kwargs and is_fsdp:
+     if remaining_world_size > 1 and "dp_shard_size" not in pc_kwargs and is_fsdp:
              pc_kwargs["dp_shard_size"] = remaining_world_size
              remaining_world_size = 1
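The divisibility rule at the heart of this suggestion can be illustrated on its own. The sketch below is an assumption-laden reduction, not the method under review: `split_world_size` is a hypothetical helper that only checks and divides, and it leaves out the FSDP and data-parallel fallbacks that `_get_parallel_config_kwargs` also handles.

```python
def split_world_size(world_size, **parallel_sizes):
    """Divide world_size by each factor in turn, checking divisibility first.

    ``parallel_sizes`` maps a label (e.g. tensor_parallel_size=2) to its size.
    Returns the number of ranks left over for data parallelism.
    """
    remaining = world_size
    for name, size in parallel_sizes.items():
        if not size or size <= 1:
            continue  # unset or 1 means this dimension is not used
        if remaining % size != 0:
            raise ValueError(
                f"Remaining world size ({remaining}) must be divisible by "
                f"{name} ({size})"
            )
        remaining //= size
    return remaining
```

For example, `split_world_size(8, tensor_parallel_size=2, context_parallel_size=2)` leaves 2 ranks, while 6 ranks with `tensor_parallel_size=4` raises before any integer division can silently drop ranks.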
🧹 Nitpick comments (1)
src/axolotl/loaders/model.py (1)
841-848: Simplify nested if statement.
The workaround logic is correct, but the nested if statements can be combined for better readability as suggested by static analysis.
- if self.cfg.tensor_parallel_size > 1:
-     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
-     # TODO(wing): remove once 4.54.1 is released
-     if self.model._tp_size != self.cfg.tensor_parallel_size:
+ # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
+ # TODO(wing): remove once 4.54.1 is released
+ if self.cfg.tensor_parallel_size > 1 and self.model._tp_size != self.cfg.tensor_parallel_size:
      self.model._tp_size = self.cfg.tensor_parallel_size
      self.model._device_mesh = self.model_kwargs["device_mesh"]
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/loaders/model.py(7 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/loaders/model.py
428-429: Use a single if statement instead of nested if statements
(SIM102)
449-450: Use a single if statement instead of nested if statements
(SIM102)
842-845: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/loaders/model.py (4)
16-16: Import changes look good.
The new imports for PartialState, ParallelismConfig, and get_world_size are appropriate for the parallelism functionality being added.
Also applies to: 49-49, 52-52
88-89: Class attributes properly typed.
The new parallelism-related class attributes are well-defined with appropriate type annotations and default values.
187-200: Parallelism configuration logic is well-structured.
The conditional logic correctly determines when to enable parallelism configuration based on FSDP, tensor parallelism, and context parallelism settings. The FSDP version check is a good safeguard.
722-728: Tensor parallelism setup is correct.
The logic properly configures tensor parallelism by setting the required model kwargs and removing incompatible device_map. The conditional check ensures it only applies when tensor parallelism is enabled.
Actionable comments posted: 1
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
cicd/multigpu.sh(1 hunks)requirements.txt(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
🔇 Additional comments (1)
cicd/multigpu.sh (1)
5-5: Good call adding --durations=10 to surface slow tests
This will automatically highlight the ten slowest tests in CI logs, making it easier to spot performance regressions without extra tooling.
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
💡 Verification agent
🧩 Analysis chain
Pin transformers to an immutable commit or tag to guarantee reproducible installs
Pointing to winglian/transformers.git@v4.54.1-maybe relies on a mutable branch/tag.
If that ref is force-pushed or deleted, every fresh environment will break or pick up different code, making binary wheels & training runs non-reproducible. Pin to a commit SHA (or an immutable annotated tag) and document why this fork is required.
Example fix:
-transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
+# TODO: replace <commit-sha> with the exact commit you depend on
+transformers @ git+https://github.com/winglian/transformers.git@<commit-sha>
Also verify the fork remains compatible with accelerate==1.9.0; the upstream PR 3682 suggests you may need a custom accelerate build as well.
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "🔍 Verifying ‘v4.54.1-maybe’ ref in winglian/transformers:"
git ls-remote --refs https://github.com/winglian/transformers.git \
refs/tags/v4.54.1-maybe refs/heads/v4.54.1-maybe || true
echo
echo "🔍 Checking for any ‘accelerate’ pins in requirements.txt:"
grep -n "accelerate" requirements.txt || true
Length of output: 523
Pin transformers to an exact commit SHA
You’re currently tracking the mutable branch v4.54.1-maybe, which can be force-pushed or deleted at any time—breaking reproducibility. Replace it with the precise commit hash we just looked up:
-requirements.txt:16
-transformers @ git+https://github.com/winglian/transformers.git@v4.54.1-maybe
+requirements.txt:16
+transformers @ git+https://github.com/winglian/transformers.git@d274901a61d5ff2977072ada1738bd38ff14be70
• Add a comment (or README entry) explaining why you’re using this fork.
• Verify that this commit remains compatible with accelerate==1.9.0 (pinned on line 18).
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In requirements.txt at line 16, replace the current transformers dependency that
tracks the mutable branch v4.54.1-maybe with the exact commit SHA to ensure
reproducibility. Add a comment explaining why this specific fork is used. Also,
verify that this commit is compatible with accelerate==1.9.0 pinned on line 18.
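One way to keep such a pin from regressing is a small lint-style check on the requirement line. The following is only a sketch under stated assumptions: `is_pinned_to_commit` is a hypothetical helper, it assumes the `name @ git+URL@ref` form used in this requirements file, and it assumes SHA-1 object names (40 hexadecimal characters), so a tag that happens to look like a SHA would also pass.

```python
import re

# A full Git SHA-1 commit id is exactly 40 hexadecimal characters.
_FULL_SHA = re.compile(r"[0-9a-f]{40}")


def is_pinned_to_commit(requirement):
    """Return True if a 'name @ git+URL@ref' line ends in a full commit SHA."""
    # Everything after the last '@' is the ref; no '@' means no VCS ref at all.
    _, sep, ref = requirement.strip().rpartition("@")
    return bool(sep) and _FULL_SHA.fullmatch(ref.lower()) is not None
```

Applied to the two lines discussed above, the `v4.54.1-maybe` ref is rejected and the 40-character commit hash is accepted; ordinary pins such as `peft==0.16.0` are reported as not commit-pinned because they carry no ref.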
Actionable comments posted: 1
🧹 Nitpick comments (1)
baseten/run.sh (1)
2-2: Add pipefail for safer error-handling in pipelines
set -eux does not cover failures that occur inside a pipeline (cmd1 | cmd2).
Using set -euo pipefail (or explicitly set -o pipefail) makes a pipeline return a failure status when any command in it fails, so set -e can abort the script instead of silently losing data.
-set -eux
+set -euo pipefail
+set -x # keep the `-x` (debug) behaviour
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
baseten/run.sh(1 hunks)
🧰 Additional context used
🧠 Learnings (1)
baseten/run.sh (1)
Learnt from: NanoCode012
PR: #2854
File: README.md:73-77
Timestamp: 2025-07-02T02:56:20.788Z
Learning: For Axolotl Docker commands, the --ipc=host flag should be included by default to prevent shared memory failures that commonly occur with PyTorch DataLoaders and multiprocessing during machine learning training workflows.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
# if node rank 0
axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared
axolotl preprocess will be executed on every node – guard with node-rank check
The comment says “if node rank 0” but the command is unconditional.
When launched via torchrun, every node/container will attempt to write to the same $BT_CHECKPOINT_DIR, leading to corruption or a noisy/slow preprocess phase.
-# if node rank 0
-axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared
+# Only the leader should run preprocessing
+if [[ "${BT_NODE_RANK:-0}" == "0" ]]; then
+ axolotl preprocess train.yaml \
+ --output-dir="$BT_CHECKPOINT_DIR" \
+ --dataset-prepared-path="${BT_CHECKPOINT_DIR}/last_run_prepared"
+fi
(Uses parameter expansion so the script still works when BT_NODE_RANK is undefined in single-node setups.)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-# if node rank 0
-axolotl preprocess train.yaml --output-dir=$BT_CHECKPOINT_DIR --dataset-prepared-path=${BT_CHECKPOINT_DIR}/last_run_prepared
+# Only the leader should run preprocessing
+if [[ "${BT_NODE_RANK:-0}" == "0" ]]; then
+  axolotl preprocess train.yaml \
+    --output-dir="$BT_CHECKPOINT_DIR" \
+    --dataset-prepared-path="${BT_CHECKPOINT_DIR}/last_run_prepared"
+fi
🤖 Prompt for AI Agents
In baseten/run.sh around lines 8 to 10, the axolotl preprocess command runs
unconditionally on all nodes, risking concurrent writes to $BT_CHECKPOINT_DIR.
Wrap the preprocess command in a conditional that checks if BT_NODE_RANK is
unset or equals 0, using parameter expansion to handle single-node setups
safely. This ensures only the node with rank 0 executes the preprocess step,
preventing directory corruption and redundant processing.
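If the same guard is ever needed from Python launch code rather than the shell script, the rank test can be sketched as below. `is_leader_node` is a hypothetical helper name; `BT_NODE_RANK` is the variable used in the suggestion above, and whether the platform always sets it is an assumption.

```python
import os


def is_leader_node(environ=None):
    """True when BT_NODE_RANK is unset, empty, or "0" (mirrors ${BT_NODE_RANK:-0})."""
    environ = os.environ if environ is None else environ
    # `or "0"` also covers an empty value, as the shell `:-` expansion does.
    return (environ.get("BT_NODE_RANK") or "0") == "0"
```

Passing the environment in as a mapping keeps the check easy to test without mutating `os.environ`; a caller would run the preprocessing step only when `is_leader_node()` returns true.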
Actionable comments posted: 0
♻️ Duplicate comments (5)
src/axolotl/utils/bench.py (1)
95-107: Improve device detection consistency.
The device type detection logic has inconsistencies that should be addressed:
- Line 97 checks MPS availability independently of the device type from line 96
- String containment checks on lines 99 and 101 are fragile
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1)
77-87: Remove redundant patching with setattr.
The function patches the same attribute twice. The direct assignment on lines 80-82 is sufficient.
src/axolotl/loaders/model.py (1)
408-460: Add validation for world size divisibility.
The method should validate that world_size is evenly divisible by each parallelism factor before performing division to prevent runtime errors.
src/axolotl/utils/schemas/validation.py (2)
1213-1218: Monkey-patching transformers internals needs documentation and version guards.
This unconditional monkey patch of transformers.modeling_flash_attention_utils._flash_supports_window_size was previously flagged but remains unaddressed. The patch runs on every transformers version and may break if internals change.
Please add:
- Version check or attribute existence guard before applying the patch
- Documentation explaining why this patch is needed and which transformers versions require it
- Clear specification of supported transformers versions in setup.py/pyproject.toml
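An attribute-existence guard of the kind requested in the first bullet might look like the sketch below. `alias_if_missing` is a hypothetical helper, and a `SimpleNamespace` stands in for `transformers.modeling_flash_attention_utils` so the example runs without transformers installed; the two attribute names are the ones from the patch under review.

```python
from types import SimpleNamespace


def alias_if_missing(module, alias, source):
    """Set module.<alias> = module.<source> only when it is safe to do so.

    Returns True if the alias was created, False if nothing was changed.
    """
    if hasattr(module, alias):
        return False  # already provided upstream; do not overwrite
    if not hasattr(module, source):
        return False  # internals changed; skip rather than crash
    setattr(module, alias, getattr(module, source))
    return True


# Stand-in for transformers.modeling_flash_attention_utils in this sketch.
fake_utils = SimpleNamespace(_flash_supports_window=True)
created = alias_if_missing(
    fake_utils, "_flash_supports_window_size", "_flash_supports_window"
)
```

The boolean return value lets the caller log whether the compatibility alias was actually applied, which also serves as a lightweight record of which transformers versions still need it.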
1257-1268: Critical: New validation mixin not included in main ValidationMixin.
The DistributedValidationMixin contains important tensor parallel + 8-bit optimizer compatibility validation, but it's not included in the main ValidationMixin class (lines 1272-1284), so it won't be executed.
Add DistributedValidationMixin to the main ValidationMixin class:
  class ValidationMixin(
      DatasetValidationMixin,
      AttentionValidationMixin,
      TrainingValidationMixin,
      LoRAValidationMixin,
      RLValidationMixin,
      OptimizationValidationMixin,
      SystemValidationMixin,
      ChatTemplateValidationMixin,
      PretrainingValidationMixin,
      ModelCompatibilityValidationMixin,
      ComplexValidationMixin,
+     DistributedValidationMixin,
  ):
Also, simplify the nested if statements:
  @model_validator(mode="after")
  def check_tensor_parallel_optimizer(self):
-     if self.tensor_parallel_size > 1:
-         if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
-             raise ValueError(
-                 "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
-             )
+     if (self.tensor_parallel_size > 1 and
+             self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]):
+         raise ValueError(
+             "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+         )
🧹 Nitpick comments (2)
src/axolotl/utils/bench.py (1)
125-125: Simplify ternary expression.
    - msg = f"{cur_device_type} memory active:" if not msg else msg
    + msg = msg if msg else f"{cur_device_type} memory active:"
src/axolotl/loaders/model.py (1)
841-847: Simplify nested conditional and document the workaround.
    - if self.cfg.tensor_parallel_size > 1:
    -     # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
    -     # TODO(wing): remove once 4.54.1 is released
    -     if self.model._tp_size != self.cfg.tensor_parallel_size:
    + # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
    + # TODO(wing): remove once 4.54.1 is released
    + if (self.cfg.tensor_parallel_size > 1 and
    +         self.model._tp_size != self.cfg.tensor_parallel_size):
          self.model._tp_size = self.cfg.tensor_parallel_size
          self.model._device_mesh = self.model_kwargs["device_mesh"]
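As a rough illustration of the combined condition, the sketch below applies the same workaround to a stand-in object. The function name `sync_tp_attributes` and the use of `SimpleNamespace` in place of a real model are assumptions for the demo, and the `getattr` default is an extra safeguard that the reviewed code does not use.

```python
from types import SimpleNamespace


def sync_tp_attributes(model, tensor_parallel_size, device_mesh):
    """Set _tp_size/_device_mesh when tensor parallelism is on and they disagree.

    Returns True if the attributes were updated. (Hypothetical helper.)
    """
    # Single flattened condition instead of two nested if statements.
    if (
        tensor_parallel_size > 1
        and getattr(model, "_tp_size", None) != tensor_parallel_size
    ):
        model._tp_size = tensor_parallel_size
        model._device_mesh = device_mesh
        return True
    return False


# Stand-in for a loaded model whose TP attributes were never set.
model = SimpleNamespace(_tp_size=None, _device_mesh=None)
changed = sync_tp_attributes(model, 2, "mesh-placeholder")
```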
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (48)
- cicd/multigpu.sh (1 hunks)
- cicd/single_gpu.py (1 hunks)
- docs/sequence_parallelism.qmd (3 hunks)
- examples/alst/llama3-8b-deepspeed-alst.yaml (1 hunks)
- requirements.txt (1 hunks)
- setup.py (1 hunks)
- src/axolotl/cli/merge_lora.py (1 hunks)
- src/axolotl/core/builders/base.py (2 hunks)
- src/axolotl/core/builders/rl.py (1 hunks)
- src/axolotl/core/trainers/base.py (2 hunks)
- src/axolotl/core/trainers/dpo/trainer.py (2 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (1 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
- src/axolotl/core/trainers/grpo/sampler.py (4 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (10 hunks)
- src/axolotl/core/trainers/mamba.py (1 hunks)
- src/axolotl/core/trainers/mixins/__init__.py (1 hunks)
- src/axolotl/core/trainers/mixins/checkpoints.py (1 hunks)
- src/axolotl/core/trainers/mixins/distributed_parallel.py (1 hunks)
- src/axolotl/core/trainers/trl.py (5 hunks)
- src/axolotl/integrations/kd/trainer.py (1 hunks)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/loaders/model.py (7 hunks)
- src/axolotl/loaders/patch_manager.py (2 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (1 hunks)
- src/axolotl/monkeypatch/ring_attn/__init__.py (1 hunks)
- src/axolotl/monkeypatch/ring_attn/patch.py (3 hunks)
- src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py (1 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/bench.py (2 hunks)
- src/axolotl/utils/callbacks/__init__.py (2 hunks)
- src/axolotl/utils/ctx_managers/sequence_parallel.py (5 hunks)
- src/axolotl/utils/data/shared.py (1 hunks)
- src/axolotl/utils/environment.py (2 hunks)
- src/axolotl/utils/samplers/multipack.py (2 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
- src/axolotl/utils/trainer.py (3 hunks)
- tests/core/test_builders.py (1 hunks)
- tests/e2e/multigpu/patched/test_sp.py (2 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
- tests/e2e/multigpu/test_fp8_fsdp2.py (2 hunks)
- tests/e2e/multigpu/test_tp.py (1 hunks)
- tests/e2e/patched/test_sp.py (0 hunks)
- tests/e2e/test_load_model.py (1 hunks)
- tests/e2e/utils.py (1 hunks)
- tests/test_loaders.py (1 hunks)
- tests/utils/schemas/validation/test_fsdp.py (0 hunks)
💤 Files with no reviewable changes (2)
- tests/utils/schemas/validation/test_fsdp.py
- tests/e2e/patched/test_sp.py
✅ Files skipped from review due to trivial changes (3)
- src/axolotl/integrations/kd/trainer.py
- tests/e2e/multigpu/patched/test_sp.py
- src/axolotl/core/trainers/dpo/trainer.py
🚧 Files skipped from review as they are similar to previous changes (38)
- examples/alst/llama3-8b-deepspeed-alst.yaml
- src/axolotl/core/trainers/mamba.py
- cicd/multigpu.sh
- src/axolotl/core/trainers/mixins/__init__.py
- src/axolotl/core/trainers/grpo/args.py
- docs/sequence_parallelism.qmd
- src/axolotl/utils/samplers/multipack.py
- src/axolotl/core/trainers/mixins/checkpoints.py
- src/axolotl/cli/merge_lora.py
- tests/core/test_builders.py
- src/axolotl/train.py
- src/axolotl/core/trainers/base.py
- cicd/single_gpu.py
- tests/e2e/test_load_model.py
- src/axolotl/core/builders/rl.py
- requirements.txt
- src/axolotl/monkeypatch/accelerate/fsdp2.py
- tests/e2e/multigpu/solo/test_grpo.py
- tests/e2e/utils.py
- src/axolotl/core/trainers/grpo/__init__.py
- tests/e2e/multigpu/test_tp.py
- src/axolotl/utils/environment.py
- src/axolotl/loaders/patch_manager.py
- setup.py
- src/axolotl/utils/trainer.py
- tests/e2e/multigpu/test_fp8_fsdp2.py
- src/axolotl/core/trainers/mixins/distributed_parallel.py
- src/axolotl/core/builders/base.py
- src/axolotl/core/trainers/grpo/trainer.py
- src/axolotl/core/trainers/grpo/sampler.py
- src/axolotl/monkeypatch/ring_attn/__init__.py
- tests/test_loaders.py
- src/axolotl/core/trainers/trl.py
- src/axolotl/utils/callbacks/__init__.py
- src/axolotl/utils/ctx_managers/sequence_parallel.py
- src/axolotl/utils/data/shared.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/integrations/liger/args.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: grpo (cfg.rl == "grpo" or cfg.rl is rltype.grpo) should be excluded from dataset label checking duri...
Learnt from: NanoCode012
PR: axolotl-ai-cloud/axolotl#2952
File: src/axolotl/common/datasets.py:125-125
Timestamp: 2025-07-22T08:27:00.129Z
Learning: GRPO (cfg.rl == "grpo" or cfg.rl is RLType.GRPO) should be excluded from dataset label checking during preprocessing, while other RL methods should continue to have this functionality available.
Applied to files:
src/axolotl/utils/schemas/validation.py
🧬 Code Graph Analysis (4)
src/axolotl/utils/bench.py (1)
src/axolotl/utils/distributed.py (1)
get_device_type(20-28)
src/axolotl/loaders/model.py (1)
src/axolotl/utils/distributed.py (3)
get_device_count(31-37)
get_device_type(20-28)
get_world_size(103-104)
src/axolotl/monkeypatch/ring_attn/patch.py (1)
src/axolotl/utils/schemas/enums.py (1)
RingAttnFunc(84-92)
src/axolotl/utils/schemas/validation.py (1)
src/axolotl/loaders/model.py (1)
load(154-183)
🪛 Ruff (0.12.2)
src/axolotl/utils/bench.py
125-125: Use msg if msg else f"{cur_device_type} memory active:" instead of f"{cur_device_type} memory active:" if not msg else msg
Replace with msg if msg else f"{cur_device_type} memory active:"
(SIM212)
src/axolotl/loaders/model.py
428-429: Use a single if statement instead of nested if statements
(SIM102)
449-450: Use a single if statement instead of nested if statements
(SIM102)
841-844: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
83-87: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
src/axolotl/utils/schemas/validation.py
1262-1263: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: preview
🔇 Additional comments (8)
src/axolotl/utils/bench.py (1)
60-63: LGTM! More comprehensive memory metrics. The updated return values provide better visibility into GPU memory usage with distinct metrics for active, allocated, and reserved memory.
src/axolotl/monkeypatch/ring_attn/patch.py (1)
137-188: Excellent refactoring to use DeviceMesh abstraction. The migration from manual process group creation to using PyTorch's DeviceMesh is a significant improvement. The error handling and logging are well-implemented.
src/axolotl/loaders/model.py (2)
187-200: LGTM! Clear parallelism configuration setup. The conditional logic properly determines when to use parallel configuration based on FSDP, tensor parallelism, and context parallelism settings.
462-487: Well-structured parallelism configuration. The method properly sets up the parallelism configuration and device mesh in PartialState. The protected member access is necessary for proper integration with accelerate.
src/axolotl/utils/schemas/validation.py (4)
676-676: LGTM! Parameter renamed consistently with refactoring effort. The update from sequence_parallel_degree to context_parallel_size aligns with the broader codebase refactoring mentioned in the AI summary.
888-912: LGTM! Improved conditional handling of DeepSpeed configuration. The updated logic now gracefully handles cases where no DeepSpeed configuration is present, making tensor parallelism setup more flexible. The temporary config file generation and settings updates are correctly implemented.
1192-1236: LGTM! Well-handled parameter renaming with backward compatibility. The method rename from check_sequence_parallel_degree to check_context_parallel_size is implemented correctly with:
- Proper backward compatibility for the deprecated parameter
- Clear deprecation warning to guide users
- Consistent validation logic maintained
- Updated error messages using the new parameter name
This follows best practices for API migration.
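The migration pattern praised above can be sketched as follows. This is an illustration of the general deprecate-and-forward approach, not the validator from the PR; the function name `migrate_parallel_size` and the rule that the new key takes precedence when both are set are assumptions.

```python
import warnings


def migrate_parallel_size(cfg):
    """Return a copy of cfg with the deprecated key mapped to the new one."""
    cfg = dict(cfg)  # do not mutate the caller's mapping
    old_value = cfg.pop("sequence_parallel_degree", None)
    if old_value is not None:
        warnings.warn(
            "`sequence_parallel_degree` is deprecated; "
            "use `context_parallel_size` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Assumption: an explicitly set new key wins over the deprecated one.
        if cfg.get("context_parallel_size") is None:
            cfg["context_parallel_size"] = old_value
    return cfg


# Capture the warning so the demo shows both the mapping and the notice.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    migrated = migrate_parallel_size({"sequence_parallel_degree": 4})
```

Configs that already use `context_parallel_size` pass through unchanged and emit no warning.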
1240-1240: LGTM! Consistent parameter name update. The update to use context_parallel_size instead of sequence_parallel_degree is consistent with the broader refactoring effort.
Description
Supersedes #2947
Needs upstream:
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests
Chores
transformers and accelerate.
Refactor