diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index ee347eeb9e..b314d93d7a 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -13,12 +13,10 @@ # limitations under the License. import argparse -import os import json - -from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster -from nemo_reinforcer.models.policy.hf_policy import HfPolicy -from nemo_reinforcer.utils.config import load_config +import os +import torch +from nemo_reinforcer.utils.native_checkpoint import convert_dcp_to_hf def parse_args(): @@ -51,41 +49,22 @@ def main(): with open(args.config, "r") as f: config = json.load(f) - dcp_ckpt = args.dcp_ckpt_path - hf_ckpt = args.hf_ckpt_path - - # Extract individual configs for easier access - policy_config = config["policy"] - cluster_config = config["cluster"] - - init_ray() - - cluster = RayVirtualCluster( - name="convert_cluster", - bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] * cluster_config["num_nodes"], - use_gpus=True, - num_gpus_per_node=cluster_config["gpus_per_node"], - max_colocated_worker_groups=1, + model_name_or_path = config["policy"]["model_name"] + # TODO: After the following PR gets merged: + # https://github.com/NVIDIA/reinforcer/pull/148/files + # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name + # We can expose an arg at the top level --tokenizer_path to plumb that through. + # This is more stable than relying on the current NeMo-RL get_tokenizer() which can + # change release to release. 
+ tokenizer_name_or_path = config["policy"]["model_name"] + + hf_ckpt = convert_dcp_to_hf( + dcp_ckpt_path=args.dcp_ckpt_path, + hf_ckpt_path=args.hf_ckpt_path, + model_name_or_path=model_name_or_path, + tokenizer_name_or_path=tokenizer_name_or_path, ) - - policy = HfPolicy( - cluster=cluster, - config=policy_config, - weights_path=dcp_ckpt, - init_optimizer=False, - ) - - policy.save_checkpoint( - weights_path=os.path.abspath(hf_ckpt), - save_hf=True, - save_torch_dist=False, - ) - - print(f"Saved HF checkpoint to: {hf_ckpt}-hf") - - cluster.shutdown() - policy.worker_group.shutdown() + print(f"Saved HF checkpoint to: {hf_ckpt}") if __name__ == "__main__": diff --git a/nemo_reinforcer/utils/native_checkpoint.py b/nemo_reinforcer/utils/native_checkpoint.py index 6f22ea82fd..aa6c4f9f01 100644 --- a/nemo_reinforcer/utils/native_checkpoint.py +++ b/nemo_reinforcer/utils/native_checkpoint.py @@ -19,6 +19,7 @@ from typing import Any, Optional import torch +from transformers import AutoConfig, AutoTokenizer import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.state_dict import ( @@ -27,6 +28,7 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save ## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html @@ -202,3 +204,60 @@ def load_checkpoint( print(f"Loading optimizer from {optimizer_path}") optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path) + + +def convert_dcp_to_hf( + dcp_ckpt_path: str, + hf_ckpt_path: str, + model_name_or_path: str, + tokenizer_name_or_path: str, + overwrite: bool = False, +): + """Convert a Torch DCP checkpoint to a Hugging Face checkpoint. + + This is not an optimized utility. 
If checkpoint is too large, consider saving DCP during training + and using this utility to convert to HF format. + + Args: + dcp_ckpt_path (str): Path to DCP checkpoint + hf_ckpt_path (str): Path to save HF checkpoint + model_name_or_path (str): Model name or path for config + tokenizer_name_or_path (str): Tokenizer name or path used to + save the tokenizer alongside the converted model. + overwrite (bool, optional): Whether to overwrite existing checkpoint. Defaults to False. + + Returns: + str: Path to the saved HF checkpoint + + Raises: + FileExistsError: If HF checkpoint already exists and overwrite is False + """ + if os.path.exists(hf_ckpt_path) and not overwrite: + raise FileExistsError( + f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True." + ) + + os.makedirs(hf_ckpt_path, exist_ok=True) + weights_path = os.path.join(hf_ckpt_path, "pytorch_model.bin") + dcp_to_torch_save(dcp_ckpt_path, weights_path) + + # Need to reload and save b/c the state dict is scoped inside the model key {"model": actual_state_dict} + state_dict = torch.load(weights_path) + assert set(state_dict.keys()) == {"model"}, ( + f"We expect that the state dict only has the top level model key, but found: {state_dict.keys()}" + ) + torch.save(state_dict["model"], weights_path) + + config = AutoConfig.from_pretrained(model_name_or_path) + config.save_pretrained(hf_ckpt_path) + + # TODO: After the following PR gets merged: + # https://github.com/NVIDIA/reinforcer/pull/148/files + # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name + # We can expose an arg at the top level --tokenizer_path to plumb that through. + # This is more stable than relying on the current NeMo-RL get_tokenizer() which can + # change release to release. 
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + tokenizer.save_pretrained(hf_ckpt_path) + + return hf_ckpt_path diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index 8f71badea1..a5b9c9d637 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -26,17 +26,28 @@ save_checkpoint, ModelState, OptimizerState, + convert_dcp_to_hf, ) +from tests.unit.test_utils import simple_loss # Define basic test config simple_policy_config = { "model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", "tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", - "train_global_batch_size": 32, + "train_global_batch_size": 4, "train_micro_batch_size": 1, "logprob_batch_size": 1, "max_total_sequence_length": 1024, "precision": "float32", + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, } @@ -85,12 +96,14 @@ def tokenizer(): @pytest.fixture(scope="function") def policy(cluster, tokenizer): """Initialize the policy.""" - return HfPolicy( + policy = HfPolicy( cluster=cluster, config=simple_policy_config, - init_optimizer=False, + init_optimizer=True, init_reference_model=False, ) + yield policy + policy.worker_group.shutdown() def get_dummy_state_dict(state_dict, dummy_dict={}): @@ -118,6 +131,15 @@ def check_dict_equality(dict1, dict2): assert dict1[k] == dict2[k] +def assert_recursive_dict_different(dict1, dict2): + """Recursively assert that two dictionaries are different""" + try: + check_dict_equality(dict1, dict2) + except AssertionError: + return + raise AssertionError("Dictionaries are equal") + + def test_model_state(mock_experiment): test_model, _, _ = mock_experiment model_state = ModelState(test_model) @@ -275,7 +297,7 @@ def test_save_and_load_hf_checkpoint(policy): 
"model.safetensors.index.json", } - coverted_model = AutoModelForCausalLM.from_pretrained( + converted_model = AutoModelForCausalLM.from_pretrained( os.path.join(tmp_dir, "test_hf_and_dcp-hf") ) original_model = AutoModelForCausalLM.from_pretrained( @@ -283,6 +305,76 @@ def test_save_and_load_hf_checkpoint(policy): ) ## make sure converted model matches the original - check_dict_equality(coverted_model.state_dict(), original_model.state_dict()) + check_dict_equality(converted_model.state_dict(), original_model.state_dict()) - policy.worker_group.shutdown() + +def test_convert_dcp_to_hf(policy): + ## warm up with a forward pass + ## this is needed before saving a checkpoint because FSDP does some lazy initialization + input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128 + attention_mask = torch.ones(4, 128) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + dummy_fwd_dict = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "labels": torch.randint(0, 16000, (4, 128)), + } + ) + policy.train(dummy_fwd_dict, simple_loss) + + with TemporaryDirectory() as tmp_dir: + policy.save_checkpoint( + os.path.join(tmp_dir, "test_hf_and_dcp"), + save_hf=True, + save_torch_dist=True, + ) + + ## make sure we save both HF and DCP checkpoints + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == { + "__0_0.distcp", + "__1_0.distcp", + ".metadata", + } + ## 1B model has two shards + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == { + "config.json", + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "model.safetensors.index.json", + } + + offline_converted_model_path = convert_dcp_to_hf( + os.path.join(tmp_dir, "test_hf_and_dcp"), + os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"), + simple_policy_config["model_name"], + # TODO: After the following PR gets merged: + # 
https://github.com/NVIDIA/reinforcer/pull/148/files + # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name + # We can expose an arg at the top level --tokenizer_path to plumb that through. + # This is more stable than relying on the current NeMo-RL get_tokenizer() which can + # change release to release. + simple_policy_config["model_name"], + ) + + offline_converted_model = AutoModelForCausalLM.from_pretrained( + offline_converted_model_path + ) + + online_converted_model = AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_dir, "test_hf_and_dcp-hf") + ) + original_model = AutoModelForCausalLM.from_pretrained( + simple_policy_config["model_name"] + ) + + ## make sure both conversions result in the same state dict + check_dict_equality( + online_converted_model.state_dict(), offline_converted_model.state_dict() + ) + # Ensure the offline one is different from the original + assert_recursive_dict_different( + offline_converted_model.state_dict(), original_model.state_dict() + )