diff --git a/src/megatron/bridge/training/checkpointing.py b/src/megatron/bridge/training/checkpointing.py index ca41f4faec..7062fd2650 100644 --- a/src/megatron/bridge/training/checkpointing.py +++ b/src/megatron/bridge/training/checkpointing.py @@ -41,6 +41,7 @@ FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) +from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization from megatron.core.msc_utils import MultiStorageClientFeature from megatron.core.num_microbatches_calculator import update_num_microbatches @@ -647,7 +648,14 @@ def save_checkpoint( validate_sharding_integrity = not ckpt_cfg.ckpt_assume_constant_structure else: validate_sharding_integrity = True - save_strategy = get_default_save_sharded_strategy(ckpt_cfg.ckpt_format) + if ckpt_cfg.ckpt_format == "torch_dist": + save_strategy = TorchDistSaveShardedStrategy( + "torch_dist", + 1, + thread_count=ckpt_cfg.storage_writers_per_rank, + ) + else: + save_strategy = get_default_save_sharded_strategy(ckpt_cfg.ckpt_format) if ckpt_cfg.ckpt_assume_constant_structure and ckpt_cfg.ckpt_format == "torch_dist": save_strategy.use_cached_ckpt_structure = ckpt_cfg.ckpt_assume_constant_structure if checkpointing_context is not None and "load_strategy" in checkpointing_context: diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py index 6b99adfc6d..59f26ae1e9 100644 --- a/src/megatron/bridge/training/config.py +++ b/src/megatron/bridge/training/config.py @@ -842,6 +842,10 @@ class CheckpointConfig: use_checkpoint_args: bool = False """Override any command line arguments with arguments from the checkpoint""" + storage_writers_per_rank: int = 1 + """Number of storage writers per rank for torch_dist checkpoint format. 
+ Affects the number of checkpoint files: saving_ranks * storage_writers_per_rank.""" + exit_on_missing_checkpoint: bool = False """If 'load' is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization.""" diff --git a/tests/functional_tests/quantization/test_qat_workflow.py b/tests/functional_tests/quantization/test_qat_workflow.py index 44b690b96e..c7ac0dfaff 100644 --- a/tests/functional_tests/quantization/test_qat_workflow.py +++ b/tests/functional_tests/quantization/test_qat_workflow.py @@ -299,11 +299,11 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path): cp = context_parallel_size or 2 world_size = tp * pp * cp - # For torch_dist format, expect 2 * world_size .distcp files + # For torch_dist format, expect world_size .distcp files - # (one for model state, one for optimizer state per rank) - expected_distcp_files = 2 * world_size + # this is dictated by the checkpoint config's default value for storage_writers_per_rank + expected_distcp_files = world_size assert len(distcp_files) == expected_distcp_files, ( - f"Expected {expected_distcp_files} .distcp files (2 * {world_size} world_size), " + f"Expected {expected_distcp_files} .distcp files with {world_size} world_size, " f"found {len(distcp_files)}: {distcp_files}" ) print( diff --git a/tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py b/tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py index b56c0dda63..2e43962ae3 100644 --- a/tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py +++ b/tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py @@ -261,7 +261,12 @@ def test_gpt_oss_finetune_recipes( finetune(config, forward_step) # Verify checkpoints were saved - verify_checkpoint_files(config.checkpoint.save, 5) + verify_checkpoint_files( + config.checkpoint.save, + 5, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: 
clear_directories(tmp_path) diff --git a/tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py b/tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py index 4c9e0676aa..c73170d354 100644 --- a/tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py +++ b/tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py @@ -162,7 +162,12 @@ def run_distill_recipe_test( distill(config=config) # Basic verification that training completed successfully - verify_checkpoint_files(config.checkpoint.save, 10) + verify_checkpoint_files( + config.checkpoint.save, + 10, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) diff --git a/tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py b/tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py index b30ca867fb..3c904df13d 100644 --- a/tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py +++ b/tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py @@ -284,7 +284,12 @@ def test_nemotron_nano_v2_finetune_recipes( finetune(config, forward_step) # Verify checkpoints were saved - verify_checkpoint_files(config.checkpoint.save, config.train.train_iters) + verify_checkpoint_files( + config.checkpoint.save, + config.train.train_iters, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) @@ -552,7 +557,12 @@ def test_nemotron_3_nano_finetune_recipes( finetune(config, forward_step) # Verify checkpoints were saved - verify_checkpoint_files(config.checkpoint.save, config.train.train_iters) + verify_checkpoint_files( + config.checkpoint.save, + config.train.train_iters, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) diff --git 
a/tests/functional_tests/recipes/utils.py b/tests/functional_tests/recipes/utils.py index 49399c805e..97dbc6fae9 100644 --- a/tests/functional_tests/recipes/utils.py +++ b/tests/functional_tests/recipes/utils.py @@ -119,7 +119,12 @@ def run_pretrain_recipe_test( pretrain(config, forward_step) # Basic verification that training completed successfully - verify_checkpoint_files(config.checkpoint.save, 10) + verify_checkpoint_files( + config.checkpoint.save, + 10, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) @@ -282,7 +287,12 @@ def run_pretrain_vl_recipe_test( pretrain(config, vlm_forward_step) # Basic verification that training completed successfully - verify_checkpoint_files(config.checkpoint.save, config.train.train_iters) + verify_checkpoint_files( + config.checkpoint.save, + config.train.train_iters, + ckpt_format=config.checkpoint.ckpt_format, + storage_writers_per_rank=config.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) diff --git a/tests/functional_tests/training/test_finetune_lora.py b/tests/functional_tests/training/test_finetune_lora.py index 52c64c2fcb..c3bf2799b2 100644 --- a/tests/functional_tests/training/test_finetune_lora.py +++ b/tests/functional_tests/training/test_finetune_lora.py @@ -87,14 +87,24 @@ def test_pretrain_then_lora_finetune(self, tmp_path): pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length ) pretrain(pretrain_cfg, forward_step) - verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters) + verify_checkpoint_files( + pretrain_checkpoint_dir, + pretrain_iters, + ckpt_format=pretrain_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank, + ) # Create LoRA config and run finetuning lora_cfg = self._create_lora_config( lora_iters, lora_checkpoint_dir, lora_tensorboard_dir, pretrain_checkpoint_dir, seq_length ) 
finetune(lora_cfg, forward_step) - verify_checkpoint_files(lora_checkpoint_dir, lora_iters) + verify_checkpoint_files( + lora_checkpoint_dir, + lora_iters, + ckpt_format=lora_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank, + ) verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters) finally: @@ -129,7 +139,12 @@ def test_lora_save_and_resume(self, tmp_path): # Run pretrain pretrain(pretrain_cfg, forward_step) - verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters) + verify_checkpoint_files( + pretrain_checkpoint_dir, + pretrain_iters, + ckpt_format=pretrain_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank, + ) # Second run: LoRA finetuning initial phase (will be "interrupted") @@ -146,7 +161,12 @@ def test_lora_save_and_resume(self, tmp_path): # Run initial LoRA finetuning (simulate job getting interrupted) finetune(lora_initial_cfg, forward_step) - verify_checkpoint_files(lora_checkpoint_dir, initial_lora_iters) + verify_checkpoint_files( + lora_checkpoint_dir, + initial_lora_iters, + ckpt_format=lora_initial_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=lora_initial_cfg.checkpoint.storage_writers_per_rank, + ) # Third run: Resume LoRA finetuning from checkpoint (adapter-only states) lora_resume_cfg = self._create_lora_config( @@ -165,7 +185,12 @@ def test_lora_save_and_resume(self, tmp_path): # Run resumed LoRA finetuning (should continue from iteration 6 to 12) finetune(lora_resume_cfg, forward_step) - verify_checkpoint_files(lora_checkpoint_dir, total_lora_iters) + verify_checkpoint_files( + lora_checkpoint_dir, + total_lora_iters, + ckpt_format=lora_resume_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=lora_resume_cfg.checkpoint.storage_writers_per_rank, + ) verify_peft_checkpoint_smaller( pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, initial_lora_iters ) @@ 
-198,7 +223,12 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path): pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length ) pretrain(pretrain_cfg, forward_step) - verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters) + verify_checkpoint_files( + pretrain_checkpoint_dir, + pretrain_iters, + ckpt_format=pretrain_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank, + ) # Create LoRA config with packed sequences and run finetuning lora_cfg = self._create_lora_config( @@ -214,7 +244,12 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path): lora_cfg.validation.eval_iters = 2 finetune(lora_cfg, forward_step) - verify_checkpoint_files(lora_checkpoint_dir, lora_iters) + verify_checkpoint_files( + lora_checkpoint_dir, + lora_iters, + ckpt_format=lora_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank, + ) verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters) finally: diff --git a/tests/functional_tests/training/test_megatron_fsdp.py b/tests/functional_tests/training/test_megatron_fsdp.py index 753e04f688..40dba4b9f4 100644 --- a/tests/functional_tests/training/test_megatron_fsdp.py +++ b/tests/functional_tests/training/test_megatron_fsdp.py @@ -312,7 +312,12 @@ def test_fsdp_pretrain_with_checkpoint(self, tmp_path): # Verify FSDP DTensor checkpoint files torch.distributed.barrier() - verify_checkpoint_files(checkpoint_dir, total_iters, ckpt_format=cfg.checkpoint.ckpt_format) + verify_checkpoint_files( + checkpoint_dir, + total_iters, + ckpt_format=cfg.checkpoint.ckpt_format, + storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(tmp_path) @@ -356,7 +361,12 @@ def test_fsdp_pretrain_save_resume(self, tmp_path): torch.distributed.barrier() # Verify FSDP DTensor checkpoint files from first run - 
verify_checkpoint_files(checkpoint_dir, checkpoint_iters, ckpt_format=cfg_first.checkpoint.ckpt_format) + verify_checkpoint_files( + checkpoint_dir, + checkpoint_iters, + ckpt_format=cfg_first.checkpoint.ckpt_format, + storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank, + ) torch.distributed.barrier() @@ -377,7 +387,12 @@ def test_fsdp_pretrain_save_resume(self, tmp_path): torch.distributed.barrier() # Verify FSDP DTensor checkpoint files from second run (should be at total_iters) - verify_checkpoint_files(checkpoint_dir, total_iters, ckpt_format=cfg_second.checkpoint.ckpt_format) + verify_checkpoint_files( + checkpoint_dir, + total_iters, + ckpt_format=cfg_second.checkpoint.ckpt_format, + storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(shared_base_dir) diff --git a/tests/functional_tests/training/test_pretrain_resume.py b/tests/functional_tests/training/test_pretrain_resume.py index 53b48f0284..7300e88fdf 100644 --- a/tests/functional_tests/training/test_pretrain_resume.py +++ b/tests/functional_tests/training/test_pretrain_resume.py @@ -163,7 +163,12 @@ def test_pretrain_save_load(self, tmp_path): torch.distributed.barrier() # Verify checkpoint files from first run - verify_checkpoint_files(checkpoint_dir, checkpoint_iters) + verify_checkpoint_files( + checkpoint_dir, + checkpoint_iters, + ckpt_format=cfg_first.checkpoint.ckpt_format, + storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank, + ) torch.distributed.barrier() @@ -248,7 +253,12 @@ def test_pretrain_save_load(self, tmp_path): torch.distributed.barrier() # Verify checkpoint files from second run - verify_checkpoint_files(checkpoint_dir, total_iters) + verify_checkpoint_files( + checkpoint_dir, + total_iters, + ckpt_format=cfg_second.checkpoint.ckpt_format, + storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(shared_base_dir) diff --git 
a/tests/functional_tests/training/test_seqpacking_cp_example.py b/tests/functional_tests/training/test_seqpacking_cp_example.py index 3c10c1b5de..e0581c3038 100644 --- a/tests/functional_tests/training/test_seqpacking_cp_example.py +++ b/tests/functional_tests/training/test_seqpacking_cp_example.py @@ -99,6 +99,11 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path): try: finetune(cfg, forward_step) - verify_checkpoint_files(checkpoint_dir, cfg.train.train_iters) + verify_checkpoint_files( + checkpoint_dir, + cfg.train.train_iters, + ckpt_format=cfg.checkpoint.ckpt_format, + storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(shared_dir) diff --git a/tests/functional_tests/training/test_sft.py b/tests/functional_tests/training/test_sft.py index 9837a5a124..c2bba449ed 100644 --- a/tests/functional_tests/training/test_sft.py +++ b/tests/functional_tests/training/test_sft.py @@ -79,7 +79,12 @@ def test_pretrain_then_finetune(self, tmp_path): pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length ) pretrain(pretrain_cfg, forward_step) - verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters) + verify_checkpoint_files( + pretrain_checkpoint_dir, + pretrain_iters, + ckpt_format=pretrain_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank, + ) # Create finetune config and run (lower LR, different seed, use pretrained checkpoint) finetune_cfg = self._create_config( @@ -92,7 +97,12 @@ def test_pretrain_then_finetune(self, tmp_path): pretrained_checkpoint=pretrain_checkpoint_dir, ) finetune(finetune_cfg, forward_step) - verify_checkpoint_files(finetune_checkpoint_dir, finetune_iters) + verify_checkpoint_files( + finetune_checkpoint_dir, + finetune_iters, + ckpt_format=finetune_cfg.checkpoint.ckpt_format, + storage_writers_per_rank=finetune_cfg.checkpoint.storage_writers_per_rank, + ) finally: clear_directories(shared_base_dir) diff 
--git a/tests/functional_tests/utils.py b/tests/functional_tests/utils.py index e5182e10ea..c99ffbe386 100644 --- a/tests/functional_tests/utils.py +++ b/tests/functional_tests/utils.py @@ -102,13 +102,21 @@ def clear_directories(path: str) -> None: torch.distributed.barrier() -def verify_checkpoint_files(checkpoint_dir: str, iteration_count: int, ckpt_format: str = "torch_dist") -> None: - """Verify that checkpoint files were created correctly for different checkpoint formats. +def verify_checkpoint_files( + checkpoint_dir: str, + iteration_count: int, + ckpt_format: str = "torch_dist", + storage_writers_per_rank: int = 1, +) -> None: + """Verify that checkpoint files were created correctly. Args: checkpoint_dir: Directory containing checkpoints iteration_count: Expected iteration number for the checkpoint ckpt_format: Checkpoint format ("torch_dist", "fsdp_dtensor", etc.) + storage_writers_per_rank: Storage writers per rank (torch_dist only). + Pass config.checkpoint.storage_writers_per_rank. + Affects expected file count: world_size * storage_writers_per_rank. 
""" if torch.distributed.is_initialized(): torch.distributed.barrier() @@ -139,7 +147,7 @@ def verify_checkpoint_files(checkpoint_dir: str, iteration_count: int, ckpt_form distcp_files = [f for f in os.listdir(final_iter_dir) if f.endswith(".distcp")] if ckpt_format == "torch_dist": - num_expected_files = 2 * torch.distributed.get_world_size() + num_expected_files = storage_writers_per_rank * torch.distributed.get_world_size() elif ckpt_format == "fsdp_dtensor": # fsdp_dtensor format creates .distcp files (one per rank) num_expected_files = torch.distributed.get_world_size() @@ -147,7 +155,7 @@ def verify_checkpoint_files(checkpoint_dir: str, iteration_count: int, ckpt_form raise ValueError(f"Unsupported checkpoint format for verification: {ckpt_format}") assert len(distcp_files) == num_expected_files, ( - f"Expected {num_expected_files} .distcp files for fsdp_dtensor, found {len(distcp_files)}: {distcp_files}" + f"Expected {num_expected_files} .distcp files for {ckpt_format}, found {len(distcp_files)}: {distcp_files}" )