37 changes: 37 additions & 0 deletions src/megatron/bridge/training/checkpointing.py
@@ -1298,6 +1298,43 @@ def _load_model_weights_from_checkpoint(
torch.distributed.barrier()


def load_model_weights(
model: list[MegatronModule],
checkpoint_path: str,
*,
fully_parallel_load: bool = False,
strict: bool = True,
return_state_dict: bool = False,
) -> Optional[StateDict]:
"""Load only model weights from a ``torch_dist`` checkpoint.

Simple API for loading pretrained model weights without optimizer state,
RNG state, or iteration tracking.

Args:
model: The model(s) to load weights into.
        checkpoint_path: Path to the checkpoint directory that directly contains the model weights.
        fully_parallel_load: If True, parallelize the load across data-parallel ranks.
strict: Whether to enforce strict state dict loading.
return_state_dict: If True, return the state dict instead of loading into model.

Returns:
If return_state_dict is True, returns the model state dict.
Otherwise returns None.

Example:
>>> load_model_weights(model, "/checkpoints/iter_0000005")
>>> state_dict = load_model_weights(model, "/checkpoints/iter_0000005", return_state_dict=True)
"""
return _load_model_weights_from_checkpoint(
checkpoint_path,
model,
fully_parallel_load=fully_parallel_load,
strict=strict,
return_state_dict=return_state_dict,
)


def load_checkpoint(
state: GlobalState,
model: list[MegatronModule],
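For reference, a minimal usage sketch of the new public API (assumes `model` is an already-constructed `list[MegatronModule]` whose parallel layout matches the checkpoint; the path is illustrative):

```python
from megatron.bridge.training.checkpointing import load_model_weights

# In-place load: restores model weights only, skipping optimizer,
# RNG, and iteration state.
load_model_weights(model, "/checkpoints/iter_0000005")

# Alternatively, fetch the state dict without mutating the model.
state_dict = load_model_weights(
    model, "/checkpoints/iter_0000005", return_state_dict=True
)
```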
12 changes: 3 additions & 9 deletions src/megatron/bridge/training/model_load_save.py
@@ -265,9 +265,7 @@ def build_and_load_model(
The model instance with loaded weights if return_state_dict is False,
otherwise returns a dictionary containing the full, unsharded model state_dict.
"""
-    from megatron.bridge.training.checkpointing import (
-        _load_model_weights_from_checkpoint,
-    )
+    from megatron.bridge.training.checkpointing import load_model_weights
from megatron.bridge.training.mlm_compat.arguments import _tokenizer_config_from_args
from megatron.bridge.training.mlm_compat.model import _get_model, _gpt_provider, _mamba_provider
from megatron.bridge.training.post_training.checkpointing import has_modelopt_state
@@ -324,15 +322,11 @@ def _load_checkpoint():

load_modelopt_state(model, checkpoint_path)

-        maybe_state_dict = _load_model_weights_from_checkpoint(
-            checkpoint_path, model, return_state_dict=return_state_dict
-        )
-
+        maybe_state_dict = load_model_weights(model, checkpoint_path, return_state_dict=return_state_dict)
         if return_state_dict:
             del model
             return maybe_state_dict
-        else:
-            return model
+        return model

if skip_temp_dist_context:
return _load_checkpoint()
163 changes: 163 additions & 0 deletions tests/functional_tests/training/test_load_model_weights_e2e.py
@@ -0,0 +1,163 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""End-to-end tests for the load_model_weights public API.

Verifies save -> load roundtrips for the ``torch_dist`` checkpoint format
using real GPT models on GPU.

Multi-GPU safe: rank 0 creates the temp directory and broadcasts the path
to all other ranks before any checkpoint I/O.
"""

import os
import tempfile

import pytest
import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import MegatronModule

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.training.checkpointing import load_model_weights
from megatron.bridge.training.model_load_save import save_megatron_model
from tests.functional_tests.utils import broadcast_path, clear_directories, initialize_distributed


def _create_gpt_model() -> list[MegatronModule]:
"""Create a minimal GPT model on GPU for checkpoint roundtrip testing."""
provider = GPTModelProvider(
num_layers=2,
hidden_size=128,
num_attention_heads=4,
seq_length=64,
vocab_size=256,
ffn_hidden_size=256,
)
provider._pg_collection = ProcessGroupCollection.use_mpu_process_groups()
provider.finalize()
model = provider.provide_distributed_model(ddp_config=None, wrap_with_ddp=False)
return [m.cuda() for m in model]


def _snapshot_weights(model: list[MegatronModule]) -> dict[str, torch.Tensor]:
"""Deep-copy all named parameters from the first model chunk to CPU."""
return {name: param.data.detach().cpu().clone() for name, param in model[0].named_parameters()}


def _randomize_weights(model: list[MegatronModule]) -> None:
"""Replace all weights with random values so they differ from the original."""
with torch.no_grad():
for param in model[0].parameters():
param.data.uniform_(-1.0, 1.0)


class TestLoadModelWeightsE2E:
"""Save -> load roundtrip tests that exercise real checkpoint I/O on GPU."""

@pytest.fixture(autouse=True)
def setup_distributed(self):
"""Initialize distributed and model-parallel state (once per process)."""
initialize_distributed()

if not parallel_state.model_parallel_is_initialized():
parallel_state.initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
)

from megatron.bridge.training.initialize import _set_random_seed

pg_collection = ProcessGroupCollection.use_mpu_process_groups()
_set_random_seed(
seed_=1234,
data_parallel_random_init=False,
te_rng_tracker=True,
inference_rng_tracker=False,
pg_collection=pg_collection,
)

yield

from megatron.core.rerun_state_machine import destroy_rerun_state_machine

destroy_rerun_state_machine()

@pytest.fixture()
def shared_tmp_dir(self):
"""Create a temp directory on rank 0 and broadcast the path to all ranks."""
if dist.get_rank() == 0:
tmp_dir = tempfile.mkdtemp()
else:
tmp_dir = ""
tmp_dir = broadcast_path(tmp_dir)

yield tmp_dir

clear_directories(tmp_dir)

# ------------------------------------------------------------------
# torch_dist format
# ------------------------------------------------------------------

@pytest.mark.run_only_on("GPU")
def test_torch_dist_save_load_roundtrip(self, shared_tmp_dir):
"""Weights survive a torch_dist save -> load_model_weights cycle."""
save_dir = os.path.join(shared_tmp_dir, "checkpoint")
ckpt_path = os.path.join(save_dir, "iter_0000000")

model = _create_gpt_model()
original = _snapshot_weights(model)

save_megatron_model(model, save_dir, ckpt_format="torch_dist")
assert os.path.isdir(ckpt_path), f"Checkpoint dir not created at {ckpt_path}"

model2 = _create_gpt_model()
_randomize_weights(model2)

for name in original:
assert not torch.equal(model2[0].state_dict()[name].cpu(), original[name]), (
f"Weights for '{name}' should differ before load"
)

load_model_weights(model2, ckpt_path)

for name, expected in original.items():
actual = model2[0].state_dict()[name].cpu()
assert torch.allclose(actual, expected, atol=1e-6), (
f"torch_dist weight mismatch for '{name}': max diff = {(actual - expected).abs().max().item():.2e}"
)

@pytest.mark.run_only_on("GPU")
def test_torch_dist_return_state_dict(self, shared_tmp_dir):
"""load_model_weights can return a state dict instead of loading in-place."""
save_dir = os.path.join(shared_tmp_dir, "checkpoint")
ckpt_path = os.path.join(save_dir, "iter_0000000")

model = _create_gpt_model()
original = _snapshot_weights(model)

save_megatron_model(model, save_dir, ckpt_format="torch_dist")

state_dict = load_model_weights(model, ckpt_path, return_state_dict=True)

assert state_dict is not None, "return_state_dict=True should return a dict"
assert "model" in state_dict, "state dict must contain 'model' key"

for name in original:
assert name in state_dict["model"], f"Key '{name}' missing from returned state dict"
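The tests above rely on a rank-0 broadcast so that every rank performs checkpoint I/O against the same temp directory. A minimal sketch of that pattern, assuming `torch.distributed` is already initialized (the real `broadcast_path` helper lives in `tests.functional_tests.utils` and may differ):

```python
import torch.distributed as dist

def broadcast_path(path: str) -> str:
    """Broadcast a path string from rank 0 to all ranks (illustrative sketch)."""
    # broadcast_object_list overwrites the list contents on non-source ranks,
    # so only rank 0's value survives.
    container = [path]
    dist.broadcast_object_list(container, src=0)
    return container[0]
```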
71 changes: 71 additions & 0 deletions tests/unit_tests/training/test_checkpointing.py
@@ -1513,6 +1513,77 @@ def test_load_model_state_dict_non_strict_raises(self):
_load_model_state_dict(module, {"w": 1}, strict=False)


class TestLoadModelWeights:
"""Test the load_model_weights function."""

@pytest.fixture
def mock_model(self):
"""Create a mock model for testing."""
model = Mock()
model.sharded_state_dict.return_value = {"weight": torch.randn(10, 10)}
return [model]

@patch("megatron.bridge.training.checkpointing._load_model_weights_from_checkpoint")
def test_load_model_weights_delegates_to_loader(
self,
mock_load_weights,
mock_model,
):
"""Test load_model_weights delegates to _load_model_weights_from_checkpoint."""
from megatron.bridge.training.checkpointing import load_model_weights

load_model_weights(mock_model, "/checkpoint/iter_0000005")

mock_load_weights.assert_called_once_with(
"/checkpoint/iter_0000005",
mock_model,
fully_parallel_load=False,
strict=True,
return_state_dict=False,
)

@patch("megatron.bridge.training.checkpointing._load_model_weights_from_checkpoint")
def test_load_model_weights_with_fully_parallel_load(
self,
mock_load_weights,
mock_model,
):
"""Test load_model_weights with fully_parallel_load enabled."""
from megatron.bridge.training.checkpointing import load_model_weights

load_model_weights(mock_model, "/checkpoint/iter_0000005", fully_parallel_load=True)

mock_load_weights.assert_called_once_with(
"/checkpoint/iter_0000005",
mock_model,
fully_parallel_load=True,
strict=True,
return_state_dict=False,
)

@patch("megatron.bridge.training.checkpointing._load_model_weights_from_checkpoint")
def test_load_model_weights_return_state_dict(
self,
mock_load_weights,
mock_model,
):
"""Test load_model_weights with return_state_dict=True."""
from megatron.bridge.training.checkpointing import load_model_weights

mock_load_weights.return_value = {"model": {"weight": torch.randn(10, 10)}}

result = load_model_weights(mock_model, "/checkpoint/iter_0000005", return_state_dict=True)

mock_load_weights.assert_called_once_with(
"/checkpoint/iter_0000005",
mock_model,
fully_parallel_load=False,
strict=True,
return_state_dict=True,
)
assert result is not None


class TestMegatronLMCompatibility:
"""Test Megatron-LM checkpoint compatibility features."""
