diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml index 1950314948..bbe8785f16 100644 --- a/examples/cloud/modal.yaml +++ b/examples/cloud/modal.yaml @@ -26,3 +26,5 @@ timeout: 86400 # Preprocess specific configurations memory_preprocess: 32 timeout_preprocess: 14400 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml index 4a30e9a776..da2777270e 100644 --- a/examples/cohere/command-r-7b-qlora.yml +++ b/examples/cohere/command-r-7b-qlora.yml @@ -35,7 +35,6 @@ wandb_watch: wandb_name: wandb_log_model: - gradient_accumulation_steps: 4 micro_batch_size: 1 num_epochs: 4 @@ -56,3 +55,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml index 2c0495ceda..1a051b98bd 100644 --- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml @@ -56,3 +56,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml index de9c956e0c..8073426412 100644 --- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml @@ -56,3 +56,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml index 0ed97db369..78bf6b1797 100644 --- a/examples/deepseek-v2/fft-fsdp-16b.yaml +++ b/examples/deepseek-v2/fft-fsdp-16b.yaml @@ -55,3 +55,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml index 34dbeaafed..da1d9aefd5 100644 --- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -79,3 +79,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml index dc0051bd5e..9d92e8662f 100644 --- a/examples/devstral/devstral-small-qlora.yml +++ b/examples/devstral/devstral-small-qlora.yml @@ -62,3 +62,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml index 1dd901154a..484c31fecc 100644 --- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml @@ -69,3 +69,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml index 24dc7cae37..dea2a6e6d8 100644 --- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml @@ -46,7 +46,6 @@ wandb_watch: wandb_name: wandb_log_model: - gradient_accumulation_steps: 4 micro_batch_size: 1 num_epochs: 4 @@ -69,3 +68,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml index 43eb1967ba..b187efbf6e 100644 --- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml @@ -69,3 +69,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml index 00929bbf01..4d981ad95f 100644 --- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml @@ -69,3 +69,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml index e2640de7bc..5ee13facd3 100644 --- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml @@ -69,3 +69,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml index 183e423b51..4b665c3cd9 100644 --- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml @@ -69,3 +69,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml index cb96a32c1d..68d213fada 100644 --- a/examples/gemma2/qlora.yml +++ b/examples/gemma2/qlora.yml @@ -60,3 +60,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml index ce01a4572e..624ebdcd22 100644 --- a/examples/gemma2/reward-model.yaml +++ b/examples/gemma2/reward-model.yaml @@ -50,3 +50,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml index 217c887aa6..99921770db 100644 --- a/examples/gemma3/gemma-3-1b-qlora.yml +++ b/examples/gemma3/gemma-3-1b-qlora.yml @@ -66,3 +66,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index d78559ae3b..025cb9240f 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -60,3 +60,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index 183eb88e84..e9e606b69f 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -62,3 +62,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml index 86d9b43f8b..8973cedd4b 100644 --- a/examples/glm4/qlora-32b.yaml +++ b/examples/glm4/qlora-32b.yaml @@ -60,3 +60,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml index 2cb0eea411..494154886b 100644 --- a/examples/jamba/qlora.yaml +++ b/examples/jamba/qlora.yaml @@ -54,3 +54,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml index d13ce64839..64db8f2ff7 100644 --- a/examples/jamba/qlora_deepspeed.yaml +++ b/examples/jamba/qlora_deepspeed.yaml @@ -55,3 +55,5 @@ saves_per_epoch: 1 deepspeed: deepspeed_configs/zero2.json weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml index 6badaba19b..fda30e2d2f 100644 --- a/examples/jamba/qlora_fsdp_large.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -64,3 +64,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml index 95961557e3..74c90c1e1e 100644 --- a/examples/lfm2/lfm2-350m-fft.yaml +++ b/examples/lfm2/lfm2-350m-fft.yaml @@ -46,3 +46,5 @@ evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 86b1b6a218..c44cd22307 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -55,3 +55,5 @@ saves_per_epoch: 1 deepspeed: #deepspeed_configs/zero2.json # multi-gpu only weight_decay: 0.1 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 0f1b34016c..580fabdf8d 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -64,3 +64,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index a76a792aef..a44e261beb 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -60,3 +60,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml index 22dbf2d992..085627f63b 100644 --- a/examples/llama-2/loftq.yml +++ b/examples/llama-2/loftq.yml @@ -52,3 +52,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 679aed3a99..759fce0441 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -52,3 +52,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index a42eabd4b7..3bf30120bb 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -67,3 +67,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index de65928bc8..09596c71e3 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -53,3 +53,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index e0a5f7068c..ca8b14a1cc 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -58,3 +58,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml index 2b0ae2c70b..64d749b5a9 100644 --- a/examples/llama-3-vision/lora-11b.yaml +++ b/examples/llama-3-vision/lora-11b.yaml @@ -57,3 +57,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml index 5d979c96c2..08d8ee5c13 100644 --- a/examples/llama-3/3b-qat-fsdp2.yaml +++ b/examples/llama-3/3b-qat-fsdp2.yaml @@ -77,3 +77,5 @@ fsdp_config: special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index eccfa6d8c1..e2808935f6 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -72,3 +72,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index fdae3e6c4d..2dfe6d492d 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -42,3 +42,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml index 51f1c768b1..10ab2a320c 100644 --- a/examples/llama-3/instruct-dpo-lora-8b.yml +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -71,3 +71,5 @@ warmup_steps: 10 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index acab862f64..83b7f9a37c 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -64,3 +64,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml index 10e9747cb1..b20dbad844 100644 --- a/examples/llama-3/lora-1b-deduplicate-dpo.yml +++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml @@ -83,3 +83,5 @@ warmup_steps: 10 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml index 630ec92f6a..67e518184b 100644 --- a/examples/llama-3/lora-1b-deduplicate-sft.yml +++ b/examples/llama-3/lora-1b-deduplicate-sft.yml @@ -61,3 +61,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml index a2d07ca491..92a948c2e6 100644 --- a/examples/llama-3/lora-1b-kernels.yml +++ b/examples/llama-3/lora-1b-kernels.yml @@ -65,3 +65,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml index bb23164ebb..178a1fb89c 100644 --- a/examples/llama-3/lora-1b-ray.yml +++ b/examples/llama-3/lora-1b-ray.yml @@ -64,3 +64,5 @@ special_tokens: use_ray: true ray_num_workers: 4 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml index 769dd32e60..c4ce3eb0fd 100644 --- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml +++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml @@ -63,3 +63,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index acc17e21f2..82085483f1 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -60,3 +60,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index ad50cd38a3..c393897557 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -57,3 +57,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index 89a51ea68f..f156e23d36 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -61,3 +61,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml index 5c8fe66289..6b76ea8d9d 100644 --- a/examples/llama-3/qlora-1b.yml +++ b/examples/llama-3/qlora-1b.yml @@ -62,3 +62,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml index 2b7d51925c..1ee922b59b 100644 --- a/examples/llama-3/qlora-fsdp-405b.yaml +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -60,3 +60,5 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD special_tokens: pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml index 412b6721ca..5edd8353ad 100644 --- a/examples/llama-3/qlora-fsdp-70b.yaml +++ b/examples/llama-3/qlora-fsdp-70b.yaml @@ -69,3 +69,5 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index 4cc9fc3dba..a674eca279 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -54,3 +54,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml index 1bbb880289..8577a19d2f 100644 --- a/examples/llama-3/sparse-finetuning.yaml +++ b/examples/llama-3/sparse-finetuning.yaml @@ -75,3 +75,5 @@ llmcompressor: ] start: 0 save_compressed: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml index 2be94f4efa..d4a038e113 100644 --- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml @@ -86,3 +86,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml index eeae872a6b..bea10d979e 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml @@ -90,3 +90,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml index 17ad706344..737d938126 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml @@ -83,3 +83,5 @@ weight_decay: 0.0 special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml index eff708e4db..390be5af78 100644 --- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml @@ -86,3 +86,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml index 9a411883e4..b319349c4a 100644 --- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml +++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml @@ -84,3 +84,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml index 20352f81eb..6be3988ef0 100644 --- a/examples/llama-4/scout-qlora-single-h100-flex.yaml +++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml @@ -82,3 +82,5 @@ weight_decay: 0.0 special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml index 9fbd34107b..a67936cf18 100644 --- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml +++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml @@ -87,3 +87,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml index 5198c8e744..a4bac8987d 100644 --- a/examples/llava/lora-7b.yaml +++ b/examples/llava/lora-7b.yaml @@ -53,3 +53,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml index b10e8baf6a..b23d2309a0 100644 --- a/examples/magistral/magistral-small-fsdp-qlora.yaml +++ b/examples/magistral/magistral-small-fsdp-qlora.yaml @@ -70,3 +70,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer fsdp_activation_checkpointing: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml index e3e746f224..f0fce014fa 100644 --- a/examples/magistral/magistral-small-qlora.yaml +++ b/examples/magistral/magistral-small-qlora.yaml @@ -61,3 +61,5 @@ flash_attention: true warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 3d4583932e..2261bd2156 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -48,3 +48,5 @@ weight_decay: 0.0 special_tokens: tokens: save_safetensors: False + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml index f626a92a17..e9bcbb7d68 100644 --- a/examples/mistral/bigstral-ds-zero3.yaml +++ b/examples/mistral/bigstral-ds-zero3.yaml @@ -53,3 +53,5 @@ special_tokens: eos_token: "<|im_end|>" tokens: - "<|im_start|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 15edffb44e..8c4d80f792 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -43,3 +43,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml index e6f46affb1..d54c3e30bb 100644 --- a/examples/mistral/lora-mps.yml +++ b/examples/mistral/lora-mps.yml @@ -64,3 +64,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml index 9af4274fdf..161255468e 100644 --- a/examples/mistral/lora.yml +++ b/examples/mistral/lora.yml @@ -64,3 +64,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml index af707973ff..8d03786904 100644 --- a/examples/mistral/mistral-dpo-qlora.yml +++ b/examples/mistral/mistral-dpo-qlora.yml @@ -80,3 +80,5 @@ weight_decay: 0.0 special_tokens: bos_token: "<|im_start|>" eos_token: "<|im_end|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml index e234b19a24..cec958c54c 100644 --- a/examples/mistral/mistral-qlora-fsdp.yml +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -74,3 +74,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml index 6c0212b7cb..f37dc09fa3 100644 --- a/examples/mistral/mistral-qlora-orpo.yml +++ b/examples/mistral/mistral-qlora-orpo.yml @@ -69,3 +69,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml index 3e3b45862d..4a492c5953 100644 --- a/examples/mistral/mistral-small-3.1-24B-lora.yml +++ b/examples/mistral/mistral-small-3.1-24B-lora.yml @@ -56,3 +56,5 @@ evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml index af6ba5a769..64ef9930ce 100644 --- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml +++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml @@ -72,3 +72,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml index b1843a138d..c8d0a2711b 100644 --- a/examples/mistral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral-qlora-fsdp.yml @@ -77,3 +77,5 @@ fsdp_config: fsdp_forward_prefetch: false fsdp_backward_prefetch: BACKWARD_PRE special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 4c256420cf..5be9b4db89 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -81,3 +81,5 @@ saves_per_epoch: 1 deepspeed: deepspeed_configs/zero2.json weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml index 25e1d71551..100e4464ff 100644 --- a/examples/mistral/mixtral_22.yml +++ b/examples/mistral/mixtral_22.yml @@ -51,3 +51,5 @@ special_tokens: eos_token: "<|im_end|>" tokens: - "<|im_start|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 607e337014..08df36e150 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -64,3 +64,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml index 9bcbbeee0d..57f65d9666 100644 --- a/examples/orpheus/finetune.yml +++ b/examples/orpheus/finetune.yml @@ -50,3 +50,5 @@ weight_decay: 0.05 special_tokens: pad_token: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml index ad4ce9cd44..9f3bbdf539 100644 --- a/examples/phi/lora-3.5.yaml +++ b/examples/phi/lora-3.5.yaml @@ -63,3 +63,5 @@ warmup_steps: 10 evals_per_epoch: 4 saves_per_epoch: 4 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 1562a73536..fc6d649d71 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -57,3 +57,5 @@ weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 4cd53db979..ccd92c817f 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -60,3 +60,5 @@ weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index ca733cc71b..853250ccbb 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -57,3 +57,5 @@ weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml index d0d14fea67..130298bc04 100644 --- a/examples/phi/phi3-ft-fsdp.yml +++ b/examples/phi/phi3-ft-fsdp.yml @@ -71,3 +71,5 @@ fsdp_config: resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml index 17c48da6f7..42b87e8d05 100644 --- a/examples/phi/phi3-ft.yml +++ b/examples/phi/phi3-ft.yml @@ -59,3 +59,5 @@ warmup_ratio: 0.2 debug: true weight_decay: 0.1 resize_token_embeddings_to_32x: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index 6ad0a5e999..ea769d202c 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -55,3 +55,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml index e8932b9688..8ea6081999 100644 --- a/examples/qwen2-vl/lora-7b.yaml +++ b/examples/qwen2-vl/lora-7b.yaml @@ -53,3 +53,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index bd896c2b3d..69a74ae4a4 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -54,3 +54,5 @@ warmup_steps: 10 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml index 4afa24f3ce..af188f75d6 100644 --- a/examples/qwen2/prm.yaml +++ b/examples/qwen2/prm.yaml @@ -55,3 +55,5 @@ eval_steps: 100 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index ed2670ab61..861ce5517e 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -67,3 +67,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml index 822407a1fe..1854b8216b 100644 --- a/examples/qwen2/reward-model.yaml +++ b/examples/qwen2/reward-model.yaml @@ -26,7 +26,6 @@ wandb_watch: wandb_name: wandb_log_model: - gradient_accumulation_steps: 4 micro_batch_size: 2 num_epochs: 4 @@ -50,3 +49,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml index 25d02805f7..13a97dec3e 100644 --- a/examples/qwen2_5-vl/lora-7b.yaml +++ b/examples/qwen2_5-vl/lora-7b.yaml @@ -53,3 +53,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml index 45a4395ac1..1f148ece5c 100644 --- a/examples/qwen3/32b-qlora.yaml +++ b/examples/qwen3/32b-qlora.yaml @@ -67,3 +67,5 @@ evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml index 6832b6af75..e4d0ed4fb0 100644 --- a/examples/qwen3/8b-qat-fsdp2.yml +++ b/examples/qwen3/8b-qat-fsdp2.yml @@ -76,3 +76,5 @@ fsdp_config: fsdp_activation_checkpointing: true special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml index dc3377b4f2..762f9648d1 100644 --- a/examples/qwen3/qlora-fsdp.yaml +++ b/examples/qwen3/qlora-fsdp.yaml @@ -66,3 +66,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 4df0100406..d3a3b32424 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -36,6 +36,7 @@ GCCallback, GPUStatsCallback, SaveAxolotlConfigtoWandBCallback, + SaveModelOnFirstStepCallback, ) from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.schemas.enums import CustomSupportedOptimizers @@ -135,6 +136,8 @@ def get_callbacks(self) -> list[TrainerCallback]: callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.save_first_step: + callbacks.append(SaveModelOnFirstStepCallback()) callbacks.append(GPUStatsCallback(cfg=self.cfg)) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 5f804d6afa..bb777fc90f 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -64,7 +64,7 @@ def on_step_end( state: TrainerState, control: TrainerControl, **kwargs, - ): + ) -> TrainerControl: # Save if ( args.save_strategy == IntervalStrategy.STEPS @@ -100,11 +100,11 @@ def __init__(self, cfg): def on_step_end( self, - args: TrainingArguments, + args: TrainingArguments, # pylint: disable=unused-argument state: TrainerState, control: TrainerControl, **kwargs, - ): + ) -> TrainerControl: if not self.logged and state.global_step > 1: log_gpu_memory_usage(LOG, "while training", self.cfg.device) self.logged = True @@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback): def __init__(self, cfg): self.cfg = cfg - self.logged = False self.violations = 0 self.threshold = cfg.loss_watchdog_threshold self.patience = cfg.loss_watchdog_patience or 3 def on_step_end( self, - _args: TrainingArguments, + args: TrainingArguments, # pylint: disable=unused-argument state: TrainerState, control: TrainerControl, **_kwargs, - ): + ) -> TrainerControl: if len(state.log_history) > 0 and "loss" in state.log_history[-1]: if state.log_history[-1]["loss"] > self.threshold: self.violations += 1 @@ -141,6 +140,21 @@ def on_step_end( return control +class SaveModelOnFirstStepCallback(TrainerCallback): + """Callback to save the model on the first step of training if enabled""" + + def on_step_end( + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, + **_kwargs, + ) -> TrainerControl: + if state.global_step == 1: + control.should_save = True + return control + + def bench_eval_callback_factory(trainer, tokenizer): accuracy = evaluate.load("accuracy") abcd_idx = [ diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1726feb67f..d5fea9158a 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -695,6 +695,7 @@ class AxolotlInputConfig( "description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`" }, ) + save_steps: int | float | None = Field( default=None, json_schema_extra={ @@ -716,6 +717,13 @@ class AxolotlInputConfig( save_total_limit: int | None = Field( default=None, json_schema_extra={"description": "Checkpoints saved at a time"} ) + save_first_step: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to checkpoint a model after the first step of training. Defaults to False." + }, + ) + logging_steps: int | None = Field( default=None, json_schema_extra={"description": "Logging frequency"} ) diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 790b34f3e6..34e6c96447 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -44,6 +44,7 @@ def min_cfg(temp_dir): "save_safetensors": True, "max_steps": 10, "bf16": "auto", + "save_first_step": False, } @@ -98,6 +99,7 @@ def test_qwen2_w_cce(self, temp_dir): "save_safetensors": True, "max_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py index 4734449fe2..8743efb981 100644 --- a/tests/e2e/integrations/test_hooks.py +++ b/tests/e2e/integrations/test_hooks.py @@ -153,6 +153,7 @@ def test_plugin_hooks(self, temp_dir): "max_steps": 5, "flash_attention": True, "bf16": "auto", + "save_first_step": False, } ) diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index 212450e89e..1ac3b537e7 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -67,6 +67,7 @@ def min_cfg(temp_dir): "output_dir": temp_dir, "save_safetensors": True, "use_tensorboard": True, + "save_first_step": False, } diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 6ab3d7ab89..b1f5befdd4 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -50,6 +50,7 @@ def test_llama_wo_flce(self, temp_dir): "save_safetensors": True, "bf16": "auto", "max_steps": 5, + "save_first_step": False, } ) # pylint: disable=duplicate-code @@ -96,6 +97,7 @@ def test_llama_w_flce(self, temp_dir): "save_safetensors": True, "bf16": "auto", "max_steps": 5, + "save_first_step": False, } ) # pylint: disable=duplicate-code diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py index 247ae3bac2..dceecea9ff 100644 --- a/tests/e2e/integrations/test_llm_compressor.py +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -81,6 +81,7 @@ def test_llmcompressor_plugin( }, "save_compressed": save_compressed, }, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index 5593c7eb62..80098e6841 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -69,6 +69,7 @@ def _run_sequence_parallel_test( "use_tensorboard": True, "sequence_parallel_degree": 2, "ring_attn_func": ring_attn_func, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py index bdf5ada6ba..cbdf8de96b 100644 --- a/tests/e2e/multigpu/solo/test_flex.py +++ b/tests/e2e/multigpu/solo/test_flex.py @@ -61,6 +61,7 @@ def test_loss_llama(self, temp_dir): "max_steps": 2, "use_tensorboard": True, "save_strategy": "no", + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index c047343456..d022ae2d92 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -223,6 +223,7 @@ def test_llama_dora(self, temp_dir, num_gpus): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) @@ -317,6 +318,7 @@ def test_llama_lora_sp(self, temp_dir): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) @@ -409,6 +411,7 @@ def test_llama_fft(self, temp_dir, num_gpus): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index d6429cf63c..4f86278ffb 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -67,6 +67,7 @@ def test_eval_sample_packing(self, temp_dir): "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, + "save_first_step": False, } ) @@ -138,6 +139,7 @@ def test_eval(self, temp_dir): "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py index 3868d90f0f..4a7b101a83 100644 --- a/tests/e2e/multigpu/test_gemma3.py +++ b/tests/e2e/multigpu/test_gemma3.py @@ -71,6 +71,7 @@ def test_lora_ddp_packed(self, temp_dir): "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index f0c74fbf8b..aab14dcc4f 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -69,6 +69,7 @@ def test_lora_ddp(self, temp_dir): "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -135,6 +136,7 @@ def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps): "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -210,6 +212,7 @@ def test_dpo_lora_ddp(self, temp_dir): "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -289,6 +292,7 @@ def test_dpo_qlora_ddp(self, temp_dir): "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -365,6 +369,7 @@ def test_fsdp(self, temp_dir, gradient_accumulation_steps): }, "use_tensorboard": True, "seed": 42, + "save_first_step": False, } ) @@ -442,6 +447,7 @@ def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type): "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", }, "use_tensorboard": True, + "save_first_step": False, } ) @@ -520,6 +526,7 @@ def test_fsdp2_packed( "fsdp_reshard_after_forward": fsdp_reshard_after_forward, }, "use_tensorboard": True, + "save_first_step": False, } ) if attention_backend == "flash": @@ -605,6 +612,7 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir): "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", }, "use_tensorboard": True, + "save_first_step": False, } ) @@ -689,6 +697,7 @@ def test_ds_zero3_packed( "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / deepspeed), "use_tensorboard": True, + "save_first_step": False, **adapter, } ) @@ -765,6 +774,7 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora): "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), "use_tensorboard": True, "seed": 42, + "save_first_step": False, **adapter, } ) @@ -840,6 +850,7 @@ def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora): "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), "use_tensorboard": True, + "save_first_step": False, **adapter, } ) @@ -908,6 +919,7 @@ def test_fix_untrained_tokens(self, temp_dir): "save_safetensors": True, # "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 43a722b488..dd14222968 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -56,6 +56,7 @@ def test_lora_ddp(self, temp_dir): "use_tensorboard": True, "use_ray": True, "ray_num_workers": 2, + "save_first_step": False, } ) @@ -115,6 +116,7 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps): "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 08b62accc0..1824443e79 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -55,6 +55,7 @@ def test_sdp_lora_packing(self, temp_dir): "save_steps": 3, "eval_steps": 4, "fp16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -102,6 +103,7 @@ def test_torch_lora_packing(self, temp_dir): "save_steps": 3, "eval_steps": 4, "fp16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index d494ed1ebd..3d5b3dc56c 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -69,6 +69,7 @@ def test_activation_checkpointing_offload( "bf16": True, "save_safetensors": True, "gradient_checkpointing": gradient_checkpointing, + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index ca8b21178f..38099b220b 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -62,6 +62,7 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index a593b07918..ef31b11c78 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -58,6 +58,7 @@ def test_qlora(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -99,6 +100,7 @@ def test_ft(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py index f77a1fbe5a..fdaab558dc 100644 --- a/tests/e2e/patched/test_flattening.py +++ b/tests/e2e/patched/test_flattening.py @@ -61,6 +61,7 @@ def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps): "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 1bbc82a38a..a3fe591ee8 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -53,6 +53,7 @@ def test_fft_packing(self, temp_dir): "max_steps": 10, "save_steps": 5, "eval_steps": 5, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index d2dcc5e4b7..ba5556a593 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -58,6 +58,7 @@ def test_lora_s2_attn(self, temp_dir): "save_steps": 5, "eval_steps": 5, "bf16": "auto", + "save_first_step": False, } ) @@ -100,6 +101,7 @@ def test_fft_s2_attn(self, temp_dir): "save_steps": 5, "eval_steps": 5, "bf16": "auto", + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 5df6bfecc6..fdf6adbc6c 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -55,6 +55,7 @@ def test_lora_packing(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -108,6 +109,7 @@ def test_lora_gptq_packed(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index 442089bae7..bea0f9c68c 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -56,6 +56,7 @@ def test_lora_packing(self, temp_dir): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -97,6 +98,7 @@ def test_ft_packing(self, temp_dir): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 5f778660bb..09e427abd3 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -52,6 +52,7 @@ def test_qlora(self, temp_dir): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -90,6 +91,7 @@ def test_ft(self, temp_dir): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 5ea88b001b..b90be23e43 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -45,6 +45,7 @@ def test_mixtral_multipack(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -78,6 +79,7 @@ def test_mistral_multipack(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py index d4f59a128f..4769319aef 100644 --- a/tests/e2e/patched/test_peft_embeddings.py +++ b/tests/e2e/patched/test_peft_embeddings.py @@ -49,6 +49,7 @@ def test_peft_embeddings_upcast(self, temp_dir): "bf16": "auto", "save_safetensors": True, "embeddings_skip_upcast": True, + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index d241ce1853..1f0ddd6303 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -54,6 +54,7 @@ def test_ft_packed(self, temp_dir): "eval_steps": 3, "save_steps": 4, "bf16": "auto", + "save_first_step": False, } ) @@ -105,6 +106,7 @@ def test_qlora_packed(self, temp_dir): "eval_steps": 3, "save_steps": 4, "bf16": "auto", + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 3639567335..54b8245eec 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -58,6 +58,7 @@ def test_resume_lora_packed(self, temp_dir): "max_steps": 15, "use_tensorboard": True, "save_safetensors": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 2b4d11b30c..4a2c69d457 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -47,6 +47,7 @@ def fixture_cfg(): "special_tokens": { "pad_token": "<|endoftext|>", }, + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 69171481c7..2c8ee4eb05 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -62,6 +62,7 @@ def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing): "lr_scheduler": "cosine", "use_tensorboard": True, "bf16": "auto", + "save_first_step": False, } ) @@ -112,6 +113,7 @@ def test_unsloth_llama_qlora_unpacked(self, temp_dir): "lr_scheduler": "cosine", "use_tensorboard": True, "bf16": "auto", + "save_first_step": False, } ) @@ -167,6 +169,7 @@ def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention) "lr_scheduler": "cosine", "use_tensorboard": True, "fp16": True, + "save_first_step": False, } ) diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index 2799137136..76364fc0e5 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -49,6 +49,7 @@ def test_loss_llama(self, temp_dir): "lr_scheduler": "cosine", "max_steps": 5, "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index 7af5504963..f6fcad8415 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -65,6 +65,7 @@ def test_relora(self, temp_dir): "lr_scheduler": "cosine", "save_safetensors": True, "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py index 7dfc4ae159..e4a47fb0aa 100644 --- a/tests/e2e/test_deepseekv3.py +++ b/tests/e2e/test_deepseekv3.py @@ -67,6 +67,7 @@ def test_lora_deepseekv3(self, temp_dir, sample_packing): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -116,6 +117,7 @@ def test_fft_deepseekv3(self, temp_dir, sample_packing): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 2cdb576891..a1df695352 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -56,6 +56,7 @@ def test_dpo_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -105,6 +106,7 @@ def test_dpo_nll_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -154,6 +156,7 @@ def test_dpo_use_weighting(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -203,6 +206,7 @@ def test_kto_pair_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -251,6 +255,7 @@ def test_ipo_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -302,6 +307,7 @@ def test_orpo_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -370,6 +376,7 @@ def test_kto_lora(self, temp_dir): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 9b65f8feb6..e4a06ad148 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -48,6 +48,7 @@ def test_train_w_embedding_lr_scale(self, temp_dir): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) @@ -93,6 +94,7 @@ def test_train_w_embedding_lr(self, temp_dir): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py index 6271bba289..977497e5e8 100644 --- a/tests/e2e/test_evaluate.py +++ b/tests/e2e/test_evaluate.py @@ -36,6 +36,7 @@ def test_evaluate(self, temp_dir): "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 20, + "save_first_step": False, } ) diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 4f88e740c3..5be6efcf64 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -60,6 +60,7 @@ def test_lora(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) @@ -115,6 +116,7 @@ def test_lora_added_vocab(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) @@ -156,6 +158,7 @@ def test_ft(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index 3f00a13844..ef38d028dc 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -63,6 +63,7 @@ def test_lora_gemma3_text(self, temp_dir, sample_packing): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -113,6 +114,7 @@ def test_fft_gemma3_text(self, temp_dir, sample_packing): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 2b180029ce..1e6df0be9c 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -45,6 +45,7 @@ def test_fft_trust_remote_code(self, temp_dir): "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) @@ -92,6 +93,7 @@ def test_fix_untrained_tokens(self, temp_dir): "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) @@ -136,6 +138,7 @@ def test_fix_untrained_tokens_already_trained(self, temp_dir): "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) @@ -176,6 +179,7 @@ def test_batch_flattening(self, temp_dir): "batch_flattening": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index fdebf2173b..bd55023008 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -53,6 +53,7 @@ def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index ad4a83c6a2..760759bcaa 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -54,6 +54,7 @@ def test_lora_llama_vision_text_only_dataset(self, temp_dir): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) @@ -100,6 +101,7 @@ def test_lora_llama_vision_multimodal_dataset(self, temp_dir): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 3015653021..7e0ff46cf7 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -49,6 +49,7 @@ def test_lora(self, temp_dir): "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 5, + "save_first_step": False, } ) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index 1824619a6a..73d3bdc260 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -51,6 +51,7 @@ def test_fft(self, temp_dir): "save_steps": 10, "eval_steps": None, "save_safetensors": False, + "save_first_step": False, } ) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 5d9b8ba8c1..f47f794e06 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -55,6 +55,7 @@ def test_lora(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) @@ -95,6 +96,7 @@ def test_ft(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 761e59391c..3fe2bf70f1 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -61,6 +61,7 @@ def test_qlora_w_fa2(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) @@ -116,6 +117,7 @@ def test_qlora_wo_fa2(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) @@ -170,6 +172,7 @@ def test_16bit_lora_w_fa2(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -228,6 +231,7 @@ def test_16bit_lora_wo_fa2(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) @@ -273,6 +277,7 @@ def test_ft(self, temp_dir): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 53ef86022f..1d233a2013 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -55,6 +55,7 @@ def test_optimi_adamw(self, temp_dir): "optimizer": "optimi_adamw", "max_steps": 5, "lr_scheduler": "cosine", + "save_first_step": False, } ) @@ -100,6 +101,7 @@ def test_adopt_adamw(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adopt_adamw", "lr_scheduler": "cosine", + "save_first_step": False, } ) @@ -146,6 +148,7 @@ def test_muon(self, temp_dir): "optimizer": "muon", "lr_scheduler": "cosine", "weight_decay": 0.01, + "save_first_step": False, } ) @@ -184,6 +187,7 @@ def test_fft_schedule_free_adamw(self, temp_dir): "lr_scheduler": "constant", "save_safetensors": True, "max_steps": 10, + "save_first_step": False, } ) # pylint: disable=duplicate-code @@ -232,6 +236,7 @@ def test_came_pytorch(self, temp_dir): "adam_epsilon2": 1e-16, "max_steps": 5, "lr_scheduler": "cosine", + "save_first_step": False, } ) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index cc2db72e0e..aec9d95f8b 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -48,6 +48,7 @@ def test_loss_packed(self, temp_dir): "lr_scheduler": "cosine", "max_steps": 5, "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 88fda91915..ab3a636748 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -53,6 +53,7 @@ def test_phi_ft(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -102,6 +103,7 @@ def test_phi_qlora(self, temp_dir): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index abfe1b0c55..bd9eec48b6 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -49,6 +49,7 @@ def test_prm(self, temp_dir): "use_tensorboard": True, "special_tokens": {"pad_token": "<|endoftext|>"}, "seed": 42, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index ef726079d0..139ae155ac 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -57,6 +57,7 @@ def test_qat(self, temp_dir): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -115,6 +116,7 @@ def test_qat_dpo(self, temp_dir): "weight_dtype": "int8", "group_size": 8, }, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py index aa8b9f6c05..59267d14db 100644 --- a/tests/e2e/test_qwen.py +++ b/tests/e2e/test_qwen.py @@ -59,6 +59,7 @@ def test_dpo(self, base_model, temp_dir): "bf16": "auto", "tf32": True, "gradient_checkpointing": True, + "save_first_step": False, } ) diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 5d52bcc865..82513f99f2 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -58,6 +58,7 @@ def test_rm_lora(self, temp_dir): "gradient_checkpointing": True, "warmup_ratio": 0.1, "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py new file mode 100644 index 0000000000..5bbd2302b9 --- /dev/null +++ b/tests/e2e/test_save_first_step.py @@ -0,0 +1,102 @@ +""" +E2E tests for relora llama +""" + +import unittest +from pathlib import Path + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestSaveFirstStepCallback(unittest.TestCase): + """Test cases for save_first_step callback config.""" + + @with_temp_dir + def test_save_first_step(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 512, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + "save_first_step": True, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg) + + @with_temp_dir + def test_no_save_first_step(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 512, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + with pytest.raises(AssertionError): + check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg) diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index e98378f08a..8f7a13aeea 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -51,6 +51,7 @@ def test_rex_scheduler(self, temp_dir): "lr_scheduler": "rex", "warmup_steps": 5, "cosine_min_lr_ratio": 0.05, + "save_first_step": False, } )