Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions examples/convert_dcp_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
# limitations under the License.

import argparse
import os
import json

from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster
from nemo_reinforcer.models.policy.hf_policy import HfPolicy
from nemo_reinforcer.utils.config import load_config
import os
import torch
from nemo_reinforcer.utils.native_checkpoint import convert_dcp_to_hf


def parse_args():
Expand Down Expand Up @@ -51,41 +49,22 @@ def main():
with open(args.config, "r") as f:
config = json.load(f)

dcp_ckpt = args.dcp_ckpt_path
hf_ckpt = args.hf_ckpt_path

# Extract individual configs for easier access
policy_config = config["policy"]
cluster_config = config["cluster"]

init_ray()

cluster = RayVirtualCluster(
name="convert_cluster",
bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
* cluster_config["num_nodes"],
use_gpus=True,
num_gpus_per_node=cluster_config["gpus_per_node"],
max_colocated_worker_groups=1,
model_name_or_path = config["policy"]["model_name"]
# TODO: After the following PR gets merged:
# https://github.com/NVIDIA/reinforcer/pull/148/files
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
# We can expose a arg at the top level --tokenizer_path to plumb that through.
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
# change release to release.
tokenizer_name_or_path = config["policy"]["model_name"]

hf_ckpt = convert_dcp_to_hf(
dcp_ckpt_path=args.dcp_ckpt_path,
hf_ckpt_path=args.hf_ckpt_path,
model_name_or_path=model_name_or_path,
tokenizer_name_or_path=tokenizer_name_or_path,
)

policy = HfPolicy(
cluster=cluster,
config=policy_config,
weights_path=dcp_ckpt,
init_optimizer=False,
)

policy.save_checkpoint(
weights_path=os.path.abspath(hf_ckpt),
save_hf=True,
save_torch_dist=False,
)

print(f"Saved HF checkpoint to: {hf_ckpt}-hf")

cluster.shutdown()
policy.worker_group.shutdown()
print(f"Saved HF checkpoint to: {hf_ckpt}")


if __name__ == "__main__":
Expand Down
59 changes: 59 additions & 0 deletions nemo_reinforcer/utils/native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Optional
import torch

from transformers import AutoConfig, AutoTokenizer
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import (
Expand All @@ -27,6 +28,7 @@
get_optimizer_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save


## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
Expand Down Expand Up @@ -202,3 +204,60 @@ def load_checkpoint(
print(f"Loading optimizer from {optimizer_path}")
optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)}
dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path)


def convert_dcp_to_hf(
dcp_ckpt_path: str,
hf_ckpt_path: str,
model_name_or_path: str,
tokenizer_name_or_path: str,
overwrite: bool = False,
):
"""Convert a Torch DCP checkpoint to a Hugging Face checkpoint.

This is not an optimized utility. If checkpoint is too large, consider saving DCP during training
and using this utility to convert to HF format.

Args:
dcp_ckpt_path (str): Path to DCP checkpoint
hf_ckpt_path (str): Path to save HF checkpoint
model_name_or_path (str): Model name or path for config
tokenizer_name_or_path (str, optional): Tokenizer name or path.
Defaults to model_name_or_path if None.
overwrite (bool, optional): Whether to overwrite existing checkpoint. Defaults to False.

Returns:
str: Path to the saved HF checkpoint

Raises:
FileExistsError: If HF checkpoint already exists and overwrite is False
"""
if os.path.exists(hf_ckpt_path) and not overwrite:
raise FileExistsError(
f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True."
)

os.makedirs(hf_ckpt_path, exist_ok=True)
weights_path = os.path.join(hf_ckpt_path, "pytorch_model.bin")
dcp_to_torch_save(dcp_ckpt_path, weights_path)

# Need to reload and save b/c the state dict is scoped inside the model key {"model": actual_state_dict}
state_dict = torch.load(weights_path)
assert set(state_dict.keys()) == {"model"}, (
f"We expect that the state dict only has the top level model key, but found: {state_dict.keys()}"
)
torch.save(state_dict["model"], weights_path)

config = AutoConfig.from_pretrained(model_name_or_path)
config.save_pretrained(hf_ckpt_path)

# TODO: After the following PR gets merged:
# https://github.com/NVIDIA/reinforcer/pull/148/files
# tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
# We can expose a arg at the top level --tokenizer_path to plumb that through.
# This is more stable than relying on the current NeMo-RL get_tokenizer() which can
# change release to release.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
tokenizer.save_pretrained(hf_ckpt_path)

return hf_ckpt_path
104 changes: 98 additions & 6 deletions tests/unit/utils/test_native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,28 @@
save_checkpoint,
ModelState,
OptimizerState,
convert_dcp_to_hf,
)
from tests.unit.test_utils import simple_loss

# Define basic test config
# Minimal policy configuration shared by the checkpointing tests in this module.
simple_policy_config = {
    # Small real model; the commented-out alternative is a tiny random model for faster runs.
    "model_name": "meta-llama/Llama-3.2-1B",  # "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
    "tokenizer_name": "meta-llama/Llama-3.2-1B",  # "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
    # NOTE(review): removed a stale duplicate "train_global_batch_size": 32 entry that was
    # silently shadowed by this value (last key wins in a dict literal).
    "train_global_batch_size": 4,
    "train_micro_batch_size": 1,
    "logprob_batch_size": 1,
    "max_total_sequence_length": 1024,
    "precision": "float32",
    "optimizer": {
        "name": "torch.optim.AdamW",
        "kwargs": {
            "lr": 5e-6,
            "weight_decay": 0.01,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
        },
    },
}


Expand Down Expand Up @@ -85,12 +96,14 @@ def tokenizer():
@pytest.fixture(scope="function")
def policy(cluster, tokenizer):
    """Yield a freshly-built HfPolicy; tear down its worker group after the test."""
    hf_policy = HfPolicy(
        config=simple_policy_config,
        cluster=cluster,
        init_reference_model=False,
        init_optimizer=True,
    )
    yield hf_policy
    # Teardown: runs after the test body finishes with the fixture.
    hf_policy.worker_group.shutdown()


def get_dummy_state_dict(state_dict, dummy_dict={}):
Expand Down Expand Up @@ -118,6 +131,15 @@ def check_dict_equality(dict1, dict2):
assert dict1[k] == dict2[k]


def assert_recursive_dict_different(dict1, dict2):
    """Assert that two (possibly nested) dictionaries differ somewhere.

    Delegates to check_dict_equality: if that check raises, the dicts differ
    and this assertion succeeds; if it passes, the dicts match and we fail.
    """
    dicts_match = True
    try:
        check_dict_equality(dict1, dict2)
    except AssertionError:
        dicts_match = False
    if dicts_match:
        raise AssertionError("Dictionaries are equal")


def test_model_state(mock_experiment):
test_model, _, _ = mock_experiment
model_state = ModelState(test_model)
Expand Down Expand Up @@ -275,14 +297,84 @@ def test_save_and_load_hf_checkpoint(policy):
"model.safetensors.index.json",
}

coverted_model = AutoModelForCausalLM.from_pretrained(
converted_model = AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_dir, "test_hf_and_dcp-hf")
)
original_model = AutoModelForCausalLM.from_pretrained(
simple_policy_config["model_name"]
)

## make sure converted model matches the original
check_dict_equality(coverted_model.state_dict(), original_model.state_dict())
check_dict_equality(converted_model.state_dict(), original_model.state_dict())

policy.worker_group.shutdown()

def test_convert_dcp_to_hf(policy):
    """End-to-end check of the offline DCP -> HF conversion utility.

    Trains one step (to materialize FSDP state and perturb the weights), saves
    both DCP and HF checkpoints, converts the DCP checkpoint offline via
    convert_dcp_to_hf, then verifies the offline conversion matches the online
    HF save while differing from the untrained base model.
    """
    ## warm up with a forward pass
    ## this is needed before saving a checkpoint because FSDP does some lazy initialization
    input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128
    attention_mask = torch.ones(4, 128)
    input_lengths = attention_mask.sum(dim=1).to(torch.int32)
    dummy_fwd_dict = BatchedDataDict(
        {
            "input_ids": input_ids,
            "input_lengths": input_lengths,
            "attention_mask": attention_mask,
            "labels": torch.randint(0, 16000, (4, 128)),
        }
    )
    # One training step so the saved weights differ from the pretrained base model.
    policy.train(dummy_fwd_dict, simple_loss)

    with TemporaryDirectory() as tmp_dir:
        policy.save_checkpoint(
            os.path.join(tmp_dir, "test_hf_and_dcp"),
            save_hf=True,
            save_torch_dist=True,
        )

        ## make sure we save both HF and DCP checkpoints
        # Two .distcp shards: presumably one per GPU/rank in the test cluster — TODO confirm.
        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == {
            "__0_0.distcp",
            "__1_0.distcp",
            ".metadata",
        }
        ## 1B model has two shards
        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == {
            "config.json",
            "generation_config.json",
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
            "model.safetensors.index.json",
        }

        # Offline conversion: DCP checkpoint dir -> a sibling "-hf-offline" directory.
        offline_converted_model_path = convert_dcp_to_hf(
            os.path.join(tmp_dir, "test_hf_and_dcp"),
            os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"),
            simple_policy_config["model_name"],
            # TODO: After the following PR gets merged:
            # https://github.com/NVIDIA/reinforcer/pull/148/files
            # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
            # We can expose an arg at the top level --tokenizer_path to plumb that through.
            # This is more stable than relying on the current NeMo-RL get_tokenizer() which can
            # change release to release.
            simple_policy_config["model_name"],
        )

        offline_converted_model = AutoModelForCausalLM.from_pretrained(
            offline_converted_model_path
        )

        online_converted_model = AutoModelForCausalLM.from_pretrained(
            os.path.join(tmp_dir, "test_hf_and_dcp-hf")
        )
        original_model = AutoModelForCausalLM.from_pretrained(
            simple_policy_config["model_name"]
        )

        ## make sure both conversions result in the same state dict
        check_dict_equality(
            online_converted_model.state_dict(), offline_converted_model.state_dict()
        )
        # Ensure the offline one is different from the original (training must have changed weights)
        assert_recursive_dict_different(
            offline_converted_model.state_dict(), original_model.state_dict()
        )