diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml
index 1950314948..bbe8785f16 100644
--- a/examples/cloud/modal.yaml
+++ b/examples/cloud/modal.yaml
@@ -26,3 +26,5 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml
index 4a30e9a776..da2777270e 100644
--- a/examples/cohere/command-r-7b-qlora.yml
+++ b/examples/cohere/command-r-7b-qlora.yml
@@ -35,7 +35,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
index 2c0495ceda..1a051b98bd 100644
--- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
index de9c956e0c..8073426412 100644
--- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
index 0ed97db369..78bf6b1797 100644
--- a/examples/deepseek-v2/fft-fsdp-16b.yaml
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index 34dbeaafed..da1d9aefd5 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index dc0051bd5e..9d92e8662f 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
index 1dd901154a..484c31fecc 100644
--- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
index 24dc7cae37..dea2a6e6d8 100644
--- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
@@ -46,7 +46,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
index 43eb1967ba..b187efbf6e 100644
--- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
index 00929bbf01..4d981ad95f 100644
--- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
index e2640de7bc..5ee13facd3 100644
--- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
index 183e423b51..4b665c3cd9 100644
--- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index cb96a32c1d..68d213fada 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index ce01a4572e..624ebdcd22 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 217c887aa6..99921770db 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index d78559ae3b..025cb9240f 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index 183eb88e84..e9e606b69f 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml
index 86d9b43f8b..8973cedd4b 100644
--- a/examples/glm4/qlora-32b.yaml
+++ b/examples/glm4/qlora-32b.yaml
@@ -60,3 +60,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
index 2cb0eea411..494154886b 100644
--- a/examples/jamba/qlora.yaml
+++ b/examples/jamba/qlora.yaml
@@ -54,3 +54,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml
index d13ce64839..64db8f2ff7 100644
--- a/examples/jamba/qlora_deepspeed.yaml
+++ b/examples/jamba/qlora_deepspeed.yaml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
index 6badaba19b..fda30e2d2f 100644
--- a/examples/jamba/qlora_fsdp_large.yaml
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -64,3 +64,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml
index 95961557e3..74c90c1e1e 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/lfm2/lfm2-350m-fft.yaml
@@ -46,3 +46,5 @@ evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index 86b1b6a218..c44cd22307 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index 0f1b34016c..580fabdf8d 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -64,3 +64,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
index a76a792aef..a44e261beb 100644
--- a/examples/llama-2/lisa.yml
+++ b/examples/llama-2/lisa.yml
@@ -60,3 +60,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 22dbf2d992..085627f63b 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 679aed3a99..759fce0441 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
index a42eabd4b7..3bf30120bb 100644
--- a/examples/llama-2/qlora-fsdp.yml
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index de65928bc8..09596c71e3 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -53,3 +53,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index e0a5f7068c..ca8b14a1cc 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -58,3 +58,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 2b0ae2c70b..64d749b5a9 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -57,3 +57,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml
index 5d979c96c2..08d8ee5c13 100644
--- a/examples/llama-3/3b-qat-fsdp2.yaml
+++ b/examples/llama-3/3b-qat-fsdp2.yaml
@@ -77,3 +77,5 @@ fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index eccfa6d8c1..e2808935f6 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -72,3 +72,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index fdae3e6c4d..2dfe6d492d 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -42,3 +42,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 51f1c768b1..10ab2a320c 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -71,3 +71,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index acab862f64..83b7f9a37c 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -64,3 +64,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
index 10e9747cb1..b20dbad844 100644
--- a/examples/llama-3/lora-1b-deduplicate-dpo.yml
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -83,3 +83,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
index 630ec92f6a..67e518184b 100644
--- a/examples/llama-3/lora-1b-deduplicate-sft.yml
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index a2d07ca491..92a948c2e6 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -65,3 +65,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml
index bb23164ebb..178a1fb89c 100644
--- a/examples/llama-3/lora-1b-ray.yml
+++ b/examples/llama-3/lora-1b-ray.yml
@@ -64,3 +64,5 @@ special_tokens:
use_ray: true
ray_num_workers: 4
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
index 769dd32e60..c4ce3eb0fd 100644
--- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml
+++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
@@ -63,3 +63,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml
index acc17e21f2..82085483f1 100644
--- a/examples/llama-3/lora-1b.yml
+++ b/examples/llama-3/lora-1b.yml
@@ -60,3 +60,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index ad50cd38a3..c393897557 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -57,3 +57,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
index 89a51ea68f..f156e23d36 100644
--- a/examples/llama-3/qlora-1b-kto.yaml
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
index 5c8fe66289..6b76ea8d9d 100644
--- a/examples/llama-3/qlora-1b.yml
+++ b/examples/llama-3/qlora-1b.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 2b7d51925c..1ee922b59b 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -60,3 +60,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml
index 412b6721ca..5edd8353ad 100644
--- a/examples/llama-3/qlora-fsdp-70b.yaml
+++ b/examples/llama-3/qlora-fsdp-70b.yaml
@@ -69,3 +69,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index 4cc9fc3dba..a674eca279 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -54,3 +54,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 1bbb880289..8577a19d2f 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -75,3 +75,5 @@ llmcompressor:
]
start: 0
save_compressed: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
index 2be94f4efa..d4a038e113 100644
--- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
index eeae872a6b..bea10d979e 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
@@ -90,3 +90,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
index 17ad706344..737d938126 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
@@ -83,3 +83,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
index eff708e4db..390be5af78 100644
--- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
index 9a411883e4..b319349c4a 100644
--- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -84,3 +84,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
index 20352f81eb..6be3988ef0 100644
--- a/examples/llama-4/scout-qlora-single-h100-flex.yaml
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -82,3 +82,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
index 9fbd34107b..a67936cf18 100644
--- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -87,3 +87,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index 5198c8e744..a4bac8987d 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
index b10e8baf6a..b23d2309a0 100644
--- a/examples/magistral/magistral-small-fsdp-qlora.yaml
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -70,3 +70,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml
index e3e746f224..f0fce014fa 100644
--- a/examples/magistral/magistral-small-qlora.yaml
+++ b/examples/magistral/magistral-small-qlora.yaml
@@ -61,3 +61,5 @@ flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 3d4583932e..2261bd2156 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -48,3 +48,5 @@ weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml
index f626a92a17..e9bcbb7d68 100644
--- a/examples/mistral/bigstral-ds-zero3.yaml
+++ b/examples/mistral/bigstral-ds-zero3.yaml
@@ -53,3 +53,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index 15edffb44e..8c4d80f792 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -43,3 +43,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml
index e6f46affb1..d54c3e30bb 100644
--- a/examples/mistral/lora-mps.yml
+++ b/examples/mistral/lora-mps.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml
index 9af4274fdf..161255468e 100644
--- a/examples/mistral/lora.yml
+++ b/examples/mistral/lora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml
index af707973ff..8d03786904 100644
--- a/examples/mistral/mistral-dpo-qlora.yml
+++ b/examples/mistral/mistral-dpo-qlora.yml
@@ -80,3 +80,5 @@ weight_decay: 0.0
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml
index e234b19a24..cec958c54c 100644
--- a/examples/mistral/mistral-qlora-fsdp.yml
+++ b/examples/mistral/mistral-qlora-fsdp.yml
@@ -74,3 +74,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
index 6c0212b7cb..f37dc09fa3 100644
--- a/examples/mistral/mistral-qlora-orpo.yml
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -69,3 +69,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 3e3b45862d..4a492c5953 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
index af6ba5a769..64ef9930ce 100644
--- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
@@ -72,3 +72,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
index b1843a138d..c8d0a2711b 100644
--- a/examples/mistral/mixtral-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -77,3 +77,5 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4c256420cf..5be9b4db89 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,3 +81,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml
index 25e1d71551..100e4464ff 100644
--- a/examples/mistral/mixtral_22.yml
+++ b/examples/mistral/mixtral_22.yml
@@ -51,3 +51,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 607e337014..08df36e150 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml
index 9bcbbeee0d..57f65d9666 100644
--- a/examples/orpheus/finetune.yml
+++ b/examples/orpheus/finetune.yml
@@ -50,3 +50,5 @@ weight_decay: 0.05
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
index ad4ce9cd44..9f3bbdf539 100644
--- a/examples/phi/lora-3.5.yaml
+++ b/examples/phi/lora-3.5.yaml
@@ -63,3 +63,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index 1562a73536..fc6d649d71 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index 4cd53db979..ccd92c817f 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -60,3 +60,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml
index ca733cc71b..853250ccbb 100644
--- a/examples/phi/phi2-ft.yml
+++ b/examples/phi/phi2-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml
index d0d14fea67..130298bc04 100644
--- a/examples/phi/phi3-ft-fsdp.yml
+++ b/examples/phi/phi3-ft-fsdp.yml
@@ -71,3 +71,5 @@ fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml
index 17c48da6f7..42b87e8d05 100644
--- a/examples/phi/phi3-ft.yml
+++ b/examples/phi/phi3-ft.yml
@@ -59,3 +59,5 @@ warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index 6ad0a5e999..ea769d202c 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
index e8932b9688..8ea6081999 100644
--- a/examples/qwen2-vl/lora-7b.yaml
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index bd896c2b3d..69a74ae4a4 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -54,3 +54,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
index 4afa24f3ce..af188f75d6 100644
--- a/examples/qwen2/prm.yaml
+++ b/examples/qwen2/prm.yaml
@@ -55,3 +55,5 @@ eval_steps: 100
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index ed2670ab61..861ce5517e 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
index 822407a1fe..1854b8216b 100644
--- a/examples/qwen2/reward-model.yaml
+++ b/examples/qwen2/reward-model.yaml
@@ -26,7 +26,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
index 25d02805f7..13a97dec3e 100644
--- a/examples/qwen2_5-vl/lora-7b.yaml
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml
index 45a4395ac1..1f148ece5c 100644
--- a/examples/qwen3/32b-qlora.yaml
+++ b/examples/qwen3/32b-qlora.yaml
@@ -67,3 +67,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml
index 6832b6af75..e4d0ed4fb0 100644
--- a/examples/qwen3/8b-qat-fsdp2.yml
+++ b/examples/qwen3/8b-qat-fsdp2.yml
@@ -76,3 +76,5 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml
index dc3377b4f2..762f9648d1 100644
--- a/examples/qwen3/qlora-fsdp.yaml
+++ b/examples/qwen3/qlora-fsdp.yaml
@@ -66,3 +66,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 4df0100406..d3a3b32424 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,6 +36,7 @@
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
+ SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ def get_callbacks(self) -> list[TrainerCallback]:
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
+ if self.cfg.save_first_step:
+ callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 5f804d6afa..bb777fc90f 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -64,7 +64,7 @@ def on_step_end(
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ def __init__(self, cfg):
def on_step_end(
self,
- args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
def __init__(self, cfg):
self.cfg = cfg
- self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
- _args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**_kwargs,
- ):
+ ) -> TrainerControl:
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
@@ -141,6 +140,21 @@ def on_step_end(
return control
+class SaveModelOnFirstStepCallback(TrainerCallback):
+ """Callback to save the model on the first step of training if enabled"""
+
+ def on_step_end(
+ self,
+ args: TrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl,
+ **_kwargs,
+ ) -> TrainerControl:
+ if state.global_step == 1:
+ control.should_save = True
+ return control
+
+
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 1726feb67f..d5fea9158a 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -695,6 +695,7 @@ class AxolotlInputConfig(
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
},
)
+
save_steps: int | float | None = Field(
default=None,
json_schema_extra={
@@ -716,6 +717,13 @@ class AxolotlInputConfig(
save_total_limit: int | None = Field(
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
)
+ save_first_step: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
+ },
+ )
+
logging_steps: int | None = Field(
default=None, json_schema_extra={"description": "Logging frequency"}
)
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index 790b34f3e6..34e6c96447 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -44,6 +44,7 @@ def min_cfg(temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
@@ -98,6 +99,7 @@ def test_qwen2_w_cce(self, temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 4734449fe2..8743efb981 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -153,6 +153,7 @@ def test_plugin_hooks(self, temp_dir):
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 212450e89e..1ac3b537e7 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 6ab3d7ab89..b1f5befdd4 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -50,6 +50,7 @@ def test_llama_wo_flce(self, temp_dir):
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -96,6 +97,7 @@ def test_llama_w_flce(self, temp_dir):
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py
index 247ae3bac2..dceecea9ff 100644
--- a/tests/e2e/integrations/test_llm_compressor.py
+++ b/tests/e2e/integrations/test_llm_compressor.py
@@ -81,6 +81,7 @@ def test_llmcompressor_plugin(
},
"save_compressed": save_compressed,
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 5593c7eb62..80098e6841 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -69,6 +69,7 @@ def _run_sequence_parallel_test(
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index bdf5ada6ba..cbdf8de96b 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -61,6 +61,7 @@ def test_loss_llama(self, temp_dir):
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index c047343456..d022ae2d92 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -223,6 +223,7 @@ def test_llama_dora(self, temp_dir, num_gpus):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -317,6 +318,7 @@ def test_llama_lora_sp(self, temp_dir):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -409,6 +411,7 @@ def test_llama_fft(self, temp_dir, num_gpus):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
index d6429cf63c..4f86278ffb 100644
--- a/tests/e2e/multigpu/test_eval.py
+++ b/tests/e2e/multigpu/test_eval.py
@@ -67,6 +67,7 @@ def test_eval_sample_packing(self, temp_dir):
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -138,6 +139,7 @@ def test_eval(self, temp_dir):
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
index 3868d90f0f..4a7b101a83 100644
--- a/tests/e2e/multigpu/test_gemma3.py
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -71,6 +71,7 @@ def test_lora_ddp_packed(self, temp_dir):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index f0c74fbf8b..aab14dcc4f 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -69,6 +69,7 @@ def test_lora_ddp(self, temp_dir):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -135,6 +136,7 @@ def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -210,6 +212,7 @@ def test_dpo_lora_ddp(self, temp_dir):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -289,6 +292,7 @@ def test_dpo_qlora_ddp(self, temp_dir):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -365,6 +369,7 @@ def test_fsdp(self, temp_dir, gradient_accumulation_steps):
},
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
}
)
@@ -442,6 +447,7 @@ def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -520,6 +526,7 @@ def test_fsdp2_packed(
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if attention_backend == "flash":
@@ -605,6 +612,7 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -689,6 +697,7 @@ def test_ds_zero3_packed(
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -765,6 +774,7 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
**adapter,
}
)
@@ -840,6 +850,7 @@ def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -908,6 +919,7 @@ def test_fix_untrained_tokens(self, temp_dir):
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index 43a722b488..dd14222968 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -56,6 +56,7 @@ def test_lora_ddp(self, temp_dir):
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index 08b62accc0..1824443e79 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -55,6 +55,7 @@ def test_sdp_lora_packing(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ def test_torch_lora_packing(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
index d494ed1ebd..3d5b3dc56c 100644
--- a/tests/e2e/patched/test_activation_checkpointing.py
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -69,6 +69,7 @@ def test_activation_checkpointing_offload(
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index ca8b21178f..38099b220b 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -62,6 +62,7 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index a593b07918..ef31b11c78 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -58,6 +58,7 @@ def test_qlora(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -99,6 +100,7 @@ def test_ft(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py
index f77a1fbe5a..fdaab558dc 100644
--- a/tests/e2e/patched/test_flattening.py
+++ b/tests/e2e/patched/test_flattening.py
@@ -61,6 +61,7 @@ def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 1bbc82a38a..a3fe591ee8 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -53,6 +53,7 @@ def test_fft_packing(self, temp_dir):
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index d2dcc5e4b7..ba5556a593 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -58,6 +58,7 @@ def test_lora_s2_attn(self, temp_dir):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ def test_fft_s2_attn(self, temp_dir):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index 5df6bfecc6..fdf6adbc6c 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -55,6 +55,7 @@ def test_lora_packing(self, temp_dir):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -108,6 +109,7 @@ def test_lora_gptq_packed(self, temp_dir):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 442089bae7..bea0f9c68c 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -56,6 +56,7 @@ def test_lora_packing(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -97,6 +98,7 @@ def test_ft_packing(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 5f778660bb..09e427abd3 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -52,6 +52,7 @@ def test_qlora(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -90,6 +91,7 @@ def test_ft(self, temp_dir):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 5ea88b001b..b90be23e43 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -45,6 +45,7 @@ def test_mixtral_multipack(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -78,6 +79,7 @@ def test_mistral_multipack(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
index d4f59a128f..4769319aef 100644
--- a/tests/e2e/patched/test_peft_embeddings.py
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -49,6 +49,7 @@ def test_peft_embeddings_upcast(self, temp_dir):
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index d241ce1853..1f0ddd6303 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -54,6 +54,7 @@ def test_ft_packed(self, temp_dir):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ def test_qlora_packed(self, temp_dir):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 3639567335..54b8245eec 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -58,6 +58,7 @@ def test_resume_lora_packed(self, temp_dir):
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 2b4d11b30c..4a2c69d457 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -47,6 +47,7 @@ def fixture_cfg():
"special_tokens": {
"pad_token": "<|endoftext|>",
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 69171481c7..2c8ee4eb05 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -62,6 +62,7 @@ def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -112,6 +113,7 @@ def test_unsloth_llama_qlora_unpacked(self, temp_dir):
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -167,6 +169,7 @@ def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention)
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
index 2799137136..76364fc0e5 100644
--- a/tests/e2e/solo/test_flex.py
+++ b/tests/e2e/solo/test_flex.py
@@ -49,6 +49,7 @@ def test_loss_llama(self, temp_dir):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index 7af5504963..f6fcad8415 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -65,6 +65,7 @@ def test_relora(self, temp_dir):
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index 7dfc4ae159..e4a47fb0aa 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -67,6 +67,7 @@ def test_lora_deepseekv3(self, temp_dir, sample_packing):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -116,6 +117,7 @@ def test_fft_deepseekv3(self, temp_dir, sample_packing):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 2cdb576891..a1df695352 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -56,6 +56,7 @@ def test_dpo_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ def test_dpo_nll_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -154,6 +156,7 @@ def test_dpo_use_weighting(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -203,6 +206,7 @@ def test_kto_pair_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -251,6 +255,7 @@ def test_ipo_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -302,6 +307,7 @@ def test_orpo_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -370,6 +376,7 @@ def test_kto_lora(self, temp_dir):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 9b65f8feb6..e4a06ad148 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -48,6 +48,7 @@ def test_train_w_embedding_lr_scale(self, temp_dir):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -93,6 +94,7 @@ def test_train_w_embedding_lr(self, temp_dir):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py
index 6271bba289..977497e5e8 100644
--- a/tests/e2e/test_evaluate.py
+++ b/tests/e2e/test_evaluate.py
@@ -36,6 +36,7 @@ def test_evaluate(self, temp_dir):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 4f88e740c3..5be6efcf64 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -60,6 +60,7 @@ def test_lora(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ def test_lora_added_vocab(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -156,6 +158,7 @@ def test_ft(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py
index 3f00a13844..ef38d028dc 100644
--- a/tests/e2e/test_gemma3_text.py
+++ b/tests/e2e/test_gemma3_text.py
@@ -63,6 +63,7 @@ def test_lora_gemma3_text(self, temp_dir, sample_packing):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -113,6 +114,7 @@ def test_fft_gemma3_text(self, temp_dir, sample_packing):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 2b180029ce..1e6df0be9c 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -45,6 +45,7 @@ def test_fft_trust_remote_code(self, temp_dir):
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -92,6 +93,7 @@ def test_fix_untrained_tokens(self, temp_dir):
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -136,6 +138,7 @@ def test_fix_untrained_tokens_already_trained(self, temp_dir):
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -176,6 +179,7 @@ def test_batch_flattening(self, temp_dir):
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index fdebf2173b..bd55023008 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -53,6 +53,7 @@ def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index ad4a83c6a2..760759bcaa 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -54,6 +54,7 @@ def test_lora_llama_vision_text_only_dataset(self, temp_dir):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 3015653021..7e0ff46cf7 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -49,6 +49,7 @@ def test_lora(self, temp_dir):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 1824619a6a..73d3bdc260 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -51,6 +51,7 @@ def test_fft(self, temp_dir):
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 5d9b8ba8c1..f47f794e06 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -55,6 +55,7 @@ def test_lora(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -95,6 +96,7 @@ def test_ft(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 761e59391c..3fe2bf70f1 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -61,6 +61,7 @@ def test_qlora_w_fa2(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -116,6 +117,7 @@ def test_qlora_wo_fa2(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -170,6 +172,7 @@ def test_16bit_lora_w_fa2(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -228,6 +231,7 @@ def test_16bit_lora_wo_fa2(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -273,6 +277,7 @@ def test_ft(self, temp_dir):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 53ef86022f..1d233a2013 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -55,6 +55,7 @@ def test_optimi_adamw(self, temp_dir):
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ def test_adopt_adamw(self, temp_dir):
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -146,6 +148,7 @@ def test_muon(self, temp_dir):
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
+ "save_first_step": False,
}
)
@@ -184,6 +187,7 @@ def test_fft_schedule_free_adamw(self, temp_dir):
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -232,6 +236,7 @@ def test_came_pytorch(self, temp_dir):
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index cc2db72e0e..aec9d95f8b 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -48,6 +48,7 @@ def test_loss_packed(self, temp_dir):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 88fda91915..ab3a636748 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -53,6 +53,7 @@ def test_phi_ft(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ def test_phi_qlora(self, temp_dir):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py
index abfe1b0c55..bd9eec48b6 100644
--- a/tests/e2e/test_process_reward_model_smollm2.py
+++ b/tests/e2e/test_process_reward_model_smollm2.py
@@ -49,6 +49,7 @@ def test_prm(self, temp_dir):
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index ef726079d0..139ae155ac 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -57,6 +57,7 @@ def test_qat(self, temp_dir):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -115,6 +116,7 @@ def test_qat_dpo(self, temp_dir):
"weight_dtype": "int8",
"group_size": 8,
},
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
index aa8b9f6c05..59267d14db 100644
--- a/tests/e2e/test_qwen.py
+++ b/tests/e2e/test_qwen.py
@@ -59,6 +59,7 @@ def test_dpo(self, base_model, temp_dir):
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py
index 5d52bcc865..82513f99f2 100644
--- a/tests/e2e/test_reward_model_smollm2.py
+++ b/tests/e2e/test_reward_model_smollm2.py
@@ -58,6 +58,7 @@ def test_rm_lora(self, temp_dir):
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py
new file mode 100644
index 0000000000..5bbd2302b9
--- /dev/null
+++ b/tests/e2e/test_save_first_step.py
@@ -0,0 +1,102 @@
+"""
+E2E tests for relora llama
+"""
+
+import unittest
+from pathlib import Path
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists, with_temp_dir
+
+
+class TestSaveFirstStepCallback(unittest.TestCase):
+ """Test cases for save_first_step callback config."""
+
+ @with_temp_dir
+ def test_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": True,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
+
+ @with_temp_dir
+ def test_no_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": False,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ with pytest.raises(AssertionError):
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py
index e98378f08a..8f7a13aeea 100644
--- a/tests/e2e/test_schedulers.py
+++ b/tests/e2e/test_schedulers.py
@@ -51,6 +51,7 @@ def test_rex_scheduler(self, temp_dir):
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
+ "save_first_step": False,
}
)