checkpoint model on first step callback #2906
Walkthrough
A new configuration option,
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant Config
participant TrainerBuilder
participant Callback
participant Trainer
User->>Config: Set save_first_step (True/False)
Config->>TrainerBuilder: Pass configuration
TrainerBuilder->>Callback: Add SaveModelOnFirstStepCallback if save_first_step is True
TrainerBuilder->>Trainer: Build with callbacks
Trainer->>Callback: on_step_end (after step 1)
Callback->>Trainer: If step==1, set control.should_save=True
Trainer->>Trainer: Save checkpoint if should_save
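The flow in the diagram can be sketched in a few lines of Python. This is a minimal illustration, not the code from the PR: `TrainerState` and `TrainerControl` below are simplified stand-ins for the Hugging Face classes of the same name (only the fields the diagram mentions are modelled), and `run_steps` is a hypothetical toy loop standing in for the trainer.

```python
from dataclasses import dataclass


@dataclass
class TrainerState:
    """Stand-in for the trainer state; only global_step is modelled."""
    global_step: int = 0


@dataclass
class TrainerControl:
    """Stand-in for the trainer control flags; only should_save is modelled."""
    should_save: bool = False


class SaveModelOnFirstStepCallback:
    """Request a checkpoint as soon as the first training step has finished."""

    def on_step_end(self, args, state, control, **kwargs):
        # Mirrors the diagram: if step == 1, set control.should_save = True.
        if state.global_step == 1:
            control.should_save = True
        return control


def run_steps(callback, num_steps):
    """Hypothetical toy loop: return the steps at which a save was requested."""
    state = TrainerState()
    saved_at = []
    for _ in range(num_steps):
        state.global_step += 1
        control = TrainerControl()  # flags are reset for every step
        control = callback.on_step_end(None, state, control)
        if control.should_save:
            saved_at.append(state.global_step)
    return saved_at
```

In this sketch, a three-step run requests a save only after step 1, which is the behaviour the diagram describes.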
winglian
left a comment
Good to go once prints from testing are removed.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/core/builders/base.py (2 hunks)
- src/axolotl/utils/callbacks/__init__.py (4 hunks)
- src/axolotl/utils/schemas/config.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/core/builders/base.py (1)
src/axolotl/utils/callbacks/__init__.py (1)
SaveModelOnFirstStepCallback(143-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (7)
src/axolotl/utils/schemas/config.py (1)
697-702: LGTM! Configuration field properly implemented. The new save_first_step field is well-positioned among other checkpoint-related options and includes appropriate documentation. The default value of True aligns with the PR objective to enable early failure detection by default.
src/axolotl/core/builders/base.py (2)
39-39: LGTM! Proper import addition. The import is correctly added to the existing callback imports.
145-146: LGTM! Callback integration follows established patterns. The conditional logic properly checks the configuration flag and appends the callback when enabled, following the same pattern as other optional callbacks in the method.
src/axolotl/utils/callbacks/__init__.py (4)
67-67: LGTM! Type annotation improvement. Adding the return type annotation improves code clarity and type safety.
103-103: LGTM! Type annotation and pylint directive improvement. The return type annotation and unused argument directive improve code quality.
125-125: LGTM! Type annotation and pylint directive improvement. The return type annotation and unused argument directive improve code quality.
129-129: LGTM! Type annotation improvement. Adding the return type annotation improves code clarity and type safety.
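The conditional registration described for src/axolotl/core/builders/base.py (lines 145-146) might look roughly like the sketch below. The function name `build_callbacks`, the `SimpleNamespace` config object and the empty callback class are assumptions made for illustration; the actual builder method is not shown in this conversation.

```python
from types import SimpleNamespace


class SaveModelOnFirstStepCallback:
    """Placeholder for the real callback; its behaviour is not modelled here."""


def build_callbacks(cfg):
    """Hypothetical helper: collect optional callbacks based on config flags."""
    callbacks = []
    # Only register the first-step checkpoint callback when the flag is set.
    if cfg.save_first_step:
        callbacks.append(SaveModelOnFirstStepCallback())
    return callbacks


enabled = build_callbacks(SimpleNamespace(save_first_step=True))
disabled = build_callbacks(SimpleNamespace(save_first_step=False))
```

With the flag set, the returned list holds one `SaveModelOnFirstStepCallback`; with it unset, the list stays empty.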
📖 Documentation Preview: https://6876a30c4b8cc53b5b9d7ab3--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit 60428fc
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/e2e/solo/test_save_first_step.py (2)
1-3: Fix docstring inconsistency. The docstring mentions "relora llama" but this file is specifically for testing the save_first_step callback feature.
-"""
-E2E tests for relora llama
-"""
+"""
+E2E tests for save_first_step callback
+"""
22-53: Consider reducing code duplication. Both test methods have nearly identical configurations. Consider extracting the common configuration into a helper method to reduce duplication and improve maintainability.
+    def _get_base_config(self, temp_dir, save_first_step):
+        return DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 512,
+                "val_set_size": 0.02,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 3,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+                "save_first_step": save_first_step,
+            }
+        )
Also applies to: 62-94
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (53)
- tests/e2e/integrations/test_cut_cross_entropy.py (2 hunks)
- tests/e2e/integrations/test_hooks.py (1 hunks)
- tests/e2e/integrations/test_kd.py (1 hunks)
- tests/e2e/integrations/test_liger.py (2 hunks)
- tests/e2e/integrations/test_llm_compressor.py (1 hunks)
- tests/e2e/multigpu/patched/test_sp.py (1 hunks)
- tests/e2e/multigpu/solo/test_flex.py (1 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (3 hunks)
- tests/e2e/multigpu/test_eval.py (2 hunks)
- tests/e2e/multigpu/test_gemma3.py (1 hunks)
- tests/e2e/multigpu/test_llama.py (12 hunks)
- tests/e2e/multigpu/test_qwen2.py (1 hunks)
- tests/e2e/multigpu/test_ray.py (2 hunks)
- tests/e2e/patched/test_4d_multipack_llama.py (2 hunks)
- tests/e2e/patched/test_activation_checkpointing.py (1 hunks)
- tests/e2e/patched/test_fa_xentropy.py (1 hunks)
- tests/e2e/patched/test_falcon_samplepack.py (2 hunks)
- tests/e2e/patched/test_flattening.py (1 hunks)
- tests/e2e/patched/test_fused_llama.py (1 hunks)
- tests/e2e/patched/test_llama_s2_attention.py (2 hunks)
- tests/e2e/patched/test_lora_llama_multipack.py (2 hunks)
- tests/e2e/patched/test_mistral_samplepack.py (2 hunks)
- tests/e2e/patched/test_mixtral_samplepack.py (2 hunks)
- tests/e2e/patched/test_model_patches.py (2 hunks)
- tests/e2e/patched/test_peft_embeddings.py (1 hunks)
- tests/e2e/patched/test_phi_multipack.py (2 hunks)
- tests/e2e/patched/test_resume.py (1 hunks)
- tests/e2e/patched/test_sp.py (1 hunks)
- tests/e2e/patched/test_unsloth_qlora.py (3 hunks)
- tests/e2e/solo/test_flex.py (1 hunks)
- tests/e2e/solo/test_relora_llama.py (1 hunks)
- tests/e2e/solo/test_save_first_step.py (1 hunks)
- tests/e2e/test_deepseekv3.py (2 hunks)
- tests/e2e/test_dpo.py (7 hunks)
- tests/e2e/test_embeddings_lr.py (2 hunks)
- tests/e2e/test_evaluate.py (1 hunks)
- tests/e2e/test_falcon.py (3 hunks)
- tests/e2e/test_gemma3_text.py (2 hunks)
- tests/e2e/test_llama.py (4 hunks)
- tests/e2e/test_llama_pretrain.py (1 hunks)
- tests/e2e/test_llama_vision.py (2 hunks)
- tests/e2e/test_lora_llama.py (1 hunks)
- tests/e2e/test_mamba.py (1 hunks)
- tests/e2e/test_mistral.py (2 hunks)
- tests/e2e/test_mixtral.py (5 hunks)
- tests/e2e/test_optimizers.py (5 hunks)
- tests/e2e/test_packing_loss.py (1 hunks)
- tests/e2e/test_phi.py (2 hunks)
- tests/e2e/test_process_reward_model_smollm2.py (1 hunks)
- tests/e2e/test_qat.py (2 hunks)
- tests/e2e/test_qwen.py (1 hunks)
- tests/e2e/test_reward_model_smollm2.py (1 hunks)
- tests/e2e/test_schedulers.py (1 hunks)
✅ Files skipped from review due to trivial changes (26)
- tests/e2e/multigpu/solo/test_flex.py
- tests/e2e/multigpu/test_ray.py
- tests/e2e/patched/test_phi_multipack.py
- tests/e2e/patched/test_lora_llama_multipack.py
- tests/e2e/test_mistral.py
- tests/e2e/patched/test_sp.py
- tests/e2e/patched/test_model_patches.py
- tests/e2e/patched/test_4d_multipack_llama.py
- tests/e2e/test_schedulers.py
- tests/e2e/test_qat.py
- tests/e2e/integrations/test_kd.py
- tests/e2e/integrations/test_cut_cross_entropy.py
- tests/e2e/test_falcon.py
- tests/e2e/test_deepseekv3.py
- tests/e2e/patched/test_unsloth_qlora.py
- tests/e2e/test_llama_vision.py
- tests/e2e/test_dpo.py
- tests/e2e/multigpu/solo/test_grpo.py
- tests/e2e/patched/test_llama_s2_attention.py
- tests/e2e/test_phi.py
- tests/e2e/test_embeddings_lr.py
- tests/e2e/test_mixtral.py
- tests/e2e/test_llama.py
- tests/e2e/integrations/test_liger.py
- tests/e2e/test_optimizers.py
- tests/e2e/multigpu/test_llama.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: preview
🔇 Additional comments (30)
tests/e2e/test_lora_llama.py (1)
52-52: Configuration addition looks good. The explicit setting of save_first_step: False appropriately disables the new checkpoint callback during testing, avoiding unnecessary overhead in this short-running test.
tests/e2e/multigpu/test_qwen2.py (1)
75-75: Appropriate configuration for multi-GPU testing. Setting save_first_step: False is particularly sensible in this multi-GPU FSDP test context, avoiding potential complications with early checkpointing during distributed training.
tests/e2e/patched/test_peft_embeddings.py (1)
52-52: Good integration with PEFT testing. Disabling save_first_step appropriately prevents interference with the test's focus on embedding upcast behavior and avoids potential early checkpointing issues with quantized models.
tests/e2e/patched/test_activation_checkpointing.py (1)
72-72: Maintains test focus on activation checkpointing. Setting save_first_step: False appropriately prevents potential confusion between gradient checkpointing (the test's focus) and model checkpointing (the new callback feature).
tests/e2e/test_qwen.py (1)
62-62: Appropriate for DPO testing context. Disabling save_first_step is sensible for this DPO test, preventing potential interference with preference optimization training while maintaining test focus.
tests/e2e/test_packing_loss.py (1)
51-51: LGTM - Consistent implementation of new configuration option. The addition of "save_first_step": False is consistent with the PR objectives to introduce early checkpointing capability. Setting it to False in tests maintains existing behavior while allowing the feature to be tested elsewhere.
tests/e2e/test_reward_model_smollm2.py (1)
61-61: LGTM - Consistent implementation of new configuration option. The addition of "save_first_step": False maintains consistency with the broader PR changes and appropriately disables the new checkpointing feature for this test to preserve existing behavior.
tests/e2e/patched/test_flattening.py (1)
64-64: LGTM - Consistent implementation of new configuration option. The addition of "save_first_step": False is properly placed within the configuration dictionary and aligns with the PR's systematic approach to introducing the new checkpointing feature while maintaining existing test behavior.
tests/e2e/patched/test_resume.py (1)
61-61: LGTM - Appropriate for resume testing context. The addition of "save_first_step": False is well-placed and particularly suitable for a resume test, where the focus is on resumption behavior rather than initial checkpointing functionality.
tests/e2e/solo/test_flex.py (1)
52-52: LGTM - Appropriate for focused feature testing. The addition of "save_first_step": False is correctly implemented and appropriate for this flex attention test, allowing it to focus on the core functionality without interference from the new checkpointing feature.
tests/e2e/patched/test_fa_xentropy.py (1)
65-65: LGTM: Appropriate test configuration update. Setting save_first_step to False in tests is the correct approach to avoid unnecessary checkpointing overhead while maintaining focus on the test's specific functionality.
tests/e2e/solo/test_relora_llama.py (1)
68-68: LGTM: Prevents interference with ReLoRA checkpoint validation. Disabling first-step checkpointing is appropriate here since this test has specific checkpoint validation logic for ReLoRA functionality, and the additional checkpoint would complicate the test assertions.
tests/e2e/test_evaluate.py (1)
39-39: LGTM: Appropriate for evaluation-focused test. Disabling first-step checkpointing is correct for an evaluation test since checkpointing behavior is not relevant to validating the evaluate CLI functionality.
tests/e2e/test_process_reward_model_smollm2.py (1)
52-52: LGTM: Maintains focus on process reward model functionality. Disabling first-step checkpointing is appropriate for this specialized test, allowing it to focus on validating the process reward model's token classification capabilities without checkpointing overhead.
tests/e2e/patched/test_mixtral_samplepack.py (1)
55-55: LGTM: Consistent configuration across both test methods. Both test methods appropriately disable first-step checkpointing, maintaining focus on their respective testing objectives (QLoRA and full fine-tuning) without unnecessary checkpointing overhead.
Also applies to: 94-94
tests/e2e/integrations/test_hooks.py (1)
156-156: Good addition for test stability. Adding save_first_step: False to this existing test configuration is appropriate to prevent the new checkpointing behavior from interfering with the plugin hooks testing logic.
tests/e2e/multigpu/test_gemma3.py (1)
74-74: Consistent test configuration update. Properly disabling the new save_first_step feature in this multi-GPU test maintains focus on the core DDP functionality being tested.
tests/e2e/patched/test_fused_llama.py (1)
56-56: Good practice for future test enablement. Adding the save_first_step: False configuration even to this skipped test ensures consistency and prevents issues when the test is re-enabled.
tests/e2e/test_mamba.py (1)
54-54: Consistent configuration management. Properly including the save_first_step: False option maintains test configuration consistency across the test suite.
tests/e2e/multigpu/test_eval.py (1)
70-70: Well-maintained test configuration consistency. Both evaluation test methods properly include the save_first_step: False configuration, ensuring the new checkpointing feature doesn't interfere with evaluation-focused testing.
Also applies to: 142-142
tests/e2e/patched/test_mistral_samplepack.py (2)
59-59: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core LoRA packing functionality.
101-101: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core fine-tuning packing functionality.
tests/e2e/test_llama_pretrain.py (1)
56-56: LGTM! Consistent configuration for pretraining integration test. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this pretraining integration test, allowing it to focus on the core pretraining functionality across different packing configurations.
tests/e2e/multigpu/patched/test_sp.py (1)
72-72: LGTM! Consistent configuration for sequence parallelism integration test. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this multi-GPU sequence parallelism integration test, allowing it to focus on the core parallelism functionality.
tests/e2e/patched/test_falcon_samplepack.py (2)
61-61: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core QLoRA functionality when the test becomes active.
103-103: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core fine-tuning functionality when the test becomes active.
tests/e2e/test_gemma3_text.py (2)
66-66: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core LoRA functionality with Gemma3 models across different packing configurations.
117-117: LGTM! Consistent configuration for integration testing. The addition of "save_first_step": False appropriately disables the new first-step checkpointing feature for this integration test, allowing it to focus on the core full fine-tuning functionality with Gemma3 models across different packing configurations.
tests/e2e/integrations/test_llm_compressor.py (1)
84-84: LGTM! Appropriate configuration for existing test. Adding save_first_step: False ensures this existing test maintains its original behavior without the overhead of saving a first-step checkpoint, which is appropriate since this test focuses on LLMCompressor integration rather than the save_first_step feature.
tests/e2e/solo/test_save_first_step.py (1)
59-60: Excellent test coverage for the save_first_step feature. The test logic is well-designed:
- Positive test verifies checkpoint-1 exists when save_first_step=True
- Negative test uses pytest.raises(AssertionError) to verify checkpoint-1 doesn't exist when save_first_step=False
This provides comprehensive coverage for the new callback functionality and aligns perfectly with the PR objectives of enabling early checkpoint validation.
Also applies to: 100-102
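The positive and negative checks described for tests/e2e/solo/test_save_first_step.py come down to testing whether a checkpoint-1 directory exists under the output directory. The sketch below illustrates that check only. `simulate_run` is a hypothetical stand-in that creates the directory instead of training, and `first_step_checkpoint_exists` is an assumed helper name that does not come from the PR.

```python
import tempfile
from pathlib import Path


def first_step_checkpoint_exists(output_dir):
    """Return True if a checkpoint-1 directory is present in output_dir."""
    return (Path(output_dir) / "checkpoint-1").is_dir()


def simulate_run(output_dir, save_first_step):
    """Hypothetical stand-in for a training run; no model is trained here."""
    if save_first_step:
        (Path(output_dir) / "checkpoint-1").mkdir(parents=True)


def check_both_cases():
    """Run the positive and negative case, each in a fresh temporary directory."""
    results = {}
    for flag in (True, False):
        with tempfile.TemporaryDirectory() as temp_dir:
            simulate_run(temp_dir, flag)
            results[flag] = first_step_checkpoint_exists(temp_dir)
    return results
```

One possible design choice: asserting the absence of the directory directly, as the `False` case above does, may read more plainly than wrapping an existence assertion in pytest.raises(AssertionError), though both express the same expectation.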
NanoCode012
left a comment
I like this idea, but we may now get people complaining that training takes a bit longer to start (depending on model size).
However, I would prefer to have this in, as it would save time and resources for people doing large runs.
I would suggest we add |
I think the latter is a good alternative. Or, we could raise a warning when it's not explicitly set? I'd prefer not to add warnings when possible, though, since we already have a lot of log clutter.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/e2e/test_save_first_step.py (1)
1-3: Update docstring to reflect the actual test scope. The docstring mentions "relora llama" but these tests are actually for the general save_first_step callback feature and aren't specific to relora configurations.
-"""
-E2E tests for relora llama
-"""
+"""
+E2E tests for save_first_step callback feature
+"""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/e2e/test_save_first_step.py(1 hunks)
🔇 Additional comments (1)
tests/e2e/test_save_first_step.py (1)
21-61: LGTM! Test correctly validates save_first_step=True behavior. The test properly sets up a minimal training configuration with save_first_step=True and verifies that the checkpoint is created after training.
2d6b619 to
f3703ea
I rebased and pushed, as there were some merge conflicts from merging other PRs.
Description
See title. This is a good sanity check: it ensures users don't spend a long time training only for their model checkpoint to fail at the end of the run.
Note that we default the config value save_first_step to True. It's up for debate as to whether this is a good default.
Motivation and Context
As reported on our Discord, some user runs end with errors on the final model checkpoint. This should limit the frustration in these cases by failing faster.
How has this been tested?
Manually with a few configs.