Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.num_microbatches_calculator import update_num_microbatches
Expand Down Expand Up @@ -647,7 +648,14 @@ def save_checkpoint(
validate_sharding_integrity = not ckpt_cfg.ckpt_assume_constant_structure
else:
validate_sharding_integrity = True
save_strategy = get_default_save_sharded_strategy(ckpt_cfg.ckpt_format)
if ckpt_cfg.ckpt_format == "torch_dist":
save_strategy = TorchDistSaveShardedStrategy(
"torch_dist",
1,
thread_count=ckpt_cfg.storage_writers_per_rank,
)
else:
save_strategy = get_default_save_sharded_strategy(ckpt_cfg.ckpt_format)
if ckpt_cfg.ckpt_assume_constant_structure and ckpt_cfg.ckpt_format == "torch_dist":
save_strategy.use_cached_ckpt_structure = ckpt_cfg.ckpt_assume_constant_structure
if checkpointing_context is not None and "load_strategy" in checkpointing_context:
Expand Down
4 changes: 4 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,10 @@ class CheckpointConfig:
use_checkpoint_args: bool = False
"""Override any command line arguments with arguments from the checkpoint"""

storage_writers_per_rank: int = 1
"""Number of storage writers per rank for torch_dist checkpoint format.
Affects the number of checkpoint files: saving_ranks * storage_writers_per_rank."""

exit_on_missing_checkpoint: bool = False
"""If 'load' is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization."""

Expand Down
7 changes: 4 additions & 3 deletions tests/functional_tests/quantization/test_qat_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,12 @@ def test_qat_workflow(self, recipe_name, parallelism_overrides, tmp_path):
cp = context_parallel_size or 2
world_size = tp * pp * cp

# For torch_dist format, expect 2 * world_size .distcp files
# For torch_dist format, expect world_size * storage_writers_per_rank .distcp files;
# the checkpoint config's default storage_writers_per_rank of 1 gives world_size files.
expected_distcp_files = world_size
assert len(distcp_files) == expected_distcp_files, (
f"Expected {expected_distcp_files} .distcp files (2 * {world_size} world_size), "
f"Expected {expected_distcp_files} .distcp files with world_size {world_size}, "
f"found {len(distcp_files)}: {distcp_files}"
)
print(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,12 @@ def test_gpt_oss_finetune_recipes(
finetune(config, forward_step)

# Verify checkpoints were saved
verify_checkpoint_files(config.checkpoint.save, 5)
verify_checkpoint_files(
config.checkpoint.save,
5,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ def run_distill_recipe_test(
distill(config=config)

# Basic verification that training completed successfully
verify_checkpoint_files(config.checkpoint.save, 10)
verify_checkpoint_files(
config.checkpoint.save,
10,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
14 changes: 12 additions & 2 deletions tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,12 @@ def test_nemotron_nano_v2_finetune_recipes(
finetune(config, forward_step)

# Verify checkpoints were saved
verify_checkpoint_files(config.checkpoint.save, config.train.train_iters)
verify_checkpoint_files(
config.checkpoint.save,
config.train.train_iters,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
Expand Down Expand Up @@ -552,7 +557,12 @@ def test_nemotron_3_nano_finetune_recipes(
finetune(config, forward_step)

# Verify checkpoints were saved
verify_checkpoint_files(config.checkpoint.save, config.train.train_iters)
verify_checkpoint_files(
config.checkpoint.save,
config.train.train_iters,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
14 changes: 12 additions & 2 deletions tests/functional_tests/recipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def run_pretrain_recipe_test(
pretrain(config, forward_step)

# Basic verification that training completed successfully
verify_checkpoint_files(config.checkpoint.save, 10)
verify_checkpoint_files(
config.checkpoint.save,
10,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
Expand Down Expand Up @@ -282,7 +287,12 @@ def run_pretrain_vl_recipe_test(
pretrain(config, vlm_forward_step)

# Basic verification that training completed successfully
verify_checkpoint_files(config.checkpoint.save, config.train.train_iters)
verify_checkpoint_files(
config.checkpoint.save,
config.train.train_iters,
ckpt_format=config.checkpoint.ckpt_format,
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
49 changes: 42 additions & 7 deletions tests/functional_tests/training/test_finetune_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,24 @@ def test_pretrain_then_lora_finetune(self, tmp_path):
pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length
)
pretrain(pretrain_cfg, forward_step)
verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters)
verify_checkpoint_files(
pretrain_checkpoint_dir,
pretrain_iters,
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)

# Create LoRA config and run finetuning
lora_cfg = self._create_lora_config(
lora_iters, lora_checkpoint_dir, lora_tensorboard_dir, pretrain_checkpoint_dir, seq_length
)
finetune(lora_cfg, forward_step)
verify_checkpoint_files(lora_checkpoint_dir, lora_iters)
verify_checkpoint_files(
lora_checkpoint_dir,
lora_iters,
ckpt_format=lora_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank,
)
verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters)

finally:
Expand Down Expand Up @@ -129,7 +139,12 @@ def test_lora_save_and_resume(self, tmp_path):
# Run pretrain
pretrain(pretrain_cfg, forward_step)

verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters)
verify_checkpoint_files(
pretrain_checkpoint_dir,
pretrain_iters,
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)

# Second run: LoRA finetuning initial phase (will be "interrupted")

Expand All @@ -146,7 +161,12 @@ def test_lora_save_and_resume(self, tmp_path):
# Run initial LoRA finetuning (simulate job getting interrupted)
finetune(lora_initial_cfg, forward_step)

verify_checkpoint_files(lora_checkpoint_dir, initial_lora_iters)
verify_checkpoint_files(
lora_checkpoint_dir,
initial_lora_iters,
ckpt_format=lora_initial_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=lora_initial_cfg.checkpoint.storage_writers_per_rank,
)

# Third run: Resume LoRA finetuning from checkpoint (adapter-only states)
lora_resume_cfg = self._create_lora_config(
Expand All @@ -165,7 +185,12 @@ def test_lora_save_and_resume(self, tmp_path):
# Run resumed LoRA finetuning (should continue from iteration 6 to 12)
finetune(lora_resume_cfg, forward_step)

verify_checkpoint_files(lora_checkpoint_dir, total_lora_iters)
verify_checkpoint_files(
lora_checkpoint_dir,
total_lora_iters,
ckpt_format=lora_resume_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=lora_resume_cfg.checkpoint.storage_writers_per_rank,
)
verify_peft_checkpoint_smaller(
pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, initial_lora_iters
)
Expand Down Expand Up @@ -198,7 +223,12 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path):
pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length
)
pretrain(pretrain_cfg, forward_step)
verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters)
verify_checkpoint_files(
pretrain_checkpoint_dir,
pretrain_iters,
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)

# Create LoRA config with packed sequences and run finetuning
lora_cfg = self._create_lora_config(
Expand All @@ -214,7 +244,12 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path):
lora_cfg.validation.eval_iters = 2

finetune(lora_cfg, forward_step)
verify_checkpoint_files(lora_checkpoint_dir, lora_iters)
verify_checkpoint_files(
lora_checkpoint_dir,
lora_iters,
ckpt_format=lora_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank,
)
verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters)

finally:
Expand Down
21 changes: 18 additions & 3 deletions tests/functional_tests/training/test_megatron_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,12 @@ def test_fsdp_pretrain_with_checkpoint(self, tmp_path):

# Verify FSDP DTensor checkpoint files
torch.distributed.barrier()
verify_checkpoint_files(checkpoint_dir, total_iters, ckpt_format=cfg.checkpoint.ckpt_format)
verify_checkpoint_files(
checkpoint_dir,
total_iters,
ckpt_format=cfg.checkpoint.ckpt_format,
storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(tmp_path)
Expand Down Expand Up @@ -356,7 +361,12 @@ def test_fsdp_pretrain_save_resume(self, tmp_path):
torch.distributed.barrier()

# Verify FSDP DTensor checkpoint files from first run
verify_checkpoint_files(checkpoint_dir, checkpoint_iters, ckpt_format=cfg_first.checkpoint.ckpt_format)
verify_checkpoint_files(
checkpoint_dir,
checkpoint_iters,
ckpt_format=cfg_first.checkpoint.ckpt_format,
storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank,
)

torch.distributed.barrier()

Expand All @@ -377,7 +387,12 @@ def test_fsdp_pretrain_save_resume(self, tmp_path):
torch.distributed.barrier()

# Verify FSDP DTensor checkpoint files from second run (should be at total_iters)
verify_checkpoint_files(checkpoint_dir, total_iters, ckpt_format=cfg_second.checkpoint.ckpt_format)
verify_checkpoint_files(
checkpoint_dir,
total_iters,
ckpt_format=cfg_second.checkpoint.ckpt_format,
storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(shared_base_dir)
14 changes: 12 additions & 2 deletions tests/functional_tests/training/test_pretrain_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def test_pretrain_save_load(self, tmp_path):
torch.distributed.barrier()

# Verify checkpoint files from first run
verify_checkpoint_files(checkpoint_dir, checkpoint_iters)
verify_checkpoint_files(
checkpoint_dir,
checkpoint_iters,
ckpt_format=cfg_first.checkpoint.ckpt_format,
storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank,
)

torch.distributed.barrier()

Expand Down Expand Up @@ -248,7 +253,12 @@ def test_pretrain_save_load(self, tmp_path):
torch.distributed.barrier()

# Verify checkpoint files from second run
verify_checkpoint_files(checkpoint_dir, total_iters)
verify_checkpoint_files(
checkpoint_dir,
total_iters,
ckpt_format=cfg_second.checkpoint.ckpt_format,
storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(shared_base_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):

try:
finetune(cfg, forward_step)
verify_checkpoint_files(checkpoint_dir, cfg.train.train_iters)
verify_checkpoint_files(
checkpoint_dir,
cfg.train.train_iters,
ckpt_format=cfg.checkpoint.ckpt_format,
storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank,
)
finally:
clear_directories(shared_dir)
14 changes: 12 additions & 2 deletions tests/functional_tests/training/test_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def test_pretrain_then_finetune(self, tmp_path):
pretrain_iters, pretrain_checkpoint_dir, pretrain_tensorboard_dir, seq_length
)
pretrain(pretrain_cfg, forward_step)
verify_checkpoint_files(pretrain_checkpoint_dir, pretrain_iters)
verify_checkpoint_files(
pretrain_checkpoint_dir,
pretrain_iters,
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
)

# Create finetune config and run (lower LR, different seed, use pretrained checkpoint)
finetune_cfg = self._create_config(
Expand All @@ -92,7 +97,12 @@ def test_pretrain_then_finetune(self, tmp_path):
pretrained_checkpoint=pretrain_checkpoint_dir,
)
finetune(finetune_cfg, forward_step)
verify_checkpoint_files(finetune_checkpoint_dir, finetune_iters)
verify_checkpoint_files(
finetune_checkpoint_dir,
finetune_iters,
ckpt_format=finetune_cfg.checkpoint.ckpt_format,
storage_writers_per_rank=finetune_cfg.checkpoint.storage_writers_per_rank,
)

finally:
clear_directories(shared_base_dir)
Expand Down
16 changes: 12 additions & 4 deletions tests/functional_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ def clear_directories(path: str) -> None:
torch.distributed.barrier()


def verify_checkpoint_files(checkpoint_dir: str, iteration_count: int, ckpt_format: str = "torch_dist") -> None:
"""Verify that checkpoint files were created correctly for different checkpoint formats.
def verify_checkpoint_files(
checkpoint_dir: str,
iteration_count: int,
ckpt_format: str = "torch_dist",
storage_writers_per_rank: int = 1,
) -> None:
"""Verify that checkpoint files were created correctly.

Args:
checkpoint_dir: Directory containing checkpoints
iteration_count: Expected iteration number for the checkpoint
ckpt_format: Checkpoint format ("torch_dist", "fsdp_dtensor", etc.)
storage_writers_per_rank: Storage writers per rank (torch_dist only).
Pass config.checkpoint.storage_writers_per_rank.
Affects expected file count: world_size * storage_writers_per_rank.
"""
if torch.distributed.is_initialized():
torch.distributed.barrier()
Expand Down Expand Up @@ -139,15 +147,15 @@ def verify_checkpoint_files(checkpoint_dir: str, iteration_count: int, ckpt_form
distcp_files = [f for f in os.listdir(final_iter_dir) if f.endswith(".distcp")]

if ckpt_format == "torch_dist":
num_expected_files = 2 * torch.distributed.get_world_size()
num_expected_files = storage_writers_per_rank * torch.distributed.get_world_size()
elif ckpt_format == "fsdp_dtensor":
# fsdp_dtensor format creates .distcp files (one per rank)
num_expected_files = torch.distributed.get_world_size()
else:
raise ValueError(f"Unsupported checkpoint format for verification: {ckpt_format}")

assert len(distcp_files) == num_expected_files, (
f"Expected {num_expected_files} .distcp files for fsdp_dtensor, found {len(distcp_files)}: {distcp_files}"
f"Expected {num_expected_files} .distcp files for {ckpt_format}, found {len(distcp_files)}: {distcp_files}"
)


Expand Down