bump hf deps #2735
Walkthrough
This update removes monkeypatches for Gemma3 and Mllama attention, refactors dataloader construction in the AxolotlTrainer, adjusts LoRA target module regexes across configs and docs, updates dependencies and test thresholds, and improves test coverage for LoRA packing. It also refines multipack batch handling logic and patches for FSDP2 state dict loading. Additionally, it extends VLLM serve functionality with reasoning options and adds sequence parallelism support in GRPO training.
Sequence Diagram(s)
sequenceDiagram
participant User
participant AxolotlTrainer
participant Dataset
participant Sampler
participant DataCollator
participant Accelerator
User->>AxolotlTrainer: Request train/eval dataloader
AxolotlTrainer->>Dataset: Remove "length" column if present
AxolotlTrainer->>AxolotlTrainer: Remove unused columns (based on mode and packing)
AxolotlTrainer->>DataCollator: Select appropriate collator
AxolotlTrainer->>Sampler: Create sampler (if needed)
AxolotlTrainer->>AxolotlTrainer: Build dataloader params
AxolotlTrainer->>Accelerator: Prepare dataloader
Accelerator-->>AxolotlTrainer: Return prepared dataloader
AxolotlTrainer-->>User: Return dataloader
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
182-256: Consider breaking down this method for better maintainability.
While the consolidation of dataloader creation is beneficial, the method handles many concerns. Consider extracting some logic into helper methods for better readability and testing.
Potential extractions:
- Dataset preprocessing logic (lines 193-206)
- Persistent workers handling (lines 233-248)
- Sampler configuration (lines 216-230)
This would make the main method more focused and easier to understand.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/core/trainers/base.py (6 hunks)
- tests/e2e/integrations/test_kd.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/e2e/integrations/test_kd.py (1)
- tests/conftest.py (1): temp_dir (414-419)
src/axolotl/core/trainers/base.py (2)
- src/axolotl/core/trainers/grpo/trainer.py (1): _get_train_sampler (141-161)
- src/axolotl/core/trainer_builder.py (4): train_dataset (132-133), train_dataset (136-137), eval_dataset (140-141), eval_dataset (144-145)
⏰ Context from checks skipped due to timeout of 90000ms (11)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2, true)
🔇 Additional comments (4)
tests/e2e/integrations/test_kd.py (1)
93-93: Threshold adjustment looks reasonable.
The increase from 1.2 to 1.4 aligns with the dependency upgrades that may affect training dynamics.
src/axolotl/core/trainers/base.py (3)
9-10: Import additions are appropriate.
The new imports support the refactored dataloader creation logic and improved type hints.
Also applies to: 22-23
116-131: Good improvement to method flexibility.
The optional dataset parameter and proper validation using has_length utility enhance the method's robustness and reusability.
291-300: Clean refactoring to use the consolidated dataloader method.
The changes properly leverage _get_dataloader while maintaining the necessary data collator management.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
199-201: Simplify nested if statements.
The nested if statements can be combined into a single condition for better readability.
-if isinstance(train_dataset, datasets.Dataset):
-    if self.args.sample_packing and not self.args.pretraining:
-        train_dataset = train_dataset.remove_columns(["length"])
+if isinstance(train_dataset, datasets.Dataset) and self.args.sample_packing and not self.args.pretraining:
+    train_dataset = train_dataset.remove_columns(["length"])
🧰 Tools
🪛 Ruff (0.11.9)
199-200: Use a single if statement instead of nested if statements (SIM102)
📜 Review details
📒 Files selected for processing (1)
- src/axolotl/core/trainers/base.py (4 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (2)
- src/axolotl/core/trainers/grpo/trainer.py (2): _get_train_sampler (141-161), get_train_dataloader (225-250)
- src/axolotl/core/trainer_builder.py (4): train_dataset (132-133), train_dataset (136-137), eval_dataset (140-141), eval_dataset (144-145)
🪛 Ruff (0.11.9)
src/axolotl/core/trainers/base.py
23-23: transformers.trainer_utils.seed_worker imported but unused
Remove unused import: transformers.trainer_utils.seed_worker
(F401)
199-200: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (5)
src/axolotl/core/trainers/base.py (5)
116-152: LGTM! Good improvement to method flexibility.
The updated method signature with explicit dataset parameter and proper null checks improves the method's flexibility and error handling.
154-186: LGTM! Consistent with the train sampler changes.
The explicit dataset parameter and null check improve error handling and API consistency.
212-268: Excellent refactoring with improved error handling and optimization!
The changes include:
- Proper error handling for missing eval dataset
- Smart optimization to reuse persistent eval dataloaders
- Correct data collator switching and restoration
- Consistent use of parent class's _get_dataloader method
This implementation is more robust and maintainable.
270-278: LGTM! Good use of the utility function.
Using has_length from transformers is the proper way to check dataset validity.
280-317: Well-implemented benchmarking dataloader with proper error handling!
The rewritten method includes:
- Clear documentation
- Proper validation of the benchmark dataset
- Correct data collator switching pattern
- Appropriate handling of iterable vs non-iterable datasets
The implementation is thorough and maintainable.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/core/trainers/base.py (1)
23-23: Unused import issue resolved.
The seed_worker import is now properly used in line 236 within the _get_dataloader method, resolving the previously flagged unused import issue.
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
261-264: Consider simplifying nested conditional logic.
The static analysis tool suggests combining these nested if statements for better readability.
-        if isinstance(train_dataset, datasets.Dataset):
-            if self.args.sample_packing and not self.args.pretraining:
-                train_dataset = train_dataset.remove_columns(["length"])
+        if (isinstance(train_dataset, datasets.Dataset)
+                and self.args.sample_packing
+                and not self.args.pretraining):
+            train_dataset = train_dataset.remove_columns(["length"])
🧰 Tools
🪛 Ruff (0.11.9)
261-262: Use a single if statement instead of nested if statements (SIM102)
📜 Review details
📒 Files selected for processing (1)
- src/axolotl/core/trainers/base.py (5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (2)
- src/axolotl/core/trainers/grpo/trainer.py (2): _get_train_sampler (141-161), get_train_dataloader (225-250)
- src/axolotl/core/trainer_builder.py (4): train_dataset (132-133), train_dataset (136-137), eval_dataset (140-141), eval_dataset (144-145)
🪛 Ruff (0.11.9)
src/axolotl/core/trainers/base.py
261-262: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (10)
src/axolotl/core/trainers/base.py (10)
10-10: LGTM: Enhanced type annotations.
Good addition of more specific typing imports to improve type safety and code clarity.
15-15: LGTM: Added dataset availability check.
The import of is_datasets_available utility function enhances compatibility checking, which is used appropriately in the _get_dataloader method.
47-48: LGTM: Added necessary type annotations and caching attribute.
The explicit data_collator type annotation and _eval_dataloaders dictionary for caching persistent evaluation dataloaders are well-designed additions that improve type safety and performance.
118-156: LGTM: Improved sampler method with explicit dataset parameter.
The refactoring of _get_train_sampler to accept an explicit train_dataset parameter enhances flexibility and error handling. The method now properly handles None cases and maintains backward compatibility by defaulting to self.train_dataset.
158-190: LGTM: Enhanced eval sampler with required dataset parameter.
The _get_eval_sampler method correctly requires a non-None eval_dataset parameter, improving error handling and making dependencies explicit. The logic for handling multipack sampling is clear and consistent.
192-251: LGTM: Excellent unified dataloader creation method.
The new _get_dataloader method consolidates dataloader creation logic effectively. Key improvements include:
- Unified parameter handling for batch size, samplers, and workers
- Proper dataset compatibility checking with is_datasets_available()
- Support for persistent workers and prefetch factors
- Intelligent caching of eval dataloaders to prevent fork bombs
- Clear separation of training vs non-training configurations
The implementation is well-structured and handles edge cases appropriately.
253-272: LGTM: Simplified training dataloader with improved error handling.
The refactored get_train_dataloader method properly validates the training dataset existence and uses the new unified _get_dataloader method. The dataset preprocessing for sample packing is handled correctly.
🧰 Tools
🪛 Ruff (0.11.9)
261-262: Use a single if statement instead of nested if statements (SIM102)
274-336: LGTM: Comprehensive evaluation dataloader with persistent caching.
The refactored get_eval_dataloader method includes excellent improvements:
- Proper validation of evaluation dataset availability
- Intelligent reuse of cached persistent dataloaders to prevent resource waste
- Correct handling of data collator switching for different evaluation modes
- Proper restoration of original data collators after temporary switches
The logic for handling different sample packing scenarios is comprehensive and correct.
338-346: LGTM: Improved bench sampler with proper validation.
The updated _get_bench_sampler method now uses has_length utility for proper dataset validation, which is more robust than basic None checking.
348-401: LGTM: Comprehensive benchmarking dataloader implementation.
The completely rewritten get_bench_dataloader method is well-structured with:
- Proper dataset validation with clear error messages
- Correct data collator management with save/restore pattern
- Comprehensive parameter handling including conditional parameters
- Proper sampler integration for different dataset types
- Good documentation with clear method signature
The implementation follows established patterns from the other dataloader methods and handles edge cases appropriately.
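The save/restore pattern for the data collator mentioned in these comments can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: `DummyTrainer`, `swapped_collator`, and the two collate functions are hypothetical names, and the only assumption carried over from the review is that the trainer exposes a `data_collator` attribute that is switched temporarily and then restored.

```python
from contextlib import contextmanager


class DummyTrainer:
    """Hypothetical stand-in for a trainer; not the Axolotl class."""

    def __init__(self, data_collator):
        self.data_collator = data_collator


@contextmanager
def swapped_collator(trainer, temporary_collator):
    """Temporarily replace trainer.data_collator and restore it on exit."""
    original = trainer.data_collator
    trainer.data_collator = temporary_collator
    try:
        yield trainer
    finally:
        # Restore the original collator even if dataloader construction raises.
        trainer.data_collator = original


def train_collate(batch):
    return {"mode": "train", "size": len(batch)}


def bench_collate(batch):
    return {"mode": "bench", "size": len(batch)}


trainer = DummyTrainer(train_collate)
with swapped_collator(trainer, bench_collate):
    inside = trainer.data_collator([1, 2, 3])
after = trainer.data_collator([1, 2, 3])
```

Using `try`/`finally` means the original collator comes back even when building the dataloader fails part-way, which is the main reason to prefer it over a manual assign-and-reassign pair.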
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/axolotl/core/trainers/base.py (1)
23-23: The unused import concern is now resolved.
The seed_worker import is actually used on line 236 in the worker initialization function, so the previous review comment about removing this unused import is no longer applicable.
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
202-206: Remove redundant batch_size assignment.
The batch_size assignment is redundant since self.args.sample_packing is always True in this elif block (it's part of the condition on line 199). This will always assign batch_size to itself.
-    batch_size = (
-        batch_size
-        if self.args.sample_packing
-        else self.args.per_device_eval_batch_size
-    )
Remove this redundant assignment since the condition will always be true in this context.
📜 Review details
📒 Files selected for processing (2)
- src/axolotl/core/trainers/base.py (4 hunks)
- tests/e2e/test_llama_vision.py (2 hunks)
🔇 Additional comments (8)
tests/e2e/test_llama_vision.py (1)
89-89: LGTM! Consistent module path update in multimodal test.
This change matches the update in the text-only test method and maintains consistency across both test configurations.
src/axolotl/core/trainers/base.py (7)
116-145: Good refactoring to make the sampler method more flexible.
The updated _get_train_sampler method now accepts an optional train_dataset parameter, making it more flexible and allowing explicit dataset passing to the multipack sampler. The logic correctly handles different sampling strategies.
147-173: Improved evaluation sampler with explicit dataset parameter.
The _get_eval_sampler method now properly accepts an eval_dataset parameter and uses it directly instead of relying on internal state. This makes the method more predictable and testable.
186-186: Clear data collator selection logic.
The conditional assignment of data collator based on training mode is straightforward and correct.
223-230: Well-implemented sampler handling logic.
The method correctly handles both regular samplers and batch samplers, with proper parameter exclusion when using batch samplers. The mutually exclusive handling of batch_size and batch_sampler is implemented correctly.
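The mutual exclusion between batch_size and batch_sampler can be illustrated with a small parameter builder. This is a sketch under stated assumptions, not the method from the PR: `build_dataloader_params` is a hypothetical helper, and it assumes the resulting dict is later passed to `torch.utils.data.DataLoader`, which rejects `batch_sampler` combined with `batch_size`, `shuffle`, `sampler`, or `drop_last`.

```python
def build_dataloader_params(batch_size, collate_fn, sampler=None,
                            batch_sampler=None, drop_last=False):
    """Hypothetical helper: assemble keyword arguments for a DataLoader.

    A batch sampler already yields whole batches, so the per-sample
    options (batch_size, sampler, drop_last) are omitted when one is given.
    """
    params = {"collate_fn": collate_fn}
    if batch_sampler is not None:
        params["batch_sampler"] = batch_sampler
    else:
        params["batch_size"] = batch_size
        params["drop_last"] = drop_last
        if sampler is not None:
            params["sampler"] = sampler
    return params


def identity_collate(batch):
    return batch


# A batch sampler is any iterable of index lists; a plain list works here.
packed = build_dataloader_params(4, identity_collate, batch_sampler=[[0, 1], [2]])
plain = build_dataloader_params(4, identity_collate, sampler=[2, 0, 1])
```

Here `packed` carries only `collate_fn` and `batch_sampler`, while `plain` carries `batch_size`, `drop_last`, and `sampler`.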
234-239: Proper worker initialization for training.
The worker initialization with seeding is correctly applied only for training dataloaders, using the appropriate parameters for distributed training.
248-256: Good implementation of persistent dataloader storage.
The logic for storing non-prepared dataloaders for evaluation with persistent workers is well-implemented, handling both initialization and updates of the _eval_dataloaders dictionary properly.
175-258: Excellent consolidation of dataloader creation logic.
This new _get_dataloader method successfully consolidates the previously separate dataloader creation logic into a single, flexible method. The method handles:
- Appropriate data collator selection
- Dataset column management based on training mode and sample packing
- Comprehensive dataloader parameter configuration
- Proper sampler integration
- Persistent dataloader storage for evaluation
The refactoring improves code maintainability and reduces duplication while supporting advanced features like sample packing.
  "lora_alpha": 16,
  "lora_dropout": 0.05,
- "lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
+ "lora_target_modules": r"language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
💡 Verification agent
🧩 Analysis chain
LGTM! Module path update aligns with codebase conventions.
The removal of .model from the LoRA target modules path correctly aligns with the updated module structure.
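To see what dropping the `.model` segment changes, the two patterns from the diff can be checked against sample module names. This is only a sketch: the module names are hypothetical, and full-string matching via `re.fullmatch` is an assumption about how the pattern is applied, not something this thread confirms.

```python
import re

# Patterns copied from the diff above.
OLD = r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"
NEW = r"language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"

# Hypothetical module names, for illustration only.
old_style_name = "language_model.model.layers.0.self_attn.q_proj"
new_style_name = "language_model.layers.0.self_attn.q_proj"


def matches(pattern, name):
    """Return True if the whole module name matches the pattern."""
    return re.fullmatch(pattern, name) is not None


results = {
    ("old", "old_name"): matches(OLD, old_style_name),
    ("old", "new_name"): matches(OLD, new_style_name),
    ("new", "old_name"): matches(NEW, old_style_name),
    ("new", "new_name"): matches(NEW, new_style_name),
}
```

Under these assumptions each pattern matches only names from its own layout, so a stale pattern would select no modules at all, and whether that surfaces as an error depends on the loader.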
🏁 Script executed:
#!/bin/bash
# Description: Verify the correct module path structure for Llama Vision models
# Expected: Find references to "language_model.layers" without ".model" component
echo "Checking for LoRA target module patterns in the codebase..."
rg -A 3 -B 3 "lora_target_modules.*language_model\.layers" --type py
echo -e "\nChecking for any remaining references to old pattern..."
rg "language_model\.model\.layers" --type py
Length of output: 1589
Inconsistent LoRA module path in monkeypatch
The end-to-end tests now target language_model.layers…, but src/axolotl/monkeypatch/lora_kernels.py still uses model.language_model.model.layers. This mismatch will break LoRA injection at runtime.
Please update the monkeypatch to match the new module path:
• File: src/axolotl/monkeypatch/lora_kernels.py
Change:
- layers = model.language_model.model.layers
+ layers = model.language_model.layers
Ensure any other occurrences of language_model.model.layers are removed or updated to language_model.layers.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/lora_kernels.py, locate all occurrences of the module
path language_model.model.layers and update them to language_model.layers to
match the new module structure used in the tests and codebase. Ensure no
references to the old path remain to prevent runtime errors in LoRA injection.
Actionable comments posted: 1
📜 Review details
📒 Files selected for processing (3)
- docs/multimodal.qmd (1 hunks)
- examples/llama-3-vision/lora-11b.yaml (1 hunks)
- tests/e2e/test_llama_vision.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/llama-3-vision/lora-11b.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/e2e/test_llama_vision.py
  # load_in_8bit: true
  adapter: lora
- lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+ lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
💡 Verification agent
🧩 Analysis chain
LGTM! LoRA target module path correction looks good.
The updated regex pattern correctly targets the model hierarchy with model.language_model.layers instead of the previous pattern. This change aligns with the broader codebase updates mentioned in the PR summary.
Please verify that this pattern matches the actual model architecture by running:
🏁 Script executed:
#!/bin/bash
# Description: Verify the LoRA target module pattern is consistent across the codebase
# Expected: All lora_target_modules should use the same pattern format
# Search for all lora_target_modules configurations
rg -A 2 -B 2 "lora_target_modules.*model\.language_model\.layers"
# Also check for any remaining old patterns that might need updating
rg -A 2 -B 2 "lora_target_modules.*language_model\.model\.layers"
Length of output: 3449
Fix outdated LoRA module patterns in example configs
The update to use model.language_model.layers is correct, but several example YAMLs still reference the old pattern. Please update these to maintain consistency:
• examples/pixtral/lora-12b.yml
• examples/mistral/mistral-small-3.1-24B-lora.yml
• examples/llava/lora-7b.yaml
• examples/gemma3/gemma-3-4b-vision-qlora.yml
• examples/gemma3/gemma-3-4b-qlora.yml
Suggested change in each file (replace the old regex):
- lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+ lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
🤖 Prompt for AI Agents
In docs/multimodal.qmd at line 46, the lora_target_modules regex pattern has
been correctly updated to use model.language_model.layers. To maintain
consistency across the codebase, update the same regex pattern in the example
YAML configuration files listed: examples/pixtral/lora-12b.yml,
examples/mistral/mistral-small-3.1-24B-lora.yml, examples/llava/lora-7b.yaml,
examples/gemma3/gemma-3-4b-vision-qlora.yml, and
examples/gemma3/gemma-3-4b-qlora.yml. Replace the old lora_target_modules regex
pattern with the new one matching model.language_model.layers as shown in
docs/multimodal.qmd line 46.
Actionable comments posted: 1
🧹 Nitpick comments (3)
examples/gemma3/gemma-3-4b-vision-qlora.yml (1)
33-33: Nit: Anchor and escape your regex
To ensure exact matching and prevent unintended partial matches, consider anchoring and escaping:
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: '^model\\.language_model\\.layers\\.\\d+\\.(mlp|cross_attn|self_attn)\\.(up|down|gate|q|k|v|o)_proj$'
examples/gemma3/gemma-3-4b-qlora.yml (1)
31-31: Nit: Anchor and escape your regex
For precision, wrap with anchors and escape dots:
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: '^model\\.language_model\\.layers\\.\\d+\\.(mlp|cross_attn|self_attn)\\.(up|down|gate|q|k|v|o)_proj$'
examples/pixtral/lora-12b.yml (1)
28-28: Nit: Anchor and escape regex for precision
Consider adding anchors and escaping dots:
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: '^model\\.language_model\\.layers\\.\\d+\\.(mlp|cross_attn|self_attn)\\.(up|down|gate|q|k|v|o)_proj$'
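The effect of anchoring and escaping can be checked with Python's `re` module. This is a sketch with hypothetical module names; how the YAML value reaches the regex engine, and whether matching is by search or full match, are assumptions. One caveat: single-quoted YAML scalars keep backslashes literally, so the doubled backslashes in the suggestions above would likely arrive as `\\.`, which is a literal backslash followed by any character; `DOUBLED` below models that case.

```python
import re

LOOSE = r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"
STRICT = r"^model\.language_model\.layers\.\d+\.(mlp|cross_attn|self_attn)\.(up|down|gate|q|k|v|o)_proj$"
# What the regex engine would see if the doubled backslashes were kept literally.
DOUBLED = r"^model\\.language_model\\.layers\\.\\d+\\.(mlp|cross_attn|self_attn)\\.(up|down|gate|q|k|v|o)_proj$"

# Hypothetical module names, for illustration only.
good = "model.language_model.layers.3.mlp.up_proj"
odd = "modelXlanguage_model.layers.3.mlp.up_proj"
prefixed = "base_model.language_model.layers.3.mlp.up_proj"


def found(pattern, name):
    """Return True if the pattern matches anywhere in the name."""
    return re.search(pattern, name) is not None


loose_hits = [found(LOOSE, n) for n in (good, odd, prefixed)]
strict_hits = [found(STRICT, n) for n in (good, odd, prefixed)]
doubled_hits = [found(DOUBLED, n) for n in (good, odd, prefixed)]
```

With search semantics the loose pattern accepts all three names, the strict pattern accepts only the intended one, and the doubled-backslash form accepts none of them. So if the anchored form is adopted in a single-quoted YAML value, single backslashes are probably what is wanted.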
📜 Review details
📒 Files selected for processing (5)
- examples/gemma3/gemma-3-4b-qlora.yml (1 hunks)
- examples/gemma3/gemma-3-4b-vision-qlora.yml (1 hunks)
- examples/llava/lora-7b.yaml (1 hunks)
- examples/mistral/mistral-small-3.1-24B-lora.yml (1 hunks)
- examples/pixtral/lora-12b.yml (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- examples/llava/lora-7b.yaml
- examples/mistral/mistral-small-3.1-24B-lora.yml
🔇 Additional comments (5)
examples/gemma3/gemma-3-4b-vision-qlora.yml (1)
33-33: Regex prefix update approved
The lora_target_modules pattern prefix has been updated to model.language_model, correctly reflecting the current module hierarchy.
examples/gemma3/gemma-3-4b-qlora.yml (2)
31-31: Regex prefix update approved
The lora_target_modules regex now correctly uses model.language_model as the prefix, aligning with the code’s attribute path.
31-31: Verify no residual old patterns
✅ Verification successful
Run a global search to catch any configs still referencing language_model.model:
🏁 Script executed:
rg "lora_target_modules" | grep -n "language_model\\.model"
Length of output: 58
#!/bin/bash
# Search for any leftover references to the old pattern
rg -n "language_model\\.model"
No obsolete language_model.model references detected
A global search for language_model.model across the repository returned no matches, confirming all lora_target_modules patterns have been updated.
examples/pixtral/lora-12b.yml (2)
28-28: Regex prefix update approved
The updated lora_target_modules prefix now matches model.language_model, consistent with other example configurations.
28-28: Verify consistency across configs
✅ Verification successful
Search for any remaining outdated patterns:
🏁 Script executed:
rg "lora_target_modules" | grep -n "language_model\\.model"
Length of output: 58
All lora_target_modules patterns are consistent across configs
A search for any stray language_model.model occurrences within lora_target_modules yielded no results, so there are no outdated patterns to update.
  lora_alpha: 16
  lora_dropout: 0.05
- lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+ lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
💡 Verification agent
🧩 Analysis chain
Verify no residual old patterns
Ensure consistency by searching for any configs still using the old language_model.model prefix:
Let me know if you’d like a sed-based script to batch-update remaining files.
🏁 Script executed:
rg "lora_target_modules" -A1 | grep -B1 "language_model\\.model"
Length of output: 63
🏁 Script executed:
#!/bin/bash
# Search the entire repository for any occurrences of the old prefix
rg -n "language_model\.model"
Length of output: 126
Residual old prefix detected in lora_kernels.py
The search revealed one remaining use of the old language_model.model path:
File: src/axolotl/monkeypatch/lora_kernels.py
Line 346:
- layers = model.language_model.model.layers
+ layers = model.language_model.layers
Please update this occurrence (and any similar references) to remove the extra .model segment so it aligns with the new target-module pattern.
🤖 Prompt for AI Agents
In examples/gemma3/gemma-3-4b-vision-qlora.yml at line 33, the
lora_target_modules pattern is updated correctly, but there is a residual old
prefix "language_model.model" used in src/axolotl/monkeypatch/lora_kernels.py at
line 346. Locate this line and any similar references in that file and remove
the extra ".model" segment from the path to ensure consistency with the new
pattern. This will align all references to use the updated module path format.
I also think this file can be updated. Let me know if you'd like a hand with this.
@salmanmohammadi, what needs to be changed in the linked file? The current error seems to occur because our linked test loads from https://huggingface.co/JackFram/llama-68m, which is a PyTorch bin model: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/15346236028/job/43186938948?pr=2735 (Is this our sign to drop PyTorch 2.5 support?)
Actionable comments posted: 1
📜 Review details
📒 Files selected for processing (3)
- _quarto.yml (0 hunks)
- src/axolotl/core/builders/causal.py (1 hunks)
- tests/e2e/test_optimizers.py (2 hunks)
💤 Files with no reviewable changes (1)
- _quarto.yml
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/test_optimizers.py (1)
- tests/e2e/utils.py (4): check_model_output_exists (150-171), require_torch_2_5_1 (56-65), require_torch_2_6_0 (68-77), with_temp_dir (19-31)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
🔇 Additional comments (2)
tests/e2e/test_optimizers.py (2)
15-20: LGTM: Import addition aligns with test requirements.
The addition of require_torch_2_6_0 to the imports is necessary for the decorator usage in the test_came_pytorch method.
209-209: LGTM: Appropriate version gating for compatibility.
Adding the @require_torch_2_6_0 decorator to the test_came_pytorch method is a good solution that aligns with the PR's dependency bump objectives. This ensures the test only runs when PyTorch 2.6.0+ is available, which is particularly relevant given:
- The test uses the "JackFram/llama-68m" model and "came_pytorch" optimizer
- The PR comments mentioned errors related to this specific model
- The pattern follows existing version requirements used elsewhere in the codebase
This approach maintains test coverage while preventing compatibility issues on older PyTorch versions.
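Version gating of this kind can be sketched with the standard library alone. The code below is not the `require_torch_2_6_0` helper from tests/e2e/utils.py, whose implementation is not shown in this thread; `version_tuple` and `require_min_version` are hypothetical names, and the installed version is passed in as a string so the sketch runs without PyTorch.

```python
import unittest


def version_tuple(version):
    """Parse '2.6.0' or '2.6.0+cu124' into (2, 6, 0).

    Pre-release suffixes such as '2.6.0a0' are not handled in this sketch.
    """
    core = version.split("+", 1)[0]
    return tuple(int(part) for part in core.split(".")[:3])


def require_min_version(installed, minimum):
    """Hypothetical decorator factory: skip the test unless installed >= minimum."""
    return unittest.skipUnless(
        version_tuple(installed) >= version_tuple(minimum),
        f"test requires version >= {minimum}",
    )


class GatedExample(unittest.TestCase):
    @require_min_version("2.5.1", "2.6.0")
    def test_needs_newer(self):
        self.assertTrue(True)

    @require_min_version("2.7.0+cu126", "2.6.0")
    def test_runs(self):
        self.assertTrue(True)
```

Running `GatedExample` reports `test_needs_newer` as skipped rather than failed, which keeps the suite green on older versions while still exercising the test on newer ones.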
  training_arguments_kwargs["multipack_real_batches"] = (
-     not self.cfg.flash_attention or self.cfg.multipack_real_batches
+     self.cfg.multipack_real_batches
+     if self.cfg.multipack_real_batches is not None
+     else not self.cfg.flash_attention
  )
💡 Verification agent
🧩 Analysis chain
Breaking change: Verify the new multipack_real_batches logic is intentional
This change modifies the precedence for determining multipack_real_batches and could be a breaking change:
Previous behavior: not self.cfg.flash_attention or self.cfg.multipack_real_batches
- When flash_attention=False, result was always True regardless of multipack_real_batches setting
- Only respected explicit multipack_real_batches when flash_attention=True
New behavior: Explicit configuration takes priority over derived behavior
- If multipack_real_batches is set, use that value
- If multipack_real_batches is None, fall back to not flash_attention
While the new logic is more intuitive and respects user configuration, users with configs like:
flash_attention: false
multipack_real_batches: false
will now get multipack_real_batches=False instead of True.
🏁 Script executed:
#!/bin/bash
# Verify impact of multipack_real_batches logic change
echo "Searching for configurations that might be affected by this change..."
# Look for configs with both flash_attention and multipack_real_batches settings
rg -A 5 -B 5 "flash_attention.*false|multipack_real_batches.*false" --type yaml
# Search for any documentation or comments about this behavior
rg -A 10 -B 5 "multipack_real_batches.*flash_attention|flash_attention.*multipack_real_batches"
# Look for related test files that might need updates
fd "test.*multipack|multipack.*test" --type f --exec cat {}
Length of output: 16809
Document and test explicit multipack_real_batches override with flash_attention=False
The change in src/axolotl/core/builders/causal.py (246-250) only affects users who explicitly set
flash_attention: false
multipack_real_batches: false
Default behavior (no multipack_real_batches key) remains unchanged. To avoid surprises:
• Update the configuration reference or CLI help (TrainerCliArgs) to call out this override priority.
• Add a unit or E2E test for the case where flash_attention=False and multipack_real_batches=false to ensure users see the intended behavior.
• Consider a note in CHANGELOG or upgrade guide highlighting this breaking change for explicit configs.
🤖 Prompt for AI Agents
In src/axolotl/core/builders/causal.py around lines 246 to 250, the logic for
setting multipack_real_batches now prioritizes the explicit configuration over
the derived value from flash_attention, which is a breaking change for users who
set flash_attention to false and multipack_real_batches to false. To fix this,
update the TrainerCliArgs configuration reference or CLI help to clearly
document that multipack_real_batches overrides flash_attention when explicitly
set. Additionally, add a unit or end-to-end test covering the scenario where
flash_attention is false and multipack_real_batches is false to verify the new
behavior. Finally, include a note in the CHANGELOG or upgrade guide to inform
users about this breaking change.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
60-97: Consider refactoring to reduce code duplication.
The tensor distribution logic is correct, but the casting and contiguity handling is duplicated between the main process (lines 70-77) and other ranks (lines 89-96). While not critical, this could be simplified.
Consider extracting the common pattern into a helper:
 for param_name in param_names:
     mesh = meta_sharded_sd[param_name].device_mesh
     if accelerator.is_main_process:
         full_param = full_sd[param_name].detach().cuda()
         dist.broadcast(full_param, src=0, group=mesh.get_group())
         sharded_tensor = distribute_tensor(
             full_param, mesh, sharded_sd[param_name].placements
         )
-        to_contiguous, casting_dtype = _infer_parameter_dtype(
-            model,
-            param_name,
-            full_param,
-        )
-        sharded_tensor = _cast_and_contiguous(
-            sharded_tensor, to_contiguous, casting_dtype
-        )
-        sharded_sd[param_name] = sharded_tensor
     else:
         full_tensor = torch.empty(
             sharded_sd[param_name].size(),
             device="cuda",
             dtype=sharded_sd[param_name].dtype,
         )
         dist.broadcast(full_tensor, src=0, group=mesh.get_group())
         sharded_tensor = distribute_tensor(
             full_tensor, mesh, sharded_sd[param_name].placements
         )
-        to_contiguous, casting_dtype = _infer_parameter_dtype(
-            model,
-            param_name,
-            full_tensor,
-        )
-        sharded_tensor = _cast_and_contiguous(
-            sharded_tensor, to_contiguous, casting_dtype
-        )
-        sharded_sd[param_name] = sharded_tensor
+
+    # Apply dtype casting and contiguity after distribution
+    to_contiguous, casting_dtype = _infer_parameter_dtype(
+        model,
+        param_name,
+        full_param if accelerator.is_main_process else full_tensor,
+    )
+    sharded_tensor = _cast_and_contiguous(
+        sharded_tensor, to_contiguous, casting_dtype
+    )
+    sharded_sd[param_name] = sharded_tensor
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
🔇 Additional comments (5)
src/axolotl/monkeypatch/accelerate/fsdp2.py (5)
21-23: Good documentation update!
The clarification that the model is expected to be on meta device to prevent VRAM spikes is helpful for users of this function.
28-31: LGTM!
The initialization logic correctly retrieves the meta device state dict and prepares for sharded tensor distribution.
33-52: Well-designed helper function with good LoRA support!
The function properly handles both regular parameters and LoRA parameters that aren't standard PyTorch parameters. The float8 dtype checking and casting logic is appropriate.
99-101: Correct use of assign=True for meta device parameters!
The assign=True parameter is essential here since the model parameters are on meta device. This ensures the state dict tensors are assigned directly rather than copied.
194-211: LGTM! Simplified patching aligns with dependency updates.
The removal of set_state_dict_type patching and focus on just the essential patches (fsdp2_load_full_state_dict and get_state_dict) is appropriate for the updated accelerate dependencies.
🧰 Tools
🪛 Ruff (0.11.9)
199-203: Do not call setattr with a constant attribute value. It is not any safer than normal property access. Replace setattr with assignment. (B010)
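As a rough illustration of the B010 finding, the sketch below contrasts `setattr` with a constant name against plain assignment. The `target` namespace and `patched_loader` function are hypothetical stand-ins for the module and replacement function being patched; they are not the names used in fsdp2.py.

```python
import types


def patched_loader():
    # Hypothetical replacement function.
    return "patched"


# Hypothetical stand-in for the module whose attribute gets patched.
target = types.SimpleNamespace()

# Flagged by B010: the attribute name is a constant string.
setattr(target, "load_full_state_dict", patched_loader)

# Equivalent plain assignment, which is what the rule recommends.
target.load_full_state_dict = patched_loader

# setattr is still the right tool when the name is only known at runtime.
for name in ("load_full_state_dict", "get_state_dict"):
    setattr(target, name, patched_loader)
```

Both forms have the same effect at runtime; the assignment form is simply easier for linters and readers to follow.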
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/e2e/multigpu/solo/test_grpo.py (1)
268-359: Consider improving test consistency and parameterization.
The new test test_llama_lora_sp effectively exercises the sequence parallelism functionality. However, there are some inconsistencies with the existing test pattern:
- Missing parameterization: Unlike test_llama_dora, which tests with [1, 2] GPUs, this test hardcodes 2 processes (line 348). Consider adding parameterization for consistency.
- Configuration differences: The test removes DoRA configuration (peft_use_dora) and adds sequence_parallel_degree: 2, which properly exercises the new functionality from the main file.
Consider applying this diff to maintain consistency with the existing test pattern:
+    @pytest.mark.parametrize(
+        "num_gpus",
+        [2],  # Only test with 2 GPUs since sequence_parallel_degree is 2
+    )
     @require_vllm
-    def test_llama_lora_sp(self, temp_dir):
+    def test_llama_lora_sp(self, temp_dir, num_gpus):
         # ... configuration code ...
             "--num-processes",
-            str(2),
+            str(num_gpus),
Alternatively, if sequence parallelism specifically requires 2 processes, document this requirement in a comment.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/core/trainers/grpo/__init__.py (2 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/core/trainers/grpo/__init__.py (1)
src/axolotl/utils/schemas/trl.py (1)
TRLConfig(6-163)
tests/e2e/multigpu/solo/test_grpo.py (3)
tests/e2e/utils.py (1)
require_vllm(92-107)
tests/conftest.py (1)
temp_dir(414-419)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)
⏰ Context from checks skipped due to timeout of 90000ms (7)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (2)
src/axolotl/core/trainers/grpo/__init__.py (1)
72-73: LGTM! Sequence parallel degree support added correctly.
The logic correctly sets the sequence_parallel_degree parameter when it's greater than 1, which aligns with the test case that uses this feature.
tests/e2e/multigpu/solo/test_grpo.py (1)
265-267: LGTM! Minor formatting improvement.
The parentheses around recursive_kill(vllm_process) are unnecessary but don't affect functionality.
trl: TRLConfig = cfg.trl
if trl.reward_funcs and isinstance(trl.reward_funcs, list):
    for reward_func in trl.reward_funcs:
        trainer_kwargs["reward_fn"] = cls.get_reward_func(reward_func)
elif trl.reward_funcs and isinstance(trl.reward_funcs, str):
    trainer_kwargs["reward_fn"] = cls.get_reward_func(trl.reward_funcs)
Fix logic error in reward function handling.
There are two issues with this implementation:
- Missing null check: Line 129 accesses cfg.trl without verifying it exists, which could cause an AttributeError.
- Logic error: When trl.reward_funcs is a list, the loop overwrites trainer_kwargs["reward_fn"] in each iteration, meaning only the last reward function is actually set.
Apply this diff to fix both issues:
- trl: TRLConfig = cfg.trl
- if trl.reward_funcs and isinstance(trl.reward_funcs, list):
-     for reward_func in trl.reward_funcs:
-         trainer_kwargs["reward_fn"] = cls.get_reward_func(reward_func)
- elif trl.reward_funcs and isinstance(trl.reward_funcs, str):
-     trainer_kwargs["reward_fn"] = cls.get_reward_func(trl.reward_funcs)
+ if cfg.trl:
+     trl: TRLConfig = cfg.trl
+     if trl.reward_funcs and isinstance(trl.reward_funcs, list):
+         # Handle multiple reward functions - this may need different logic
+         # depending on how the trainer expects to receive multiple functions
+         trainer_kwargs["reward_fn"] = [cls.get_reward_func(reward_func) for reward_func in trl.reward_funcs]
+     elif trl.reward_funcs and isinstance(trl.reward_funcs, str):
+         trainer_kwargs["reward_fn"] = cls.get_reward_func(trl.reward_funcs)
Note: Please verify how the trainer expects to receive multiple reward functions - as a list or if only one function should be set.
🤖 Prompt for AI Agents
In src/axolotl/core/trainers/grpo/__init__.py around lines 129 to 135, first add
a check to ensure cfg.trl exists before accessing it to avoid AttributeError.
Then, fix the logic for handling trl.reward_funcs when it is a list by
collecting all reward functions into a list and assigning that list to
trainer_kwargs["reward_fn"] instead of overwriting it in each loop iteration.
Confirm whether the trainer expects a list of reward functions or a single
function and adjust accordingly.
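The overwrite problem described in this comment can be reproduced in isolation. The sketch below is a simplified stand-in: `get_reward_func` here is a hypothetical helper that returns a labelled callable, not the real `cls.get_reward_func`, and the function names in `reward_funcs` are made up for illustration.

```python
def get_reward_func(name):
    # Hypothetical stand-in: returns a callable tagged with its own name.
    return lambda completion: f"{name}:{len(completion)}"


reward_funcs = ["length_reward", "format_reward"]  # illustrative names

# Loop form from the reviewed code: each iteration replaces the previous value.
overwritten = {}
for name in reward_funcs:
    overwritten["reward_fn"] = get_reward_func(name)

# List form from the suggested diff: every function is kept.
collected = {"reward_fn": [get_reward_func(name) for name in reward_funcs]}

last_only = overwritten["reward_fn"]("abc")
all_results = [fn("abc") for fn in collected["reward_fn"]]
```

Whether a list is the right shape still depends on what the trainer accepts for this keyword, which this sketch does not settle.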
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/core/trainers/grpo/args.py (1)
16-16: Good addition, but consider adding documentation.
The new sequence_parallel_degree field is well-implemented with appropriate type annotation and default value. However, consider adding a docstring or inline comment to explain the purpose and expected values for this configuration field.
- sequence_parallel_degree: int | None = None
+ sequence_parallel_degree: int | None = None  # Degree of sequence parallelism for distributed training
Alternatively, you could add a docstring to the field or expand the class docstring to include information about all configuration options.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/core/trainers/grpo/__init__.py (3 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/core/trainers/grpo/__init__.py
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
        self.sp_group = None
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.local_rank = 0
        self.local_world_size = 1

    def train(self, *args, **kwargs):
        # Initialize the SP group
        self.sp_group = get_ring_attn_group()
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.local_rank = dist.get_rank(group=self.sp_group)
        self.local_world_size = dist.get_world_size(group=self.sp_group)

        return super().train(*args, **kwargs)
@djsaunde Does this seem sane? we talked about this the other day since we don't have ring attention groups initialized until the train is called with the context manager.
Yes, seems fine, though you don't need to re-set self.rank or self.world_size.
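For readers following the thread, here is a framework-free sketch of the pattern being discussed: placeholders are set in the constructor and the group-dependent values are filled in when train is called. All class names are hypothetical, and the `get_group` callable stands in for `get_ring_attn_group`; a plain dict replaces the real process group so the example runs without torch.distributed.

```python
class BaseTrainerSketch:
    def train(self):
        return "trained"


class SequenceParallelTrainerSketch(BaseTrainerSketch):
    def __init__(self, get_group):
        # Hypothetical callable standing in for get_ring_attn_group().
        self._get_group = get_group
        # Placeholders: the group does not exist yet at construction time.
        self.sp_group = None
        self.local_rank = 0
        self.local_world_size = 1

    def train(self):
        # Resolve the group only now, once it has been initialized elsewhere.
        self.sp_group = self._get_group()
        self.local_rank = self.sp_group["rank"]
        self.local_world_size = self.sp_group["size"]
        return super().train()


trainer = SequenceParallelTrainerSketch(lambda: {"rank": 1, "size": 2})
before = (trainer.sp_group, trainer.local_rank, trainer.local_world_size)
result = trainer.train()
after = (trainer.local_rank, trainer.local_world_size)
```

In line with the reply above, the sketch does not re-set the global rank or world size inside train, since those do not depend on the sequence parallel group.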
accelerate==1.7.0
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.0
TRL released a patch https://github.com/huggingface/trl/releases/tag/v0.18.1 but not sure if there's anything that we depend on
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/test_packed_dataset.py (1)
76-154: Clarify the sequence length vs batch shape discrepancy.
The test configures sequence_len: 1024 but asserts that batch shapes are (1, 8192). Please clarify:
- Is this 8x multiplier intentional for testing sample packing behavior?
- Does the multipack sampler concatenate sequences to achieve this shape?
- Should the test include a comment explaining this expected behavior?
The test structure and LoRA configuration look good otherwise.
src/axolotl/core/trainers/base.py (1)
175-260: Consider breaking down the large method for better maintainability.
The _get_dataloader method is quite comprehensive but could benefit from being broken into smaller, focused helper methods for better readability and testability.
Consider extracting logical chunks into helper methods such as:
- _preprocess_dataset_for_dataloader(dataset, is_training, description)
- _build_dataloader_params(batch_size, data_collator, sampler_fn, dataset, is_training)
- _cache_eval_dataloader_if_needed(dataloader, dataloader_key)
This would make the main method more focused and easier to test individual components.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (30)
- .github/workflows/multi-gpu-e2e.yml (1 hunks)
- .github/workflows/tests.yml (3 hunks)
- _quarto.yml (0 hunks)
- docs/multimodal.qmd (1 hunks)
- examples/gemma3/gemma-3-4b-qlora.yml (1 hunks)
- examples/gemma3/gemma-3-4b-vision-qlora.yml (1 hunks)
- examples/llama-3-vision/lora-11b.yaml (1 hunks)
- examples/llava/lora-7b.yaml (1 hunks)
- examples/mistral/mistral-small-3.1-24B-lora.yml (1 hunks)
- examples/pixtral/lora-12b.yml (1 hunks)
- requirements.txt (1 hunks)
- scripts/cutcrossentropy_install.py (1 hunks)
- setup.py (1 hunks)
- src/axolotl/core/builders/causal.py (1 hunks)
- src/axolotl/core/trainers/base.py (4 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (3 hunks)
- src/axolotl/core/trainers/grpo/args.py (1 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (1 hunks)
- src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py (0 hunks)
- src/axolotl/loaders/patch_manager.py (0 hunks)
- src/axolotl/monkeypatch/accelerate/fsdp2.py (2 hunks)
- src/axolotl/monkeypatch/attention/mllama.py (0 hunks)
- src/axolotl/monkeypatch/gemma3.py (0 hunks)
- src/axolotl/monkeypatch/lora_kernels.py (1 hunks)
- src/axolotl/train.py (1 hunks)
- tests/e2e/integrations/test_kd.py (1 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
- tests/e2e/test_llama_vision.py (2 hunks)
- tests/e2e/test_optimizers.py (2 hunks)
- tests/test_packed_dataset.py (2 hunks)
💤 Files with no reviewable changes (5)
- _quarto.yml
- src/axolotl/loaders/patch_manager.py
- src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
- src/axolotl/monkeypatch/attention/mllama.py
- src/axolotl/monkeypatch/gemma3.py
✅ Files skipped from review due to trivial changes (3)
- setup.py
- .github/workflows/multi-gpu-e2e.yml
- docs/multimodal.qmd
🚧 Files skipped from review as they are similar to previous changes (17)
- requirements.txt
- tests/e2e/integrations/test_kd.py
- src/axolotl/core/trainers/grpo/args.py
- tests/e2e/test_llama_vision.py
- scripts/cutcrossentropy_install.py
- examples/gemma3/gemma-3-4b-qlora.yml
- examples/gemma3/gemma-3-4b-vision-qlora.yml
- .github/workflows/tests.yml
- examples/pixtral/lora-12b.yml
- examples/llama-3-vision/lora-11b.yaml
- src/axolotl/core/builders/causal.py
- examples/mistral/mistral-small-3.1-24B-lora.yml
- tests/e2e/test_optimizers.py
- examples/llava/lora-7b.yaml
- src/axolotl/core/trainers/grpo/__init__.py
- src/axolotl/monkeypatch/lora_kernels.py
- tests/e2e/multigpu/solo/test_grpo.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/trainers/base.py (1)
src/axolotl/core/trainers/grpo/trainer.py (1)
_get_train_sampler(155-175)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (12)
src/axolotl/train.py (1)
207-207: LGTM! Good defensive programming practice.
The additional truthiness check prevents adding None or falsy reference models to the models list, which could cause issues in the SequenceParallelContextManager. This change aligns well with the sequence parallelism improvements.
tests/test_packed_dataset.py (1)
9-18: LGTM! Appropriate imports for the new test functionality.
The imported modules correctly support the new LoRA packing test with CLI args, dataset loading, trainer setup, and configuration utilities.
src/axolotl/core/trainers/grpo/trainer.py (1)
139-154: LGTM! Good architectural improvement for sequence parallel initialization.
Moving the sequence parallel group initialization from constructor to the train method ensures that:
- The SP group is available when training starts (not during construction)
- Distributed training setup is complete before initialization
- Fresh rank and world size information is obtained at training time
The initialization logic correctly updates both global and local (SP group) rank/world size information before delegating to the superclass.
src/axolotl/monkeypatch/accelerate/fsdp2.py (3)
33-58: Well-designed helper functions for parameter handling.
The _infer_parameter_dtype and _cast_and_contiguous helper functions provide clean abstractions for:
- Dtype inference with special handling for float8_e4m3fn and floating-point types
- LoRA parameter support through submodule attribute access
- Contiguity checking and casting operations
The separation of concerns makes the code more maintainable and testable.
70-77: Improved tensor casting and contiguity handling.
The refactored code properly addresses tensor dtype casting and contiguity requirements using the helper functions. This addresses the previous review concern about the additional .contiguous() call by making the contiguity handling explicit and conditional based on the original parameter properties.
Also applies to: 89-96
99-101: Correct use of assign=True for meta device parameters.
Using assign=True in load_state_dict is appropriate for meta device parameters, as it allows direct assignment without device/dtype mismatches that would occur with normal loading.
src/axolotl/core/trainers/base.py (6)
9-10: Import additions look appropriate.
The new imports partial, wraps, Callable, Literal, and Optional are all used in the refactored code and support the new functionality.
23-23: The seed_worker import is now correctly used.
This addresses the previous review comment about unused imports: seed_worker is now used in the _get_dataloader method on lines 237-240.
116-142: Good refactor with explicit dataset parameter.
The addition of the optional train_dataset parameter and its usage on line 142 makes the sampler creation more explicit and testable. The method correctly handles curriculum sampling and sample packing cases.
147-171: Eval sampler logic is consistent and well-structured.
The evaluation sampler method follows the same pattern as the training sampler and correctly handles the multipack evaluation logic based on the configuration flags.
175-213: The dataloader preprocessing logic resolves previous conditional issues.
The new conditional structure for dataset preprocessing is much cleaner than the previous implementation and correctly handles the various combinations of training/evaluation modes with sample packing. This addresses the contradictory logic flagged in previous reviews.
214-260: Comprehensive dataloader parameter construction with good caching support.
The unified dataloader parameter construction handles all the necessary configurations including batch sampling, worker settings, and sampler integration. The addition of evaluation dataloader caching (lines 252-257) for persistent workers is a good optimization.
batch_size = (
    batch_size
    if self.args.sample_packing
    else self.args.per_device_eval_batch_size
)
Fix redundant conditional in batch size assignment.
The ternary condition on lines 203-205 is redundant because we're already in a branch where self.args.sample_packing is True, so the condition will always evaluate to the first case.
- batch_size = (
-     batch_size
-     if self.args.sample_packing
-     else self.args.per_device_eval_batch_size
- )
+ # batch_size remains unchanged since sample_packing is already True
Alternatively, if this was intended to set a different batch size for evaluation, clarify the logic:
- batch_size = (
-     batch_size
-     if self.args.sample_packing
-     else self.args.per_device_eval_batch_size
- )
+ # Use the passed batch_size for evaluation with sample packing
+ pass
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/core/trainers/base.py around lines 202 to 206, the batch_size
assignment uses a redundant conditional checking self.args.sample_packing, which
is always True in this context. Simplify the assignment by directly setting
batch_size to the value corresponding to the sample_packing condition without
the ternary operator, or clarify and adjust the logic if different batch sizes
are intended for evaluation scenarios.
salmanmohammadi
left a comment
LGTM, @djsaunde could you take a pass?
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
    """Axolotl GRPO Config for GRPO training"""

    sequence_parallel_degree: int | None = None
Why do we need this? AFAICT this isn't used
we have this code in AxolotlGRPOSequenceParallelTrainer:
num_sp_groups = num_processes // self.args.sequence_parallel_degree
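As a small worked example of that line, the sketch below computes the number of sequence parallel groups and one possible rank layout. The helper name `sp_layout` is hypothetical, and grouping contiguous ranks together is an assumption made for illustration; the thread does not say how the trainer actually assigns ranks to groups.

```python
def sp_layout(num_processes, sequence_parallel_degree):
    # The integer division from the quoted line only makes sense when the
    # process count is a multiple of the degree, so check that first.
    if num_processes % sequence_parallel_degree != 0:
        raise ValueError(
            "num_processes must be divisible by sequence_parallel_degree"
        )
    num_sp_groups = num_processes // sequence_parallel_degree
    # Assumption: ranks are grouped contiguously, e.g. [0, 1], [2, 3], ...
    return [
        list(
            range(
                group * sequence_parallel_degree,
                (group + 1) * sequence_parallel_degree,
            )
        )
        for group in range(num_sp_groups)
    ]


two_gpu = sp_layout(2, 2)
four_gpu = sp_layout(4, 2)
```

With 2 processes and a degree of 2, as in the new test, there is a single group containing both ranks.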
djsaunde
left a comment
I wonder if there's an easier way to update our custom dataloader logic, seems annoying to have to update whenever upstream changes
Actionable comments posted: 1
🔭 Outside diff range comments (1)
src/axolotl/cli/vllm_serve.py (1)
83-141: 💡 Verification agent
🧩 Analysis chain
Add error handling and document maintenance risks.
The monkey patching approach has maintenance risks and the Connection operations lack error handling.
- Add error handling for Connection operations:
     # Send ready signal to parent process
-    connection.send({"status": "ready"})
+    try:
+        connection.send({"status": "ready"})
+    except Exception as e:
+        # Log the error or handle appropriately
+        raise RuntimeError(f"Failed to send ready signal: {e}")

     while True:
         # Wait for commands from the parent process
         try:
             command = connection.recv()
         except KeyboardInterrupt:
             llm.collective_rpc(method="close_communicator")
             break
+        except Exception as e:
+            # Log the error or handle appropriately
+            raise RuntimeError(f"Failed to receive command: {e}")
- Document the maintenance risk with a comment:
 def patch_vllm_worker():
+    """
+    Monkey patch the vllm worker to add reasoning support.
+
+    WARNING: This patch modifies the internal implementation of trl.scripts.vllm_serve.llm_worker.
+    It may break if the upstream implementation changes. Consider contributing this feature
+    upstream or finding a more maintainable approach.
+    """
     from multiprocessing.connection import Connection
Let me verify if the upstream llm_worker function signature is compatible:
🏁 Script executed:
#!/bin/bash
# Description: Check the signature of the original llm_worker function in trl
# to ensure our patch is compatible
# Search for the llm_worker function definition in the trl package
rg -A 5 "def llm_worker" --glob "**/*.py" | grep -E "(def llm_worker|script_args|ScriptArguments)"
Length of output: 225
Add error handling and document the maintenance risk of monkey-patching
The patch_vllm_worker helper replaces trl.scripts.vllm_serve.llm_worker in place, which may break when the upstream implementation changes. It also doesn't guard against failures in the Connection API. Please:
- Annotate the maintenance risk at the top of patch_vllm_worker with a clear docstring warning.
- Wrap all connection.send and connection.recv calls in try/except blocks and rethrow or log a clear error.
Suggested diff in src/axolotl/cli/vllm_serve.py:
 def patch_vllm_worker():
+    """
+    Monkey-patch the vLLM worker to add reasoning support.
+
+    WARNING: This overrides `trl.scripts.vllm_serve.llm_worker` internally.
+    Upstream changes to that implementation may silently break this patch.
+    Consider contributing this feature upstream or adopting a more robust extension mechanism.
+    """
     from multiprocessing.connection import Connection
@@
-    # Send ready signal to parent process
-    connection.send({"status": "ready"})
+    # Send ready signal to parent process
+    try:
+        connection.send({"status": "ready"})
+    except Exception as e:
+        raise RuntimeError(f"[vllm_serve] failed to send ready signal: {e}")

     while True:
         # Wait for commands from the parent process
-        try:
-            command = connection.recv()
-        except KeyboardInterrupt:
-            llm.collective_rpc(method="close_communicator")
-            break
+        try:
+            command = connection.recv()
+        except KeyboardInterrupt:
+            llm.collective_rpc(method="close_communicator")
+            break
+        except Exception as e:
+            raise RuntimeError(f"[vllm_serve] failed to receive command: {e}")

         # Handle commands
         if command["type"] in ["call", "fire_and_forget"]:
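The guarded send and receive calls suggested here can be tried out on their own with a standard-library pipe. This is a minimal sketch, not the worker code: `send_ready` and `recv_command` are hypothetical helper names, both ends of the pipe live in one process for simplicity, and the sketch catches OSError and EOFError rather than the broad Exception used in the suggested diff.

```python
from multiprocessing import Pipe


def send_ready(connection):
    # Wrap the send so a broken or closed pipe surfaces as a clear error.
    try:
        connection.send({"status": "ready"})
    except (OSError, EOFError) as exc:
        raise RuntimeError(f"failed to send ready signal: {exc}") from exc


def recv_command(connection):
    # Wrap the receive for the same reason.
    try:
        return connection.recv()
    except (OSError, EOFError) as exc:
        raise RuntimeError(f"failed to receive command: {exc}") from exc


parent_conn, child_conn = Pipe()

# Happy path: the worker end reports readiness, the parent sends a command.
send_ready(child_conn)
status = parent_conn.recv()
parent_conn.send({"type": "shutdown"})
command = recv_command(child_conn)

# Failure path: sending on a closed connection now raises RuntimeError.
child_conn.close()
try:
    send_ready(child_conn)
    failure = None
except RuntimeError as exc:
    failure = str(exc)
parent_conn.close()
```

Chaining with `from exc` keeps the original traceback available, which helps when the failure originates in the parent process.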
🧹 Nitpick comments (1)
src/axolotl/cli/vllm_serve.py (1)
23-23: Consider Python version compatibility for type hints.
The type hint bool | None uses Python 3.10+ union syntax. If the project needs to support older Python versions, consider using Optional[bool] instead.
For broader Python version compatibility:
- enable_reasoning: bool | None = field(default=None, kw_only=True)
+ enable_reasoning: Optional[bool] = field(default=None, kw_only=True)
Also add the import:
-from typing import Union
+from typing import Optional, Union
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- requirements.txt (1 hunks)
- src/axolotl/cli/args.py (1 hunks)
- src/axolotl/cli/vllm_serve.py (3 hunks)
- src/axolotl/utils/schemas/vllm.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/schemas/vllm.py
🚧 Files skipped from review as they are similar to previous changes (1)
- requirements.txt
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.6.0)
    enable_reasoning: Optional[bool] = field(
        default=None,
    )

    reasoning_parser: Optional[str] = field(
        default=None,
    )
🛠️ Refactor suggestion
Add metadata with help descriptions for consistency.
The new fields enable_reasoning and reasoning_parser lack metadata with help descriptions, while all other fields in the VllmServeCliArgs dataclass include them. This inconsistency affects CLI discoverability and user experience.
Apply this diff to add descriptive metadata:
     enable_reasoning: Optional[bool] = field(
         default=None,
+        metadata={
+            "help": "Enable reasoning mode for VLLM generation. When enabled, the model can perform multi-step reasoning."
+        },
     )
     reasoning_parser: Optional[str] = field(
         default=None,
+        metadata={
+            "help": "Specify the reasoning parser to use for parsing model outputs in reasoning mode."
+        },
     )
🤖 Prompt for AI Agents
In src/axolotl/cli/args.py around lines 91 to 98, the fields enable_reasoning
and reasoning_parser are missing metadata with help descriptions, unlike other
fields in the VllmServeCliArgs dataclass. Add metadata to both fields with
appropriate help strings describing their purpose to maintain consistency and
improve CLI usability.
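A quick way to see what the metadata buys is to read it back with the standard dataclasses API. The sketch below uses a hypothetical `ServeArgsSketch` class rather than the real VllmServeCliArgs, and the help strings are shortened placeholders, not the wording proposed above.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ServeArgsSketch:
    # Hypothetical stand-in for the CLI args dataclass.
    enable_reasoning: Optional[bool] = field(
        default=None,
        metadata={"help": "Enable reasoning mode (placeholder text)."},
    )
    reasoning_parser: Optional[str] = field(
        default=None,
        metadata={"help": "Name of the reasoning parser (placeholder text)."},
    )


# Tools that build CLIs from dataclasses typically read this mapping;
# a field without metadata would yield None here.
help_by_name = {f.name: f.metadata.get("help") for f in fields(ServeArgsSketch)}
defaults = ServeArgsSketch()
```

A check like this could also serve as a lightweight test that every field in an args dataclass carries a help string.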
Summary by CodeRabbit
- modal package version in CI workflows.
- cut-cross-entropy package.