From d36467bbd5ec1ddb988f8dc033d40dda769e4c0d Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Sat, 10 May 2025 11:21:20 +0800 Subject: [PATCH 1/8] refactor model merger --- scripts/model_merger.py | 882 ++++++++++++++++++++++------------------ 1 file changed, 493 insertions(+), 389 deletions(-) diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 590f4508c04..7d4d85cf092 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -12,11 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. + +To merge FSDP checkpoints: +```sh +python scripts/model_merger.py merge \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +To merge Megatron checkpoints: +```sh +python scripts/model_merger.py merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +For more details, please refer to documentation: +https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model +""" + import argparse import os import re +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Tuple +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional import numpy as np import torch @@ -27,8 +54,8 @@ AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq, - AutoTokenizer, GenerationConfig, + PretrainedConfig, ) try: @@ -37,296 +64,296 @@ except ImportError: from torch.distributed._tensor import DTensor -parser = argparse.ArgumentParser() 
-parser.add_argument("--backend", type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"]) -parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights") -parser.add_argument("--is-value-model", action="store_true", help="Whether the model loaded as value model") -parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") -parser.add_argument( - "--local_dir", - type=str, - required=True, - help=("The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`."), -) -parser.add_argument("--target_dir", required=False, default="tmp", type=str, help="The path for the target model") -parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") -parser.add_argument("--test", action="store_true", help="test correctness of hf_model") -parser.add_argument( - "--test_hf_dir", - type=str, - required=False, - help="test correctness of hf_model, , with hf_model in checkpoint.contents", -) -parser.add_argument("--private", required=False, default=False, help="Whether to upload the model to private repo") +from tqdm import tqdm + +from verl.utils import hf_processor, hf_tokenizer + + +@dataclass +class ModelMergerConfig: + operation: str # 'merge' or 'test' + backend: str + local_dir: str + target_dir: Optional[str] = "tmp" + hf_upload_path: Optional[str] = None + private: bool = False + test_hf_dir: Optional[str] = None + tie_word_embedding: bool = False + is_value_model: bool = False + hf_model_path: Optional[str] = None + hf_upload: bool = field(init=False) + + def __post_init__(self): + self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) + if self.operation == "test": + self.target_dir = None + self.hf_upload_path = None + self.private = False + + +class 
BaseModelMerger(ABC): + def __init__(self, config: ModelMergerConfig): + self.config = config + self.config_path = config.local_dir + + if config.hf_model_path: + print("Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. ") + self.config_path = config.hf_model_path + + self.model_config = AutoConfig.from_pretrained(self.config_path) + + def get_transformers_auto_model_class(self): + if "ForTokenClassification" in self.model_config.architectures[0]: + return AutoModelForTokenClassification + elif "ForCausalLM" in self.model_config.architectures[0]: + return AutoModelForCausalLM + elif "ForConditionalGeneration" in self.model_config.architectures[0]: + return AutoModelForVision2Seq + + raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + + def patch_model_generation_config(self, model): + """ + The generation_config created from model config may be different to the pretrained model, + this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 + + This function patch the generation_config created from model config to the pretrained model. 
+ """ + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained(self.config_path) + except OSError: + print(f"Warning: Generation config file not found in {self.config_path}, using a generation config created from the model config.") + return model + + def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + with torch.device("meta"): + model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) + model.to_empty(device="cpu") + model = self.patch_model_generation_config(model) + + print(f"Saving model to {self.config.target_dir}") + model.save_pretrained(self.config.target_dir, state_dict=state_dict) + del state_dict + del model + + processor = hf_processor(self.config_path) + tokenizer = hf_tokenizer(self.config_path) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def upload_to_huggingface(self): + from huggingface_hub import HfApi + + api = HfApi() + api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + + @abstractmethod + def merge_and_save(self): + raise NotImplementedError("Subclasses should implement this method") + + +class FSDPModelMerger(BaseModelMerger): + def _get_world_size(self) -> int: + """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" + for filename in os.listdir(self.config.local_dir): + match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) + if match: + return int(match.group(1)) + raise FileNotFoundError(f"Could not determine world size. 
No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}") + + def _load_rank_zero_state_dict(self, world_size: int) -> dict: + return torch.load(Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", map_location="cpu", weights_only=False) + + def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + """ + Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. + If no DTensor is found, infers a simple FSDP mesh based on world_size. + """ + pivot_key = sorted(list(state_dict.keys()))[0] + weight = state_dict[pivot_key] + + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([world_size], dtype=np.int64) + mesh_dim_names = ("fsdp",) -args = parser.parse_args() -os.makedirs(args.target_dir, exist_ok=True) -if args.test: - assert args.test_hf_dir is not None, "You must run verl save checkpoint first, with hf_model in checkpoint.contents, and provide the directory here" + return mesh, mesh_dim_names + def _calculate_shard_configuration(self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]) -> tuple[int, tuple[int, ...]]: + """Calculates the total number of shards and the shape of the device mesh.""" + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" -def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - else: - raise ValueError(f"Unsupported placement: {placement}") - - -def upload_model_to_huggingface(hf_path): - # Push to hugging face - from huggingface_hub import 
HfApi - - api = HfApi() - api.create_repo(repo_id=args.hf_upload_path, private=args.private, exist_ok=True) - api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model") - - -def test_fsdp_state_dict( - auto_model_class, - original_hf_model_path: str, - collected_state_dict: Dict[str, torch.Tensor], -) -> bool: - # load original model using bf16 since we collected state_dict with bf16 - original_model = auto_model_class.from_pretrained(original_hf_model_path, torch_dtype=torch.bfloat16) - original_state_dict = original_model.state_dict() - del original_model # Free memory - - original_keys = set(original_state_dict.keys()) - collected_keys = set(collected_state_dict.keys()) - - missing_keys = original_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - original_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in original_keys: - original_shape = original_state_dict[key].shape - collected_shape = collected_state_dict[key].shape - assert original_shape == collected_shape, f"Shape mismatch for key '{key}': original {original_shape} vs collected {collected_shape}" - - original_dtype = original_state_dict[key].dtype - collected_dtype = collected_state_dict[key].dtype - assert original_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {original_dtype} vs collected {collected_dtype}" - - torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-4, rtol=1e-4) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - return True - - -def patch_model_generation_config(model, hf_model_path): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - 
This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(hf_model_path) - except OSError: - print(f"Warning: Generation config file not found in {hf_model_path}, using a generation config created from the model config.") - pass - return model - - -def convert_fsdp_checkpoints_to_hfmodels(): - local_dir = args.local_dir - - # copy rank zero to find the shape of (dp, fsdp) - rank = 0 - world_size = 0 - for filename in os.listdir(local_dir): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - world_size = match.group(1) - break - assert world_size, "No model file with the proper format" - - state_dict = torch.load(os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu", weights_only=False) - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([int(world_size)], dtype=np.int64) - mesh_dim_names = ("fsdp",) + if "tp" in mesh_dim_names: + # TODO: "tp" is not supported yet due to the above assert + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + return total_shards, mesh_shape + + def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + """Merges a list of tensors based on their DTensor placement""" + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + + raise 
NotImplementedError(f"Unsupported placement: {placement}") + + def _load_and_merge_state_dicts(self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]) -> dict[str, torch.Tensor]: + model_state_dict_lst = [None] * total_shards + + def process_one_shard(rank: int, model_state_dict_lst: list): + model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] + for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + future.result() + + # Merge state dicts from all shards + state_dict = {} + param_placements: dict[str, list] = {} + + for key in set(model_state_dict_lst[0].keys()): + state_dict[key] = [] + for model_state_shard in model_state_dict_lst: + # add tensor shard in order of rank to state_dict[key] + tensor = model_state_shard.pop(key) + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + + placements = tuple(tensor.placements) + # replicated placement at dp dimension can be discarded + if mesh_dim_names[0] in ("dp", "ddp"): + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + del model_state_dict_lst + + # Merge tensors + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + if key in param_placements: + # merge shards + placements: tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: 
+ # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = self._merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + return state_dict - if "tp" in mesh_dim_names: - # fsdp * tp - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - # fsdp - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) + def merge_and_save(self): + world_size = self._get_world_size() + rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") + mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - model_state_dict_lst = [] - model_state_dict_lst.append(state_dict) - model_state_dict_lst.extend([""] * (total_shards - 1)) + total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + print(f"Processing model shards with {total_shards} {mesh_shape} in total") - def process_one_shard(rank, model_state_dict_lst): - model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict + merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - for rank in range(1, total_shards): - executor.submit(process_one_shard, rank, model_state_dict_lst) - state_dict = {} - param_placements: Dict[str, List[Placement]] = {} - keys = 
set(model_state_dict_lst[0].keys()) - for key in keys: - state_dict[key] = [] - for model_state_dict in model_state_dict_lst: - try: - tensor = model_state_dict.pop(key) - except Exception: - print("-" * 30) - print(model_state_dict) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] == "dp" or mesh_dim_names[0] == "ddp": - placements = placements[1:] - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: Tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._test_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() else: - state_dict[key] = torch.cat(state_dict[key], dim=0) + raise ValueError(f"Unknown operation: {self.config.operation}") - hf_path = os.path.join(local_dir, "huggingface") if args.target_dir is None else args.target_dir - config = AutoConfig.from_pretrained(args.hf_model_path) + def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() - if "ForTokenClassification" 
in config.architectures[0]: - auto_model = AutoModelForTokenClassification - elif "ForCausalLM" in config.architectures[0]: - auto_model = AutoModelForCausalLM - elif "ForConditionalGeneration" in config.architectures[0]: - auto_model = AutoModelForVision2Seq - else: - raise NotImplementedError(f"Unknown architecture {config['architectures']}") + hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_state_dict = hf_model.state_dict() + del hf_model - if args.test: - print("Running compatibility test") - test_fsdp_state_dict(auto_model, args.test_hf_dir, state_dict) + hf_model_keys = set(hf_state_dict.keys()) + collected_keys = set(state_dict.keys()) - with torch.device("meta"): - model = auto_model.from_config(config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = patch_model_generation_config(model, args.hf_model_path) + missing_keys = hf_model_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - print(f"Saving model to {hf_path}") - model.save_pretrained(hf_path, state_dict=state_dict) - del state_dict - del model + extra_keys = collected_keys - hf_model_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - print("Saving tokenizer") - tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) - tokenizer.save_pretrained(hf_path) + for key in hf_model_keys: + hf_shape = hf_state_dict[key].shape + collected_shape = state_dict[key].shape + assert hf_shape == collected_shape, f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - if args.hf_upload_path: - upload_model_to_huggingface(hf_path) + hf_dtype = hf_state_dict[key].dtype + collected_dtype = state_dict[key].dtype + assert hf_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" + torch.testing.assert_close(hf_state_dict[key], 
state_dict[key], atol=1e-6, rtol=1e-6) -def get_tp_pp_rank_from_sharded_dir(sharded_dir): - match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) - tp_rank = int(match.group(1)) - pp_rank = int(match.group(2)) - return tp_rank, pp_rank + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") -def check_megatron_checkpoint_path(model_path): - sharded_dirs = sorted(os.listdir(model_path)) - tp_size = 0 - pp_size = 0 - for sharded_dir in sharded_dirs: +class MegatronModelMerger(BaseModelMerger): + def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) assert match, f"Invalid sharded dir {sharded_dir}" - assert "model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}" tp_rank = int(match.group(1)) pp_rank = int(match.group(2)) - if tp_size < tp_rank + 1: - tp_size = tp_rank + 1 - if pp_size < pp_rank + 1: - pp_size = pp_rank + 1 - return sharded_dirs, tp_size, pp_size - - -def convert_megatron_checkpoints_to_hfmodels(): - from verl.utils.megatron_utils import get_model_checkpoint_path - - local_path = args.local_dir - - model_ckpt_path = get_model_checkpoint_path(local_path) - sharded_dirs, tp_size, pp_size = check_megatron_checkpoint_path(model_ckpt_path) - mp_size = len(sharded_dirs) - - model_state_dict_lst = [] - for i in range(pp_size): - model_state_dict_lst.append([]) - for j in range(tp_size): - model_state_dict_lst[i].append("") - - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {mp_size}") - - def process_one_shard(shard_dir, model_state_dict_lst): - model_path = os.path.join(model_ckpt_path, shard_dir, "model.pt") - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - tp_rank, pp_rank = get_tp_pp_rank_from_sharded_dir(shard_dir) - model_state_dict_lst[pp_rank][tp_rank] = state_dict - - # with 
ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - # for rank in range(1, mp_size): - # executor.submit(process_one_shard, sharded_dirs[rank]) - for sharded_dir in sharded_dirs: - process_one_shard(sharded_dir, model_state_dict_lst) - - state_dict = {} - config = AutoConfig.from_pretrained(args.hf_model_path) - if args.test: - ref_state_dict = load_file(os.path.join(args.test_hf_dir, "model.safetensors")) - - def merge_across_tp(key, tp_data): + return tp_rank, pp_rank + + def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: + """ + Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). + Determines TP and PP sizes from directory names. + """ + tp_size = 0 + pp_size = 0 + sharded_dirs = sorted(os.listdir(model_path)) + for sharded_dir in sharded_dirs: + assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" + tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) + tp_size = max(tp_size, tp_rank + 1) + pp_size = max(pp_size, pp_rank + 1) + return sharded_dirs, tp_size, pp_size + + def _merge_across_tp(self, key: str, tp_data: list[torch.Tensor], config: PretrainedConfig, tp_size: int, is_value_model: bool = False) -> torch.Tensor | list[torch.Tensor]: if "linear_fc1.weight" in key: # if the tensor is gate and proj gate_lst = [] @@ -337,7 +364,8 @@ def merge_across_tp(key, tp_data): up_lst.append(up) gate = torch.cat(gate_lst, dim=0) up = torch.cat(up_lst, dim=0) - tp_data = [gate, up] + return [gate, up] + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: # if the tensor is qkv, for each param on tp, split into q, k, v # concat q, k, v separately. 
@@ -349,6 +377,7 @@ def merge_across_tp(key, tp_data): assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] + for infer_param in tp_data: num_query_groups_per_partition = config.num_key_value_heads // tp_size for chunk in infer_param.chunk(num_query_groups_per_partition): @@ -361,86 +390,134 @@ def merge_across_tp(key, tp_data): q_lst.append(q) k_lst.append(k) v_lst.append(v) + q = torch.cat(q_lst, dim=0) k = torch.cat(k_lst, dim=0) v = torch.cat(v_lst, dim=0) + return [q, k, v] - tp_data = [q, k, v] - - elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and args.is_value_model: - tp_data = tp_data[0] + elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and is_value_model: + return tp_data[0] else: dim = 0 if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: dim = 1 - tp_data = torch.cat(tp_data, dim=dim) - - return tp_data - - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() - for key in keys: - if "extra_state" in key: - continue - if args.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - new_key = key - if "decoder.layers." 
in key: - local_layer_no = int(key.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = key.split(".") - new_key_list[2] = str(global_layer_no) - new_key = ".".join(new_key_list) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = merge_across_tp(new_key, tp_data) - if not isinstance(merged, list): - state_dict[new_key] = merged - elif len(merged) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], merged): - state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d - elif len(merged) == 2: - # split gate up - state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0] - state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1] - layers_cum += layers_handled + 1 # zero based - - del model_state_dict_lst - - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ("self_attention.linear_q", "self_attn.q_proj"), - ("self_attention.linear_k", "self_attn.k_proj"), - ("self_attention.linear_v", "self_attn.v_proj"), - ] - - if args.test: + return torch.cat(tp_data, dim=dim) + + def _load_state_dicts(self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int) -> list[list[dict]]: + model_state_dict_lst = [[None for _ in 
range(tp_size)] for _ in range(pp_size)] + + def _process_one_megatron_shard(sharded_dir: str): + model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" + state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) + tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) + model_state_dict_lst[pp_rank][tp_rank] = state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] + for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): + future.result() + + return model_state_dict_lst + + def _merge_state_dicts(self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int) -> dict[str, torch.Tensor]: + state_dict = {} + vpp_size = len(model_state_dict_lst[0][0]) + layers_cum = 0 + + for vpp_rank in range(vpp_size): + for pp_rank in range(pp_size): + layers_handled = 0 + keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() + for key in keys: + if "extra_state" in key: + continue + if self.config.tie_word_embedding and ("output_layer" in key): + print("skip lm_head and reward_head loading because of tie_word_embeddings") + continue + + new_key = key + if "decoder.layers." 
in key: + local_layer_no = int(key.split(".")[2]) + layers_handled = max(local_layer_no, layers_handled) + global_layer_no = local_layer_no + layers_cum + new_key_list = key.split(".") + new_key_list[2] = str(global_layer_no) + new_key = ".".join(new_key_list) + + tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] + merged = self._merge_across_tp(new_key, tp_data, self.model_config, tp_size, self.config.is_value_model) + + if not isinstance(merged, list): + state_dict[new_key] = merged + elif len(merged) == 3: + # split qkv + for n, d in zip(["q", "k", "v"], merged): + state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d + elif len(merged) == 2: + # split gate up + state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0] + state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1] + + layers_cum += layers_handled + 1 # zero based + + return state_dict + + def merge_and_save(self): + from verl.utils.megatron_utils import get_model_checkpoint_path + + model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) + sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) + print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") + + model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) + merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) + del model_state_dict_lst + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._test_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): + 
""" + Compares the merged Megatron state_dict against a reference safetensors model. + Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. + """ + ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") + + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ("self_attention.linear_q", "self_attn.q_proj"), + ("self_attention.linear_k", "self_attn.k_proj"), + ("self_attention.linear_v", "self_attn.v_proj"), + ] + for original_name, loaded_weight in state_dict.items(): - name = _replace_name(original_name, params_mapping) + name = self._replace_name(original_name, params_mapping) if not name or name.endswith(".bias") and name not in ref_state_dict: continue if "rotary_emb.inv_freq" in name: continue - if args.tie_word_embedding and "lm_head.weight" in name: + if self.config.tie_word_embedding and "lm_head.weight" in name: continue if name not in ref_state_dict: raise RuntimeError(f"key: {name} not exist in state_dict") @@ -448,63 +525,90 @@ def merge_across_tp(key, tp_data): assert loaded_weight.dtype == param.dtype torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4) - print("Writing to local disk") - hf_path = os.path.join(args.local_dir, "huggingface") if args.target_dir is None else args.target_dir + 
def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) -> str: + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + return megatron_name # Return original name if no mapping found + + +def main(): + parser = argparse.ArgumentParser(description="verl model merger") + subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + + base_op_parser = argparse.ArgumentParser(add_help=False) + base_op_parser.add_argument("--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model") + base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") + base_op_parser.add_argument("--hf_model_path", type=str, default=None, help="(Deprecated) Path to the original Hugging Face model for config.") + base_op_parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights (currently only Megatron supported)") + base_op_parser.add_argument("--is-value-model", action="store_true", help="Whether the model is a value model (currently only Megatron supported)") + + merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") 
+ merge_parser.add_argument("--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model") + merge_parser.add_argument("--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model") + merge_parser.add_argument("--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository") + + test_parser = subparsers.add_parser("test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model") + test_parser.add_argument("--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing") + + args = parser.parse_args() + + common_config_args = { + "operation": args.operation, + "backend": args.backend, + "tie_word_embedding": args.tie_word_embedding, + "is_value_model": args.is_value_model, + "local_dir": args.local_dir, + "hf_model_path": args.hf_model_path, + } + + if args.operation == "merge": + config = ModelMergerConfig( + **common_config_args, + target_dir=args.target_dir, + hf_upload_path=args.hf_upload_path, + private=args.private, + test_hf_dir=None, + ) + os.makedirs(config.target_dir, exist_ok=True) + elif args.operation == "test": + config = ModelMergerConfig( + **common_config_args, + test_hf_dir=args.test_hf_dir, + # the following args are not used by test operation + target_dir=None, + hf_upload_path=None, + private=False, + ) + else: + raise NotImplementedError(f"Unknown operation: {args.operation}") - if "ForTokenClassification" in config.architectures[0]: - auto_model = AutoModelForTokenClassification - elif "ForCausalLM" in config.architectures[0]: - auto_model = AutoModelForCausalLM - elif "ForConditionalGeneration" in config.architectures[0]: - auto_model = AutoModelForVision2Seq + if config.backend == "fsdp": + merger = FSDPModelMerger(config) + elif config.backend == "megatron": + merger = MegatronModelMerger(config) else: - raise NotImplementedError(f"Unknown architecture 
{config['architectures']}") - - with torch.device("meta"): - model = auto_model.from_config(config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = patch_model_generation_config(model, args.hf_model_path) - - print(f"Saving model to {hf_path}") - model.save_pretrained(hf_path, state_dict=state_dict) - del state_dict - del model - - print("Saving tokenizer") - tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) - tokenizer.save_pretrained(hf_path) - - if args.hf_upload_path: - upload_model_to_huggingface(hf_path) - - -def _replace_name(megatron_name, name_mapping): - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name + raise NotImplementedError(f"Unknown backend: {config.backend}") + + merger.merge_and_save() if __name__ == "__main__": - if args.backend == "fsdp": - convert_fsdp_checkpoints_to_hfmodels() - elif args.backend == "megatron": - convert_megatron_checkpoints_to_hfmodels() - else: - raise NotImplementedError(f"{args.backend} not supported") + main() From e55d7fe3f409175a059f7a94060f4ea0b29187db Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Sat, 10 May 2025 11:22:18 +0800 Subject: [PATCH 2/8] docs: update model merger usage --- docs/advance/checkpoint.rst | 47 +++++++++++++++++++++++++++---------- 1 file 
changed, 35 insertions(+), 12 deletions(-) diff --git a/docs/advance/checkpoint.rst b/docs/advance/checkpoint.rst index d577e40b8c3..082a1e93881 100644 --- a/docs/advance/checkpoint.rst +++ b/docs/advance/checkpoint.rst @@ -67,27 +67,50 @@ Convert FSDP and Megatron Checkpoints to HuggingFace Format Model We provide a tool to convert the FSDP and Megatron checkpoints to HuggingFace format model. The tool is located in ``scripts/model_merger.py``. -The arguments are as follows: +The script supports two main sub-commands: ``merge`` (to convert and save checkpoints) and ``test`` (to validate merged checkpoints against a reference model). +The arguments for the ``merge`` sub-command are as follows: .. code:: bash - usage: model_merger.py [-h] [--backend {fsdp,megatron}] - [--tie-word-embedding whether the model share embedding weights] - [--is-value-model whether the model is critic model] - [--hf_model_path $original_model_path, like {Qwen/Qwen2-7B}] - [--local_dir $local_directory saved fsdp or megatron models] - [--target_dir $target_dir to save converted models, default is tmp] - [--hf_upload_path $huggingface_repo to upload] - -So example use of Megatron model merger is: + usage: model_merger.py merge [-h] --backend {fsdp,megatron} --local_dir LOCAL_DIR [--hf_model_path HF_MODEL_PATH] + [--tie-word-embedding] [--is-value-model] [--target_dir TARGET_DIR] + [--hf_upload_path HF_UPLOAD_PATH] [--private] + + options: + -h, --help show this help message and exit + --backend {fsdp,megatron} + The backend of the model + --local_dir LOCAL_DIR + Path to the saved model checkpoints + --hf_model_path HF_MODEL_PATH + (Deprecated) Path to the original Hugging Face model for config.
+ --tie-word-embedding Whether to tie word embedding weights (currently only Megatron supported) + --is-value-model Whether the model is a value model (currently only Megatron supported) + --target_dir TARGET_DIR + Directory to save the merged huggingface model + --hf_upload_path HF_UPLOAD_PATH + Hugging Face repository ID to upload the model + --private Whether to upload the model to a private Hugging Face repository + +Example usage for merging Megatron checkpoints: .. code:: bash python scripts/model_merger.py \ --backend megatron \ --tie-word-embedding \ - --hf_model_path Qwen/Qwen2.5-0.5B \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model + +Example usage for merging FSDP checkpoints: + +.. code:: bash + + python scripts/model_merger.py \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model + Megatron Merger details ----------------------- From 24021bc655a110ce5d3a1f510789694ecefc24ca Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Sat, 10 May 2025 11:24:30 +0800 Subject: [PATCH 3/8] update model merger commands in CI --- .github/workflows/e2e_ppo_trainer.yml | 15 ++++++++++++--- .github/workflows/e2e_ppo_trainer_megatron.yml | 8 ++++---- tests/e2e/ppo_trainer/run_function_reward.sh | 2 ++ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index dac4c296bf7..9950f73e3a9 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -83,7 +83,7 @@ jobs: ray stop --force python3 examples/data_preprocess/gsm8k.py # Function RM - - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with 
validation and saving + - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8) run: | ray stop --force VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True bash tests/e2e/ppo_trainer/run_function_reward.sh @@ -91,10 +91,19 @@ jobs: run: | ray stop --force RESUME_MODE=auto bash tests/e2e/ppo_trainer/run_function_reward.sh - - name: Test FSDP checkpoints merging function (Qwen Actor) + - name: Test merging FSDP checkpoints (Qwen Actor) run: | exp_name="qwen2.5-0.5b-function-reward-minimal" - python scripts/model_merger.py --backend fsdp --hf_model_path ~/models/Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (DDP_SIZE=2, FSDP_SIZE=4) + run: | + ray stop --force + rm -rf checkpoints/verl-test/* + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True FSDP_SIZE=4 bash tests/e2e/ppo_trainer/run_function_reward.sh + - name: Test merging DDP+FSDP checkpoints (Qwen Actor) + run: | + exp_name="qwen2.5-0.5b-function-reward-minimal" + python scripts/model_merger.py test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E without rmpad using function rm run: | ray stop --force diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index 0e945481c4b..6634dd66906 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ 
-73,8 +73,8 @@ jobs: - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) run: | exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" - python scripts/model_merger.py --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py --backend megatron --is-value-model --hf_model_path Qwen/Qwen2.5-0.5B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) run: | ray stop --force @@ -119,8 +119,8 @@ jobs: - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) run: | exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py --backend megatron --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py --backend megatron --is-value-model --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test 
--backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: clean up run: | rm -rf checkpoints diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index 0c842e86ffe..2b262f1ff3b 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -32,6 +32,7 @@ TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1} # whether to save hf_model SAVE_HF_MODEL=${SAVE_HF_MODEL:-False} +FSDP_SIZE=${FSDP_SIZE:--1} if [ "${SAVE_HF_MODEL}" = "True" ]; then CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']" @@ -79,6 +80,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \ actor_rollout_ref.actor.checkpoint.contents=${CHECKPOINT_CONTENTS} \ actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ From b873a04fd87c1bca6e341748ef72d989b2a4f550 Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Sat, 10 May 2025 11:35:09 +0800 Subject: [PATCH 4/8] Update E2E training scripts to include VERL_EXP_NAME --- .github/workflows/e2e_ppo_trainer.yml | 11 +++++------ tests/e2e/ppo_trainer/run_function_reward.sh | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index 
9950f73e3a9..9b29e1dc48a 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -86,23 +86,22 @@ jobs: - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8) run: | ray stop --force - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True bash tests/e2e/ppo_trainer/run_function_reward.sh + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp8" bash tests/e2e/ppo_trainer/run_function_reward.sh - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm after resuming run: | ray stop --force - RESUME_MODE=auto bash tests/e2e/ppo_trainer/run_function_reward.sh + RESUME_MODE=auto VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-fsdp8" bash tests/e2e/ppo_trainer/run_function_reward.sh - name: Test merging FSDP checkpoints (Qwen Actor) run: | - exp_name="qwen2.5-0.5b-function-reward-minimal" + exp_name="qwen2.5-0.5b-function-reward-minimal-fsdp8" python scripts/model_merger.py test --backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (DDP_SIZE=2, FSDP_SIZE=4) run: | ray stop --force - rm -rf checkpoints/verl-test/* - VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True FSDP_SIZE=4 bash tests/e2e/ppo_trainer/run_function_reward.sh + VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 SAVE_HF_MODEL=True FSDP_SIZE=4 VERL_EXP_NAME="qwen2.5-0.5b-function-reward-minimal-ddp2-fsdp4" bash tests/e2e/ppo_trainer/run_function_reward.sh - name: Test merging DDP+FSDP checkpoints (Qwen Actor) run: | - exp_name="qwen2.5-0.5b-function-reward-minimal" + exp_name="qwen2.5-0.5b-function-reward-minimal-ddp2-fsdp4" python scripts/model_merger.py test 
--backend fsdp --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - name: Running GSM8K E2E without rmpad using function rm run: | diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index 2b262f1ff3b..68640dde49a 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -64,7 +64,7 @@ EOF rm -rf "${output_file}" fi -exp_name="$(basename "${MODEL_ID,,}")-function-reward-minimal" +exp_name="${VERL_EXP_NAME:-$(basename "${MODEL_ID,,}")-function-reward-minimal}" python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator="${ADV_ESTIMATOR}" \ From ada4ebfe1603dde223f8fc45911b187e9c1657f3 Mon Sep 17 00:00:00 2001 From: Qunhong Zeng <871206929@qq.com> Date: Thu, 15 May 2025 15:02:02 +0800 Subject: [PATCH 5/8] fix typo in checkpoint.rst --- docs/advance/checkpoint.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/advance/checkpoint.rst b/docs/advance/checkpoint.rst index 082a1e93881..b9bebcf57c3 100644 --- a/docs/advance/checkpoint.rst +++ b/docs/advance/checkpoint.rst @@ -96,7 +96,7 @@ Example usage for merging Megatron checkpoints: .. code:: bash - python scripts/model_merger.py \ + python scripts/model_merger.py merge \ --backend megatron \ --tie-word-embedding \ --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ @@ -106,7 +106,7 @@ Example usage for merging FSDP checkpoints: .. 
code:: bash - python scripts/model_merger.py \ + python scripts/model_merger.py merge \ --backend fsdp \ --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ --target_dir /path/to/merged_hf_model From e03e9e9a5365dc24304dd5833517a5d9000fe718 Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Thu, 15 May 2025 11:42:40 +0000 Subject: [PATCH 6/8] fix: megatron checkpoint manager currently don't save model config, still use --hf_model_path --- .github/workflows/e2e_ppo_trainer_megatron.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index 6634dd66906..7f75a41400a 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -73,8 +73,8 @@ jobs: - name: Test Megatron checkpoints merging function (Qwen Actor and Critic) run: | exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path Qwen/Qwen2.5-0.5B + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path Qwen/Qwen2.5-0.5B - name: Running GRPO GSM8K 
E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) run: | ray stop --force @@ -119,8 +119,8 @@ jobs: - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic) run: | exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal" - python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct + python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct - name: clean up run: | rm -rf checkpoints From 983f2ff55fca7f61bfc70a68d4ca842cacaae2e8 Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Fri, 16 May 2025 06:18:02 +0000 Subject: [PATCH 7/8] fix: return none if no mapping found in name_mapping --- scripts/model_merger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 7d4d85cf092..210b85e15f4 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -546,7 +546,7 @@ def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) else: param_name = megatron_name.replace(m_name, v_name) return param_name - return megatron_name # Return original name if no mapping found + return None # 
Return None if no mapping found def main(): From df7fde576bd44dfdd890113c2e10b36accb90f41 Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Fri, 16 May 2025 10:35:34 +0000 Subject: [PATCH 8/8] ci: fix testing qwen3 checkpoint --- .github/workflows/e2e_ppo_trainer_megatron.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index f49b4c9bb53..4735d4bbe50 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -256,8 +256,8 @@ jobs: - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic) run: | exp_name="qwen3-0.6b-megatron-gsm8k-minimal" - python scripts/model_merger.py --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface - python scripts/model_merger.py --backend megatron --is-value-model --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + python scripts/model_merger.py test --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface + python scripts/model_merger.py test --backend megatron --is-value-model --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) run: | ray stop --force